Mixture of Decision Trees for Interpretable Machine Learning

Training a linear gating function jointly with multiple expert decision trees via expectation-maximization yields a new, fully interpretable model that performs well on several (simple) datasets.

The explainability of predictions from a learned model is a key requirement for many applications of artificial intelligence, especially in safety-critical domains such as medicine or driving. There are two main approaches: post-hoc explanations of predictions from black-box models (explainable AI) and intrinsically interpretable models (such as linear models). [Bru22M] belongs to the latter camp (see the well-known [Rud19S] for more on this topic and for arguments in favor of interpretable models over post-hoc explanations).

The authors exploit the fact that small decision trees (DTs), including stumps, are easily interpretable by humans. To compensate for the limited predictive power of a single small DT, a Mixture of Decision Trees (MoDT) trains multiple trees together with a gating function, building on classical mixture-of-experts approaches. For each input, the gating function decides which DT makes the prediction. Because the gating function is linear, it is easy to analyze. Moreover, the authors recommend restricting it to only two features, and their reference implementation provides several algorithms for selecting the two most suitable ones. (Using only two features has the advantage that the decision boundaries of the gating function can be easily visualized, as in Figure 3.) Since both the gate and the selected small tree can be inspected directly, the prediction made by the model during inference is fully interpretable.
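To make the interpretability argument concrete, here is a short derivation assuming the softmax-normalized linear gate that is standard in mixture-of-experts models, with $K$ experts, weights $w_k$ and biases $b_k$ (this notation is ours, and the exact parametrization in [Bru22M] may differ). Because the softmax denominator is shared and $\exp$ is monotone, the expert chosen at inference depends only on the linear scores:

```latex
g_k(x) = \frac{\exp(w_k^\top x + b_k)}{\sum_{j=1}^{K} \exp(w_j^\top x + b_j)},
\qquad
\hat{k}(x) = \arg\max_{k} g_k(x) = \arg\max_{k} \left( w_k^\top x + b_k \right),
\qquad
\text{boundary between experts } j, k:\;
(w_j - w_k)^\top x + (b_j - b_k) = 0 .
```

Each expert's region is therefore an intersection of half-planes. With a two-feature gate ($x \in \mathbb{R}^2$) every boundary is a straight line and every region a convex polygon, which is why plots like Figure 3 are easy to read.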

The gating function and the trees are trained jointly with an iterative expectation-maximization (EM) scheme. During training, the gating function is used in a “soft” way: it assigns each expert a weight for every sample, so different regions of the dataset are weighted differently (thus, each tree sees all data during training, but different regions contribute differently to its loss). At inference, only the most probable DT is used, in contrast to tree-ensemble models such as Random Forests (RF), where the predictions of all trees are averaged.
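The soft-training / hard-inference scheme can be sketched in a few dozen lines of NumPy. This is a minimal sketch under stated assumptions, not the authors' reference implementation: the function names (`gate_probabilities`, `fit_stump`, `stump_proba`, `fit_modt`, `predict_modt`) are hypothetical, the experts are depth-1 stumps written from scratch (split chosen by a weighted, Laplace-smoothed log-loss) rather than deeper library trees, and the gate's M-step takes a few gradient steps instead of an exact maximization (a "generalized" EM step). The E-step uses the classical mixture-of-experts responsibilities: gate probability times the expert's likelihood of the true label, normalized over experts.

```python
import numpy as np


def gate_probabilities(X, W, b):
    """Linear gate followed by a softmax: one probability per (sample, expert)."""
    scores = X @ W + b
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    return p / p.sum(axis=1, keepdims=True)


def fit_stump(X, y, w, n_classes, alpha=1.0):
    """Hypothetical depth-1 expert: split minimizing weighted, smoothed log-loss."""
    best = None
    for f in range(X.shape[1]):
        for t in np.unique(X[:, f]):
            left = X[:, f] <= t
            leaves, loss = [], 0.0
            for mask in (left, ~left):
                counts = np.bincount(y[mask], weights=w[mask], minlength=n_classes)
                p = (counts + alpha) / (counts.sum() + alpha * n_classes)
                leaves.append(p)
                loss -= np.sum(w[mask] * np.log(p[y[mask]]))
            if best is None or loss < best[0]:
                best = (loss, f, t, leaves[0], leaves[1])
    return best[1:]  # (feature, threshold, class probs left, class probs right)


def stump_proba(stump, X):
    """Class probabilities of a stump, shape (n_samples, n_classes)."""
    f, t, p_left, p_right = stump
    return np.where((X[:, f] <= t)[:, None], p_left, p_right)


def fit_modt(X, y, n_experts=2, n_iter=30, gate_steps=50, lr=0.5, seed=0):
    rng = np.random.default_rng(seed)
    n, d = X.shape
    n_classes = int(y.max()) + 1
    W, b = np.zeros((d, n_experts)), np.zeros(n_experts)
    # Random soft assignment to break the symmetry between the experts.
    h = rng.dirichlet(np.ones(n_experts), size=n)
    for _ in range(n_iter):
        # M-step (experts): every stump sees ALL samples, weighted by h[:, k].
        experts = [fit_stump(X, y, h[:, k], n_classes) for k in range(n_experts)]
        # M-step (gate): gradient ascent on sum_ik h_ik * log g_k(x_i).
        for _ in range(gate_steps):
            g = gate_probabilities(X, W, b)
            W += lr * X.T @ (h - g) / n
            b += lr * (h - g).mean(axis=0)
        # E-step: responsibility of expert k for sample i (soft weights).
        lik = np.stack([stump_proba(e, X)[np.arange(n), y] for e in experts], axis=1)
        h = gate_probabilities(X, W, b) * lik
        h /= h.sum(axis=1, keepdims=True)
    return W, b, experts


def predict_modt(X, W, b, experts):
    """Hard gating at inference: only the most probable expert predicts."""
    chosen = gate_probabilities(X, W, b).argmax(axis=1)
    preds = np.array([stump_proba(experts[k], x[None, :])[0].argmax()
                      for k, x in zip(chosen, X)])
    return preds, chosen
```

Usage: `W, b, experts = fit_modt(X, y)` followed by `preds, chosen = predict_modt(X_new, W, b, experts)`, where `chosen` records which expert (and hence which stump) produced each prediction. Since the softmax is monotone, the argmax at inference could equally be taken over the raw linear scores.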

While not as powerful as an RF with deep trees, MoDT performs surprisingly well across many tasks. It may be a viable model for applications on tabular data that require interpretable decisions, especially if the most informative features are already known. Table 2 below compares MoDT with DTs and RFs on several standard (small) datasets.

Figure 3: [Bru22M] Exemplary decision area of the 2D gating function (on a toy dataset)

Table 2: [Bru22M] Performance of MoDT with a full gating function (FG) and with gating based on two features (2D), compared with DTs and RFs. Best interpretable method: grey background; best overall: bold.

Personal comment: While I don’t think that this method scales well to high-dimensional and complex datasets (due to the low representational power of small decision trees and linear models), it might still be useful for many applications with simple decision boundaries or with small datasets. Thanks to its simplicity, it can be implemented, trained and fine-tuned very quickly. Interpretable models should always be preferred over black-box models if their accuracy is comparable, so it is generally worth a try.

Unfortunately, the authors’ implementation mentioned above is not an installable package and seems somewhat abandoned at the time of writing. The regression setting (expert DTs for regression) was not analyzed in the paper and is not implemented. I also believe that a simple yet potentially very useful extension would be to allow expert DTs of different depths in different regions (in this paper, all expert DTs have the same depth). After all, the complexity of the problem may vary from region to region. Including such experts seems like a straightforward extension of the EM scheme used in the paper.

