Full-joint distributions are particularly attractive representations of (uncertain) knowledge, for they enable a wide variety of reasoning patterns: Causal, diagnostic and mixed inference patterns are all supported within a single framework. Given a full-joint distribution $P(X=x)$ over a set of variables $X$, we can, in particular, reason about conditional probabilities of the form $P(Q=q \mid E=e)$ for any subsets $Q \subseteq X$ and $E \subseteq X$ of the variables. Probabilistic models that represent such distributions have thus long been viewed as valuable resources for AI systems at large.
With joint probability trees (JPTs), Nyga et al. present a novel formalism for the representation of full-joint distributions with some attractive properties [Nyg23J]:
- support for hybrid domains (discrete and continuous random variables)
- tractable learning and inference
- model construction in the absence of domain knowledge (no prior specification of a dependency structure is necessary)
Representation
A joint probability tree represents a full-joint distribution over a set of random variables $X = \{ X_1, \dots, X_N \}$ as a mixture model, where
- the mixture components are obtained by recursively partitioning a training set using a greedy algorithm that seeks to minimise an impurity metric (analogous to the learning of decision trees).
- the representation of the full-joint distribution of each mixture component assumes mutual independence of the random variables, i.e. each component contains representations of $N$ univariate distributions.
Let $\Lambda$ be the set of partitions (leaves) we obtain via the learning algorithm. The full-joint distribution then decomposes as follows, \begin{align} P(X=x) &= \sum_{\lambda \in \Lambda} P(X=x \mid \lambda) P(\lambda) \\ &= \sum_{\lambda \in \Lambda} P(\lambda) \prod_{i=1}^N P(X_i=x_i \mid \lambda), \end{align} where $P(\lambda)$ is the fraction of the samples that ended up in leaf $\lambda$ during training, and $P(X_i=x_i \mid \lambda)$ is computed based on the distribution that is stored for random variable $X_i$ in leaf $\lambda$:
- For a discrete variable $X_i$, $P(X_i=x_i \mid \lambda)$ is simply the relative frequency of value $x_i$ in the training samples that ended up in leaf $\lambda$. We thus store the corresponding multinomial distribution in $\lambda$.
- For a continuous variable $X_i$, we store a cumulative distribution function (CDF), which we conveniently represent as a piecewise linear function, such that both density values of the form $P(X_i=x_i \mid \lambda)$ and range queries of the form $P(x_i^\bot \leq X_i \leq x_i^\top \mid \lambda)$ can straightforwardly be supported. To achieve compactness, the authors use a recursive optimisation scheme in order to induce the CDF representation from the training samples in $\lambda$.
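To make the decomposition above concrete, the following minimal sketch evaluates such a mixture for a query that combines a discrete value assignment with a range constraint on a continuous variable. It is purely illustrative and not the authors’ implementation; all variable names, leaf contents and numbers are invented, and densities of continuous variables are omitted in favour of range queries for brevity.

```python
import numpy as np

def piecewise_linear_cdf(x, knots_x, knots_y):
    """Evaluate a piecewise linear CDF given its knot coordinates."""
    return float(np.interp(x, knots_x, knots_y))

def range_probability(lo, hi, knots_x, knots_y):
    """P(lo <= X <= hi) as a difference of two CDF evaluations."""
    return piecewise_linear_cdf(hi, knots_x, knots_y) - piecewise_linear_cdf(lo, knots_x, knots_y)

# Each leaf stores its prior P(lambda) plus one univariate distribution per variable:
# relative frequencies for discrete variables, piecewise linear CDF knots for continuous ones.
leaves = [
    {
        "prior": 0.6,
        "discrete": {"colour": {"red": 0.7, "blue": 0.3}},
        "continuous": {"size": (np.array([0.0, 1.0, 2.0]), np.array([0.0, 0.8, 1.0]))},
    },
    {
        "prior": 0.4,
        "discrete": {"colour": {"red": 0.1, "blue": 0.9}},
        "continuous": {"size": (np.array([1.0, 3.0]), np.array([0.0, 1.0]))},
    },
]

def query(discrete_values, continuous_ranges):
    """P(values, ranges) = sum over leaves of P(lambda) * prod_i P(X_i | lambda)."""
    total = 0.0
    for leaf in leaves:
        p = leaf["prior"]
        for var, val in discrete_values.items():
            p *= leaf["discrete"][var].get(val, 0.0)
        for var, (lo, hi) in continuous_ranges.items():
            p *= range_probability(lo, hi, *leaf["continuous"][var])
        total += p
    return total

print(query({"colour": "red"}, {"size": (0.5, 1.5)}))  # mixture over both leaves
```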
Learning
JPTs use a recursive partitioning scheme that is inspired by the greedy
algorithms typically used for the induction of decision trees, which we shall
briefly review:
When inducing a (binary) decision tree, the fundamental operation is to split
the training dataset into two subsets, such that the (weighted) average impurity
of the two subsets is lower than the impurity of the set we started with.
The split condition can be any condition that compares one of the variables
against a constant (e.g. $X_i < c$ for a continuous variable $X_i$ and some
real-valued constant $c$) and we greedily choose, among all candidate splits,
the one that yields the greatest reduction in impurity.
The two subsets are then recursively split in the same way until a termination
criterion is met (e.g. until impurity is sufficiently low or until the sets have
become too small).
A reduction in impurity means that after the split condition has been checked,
we are closer to a unique answer for the target variable than we were before.
For classification, a common impurity metric is the entropy of the class
variable’s distribution, and for regression, one often uses the variance of the
regression target variable, which corresponds to the mean squared error (MSE)
when using the mean as the predictor.
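For concreteness, a minimal sketch of these two impurity measures, together with the greedy search for the best threshold split on a single feature, might look as follows. This is illustrative only; the names and details are not taken from the paper.

```python
import numpy as np

def entropy(values):
    """Impurity of a discrete variable: entropy of its empirical distribution."""
    _, counts = np.unique(values, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

def variance(values):
    """Impurity of a continuous variable: variance, i.e. the MSE of the mean predictor."""
    return float(np.var(values))

def best_threshold_split(feature, target, impurity_fn):
    """Greedily pick the threshold c on `feature` that maximises the reduction in the
    weighted average impurity of `target` (the core operation of decision-tree induction)."""
    parent = impurity_fn(target)
    best_gain, best_c = 0.0, None
    for c in np.unique(feature)[:-1]:  # candidate thresholds (exclude the maximum)
        mask = feature <= c
        left, right = target[mask], target[~mask]
        weighted = (len(left) * impurity_fn(left) + len(right) * impurity_fn(right)) / len(target)
        if parent - weighted > best_gain:
            best_gain, best_c = parent - weighted, c
    return best_c, best_gain

# e.g. best_threshold_split(np.array([1., 2., 3., 4.]), np.array(list("aabb")), entropy)
```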
In JPTs, there is no target variable; the goal is to obtain subsets where the
impurity of all variables is low, and the authors thus minimise the (weighted)
sum of the impurity values pertaining to the individual variables.
To reconcile the semantics of impurity values of discrete and continuous
variables (which are measured on entirely different scales), the best splits are
determined based on the relative improvement in impurity achieved for each
variable.
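Building on the impurity helpers above, one way to score a candidate split jointly across all variables is sketched below. The equal weighting of the per-variable relative improvements is an assumption made here for simplicity; the exact weighting and normalisation used by the authors may differ.

```python
def relative_improvement(impurity_fn, parent_values, left_values, right_values):
    """Relative (scale-free) reduction in one variable's impurity achieved by a split."""
    before = impurity_fn(parent_values)
    if before == 0.0:
        return 0.0
    n_l, n_r = len(left_values), len(right_values)
    after = (n_l * impurity_fn(left_values) + n_r * impurity_fn(right_values)) / (n_l + n_r)
    return (before - after) / before

def jpt_split_score(columns, left_idx, right_idx):
    """Score of a candidate split: sum of per-variable relative improvements.
    `columns` maps each variable name to (impurity_fn, values), with `values` a
    numpy array over the samples of the current node. Higher is better."""
    return sum(
        relative_improvement(fn, vals, vals[left_idx], vals[right_idx])
        for fn, vals in columns.values()
    )
```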
Figure 1 illustrates JPT learning and the resulting representation
based on the example of a hybrid distribution involving three variables.
The authors’ evaluation indicates promising performance in practice,
showing that generative JPTs can perform on par with discriminative models (which
were trained to solve one specific classification or regression problem) whilst
retaining a much higher degree of flexibility.
Moreover, although the authors do not explicitly mention this, the JPT learning
algorithm can also be viewed as a clustering method which falls into the
category of divisive hierarchical clustering algorithms.
When combined with appropriate, clustering-specific termination criteria, it
could prove to be a highly effective automatic clustering method that requires
no prior knowledge of the number of clusters.
Discussion
The independence assumption that is made in each leaf may seem strong, yet the
impurity minimisation at the heart of the learning algorithm seeks to ensure
that, within each leaf, there ultimately are few interactions between the
variables, reducing co-information and thus weakening the assumption.
By varying the number/size of partitions, we can essentially control the degree
to which the independence assumption constitutes an approximation.
However, with the current incarnation of JPTs as proposed by the authors,
we cannot arbitrarily reduce leaf size, because doing so may result in a
potentially problematic case of overfitting where the probability of the
evidence becomes zero for an increasingly large number of out-of-sample
predictions.
Fortunately, we can conceive of straightforward solutions to this problem:
We could, for instance, assume that the observed values of continuous variables
are but noisy observations of the true values and apply the principles of
kernel density estimators in order to induce the CDF representations.
For discrete variables, we could, in a similar spirit, assume pseudo-counts.
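To illustrate the spirit of these two remedies, consider the following sketch; it is not part of the paper, and the pseudo-count and bandwidth values are arbitrary placeholders.

```python
import numpy as np
from collections import Counter
from math import erf, sqrt

def smoothed_frequencies(values, domain, pseudo_count=1.0):
    """Discrete case: Laplace smoothing (pseudo-counts), so that no value in the
    domain is assigned probability zero."""
    counts = Counter(values)
    total = len(values) + pseudo_count * len(domain)
    return {v: (counts.get(v, 0) + pseudo_count) / total for v in domain}

def kernel_smoothed_cdf(samples, x, bandwidth=0.5):
    """Continuous case: treat each observation as the centre of a Gaussian kernel
    (i.e. as a noisy measurement of the true value) and evaluate the resulting
    smoothed CDF at x, which assigns positive probability to any finite range."""
    samples = np.asarray(samples, dtype=float)
    z = (x - samples) / (bandwidth * sqrt(2.0))
    return float(np.mean([0.5 * (1.0 + erf(zi)) for zi in z]))

print(smoothed_frequencies(["red", "red", "blue"], domain=["red", "blue", "green"]))
print(kernel_smoothed_cdf([1.0, 2.0, 2.5], x=1.5))
```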
The authors provide an
open-source implementation,
so little stands in the way of experimentation.