package uu.mal.skate.view; import java.awt.BasicStroke; import java.awt.Color; import java.awt.Dimension; import java.text.DecimalFormat; import javax.swing.BoxLayout; import javax.swing.JLabel; import javax.swing.JPanel; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.stat.StatUtils; import org.apache.commons.math3.stat.descriptive.summary.Sum; import info.monitorenter.gui.chart.Chart2D; import info.monitorenter.gui.chart.ITracePoint2D; import info.monitorenter.gui.chart.TracePoint2D; import info.monitorenter.gui.chart.labelformatters.LabelFormatterNumber; import info.monitorenter.gui.chart.rangepolicies.RangePolicyFixedViewport; import info.monitorenter.gui.chart.traces.Trace2DLtd; import info.monitorenter.gui.chart.traces.Trace2DSimple; import info.monitorenter.gui.chart.traces.painters.TracePainterLine; import info.monitorenter.util.Range; import uu.mal.skate.model.Options; import uu.mal.skate.model.Simulation; public class ChartPane extends JPanel { private static final long serialVersionUID = 1L; private final Options options; private final Chart2D ratioChart, directionMeanChart, meanChart; private final Trace2DSimple[] ratioChartTraces, directionMeanChartTraces; private final Trace2DLtd meanChartTrace; public ChartPane(Simulation sim) { this.options = sim.getOptions(); setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); setPreferredSize(new Dimension((int) (options.torusHeight / 1.25), options.torusHeight)); String[] chartTitles = sim.getSimType().getChartTitles(); meanChart = new Chart2D(); meanChart.setName(chartTitles[0]); meanChart.getAxisX().getAxisTitle().setTitle("Iteration"); meanChart.getAxisY().getAxisTitle().setTitle("Total sum"); meanChart.getAxisY().setFormatter(new LabelFormatterNumber(new DecimalFormat("#.##"))); meanChartTrace = new Trace2DLtd(200); meanChartTrace.setName(null); meanChartTrace.setTracePainter(new TracePainterLine()); meanChartTrace.setStroke(new BasicStroke(1.5f)); meanChartTrace.setColor(Color.getHSBColor(0.9f, 0.9f, 0.9f)); meanChart.addTrace(meanChartTrace); directionMeanChart = new Chart2D(); directionMeanChart.setName(chartTitles[1]); directionMeanChart.getAxisX().getAxisTitle().setTitle("Iteration"); directionMeanChart.getAxisY().getAxisTitle().setTitle("Mean"); directionMeanChart.getAxisY().setFormatter(new LabelFormatterNumber(new DecimalFormat("#.##"))); directionMeanChartTraces = new Trace2DSimple[options.directionCount]; for(int i = 0; i < directionMeanChartTraces.length; i++){ directionMeanChartTraces[i] = new Trace2DSimple(); directionMeanChartTraces[i].setName(null); directionMeanChartTraces[i].setTracePainter(new TracePainterLine()); directionMeanChartTraces[i].setStroke(new BasicStroke(1.5f)); directionMeanChartTraces[i].setColor(Color.getHSBColor((float)i * (1f/directionMeanChartTraces.length), 0.9f, 0.9f)); directionMeanChart.addTrace(directionMeanChartTraces[i]); } ratioChart = new Chart2D(); ratioChart.setName(chartTitles[2]); ratioChart.getAxisX().getAxisTitle().setTitle("Iteration"); ratioChart.getAxisY().setRangePolicy(new RangePolicyFixedViewport(new Range(0, 1))); ratioChart.getAxisY().getAxisTitle().setTitle("Ratio"); ratioChartTraces = new Trace2DSimple[options.directionCount]; for(int i = 0; i < ratioChartTraces.length; i++){ ratioChartTraces[i] = new Trace2DSimple(); ratioChartTraces[i].setName(null); ratioChartTraces[i].setTracePainter(new TracePainterLine()); ratioChartTraces[i].setStroke(new BasicStroke(1.5f)); ratioChartTraces[i].setColor(Color.getHSBColor((float)i * (1f/ratioChartTraces.length), 0.9f, 0.9f)); ratioChart.addTrace(ratioChartTraces[i]); } add(new JLabel(meanChart.getName())); add(meanChart); add(new JLabel(directionMeanChart.getName())); add(directionMeanChart); add(new JLabel(ratioChart.getName())); add(ratioChart); } public void update(long iteration, RealMatrix rewardMatrix) { Sum totalSum = new Sum(); for(int i = 0; i < rewardMatrix.getRowDimension(); i++) { totalSum.incrementAll(rewardMatrix.getRow(i)); } double allSum = totalSum.getResult(); ITracePoint2D meanPoint = new TracePoint2D(iteration, totalSum.getResult()); meanChartTrace.addPoint(meanPoint); for(int i = 0; i < rewardMatrix.getRowDimension(); i++) { double directionSum = StatUtils.sum(rewardMatrix.getRow(i)); double directionMean = StatUtils.mean(rewardMatrix.getRow(i)); ITracePoint2D ratioPoint = new TracePoint2D(iteration, directionSum / allSum); ratioChartTraces[i].addPoint(ratioPoint); ITracePoint2D directionMeanPoint = new TracePoint2D(iteration, directionMean); directionMeanChartTraces[i].addPoint(directionMeanPoint); } this.invalidate(); } }