Machine learning: Decision tree

I have finally implemented a decision tree and I want to share it with you! The decision tree concept is deceptively simple, and I thought it would be very easy to implement. However, it was not as easy as I expected. But first things first: what is a decision tree?

Decision tree

It is one of the most popular and effective machine learning algorithms. It can do classification, regression, ranking, probability estimation and clustering. A decision tree is not a black box and its results are easily interpretable. Data does not need to be normalized or specially prepared. It is intuitive, and even a child can understand the basics of how decision trees work. They are also widely applied in practice, especially in ensembles (multiple decision trees, for example random forests). All of this makes them a perfect algorithm to start machine learning with.


How decision trees work

As you can see in the picture above, this particular decision tree tries to classify whether a Titanic passenger survived or died during the catastrophe. For example, if the passenger is male and older than 9.5 years, then most likely he did not survive (this specific example also contains the probability of a correct answer). There is not much more to add. Once you have a built decision tree and want to classify a data sample, you start at the root node, check whether your data has the feature from that node, and follow the matching branch. You repeat that until you reach a leaf node, which holds the answer you were seeking. I really recommend watching this Udacity course on decision trees to understand them better and get some intuition about how a tree is built. They explain it so much better than I do.
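
The root-to-leaf walk described above can be sketched in a few lines of Java. This is only an illustration, not the implementation discussed later: the class and method names (ClassifySketch, Node, titanicLikeTree) are made up for this example, and the tree is hypothetical apart from the "male and older than 9.5 years, so most likely died" path mentioned in the text.

```java
import java.util.Map;
import java.util.function.Predicate;

class ClassifySketch {
    /** A node is either a leaf (label set) or an inner node (test plus two children). */
    static final class Node {
        final String label;                         // non-null only for leaves
        final Predicate<Map<String, Object>> test;  // non-null only for inner nodes
        final Node ifTrue;
        final Node ifFalse;

        Node(String label, Predicate<Map<String, Object>> test, Node ifTrue, Node ifFalse) {
            this.label = label;
            this.test = test;
            this.ifTrue = ifTrue;
            this.ifFalse = ifFalse;
        }

        static Node leaf(String label) {
            return new Node(label, null, null, null);
        }

        static Node inner(Predicate<Map<String, Object>> test, Node ifTrue, Node ifFalse) {
            return new Node(null, test, ifTrue, ifFalse);
        }

        /** Start at this node and follow the matching branch until a leaf is reached. */
        String classify(Map<String, Object> sample) {
            Node node = this;
            while (node.label == null) {
                node = node.test.test(sample) ? node.ifTrue : node.ifFalse;
            }
            return node.label;
        }
    }

    /** Hypothetical tree: only the "male and age > 9.5 -> died" path comes from the text. */
    static Node titanicLikeTree() {
        Node ageNode = Node.inner(
                s -> ((Number) s.get("age")).doubleValue() > 9.5,
                Node.leaf("died"),
                Node.leaf("survived"));
        return Node.inner(s -> "male".equals(s.get("sex")), ageNode, Node.leaf("survived"));
    }

    public static void main(String[] args) {
        Node tree = titanicLikeTree();
        System.out.println(tree.classify(Map.of("sex", "male", "age", 30)));   // prints "died"
        System.out.println(tree.classify(Map.of("sex", "female", "age", 30))); // prints "survived"
    }
}
```

Each inner node only asks one yes/no question, so classifying a sample costs at most as many checks as the tree is deep.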

How decision tree is built

To build a decision tree we take a set of possible features. Then we take one feature, create a tree node for it and split the training data. Once the training data is split into 2 (or n) sublists, the same thing is repeated on those sublists recursively until the whole tree is built.
Now of course you may have some questions about this simplified explanation. First of all, how do we select a feature to split on? And how do we stop the recursion?
To select the best feature for a split, we check which split gives us the most information gain (or has the least impurity, as in the pseudocode later). Let's say one split returns a sublist with 80 TRUE labels and 20 FALSE labels. It is obviously a better split than one that returns 50 TRUE labels and 50 FALSE labels. So we check each feature and select the best one. We stop when there are no more features left, or the data list is already fully homogeneous (all entries have the same label), or the tree has already reached its maximum depth (a parameter we set).
And once again, I recommend checking out that Udacity course.
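
To put numbers on the 80/20 versus 50/50 comparison, here is a small sketch of entropy-based information gain. It uses the standard textbook definition (parent entropy minus the size-weighted entropy of the children), which is not necessarily what the implementation below computes. The parent set of 100 TRUE and 100 FALSE entries and the names InformationGainSketch, entropy and informationGain are assumptions made for this example.

```java
class InformationGainSketch {
    /** Shannon entropy (in bits) of a subset with the given TRUE/FALSE counts. */
    static double entropy(int trueCount, int falseCount) {
        double total = trueCount + falseCount;
        double result = 0.0;
        for (int count : new int[] {trueCount, falseCount}) {
            if (count > 0) { // 0 * log(0) is treated as 0
                double p = count / total;
                result -= p * Math.log(p) / Math.log(2);
            }
        }
        return result;
    }

    /** Gain = parent entropy minus the size-weighted average entropy of the children. */
    static double informationGain(int[] parent, int[][] children) {
        double parentSize = parent[0] + parent[1];
        double weightedChildEntropy = 0.0;
        for (int[] child : children) {
            double weight = (child[0] + child[1]) / parentSize;
            weightedChildEntropy += weight * entropy(child[0], child[1]);
        }
        return entropy(parent[0], parent[1]) - weightedChildEntropy;
    }

    public static void main(String[] args) {
        int[] parent = {100, 100}; // assumed: 100 TRUE and 100 FALSE before the split
        double gainA = informationGain(parent, new int[][] {{80, 20}, {20, 80}});
        double gainB = informationGain(parent, new int[][] {{50, 50}, {50, 50}});
        System.out.printf("split A gain: %.3f%n", gainA);
        System.out.printf("split B gain: %.3f%n", gainB);
    }
}
```

Under these assumptions split A gains roughly 0.28 bits, while split B gains nothing, because its sublists are exactly as mixed as the data we started with.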

Decision tree training pseudocode

I used the pseudocode from this book to implement my decision tree. It goes like this:

GrowTree(D, F) – grow a feature tree from training data.

Input : data D; set of features F.
Output : feature tree T with labelled leaves.
    if Homogeneous(D) then return Label(D);
    S = BestSplit(D, F);
    split D into subsets Di according to the literals in S;
    for each i do
        if Di not empty then Ti = GrowTree(Di, F) else Ti is a leaf labelled with Label(D);
    return a tree whose root is labelled with S and whose children are Ti

BestSplit(D, F) – find the best split for a decision tree.

Input : data D; set of features F.
Output : feature f to split on.
    Imin = 1;
    for each f ∈ F do
        split D into subsets D1, ..., Dl according to the values Vj of f;
        if Impurity({D1, ..., Dl}) < Imin then
            Imin = Impurity({D1, ..., Dl});
            fbest = f;
    return fbest

Java implementation

First of all, let's think about how we are going to represent a Feature. My Feature must know whether it belongs to a data sample. This won't handle categorical features, but that's OK for now. So here is the interface which defines a feature:

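public interface Feature {
    boolean belongsTo(DataSample dataSample);

    /** Splits data into two sublists: samples that have the feature, then samples that don't. */
    default List<List<DataSample>> split(List<DataSample> data) {
        List<List<DataSample>> result = Lists.newArrayList();
        // partitioningBy is statically imported from java.util.stream.Collectors; Lists is Guava's
        Map<Boolean, List<DataSample>> split = data.parallelStream().collect(partitioningBy(dataSample -> belongsTo(dataSample)));
        if (split.get(true).size() > 0) {
            result.add(split.get(true));
        } else {
            result.add(Lists.newArrayList());
        }
        if (split.get(false).size() > 0) {
            result.add(split.get(false));
        } else {
            result.add(Lists.newArrayList());
        }
        return result;
    }
}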


As you can see, it has a belongsTo() method, which returns true if the data sample contains the feature, and a default method for splitting data into 2 sublists: one which contains the feature and one which doesn't. I happily use Java 8 streams, and even though I am sure it could be contracted to a one-liner, it still saves a lot of time. God bless Java 8.
Now here is my single implementation of the Feature interface: PredicateFeature. It takes a column name and a Java 8 Predicate function (which allows it to accept lambda expressions!). It simply looks up the value of the data sample in the provided column and tests it with the Predicate function. If that returns true, the data has that feature. For example, a PredicateFeature with column “age” and predicate age -> age > 10 will return true for all data samples that have a value greater than 10 in the “age” column. You can also check out some commonly used predicates in the P class.

public class PredicateFeature<T> implements Feature {
    /** Data column used by feature. */
    private String column;

    /** Predicate used for splitting. */
    private Predicate<T> predicate;

    /** Feature Label used for visualization and testing the tree. */
    private String label;

    private PredicateFeature(String column, Predicate<T> predicate, String label) {
        this.column = column;
        this.predicate = predicate;
        this.label = label;
    }

    @Override
    @SuppressWarnings("unchecked")
    public boolean belongsTo(DataSample dataSample) {
        Optional<Object> optionalValue = dataSample.getValue(column);
        // a sample with no value in the column does not have the feature
        return optionalValue.isPresent() && predicate.test((T) optionalValue.get());
    }

    public static <T> Feature newFeature(String column, T featureValue) {
        return new PredicateFeature<T>(column, P.isEqual(featureValue), String.format("%s = %s", column, featureValue));
    }

    public static <T> Feature newFeature(String column, Predicate<T> predicate, String predicateString) {
        return new PredicateFeature<T>(column, predicate, String.format("%s %s", column, predicateString));
    }
}
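
To see the predicate idea in isolation, here is a minimal self-contained sketch that does not depend on the classes above. It represents a data sample as a plain Map, and the names PredicateSplitSketch, row, hasFeature and split are invented for this illustration. As in belongsTo(), a sample with no value in the column counts as not having the feature.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

class PredicateSplitSketch {
    /** Builds a one-column row; a stand-in for a real data sample. */
    static Map<String, Object> row(String column, Object value) {
        return Map.of(column, value);
    }

    /** True when the row has a value in the column and the predicate accepts it. */
    @SuppressWarnings("unchecked")
    static <T> boolean hasFeature(Map<String, Object> row, String column, Predicate<T> predicate) {
        Object value = row.get(column);
        return value != null && predicate.test((T) value);
    }

    /** Partitions rows into those that have the feature (key true) and those that don't (key false). */
    static Map<Boolean, List<Map<String, Object>>> split(
            List<Map<String, Object>> rows, String column, Predicate<Integer> predicate) {
        return rows.stream()
                .collect(Collectors.partitioningBy(r -> hasFeature(r, column, predicate)));
    }

    public static void main(String[] args) {
        List<Map<String, Object>> rows = new ArrayList<>();
        rows.add(row("age", 5));
        rows.add(row("age", 30));
        rows.add(row("age", 12));
        rows.add(row("name", "no age recorded")); // missing "age" column

        Map<Boolean, List<Map<String, Object>>> parts = split(rows, "age", age -> age > 10);
        System.out.println("with feature:    " + parts.get(true).size());  // prints 2
        System.out.println("without feature: " + parts.get(false).size()); // prints 2
    }
}
```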

As we have already mentioned DataSample, let's look at my simple implementation:

public class SimpleDataSample implements DataSample {
    private Map<String, Object> values = Maps.newHashMap();

    /** Column name which contains data labels. */
    private String labelColumn;

    private SimpleDataSample(String labelColumn, String[] header, Object... dataValues) {
        this.labelColumn = labelColumn;
        for (int i = 0; i < header.length; i++) {
            this.values.put(header[i], dataValues[i]);
        }
    }

    @Override
    public Optional<Object> getValue(String column) {
        return Optional.ofNullable(values.get(column));
    }

    public Label getLabel() {
        return (Label) values.get(labelColumn);
    }

    /** Creates a sample without a label column. */
    public static SimpleDataSample newClassificationDataSample(String[] header, Object... values) {
        Preconditions.checkArgument(header.length == values.length);
        return new SimpleDataSample(null, header, values);
    }

    /** Creates a sample whose label is stored in labelColumn. */
    public static SimpleDataSample newSimpleDataSample(String labelColumn, String[] header, Object... values) {
        Preconditions.checkArgument(header.length == values.length);
        return new SimpleDataSample(labelColumn, header, values);
    }
}

It is a simple HashMap whose keys are column names. It also has one more property: “labelColumn”. This is the name of the column which contains the labels of the training data. The value in the label column must be of type Label.

OK, now we are ready to implement the GrowTree and BestSplit routines from the previously mentioned pseudocode:

    public void train(List<DataSample> trainingData, List<Feature> features) {
        root = growTree(trainingData, features, 1);
    }

    protected Node growTree(List<DataSample> trainingData, List<Feature> features, int currentDepth) {

        Label currentNodeLabel = null;
        // if dataset already homogeneous enough (has label assigned) make this node a leaf
        if ((currentNodeLabel = getLabel(trainingData)) != null) {
            log.debug("New leaf is created because data is homogeneous: {}", currentNodeLabel.getName());
            return Node.newLeafNode(currentNodeLabel);
        }

        boolean stoppingCriteriaReached = features.isEmpty() || currentDepth >= maxDepth;
        if (stoppingCriteriaReached) {
            Label majorityLabel = getMajorityLabel(trainingData);
            log.debug("New leaf is created because stopping criteria reached: {}", majorityLabel.getName());
            return Node.newLeafNode(majorityLabel);
        }

        Feature bestSplit = findBestSplitFeature(trainingData, features); // get best set of literals
        log.debug("Best split found: {}", bestSplit.toString());
        List<List<DataSample>> splitData = bestSplit.split(trainingData);
        log.debug("Data is split into sublists of sizes: {}", splitData.stream().map(List::size).collect(toList()));

        // remove best split from features (TODO check if it is not slow)
        List<Feature> newFeatures = features.stream().filter(p -> !p.equals(bestSplit)).collect(toList());
        Node node = Node.newNode(bestSplit);
        for (List<DataSample> subsetTrainingData : splitData) { // add children to current node according to split
            if (subsetTrainingData.isEmpty()) {
                // if subset data is empty add a leaf with label calculated from initial data
                node.addChild(Node.newLeafNode(getMajorityLabel(trainingData)));
            } else {
                // grow tree further recursively
                node.addChild(growTree(subsetTrainingData, newFeatures, currentDepth + 1));
            }
        }

        return node;
    }
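
The getLabel() and getMajorityLabel() helpers used above are not shown in this post. As a rough idea of what such helpers could look like, here is a hypothetical sketch that works on plain String labels instead of Label objects. The names majorityLabel and labelIfHomogeneous and the threshold parameter are assumptions for illustration, not the code from my repository.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

class LabelHelpersSketch {
    /** Most frequent label in the list; the list must be non-empty. */
    static String majorityLabel(List<String> labels) {
        Map<String, Long> counts = labels.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no labels"));
    }

    /** The majority label if its share of the data reaches the threshold, otherwise empty. */
    static Optional<String> labelIfHomogeneous(List<String> labels, double threshold) {
        if (labels.isEmpty()) {
            return Optional.empty();
        }
        String majority = majorityLabel(labels);
        long count = labels.stream().filter(majority::equals).count();
        double share = (double) count / labels.size();
        return share >= threshold ? Optional.of(majority) : Optional.empty();
    }

    public static void main(String[] args) {
        List<String> labels = List.of("died", "died", "died", "survived");
        System.out.println(majorityLabel(labels));            // prints "died"
        System.out.println(labelIfHomogeneous(labels, 0.9));  // prints "Optional.empty"
        System.out.println(labelIfHomogeneous(labels, 0.75)); // prints "Optional[died]"
    }
}
```

A threshold of 1.0 corresponds to the strict Homogeneous(D) check from the pseudocode, while a lower value matches the "homogeneous enough" wording in the code comment.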

And here is how the best feature to split on is found:

    protected Feature findBestSplitFeature(List<DataSample> data, List<Feature> features) {
        double currentImpurity = 1;
        Feature bestSplitFeature = null; // rename split to feature

        for (Feature feature : features) {
            List<List<DataSample>> splitData = feature.split(data);
            // totalSplitImpurity = sum(singleLeafImpurities) / nbOfLeafs
            // in other words splitImpurity is average of leaf impurities
            double calculatedSplitImpurity = splitData.parallelStream().filter(list -> !list.isEmpty()).mapToDouble(list -> impurityCalculationMethod.calculateImpurity(list)).average().getAsDouble();
            if (calculatedSplitImpurity < currentImpurity) {
                currentImpurity = calculatedSplitImpurity;
                bestSplitFeature = feature;
            }
        }

        // stays null if no feature yields a split impurity below 1
        return bestSplitFeature;
    }

The code pretty much follows the pseudocode and is fairly self-explanatory. One thing you might wonder about is impurityCalculationMethod. It is a simple interface for choosing a formula for the impurity calculation. Right now I have 2, Gini and Entropy, and you can check them out here.
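
For readers who don't want to follow the link, here is a minimal sketch of how such a pluggable impurity interface might look, using the standard Gini and entropy formulas. It works on plain String labels, and the names ImpuritySketch, ImpurityMeasure, GINI and ENTROPY are made up for this example, so treat it as an approximation of the idea and not as the actual classes. Both measures assume a non-empty list of labels.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

class ImpuritySketch {
    /** Strategy interface: maps the labels of one subset to an impurity value. */
    interface ImpurityMeasure {
        double impurity(List<String> labels);
    }

    /** Relative frequency of each distinct label in the list. */
    static List<Double> proportions(List<String> labels) {
        return labels.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .values().stream()
                .map(count -> count.doubleValue() / labels.size())
                .collect(Collectors.toList());
    }

    /** Gini impurity: 1 - sum(p_i^2). */
    static final ImpurityMeasure GINI = labels ->
            1.0 - proportions(labels).stream().mapToDouble(p -> p * p).sum();

    /** Entropy in bits: -sum(p_i * log2(p_i)). */
    static final ImpurityMeasure ENTROPY = labels ->
            0.0 - proportions(labels).stream().mapToDouble(p -> p * Math.log(p) / Math.log(2)).sum();

    public static void main(String[] args) {
        List<String> mixed = List.of("TRUE", "TRUE", "FALSE", "FALSE");
        List<String> pure = List.of("TRUE", "TRUE", "TRUE");
        for (ImpurityMeasure measure : List.of(GINI, ENTROPY)) {
            System.out.println(measure.impurity(mixed) + " vs " + measure.impurity(pure));
        }
    }
}
```

Both measures are 0 for a pure subset and largest for an evenly mixed one, so either can be plugged into the best split search without changing the rest of the code.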


So how can we use this decision tree? Here is a simple example:

public static void main(String[] args) throws FileNotFoundException, IOException {
    List<DataSample> trainingData = readData(true);
    DecisionTree tree = new DecisionTree();
    List<Feature> features = getFeatures();
    tree.train(trainingData, features);

    // print tree after training

    // read test data
    List<DataSample> testingData = readData(false);

    // classify all test data
    for (DataSample dataSample : testingData) {
        // classify dataSample with the trained tree
    }
}

Here is an example of how I solved the Kaggle Titanic competition using this decision tree. The score is only 0.7655, but it is still nice to do it with your own decision tree implementation :)

As always, you can find the source code on my GitHub.
Happy coding!