package net.sf.javaml.core;

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import net.sf.javaml.distance.DistanceMeasure;

import org.junit.Test;

/**
 * Smoke tests for {@link DefaultDataset}: empty-dataset queries, bulk and
 * positional insertion, instance lookup, k-nearest-neighbour search, fold
 * creation, and the equals/hashCode contract between a dataset and its copy.
 */
public class TestDefaultDataset {

    /**
     * Euclidean distance used to drive {@link DefaultDataset#kNearest}.
     * Smaller values mean "closer": {@link #compare} returns {@code x <= y}.
     */
    public static class MyDistance implements DistanceMeasure {

        private static final long serialVersionUID = 3780817385707462317L;

        /**
         * Returns the Euclidean distance between two instances.
         * <p>
         * Iterates over the attributes of {@code x} only; presumably
         * {@code y} must have at least as many attributes — TODO confirm how
         * {@code Instance.value} behaves for an out-of-range index.
         *
         * @param x the first instance
         * @param y the second instance
         * @return the square root of the summed squared per-attribute differences
         */
        public double measure(Instance x, Instance y) {
            double sumOfSquares = 0;
            for (int i = 0; i < x.noAttributes(); i++) {
                double diff = x.value(i) - y.value(i);
                sumOfSquares += diff * diff;
            }
            return Math.sqrt(sumOfSquares);
        }

        /**
         * Returns {@code true} when distance {@code x} is at least as good as
         * {@code y}, i.e. not larger.
         */
        public boolean compare(double x, double y) {
            return x <= y;
        }

        /** Returns the smallest distance this measure can produce: {@code 0}. */
        public double getMinValue() {
            return 0;
        }

        /** Returns the largest distance this measure reports: {@link Double#MAX_VALUE}. */
        public double getMaxValue() {
            return Double.MAX_VALUE;
        }
    }

    /**
     * Exercises the main {@link DefaultDataset} operations in sequence on a
     * dataset of 3-attribute {@link DenseInstance}s.
     */
    @Test
    public void test() {
        DefaultDataset dd = new DefaultDataset();

        // Queries on an empty dataset: no attributes, null class value,
        // negative index for an unknown class.
        assertTrue(dd.isEmpty());
        assertEquals(0, dd.noAttributes());
        assertNull(dd.classValue(0));
        assertTrue(dd.classIndex(null) < 0);
        // Only checks that looking up an absent class does not throw.
        dd.classIndex(Integer.valueOf(5));

        List<Instance> instances = new ArrayList<Instance>(100);
        for (int i = 0; i < 100; i++) {
            instances.add(new DenseInstance(new double[] { 1, 2, i }));
        }

        // Exercise the collection constructor, addAll, clear and positional addAll.
        dd = new DefaultDataset(instances);
        dd.addAll(instances);
        dd.clear();
        dd.addAll(0, instances);

        // Append one instance, then insert an equal-valued one at the front.
        dd.add(new DenseInstance(new double[] { 1, 4, 8 }));
        dd.add(0, new DenseInstance(new double[] { 1, 4, 8 }));
        assertEquals(new DenseInstance(new double[] { 1, 4, 8 }), dd.instance(0));
        assertNotNull(dd.classes());

        MyDistance md = new MyDistance();
        assertEquals(3, dd.kNearest(3, new DenseInstance(new double[] { 1, 4, 8 }), md).size());

        // i2 is at distance 1 from i1; every other stored instance has
        // first attribute 1, so it is at least 99 away. The test expects the
        // single nearest neighbour of i1 to be i2 (i.e. i1 itself is not
        // returned).
        Instance i1 = new DenseInstance(new double[] { 100, 4, 8 });
        Instance i2 = new DenseInstance(new double[] { 100, 5, 8 });
        dd.add(i1);
        dd.add(i2);
        Set<Instance> results = dd.kNearest(1, i1, md);
        // Without this check an empty result would pass the loop below vacuously.
        assertEquals(1, results.size());
        for (Instance neighbour : results) {
            assertEquals(i2, neighbour);
        }

        assertNotNull(dd.folds(50, new Random()));
        assertEquals(3, dd.noAttributes());
        dd.classValue(1);

        // A copy must compare equal in both directions and share the hash code.
        assertEquals(dd, dd.copy());
        assertEquals(dd.copy(), dd);
        assertEquals(dd.hashCode(), dd.copy().hashCode());
    }
}