Joint Probability Trees

Joint probability trees (JPTs) are a novel formalism for representing full-joint distributions over sets of random variables in hybrid domains. The learning algorithm builds on well-known principles of decision tree learning, decomposing the representation into tractable mixture components based on the notion of distribution impurity.

Full-joint distributions are particularly attractive representations of (uncertain) knowledge, for they enable a wide variety of reasoning patterns: causal, diagnostic and mixed inference patterns are all supported within a single framework. Given a full-joint distribution $P(\mathbf{X}=\mathbf{x})$ over a set of variables $\mathbf{X}$, we can, in particular, reason about conditional probabilities of the form $P(\mathbf{Q}=\mathbf{q} \mid \mathbf{E}=\mathbf{e})$ for any subsets $\mathbf{Q} \subseteq \mathbf{X}$ and $\mathbf{E} \subseteq \mathbf{X}$ of the variables. Probabilistic models that represent such distributions have thus long been viewed as valuable resources for AI systems at large.
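To make this concrete, the following Python sketch answers a causal and a diagnostic query from one and the same tabulated full-joint distribution, simply by summing out (marginalising) the entries that are inconsistent with the query. The variable names and probabilities are hypothetical; note that an explicit table like this grows exponentially with the number of variables, which is why compact representations are of interest.

```python
# Toy full-joint distribution over two binary variables. The variable names and
# probabilities are made up for illustration; the four entries sum to 1.
VARIABLES = ("Rain", "WetGrass")
JOINT = {
    (0, 0): 0.50,
    (0, 1): 0.10,
    (1, 0): 0.05,
    (1, 1): 0.35,
}


def probability(event):
    """P(event): sum the joint over all full assignments consistent with `event`."""
    total = 0.0
    for assignment, p in JOINT.items():
        values = dict(zip(VARIABLES, assignment))
        if all(values[var] == val for var, val in event.items()):
            total += p
    return total


def conditional(query, evidence):
    """P(Q=q | E=e) = P(Q=q, E=e) / P(E=e), for disjoint query and evidence variables."""
    return probability({**query, **evidence}) / probability(evidence)


# Causal direction: how likely is wet grass given rain?
print(conditional({"WetGrass": 1}, {"Rain": 1}))  # 0.35 / 0.40
# Diagnostic direction: how likely is rain given wet grass?
print(conditional({"Rain": 1}, {"WetGrass": 1}))  # 0.35 / 0.45
```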

With joint probability trees (JPTs), Nyga et al. present a novel formalism for the representation of full-joint distributions with some attractive properties [Nyg23J]:

  • support for hybrid domains (discrete and continuous random variables)
  • tractable learning and inference
  • model construction in the absence of domain knowledge (no prior specification of a dependency structure is necessary)

Representation

A joint probability tree represents a full-joint distribution over a set of random variables $\mathbf{X} = \{X_1, \ldots, X_N\}$ as a mixture model, where

  • the mixture components are obtained by recursively partitioning a training set using a greedy algorithm that seeks to minimise an impurity metric (analogous to the learning of decision trees).
  • the representation of the full-joint distribution of each mixture component assumes mutual independence of the random variables, i.e. each component contains representations of N univariate distributions.

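Let $\Lambda$ be the set of partitions (leaves) we obtain via the learning algorithm. The full-joint distribution then decomposes as follows,

$$P(\mathbf{X}=\mathbf{x}) = \sum_{\lambda \in \Lambda} P(\mathbf{X}=\mathbf{x} \mid \lambda)\, P(\lambda) = \sum_{\lambda \in \Lambda} P(\lambda) \prod_{i=1}^{N} P(X_i = x_i \mid \lambda),$$

where $P(\lambda)$ is the fraction of the samples that ended up in leaf $\lambda$ during training, and $P(X_i = x_i \mid \lambda)$ is computed based on the distribution that is stored for random variable $X_i$ in leaf $\lambda$: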

  • For a discrete variable $X_i$, $P(X_i = x_i \mid \lambda)$ is simply the relative frequency of value $x_i$ among the training samples that ended up in leaf $\lambda$. We thus store the corresponding categorical (multinomial) distribution in $\lambda$.
  • For a continuous variable $X_i$, we store a cumulative distribution function (CDF), which we conveniently represent as a piecewise linear function, such that both density values of the form $P(X_i = x_i \mid \lambda)$ and range queries of the form $P(x_i \le X_i \le x_i' \mid \lambda)$ can straightforwardly be supported. To achieve compactness, the authors use a recursive optimisation scheme to induce the CDF representation from the training samples in $\lambda$.
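The decomposition above can be sketched in Python. This is a minimal illustration, not the authors' implementation: the class names `PiecewiseLinearCDF` and `Leaf` are hypothetical, and the CDF knots are set by hand rather than induced by the authors' recursive optimisation scheme. It does show why a piecewise linear CDF makes both query types cheap: the density is the slope of the segment containing the value, and a range query is a difference of two interpolated CDF values.

```python
from dataclasses import dataclass
from typing import Dict, Sequence

import numpy as np


@dataclass
class PiecewiseLinearCDF:
    """CDF through the knots (xs[k], fs[k]); 0 left of xs[0] and 1 right of xs[-1]."""
    xs: np.ndarray  # strictly increasing knot positions
    fs: np.ndarray  # non-decreasing CDF values, fs[0] == 0 and fs[-1] == 1

    def cdf(self, x: float) -> float:
        return float(np.interp(x, self.xs, self.fs, left=0.0, right=1.0))

    def density(self, x: float) -> float:
        # A piecewise linear CDF has a piecewise constant density: the slope of the
        # segment [xs[k], xs[k+1]) that contains x.
        if x < self.xs[0] or x >= self.xs[-1]:
            return 0.0
        k = int(np.searchsorted(self.xs, x, side="right")) - 1
        return float((self.fs[k + 1] - self.fs[k]) / (self.xs[k + 1] - self.xs[k]))

    def prob_range(self, lo: float, hi: float) -> float:
        # Range query P(lo <= X <= hi) as a difference of CDF values.
        return self.cdf(hi) - self.cdf(lo)


@dataclass
class Leaf:
    prior: float                               # P(lambda): share of training samples in the leaf
    discrete: Dict[str, Dict[str, float]]      # variable -> {value: relative frequency}
    continuous: Dict[str, PiecewiseLinearCDF]  # variable -> CDF

    def likelihood(self, x: Dict[str, object]) -> float:
        # Variables are assumed mutually independent within a leaf.
        p = 1.0
        for var, freqs in self.discrete.items():
            p *= freqs.get(x[var], 0.0)
        for var, cdf in self.continuous.items():
            p *= cdf.density(x[var])
        return p


def joint(leaves: Sequence[Leaf], x: Dict[str, object]) -> float:
    """Mixture: sum over leaves of P(lambda) * prod_i P(X_i = x_i | lambda)."""
    return sum(leaf.prior * leaf.likelihood(x) for leaf in leaves)


# Two hypothetical leaves over a class variable and one continuous variable X.
leaves = [
    Leaf(0.6, {"Class": {"a": 0.9, "b": 0.1}},
         {"X": PiecewiseLinearCDF(np.array([0.0, 1.0]), np.array([0.0, 1.0]))}),
    Leaf(0.4, {"Class": {"a": 0.2, "b": 0.8}},
         {"X": PiecewiseLinearCDF(np.array([1.0, 3.0]), np.array([0.0, 1.0]))}),
]
print(joint(leaves, {"Class": "a", "X": 0.5}))  # 0.6 * 0.9 * 1.0
print(joint(leaves, {"Class": "b", "X": 2.0}))  # 0.4 * 0.8 * 0.5
```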

Learning

JPTs use a recursive partitioning scheme that is inspired by the greedy algorithms typically used for the induction of decision trees, which we shall briefly review: When inducing a (binary) decision tree, the fundamental operation is to split the training dataset into two subsets, such that the average impurity in the two subsets is lower than the impurity of the set we started with. The split condition can be any condition that compares one of the variables against a constant (e.g. $X_i < c$ for a continuous variable $X_i$ and some real-valued constant $c$), and we greedily choose, among all candidate splits, the one that yields the greatest reduction in impurity. The two subsets are then recursively split in the same way until a termination criterion is met (e.g. until impurity is sufficiently low or until the sets have become too small).
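A compact Python sketch of this greedy procedure, using the variance of a regression target as the impurity metric. It is illustrative only: the function names and the two termination parameters (`min_samples`, `min_gain`) are our own choices, and efficient implementations typically sort each column once and update the impurity incrementally instead of recomputing it for every candidate threshold.

```python
from typing import Callable, Optional, Tuple

import numpy as np


def variance(y: np.ndarray) -> float:
    """Impurity for a regression target: variance of the target values."""
    return float(np.var(y))


def best_split(X: np.ndarray, y: np.ndarray,
               impurity: Callable[[np.ndarray], float] = variance
               ) -> Optional[Tuple[int, float, float]]:
    """Greedily pick the split X[:, j] < c with the largest impurity reduction.

    Returns (feature index j, threshold c, reduction), or None if no split exists.
    """
    n = len(y)
    parent = impurity(y)
    best = None
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])
        # Candidate thresholds: midpoints between consecutive distinct values,
        # so both subsets are always non-empty.
        for c in (values[:-1] + values[1:]) / 2:
            left = X[:, j] < c
            n_left = int(left.sum())
            # Average impurity of the two subsets, weighted by their sizes.
            child = (n_left * impurity(y[left]) + (n - n_left) * impurity(y[~left])) / n
            if best is None or parent - child > best[2]:
                best = (j, float(c), parent - child)
    return best


def grow(X: np.ndarray, y: np.ndarray, min_samples: int = 2, min_gain: float = 1e-9) -> dict:
    """Split recursively; nodes with fewer than 2 * min_samples samples, or whose best
    split reduces impurity by at most min_gain, become leaves."""
    split = best_split(X, y) if len(y) >= 2 * min_samples else None
    if split is None or split[2] <= min_gain:
        return {"leaf": True, "n": len(y), "mean": float(np.mean(y))}
    j, c, _ = split
    left = X[:, j] < c
    return {"leaf": False, "feature": j, "threshold": c,
            "left": grow(X[left], y[left], min_samples, min_gain),
            "right": grow(X[~left], y[~left], min_samples, min_gain)}


X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
print(best_split(X, y))  # the split at the midpoint 6.5 removes all variance
tree = grow(X, y)
```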
A reduction in impurity means that, once the split condition has been checked, we are closer to a unique answer than we were before. For classification, a common impurity metric is the entropy of the class variable’s distribution, and for regression, one often uses the variance of the regression target variable, which corresponds to the mean squared error (MSE) when using the mean as the predictor.
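Both impurity metrics can be written in a few lines of Python. In this sketch (the function names are ours), entropy is measured in bits, and the last lines check numerically that the variance of a target coincides with the MSE of always predicting its mean.

```python
import numpy as np


def entropy(labels) -> float:
    """Impurity for classification: Shannon entropy (in bits) of the label distribution."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float((p * np.log2(1.0 / p)).sum())


def variance(y) -> float:
    """Impurity for regression: variance of the target values."""
    return float(np.var(y))


def mse_of_mean(y) -> float:
    """Mean squared error incurred when every value is predicted by the mean."""
    y = np.asarray(y, dtype=float)
    return float(np.mean((y - y.mean()) ** 2))


print(entropy(["a", "a", "b", "b"]))  # maximally impure for two classes: 1 bit
print(entropy(["a", "a", "a", "a"]))  # pure subset: zero entropy
y = [2.0, 4.0, 4.0, 6.0]
print(variance(y), mse_of_mean(y))    # the same quantity computed two ways
```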
In JPTs, there is no target variable; the goal is to obtain subsets in which the impurity of all variables is low, and the authors thus minimise the (weighted) sum of the impurity values pertaining to the individual variables. To reconcile the semantics of impurity values of discrete and continuous variables (which may live on very different scales), the best split is determined based on the relative improvement in impurity for each variable. Figure 1 illustrates JPT learning and the resulting representation based on the example of a hybrid distribution involving three variables.
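The following Python sketch scores a candidate split in this spirit: it sums, over all variables, the impurity reduction relative to the parent's impurity, using entropy for discrete and variance for continuous variables. The exact impurity measures and normalisation used by Nyga et al. may differ; `split_score`, its optional per-variable `weights` and the choice of entropy are our assumptions.

```python
from typing import Dict, Optional

import numpy as np


def entropy(labels: np.ndarray) -> float:
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float((p * np.log2(1.0 / p)).sum())


def variance(values: np.ndarray) -> float:
    return float(np.var(values))


def split_score(columns: Dict[str, np.ndarray], discrete: set, left: np.ndarray,
                weights: Optional[Dict[str, float]] = None) -> float:
    """Weighted sum over all variables of the relative impurity reduction achieved
    by splitting the samples into `left` and `~left` (higher is better)."""
    n, n_left = len(left), int(left.sum())
    if n_left in (0, n):
        return 0.0  # not a proper split
    score = 0.0
    for name, col in columns.items():
        impurity = entropy if name in discrete else variance
        before = impurity(col)
        if before == 0.0:
            continue  # already pure; nothing to improve
        after = (n_left * impurity(col[left]) + (n - n_left) * impurity(col[~left])) / n
        # Dividing by `before` makes bits (entropy) and squared units (variance) comparable.
        score += (1.0 if weights is None else weights[name]) * (before - after) / before
    return score


columns = {
    "X": np.array([0.1, 0.2, 0.3, 5.1, 5.2, 5.3]),
    "Class": np.array(["a", "a", "a", "b", "b", "b"]),
}
good = columns["X"] < 2.7                                # separates the two groups
bad = np.array([True, False, True, False, True, False])  # interleaves them
print(split_score(columns, {"Class"}, good), split_score(columns, {"Class"}, bad))
```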

Figure 1. Nyga et al. illustrate JPT learning and the resulting representation based on a distribution involving three variables [Nyg23J]: Two variables (X and Y) are continuous and the third is binary, representing a class. The ground truth distribution is a mixture of Gaussians (a) where the class variable determines the mixture component. Based on a number of samples from the distribution (b), a JPT is trained, colours indicating the value of the class variable and the additional lines indicating the split boundaries that were determined by the learning algorithm in order to ultimately arrive at the final JPT representation (d). Each of the seven leaves contains a representation of a CDF for X and Y (line plots) as well as a categorical distribution for the class variable (bar plot). Notice that the classes are well-separated, as each leaf is almost exclusively associated with one of the classes. Furthermore, the marginal distribution over X and Y that is recovered by the JPT (c) is fairly close to the ground truth distribution (a).

Figure 1d. Larger view of the joint probability tree.

The authors’ evaluation indicates promising performance in practice, showing that generative JPTs can perform on par with discriminative models (which were trained specifically to solve a particular classification or regression problem) whilst retaining a much higher degree of flexibility.
Moreover, although the authors do not explicitly mention this, the JPT learning algorithm can also be viewed as a clustering method which falls into the category of divisive hierarchical clustering algorithms. When combined with appropriate, clustering-specific termination criteria, it could prove to be a highly effective automatic clustering method that requires no prior knowledge of the number of clusters.

Discussion

The independence assumption that is made in each leaf may seem strong, yet the impurity minimisation that is at the heart of the learning algorithm seeks to ensure that there are ultimately few interactions between the variables, reducing co-information and thus weakening the assumption. By varying the number/size of partitions, we can essentially control the degree to which the independence assumption constitutes an approximation.
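A small numerical illustration of this effect, using Pearson correlation as a crude, linear stand-in for dependence (co-information is more general). The data are synthetic, the partitions are simply equal-count intervals of x rather than JPT leaves, and the function name is hypothetical. As the number of partitions grows, the average correlation within a partition drops, i.e. treating the variables as independent within each partition becomes a better approximation.

```python
import numpy as np


def mean_within_partition_correlation(x: np.ndarray, y: np.ndarray, n_parts: int) -> float:
    """Average Pearson correlation of (x, y) inside `n_parts` equal-count partitions of x."""
    edges = np.quantile(x, np.linspace(0.0, 1.0, n_parts + 1))
    # Partition index of each sample; the maximum of x is assigned to the last partition.
    part = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_parts - 1)
    corrs = [np.corrcoef(x[part == k], y[part == k])[0, 1] for k in range(n_parts)]
    return float(np.mean(corrs))


# Synthetic data with a strong dependency: y is x plus a little Gaussian noise.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, 2000)
y = x + rng.normal(0.0, 0.05, 2000)

for n_parts in (1, 10, 50):
    print(n_parts, round(mean_within_partition_correlation(x, y, n_parts), 3))
```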
However, with the current incarnation of JPTs as proposed by the authors, we cannot arbitrarily reduce leaf size, because doing so may result in a problematic form of overfitting: since each leaf's distributions are induced solely from the training samples in that leaf, they assign zero probability (or density) to values those samples do not cover, so the probability of the evidence becomes zero for an increasingly large number of out-of-sample predictions. Fortunately, we can conceive of straightforward solutions to this problem: We could, for instance, assume that the observed values of continuous variables are but noisy observations of the true values and apply the principles of kernel density estimation in order to induce the CDF representations. For discrete variables, we could, in a similar spirit, assume pseudo-counts.
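The two remedies suggested above could look roughly as follows in Python. This is a hypothetical sketch, not part of the authors' JPT implementation: the Gaussian kernel, the `bandwidth` and the pseudo-count `alpha` are assumptions chosen for illustration, and in a JPT the resulting smooth CDF would presumably still have to be approximated by a piecewise linear function.

```python
import math
from collections import Counter
from typing import Callable, Dict, Hashable, Sequence


def smoothed_categorical(values: Sequence[Hashable], domain: Sequence[Hashable],
                         alpha: float = 1.0) -> Dict[Hashable, float]:
    """Relative frequencies with `alpha` pseudo-counts added to every value in `domain`."""
    counts = Counter(values)
    total = len(values) + alpha * len(domain)
    return {v: (counts[v] + alpha) / total for v in domain}


def gaussian_kde_cdf(samples: Sequence[float], bandwidth: float) -> Callable[[float], float]:
    """CDF of a Gaussian kernel density estimate: the mean of normal CDFs centred on the samples."""
    def cdf(x: float) -> float:
        return sum(0.5 * (1.0 + math.erf((x - s) / (bandwidth * math.sqrt(2.0))))
                   for s in samples) / len(samples)
    return cdf


def gaussian_kde_pdf(samples: Sequence[float], bandwidth: float) -> Callable[[float], float]:
    """Density of the same estimate; positive everywhere, so unseen values keep non-zero density."""
    norm = 1.0 / (len(samples) * bandwidth * math.sqrt(2.0 * math.pi))
    def pdf(x: float) -> float:
        return norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples)
    return pdf


# Hypothetical contents of a single leaf.
colours = smoothed_categorical(["red", "red", "green"], domain=["red", "green", "blue"])
print(colours)  # "blue" was never observed in this leaf but keeps probability 1/6
cdf = gaussian_kde_cdf([0.0, 1.0], bandwidth=0.5)
pdf = gaussian_kde_pdf([0.0, 1.0], bandwidth=0.5)
print(cdf(0.5), pdf(3.0))  # pdf(3.0) > 0 although 3.0 lies outside the observed range
```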
The authors provide an open-source implementation, so little stands in the way of experimentation.

References
