package net.sf.javaml.classification.tree;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.SortedSet;
import java.util.Vector;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/* loaded from: input_file:javaml-0.1.7.jar:net/sf/javaml/classification/tree/RandomTree.class */
/**
 * A single randomized decision tree.
 *
 * <p>At every internal node a random subset of attribute indices is drawn and two
 * centroids are computed over that subset. The left centroid averages the
 * instances whose class has index 0 in the node's dataset; the right centroid
 * averages all other instances. An instance is routed to the right child when it
 * is strictly farther (Manhattan / L1 distance) from the left centroid than from
 * the right one, otherwise to the left child.
 *
 * <p>A node becomes a leaf when its data holds a single class. It also becomes a
 * leaf when no split separating the data into two non-empty parts is found; such
 * a leaf is labelled with one of its classes chosen at random.
 *
 * <p>{@link #buildClassifier} mutates this instance and draws from the
 * {@link Random} passed at construction, which is shared by all nodes.
 */
public class RandomTree implements Classifier {
    private static final long serialVersionUID = -6421557885832628441L;

    /** Base number of attributes sampled per split; later retry rounds keep more. */
    private int noSplitAttributes;

    /** Randomness for attribute sampling and fallback leaf labels (shared by all nodes). */
    private Random rg;

    /** Mean of the non-index-0 classes over {@link #splitAttributes}. */
    private float[] rightCenter;

    /** Mean of the index-0 class over {@link #splitAttributes}. */
    private float[] leftCenter;

    /** Label returned by a leaf; {@code null} before training and for internal nodes. */
    private Object finalClass;

    private RandomTree leftChild;
    private RandomTree rightChild;

    /** Attribute indices the split at this node is computed on. */
    private Vector<Integer> splitAttributes;

    /** Classes of the root's training set, passed down to every child. */
    private SortedSet<Object> parentClasses;

    private RandomTree(int noSplitAttributes, Random rg, SortedSet<Object> parentClasses) {
        this.noSplitAttributes = noSplitAttributes;
        this.rg = rg;
        this.parentClasses = parentClasses;
    }

    /**
     * Creates an untrained tree.
     *
     * @param noSplitAttributes base number of attributes sampled at each split; must
     *     be at least 1 for {@link #buildClassifier} to grow an internal node
     * @param rg random generator used for attribute sampling and fallback labels
     */
    public RandomTree(int noSplitAttributes, Random rg) {
        this(noSplitAttributes, rg, null);
    }

    /**
     * Grows the tree on {@code dataset}.
     *
     * <p>A dataset with exactly one class becomes a leaf. Otherwise rounds
     * 1, 2, ... each draw a random attribute subset, compute both centroids and
     * partition the instances. The first round that yields two non-empty partitions
     * becomes this node's split, and both children are grown recursively. A failed
     * round after which {@code round * round * noSplitAttributes} exceeds
     * {@code dataset.noAttributes()} turns this node into a leaf labelled with a
     * class drawn at random from {@code dataset.classes()}.
     *
     * <p>Whenever this node becomes a leaf, {@code dataset.clear()} is called, so the
     * caller's dataset is emptied in that case.
     *
     * @param dataset training data for this node
     * @throws IllegalArgumentException if the dataset does not have exactly one
     *     class and {@code noSplitAttributes} is less than 1
     */
    @Override
    public void buildClassifier(Dataset dataset) {
        if (this.parentClasses == null) {
            this.parentClasses = dataset.classes();
        }
        if (dataset.classes().size() == 1) {
            this.finalClass = dataset.classes().first();
            dataset.clear();
            return;
        }
        if (this.noSplitAttributes < 1) {
            // With 0 the retry loop below never terminates (the subset is always empty
            // and 0 never exceeds the attribute count); a negative value would fail
            // inside Random.nextInt(0).
            throw new IllegalArgumentException(
                    "noSplitAttributes must be at least 1, was " + this.noSplitAttributes);
        }

        DefaultDataset left = null;
        DefaultDataset right = null;
        for (int round = 1; ; round++) {
            chooseSplitAttributes(dataset.noAttributes(), round);
            computeCenters(dataset);

            left = new DefaultDataset();
            right = new DefaultDataset();
            for (Instance instance : dataset) {
                if (goesRight(project(instance))) {
                    right.add(instance);
                } else {
                    left.add(instance);
                }
            }
            if (left.size() > 0 && right.size() > 0) {
                break;
            }
            if (round * round * this.noSplitAttributes > dataset.noAttributes()) {
                // The subset can no longer grow in a useful way: give up and emit a
                // leaf with a randomly chosen class.
                Vector<Object> classes = new Vector<>(dataset.classes());
                this.finalClass = classes.get(this.rg.nextInt(classes.size()));
                dataset.clear();
                return;
            }
        }

        this.leftChild = new RandomTree(this.noSplitAttributes, this.rg, this.parentClasses);
        this.leftChild.buildClassifier(left);
        this.rightChild = new RandomTree(this.noSplitAttributes, this.rg, this.parentClasses);
        this.rightChild.buildClassifier(right);
    }

    /**
     * Resets {@link #splitAttributes} to {@code 0 .. noAttributes-1}, then removes
     * random entries while {@code size / (round * round)} (integer division) exceeds
     * {@link #noSplitAttributes}, so later rounds keep more attributes.
     */
    private void chooseSplitAttributes(int noAttributes, int round) {
        this.splitAttributes = new Vector<>();
        for (int attribute = 0; attribute < noAttributes; attribute++) {
            this.splitAttributes.add(attribute);
        }
        int roundSquared = round * round;
        while (this.splitAttributes.size() / roundSquared > this.noSplitAttributes) {
            this.splitAttributes.remove(this.rg.nextInt(this.splitAttributes.size()));
        }
    }

    /**
     * Sets {@link #leftCenter} to the mean of the instances whose class has index 0
     * in {@code dataset}, and {@link #rightCenter} to the mean of all other
     * instances, both restricted to {@link #splitAttributes}. An empty side yields
     * NaN components (float division by zero), which routes everything left.
     */
    private void computeCenters(Dataset dataset) {
        int width = this.splitAttributes.size();
        this.leftCenter = new float[width];
        this.rightCenter = new float[width];
        int leftCount = 0;
        int rightCount = 0;
        for (Instance instance : dataset) {
            float[] center;
            if (dataset.classIndex(instance.classValue()) == 0) {
                center = this.leftCenter;
                leftCount++;
            } else {
                center = this.rightCenter;
                rightCount++;
            }
            for (int k = 0; k < width; k++) {
                // Compound assignment narrows the double sum back to float.
                center[k] += instance.value(this.splitAttributes.get(k));
            }
        }
        for (int k = 0; k < width; k++) {
            this.leftCenter[k] /= leftCount;
            this.rightCenter[k] /= rightCount;
        }
    }

    /** Returns the values of {@code instance} at the indices in {@link #splitAttributes}. */
    private double[] project(Instance instance) {
        double[] values = new double[this.splitAttributes.size()];
        for (int k = 0; k < values.length; k++) {
            values[k] = instance.value(this.splitAttributes.get(k));
        }
        return values;
    }

    /** True when {@code values} is strictly farther from the left centroid; ties go left. */
    private boolean goesRight(double[] values) {
        return dist(values, this.leftCenter) > dist(values, this.rightCenter);
    }

    /** Manhattan (L1) distance over the first {@code values.length} components. */
    private double dist(double[] values, float[] center) {
        double sum = 0.0d;
        for (int k = 0; k < values.length; k++) {
            sum += Math.abs(values[k] - center[k]);
        }
        return sum;
    }

    /**
     * Returns the leaf label reached by routing {@code instance} down the tree.
     *
     * @param instance instance to classify; must expose every attribute index used
     *     by the splits
     * @return the label of the leaf reached
     */
    @Override
    public Object classify(Instance instance) {
        if (this.finalClass != null) {
            return this.finalClass;
        }
        assert this.rightCenter != null;
        assert this.leftCenter != null;
        assert this.leftChild != null;
        assert this.rightChild != null;
        assert this.splitAttributes != null;
        // Project onto exactly the attributes used during training; that subset can
        // be smaller or larger than noSplitAttributes.
        return goesRight(project(instance))
                ? this.rightChild.classify(instance)
                : this.leftChild.classify(instance);
    }

    /**
     * Returns a distribution over the root's training classes that puts 1.0 on the
     * predicted class and 0.0 on every other class.
     *
     * @param instance instance to classify
     * @return map from class to probability
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        Map<Object, Double> distribution = new HashMap<>();
        for (Object cls : this.parentClasses) {
            distribution.put(cls, 0.0d);
        }
        distribution.put(classify(instance), 1.0d);
        return distribution;
    }
}
