Dirichlet Multinomial
Overview
The Dirichlet-Multinomial framework models uncertainty in categorical probabilities and updates that uncertainty as count data arrives. It combines the Dirichlet distribution as a prior over simplex-valued probability vectors with multinomial observations, yielding closed-form Bayesian updates. This makes it a practical foundation for classification calibration, language modeling, A/B allocation with many arms, and any workflow that tracks proportions under limited data. In statistical engineering, this category matters because it provides fast, interpretable posterior inference without requiring iterative sampling for routine updates.
Core Concepts: The shared structure across these tools is conjugacy, normalization on the simplex, and log-domain numerical stability. If prior concentration parameters are \boldsymbol{\alpha} and observed counts are \mathbf{n}, posterior parameters are \boldsymbol{\alpha}'=\boldsymbol{\alpha}+\mathbf{n}, and one-step predictive probabilities are p_i=\alpha_i'/\sum_j \alpha_j'. Dirichlet density calculations depend on the multivariate beta normalizer B(\boldsymbol{\alpha}), so robust implementations often compute in log space via \log\Gamma(\cdot) and log-sum-exp transforms. Together, these ideas support both probabilistic interpretation (means, variances, credible intervals) and stable computation in real datasets.
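The conjugacy claim can be checked numerically. The sketch below assumes NumPy and SciPy and uses illustrative values \boldsymbol{\alpha}=(1,1,1) and \mathbf{n}=(5,2,3). It evaluates the log prior plus the multinomial log-likelihood kernel at random simplex points and subtracts the \mathrm{Dir}(\boldsymbol{\alpha}+\mathbf{n}) log-density. If the update rule is right, the difference is the same constant at every point.

```python
import numpy as np
from scipy.stats import dirichlet

alpha = np.array([1.0, 1.0, 1.0])   # prior concentration (illustrative)
n = np.array([5, 2, 3])             # observed counts (illustrative)
alpha_post = alpha + n              # conjugate update: alpha' = alpha + n

def log_prior_times_likelihood(theta):
    # Dirichlet log prior plus the multinomial kernel sum_i n_i log(theta_i);
    # the multinomial coefficient does not depend on theta, so it is dropped
    return dirichlet.logpdf(theta, alpha) + np.sum(n * np.log(theta))

# Random interior points of the simplex at which to compare the two log-densities
rng = np.random.default_rng(0)
thetas = rng.dirichlet(np.ones(3), size=5)
gaps = np.array([log_prior_times_likelihood(t) - dirichlet.logpdf(t, alpha_post) for t in thetas])

# Conjugacy: the gap equals log B(alpha + n) - log B(alpha) at every theta
print(np.ptp(gaps))

predictive = alpha_post / alpha_post.sum()   # p_i = alpha_i' / sum_j alpha_j'
print(predictive)
```

The printed spread of the gaps should be zero up to floating-point error, and the predictive vector is (6/13, 3/13, 4/13).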
Implementation: The category is implemented with SciPy, primarily scipy.stats.dirichlet for Dirichlet distribution summaries and scipy.special routines: gammaln and logsumexp for stable log-domain math, and betaincinv for Beta quantiles. SciPy supplies the numerical routines, while thin custom wrappers adapt them to spreadsheet-style 2D ranges for inputs and outputs.
The posterior updating and prediction workflow centers on DM_POST_UPDATE and DM_PREDICTIVE. DM_POST_UPDATE applies the conjugate update rule to convert prior pseudo-counts and observed counts into posterior hyperparameters, then returns posterior means as immediately usable category probabilities. DM_PREDICTIVE isolates the predictive step, mapping a posterior Dirichlet vector directly to one-step-ahead probabilities that sum to 1. In practice, these two tools support sequential learning pipelines where new observations are folded in repeatedly and predictions are refreshed after each batch.
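A minimal sketch of the sequential pattern, assuming NumPy. The helper names `update` and `predictive` and the batch values are hypothetical and are not the DM_POST_UPDATE or DM_PREDICTIVE wrappers themselves. Because the update only adds counts, folding batches in one at a time ends at the same posterior as a single update with the pooled counts.

```python
import numpy as np

def update(alpha, counts):
    """Conjugate Dirichlet update: add observed counts to the concentration vector."""
    return np.asarray(alpha, dtype=float) + np.asarray(counts, dtype=float)

def predictive(alpha):
    """One-step-ahead predictive probabilities (the posterior mean)."""
    alpha = np.asarray(alpha, dtype=float)
    return alpha / alpha.sum()

alpha = np.array([1.0, 1.0, 1.0])   # prior pseudo-counts (illustrative)
batches = [np.array([2, 0, 1]), np.array([1, 2, 0]), np.array([2, 0, 2])]  # hypothetical batches

for batch in batches:
    alpha = update(alpha, batch)           # fold in the new batch
    print(alpha, predictive(alpha))        # refreshed predictions after each batch

# Batch-by-batch updating matches one update with the pooled counts (5, 2, 3)
one_shot = update([1.0, 1.0, 1.0], sum(batches))
print(np.allclose(alpha, one_shot))  # True
```

The pooled counts here equal those of DM_POST_UPDATE Example 1, so the final posterior is (6, 3, 4).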
Uncertainty quantification and distribution diagnostics are handled by DM_CRED_INT and DM_DIRICHLET_SUM. DM_CRED_INT computes marginal credible intervals for each category probability using the Beta marginals implied by a Dirichlet posterior, which is essential when teams need interval-based risk communication rather than point estimates. DM_DIRICHLET_SUM provides a compact distribution report: category means, category variances, and density/log-density at a user-specified simplex point. These outputs are useful for model checking, prior elicitation reviews, and comparing how concentrated alternative priors or posteriors are around specific probability vectors.
Normalization and evidence-oriented computations are supported by DM_LOGBETA and DM_LOGSUM_NORM. DM_LOGBETA evaluates the log multivariate beta function, the key normalizing term that appears in Dirichlet and Dirichlet-multinomial log-likelihood expressions. DM_LOGSUM_NORM converts unnormalized log-scores into normalized probabilities and returns the log normalizer, a standard step in Bayesian model comparison and stable posterior weight calculations. Used together, these functions reduce overflow/underflow risk and make large-scale categorical inference pipelines numerically reliable.
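For reference, the Dirichlet-multinomial probability of a count vector \mathbf{n} with total N is p(\mathbf{n}\mid\boldsymbol{\alpha}) = \frac{N!}{\prod_i n_i!}\,\frac{B(\boldsymbol{\alpha}+\mathbf{n})}{B(\boldsymbol{\alpha})}, so in log space it needs only log-gamma terms and two log multivariate beta evaluations. The sketch below assumes NumPy and SciPy; `log_beta` and `dm_log_evidence` are illustrative helper names. With a flat prior \boldsymbol{\alpha}=(1,1,1), every count vector with the same total is equally likely, which gives an easy check: (5, 2, 3) is one of 66 compositions of 10 into three parts, so its log-probability is \log(1/66).

```python
import itertools

import numpy as np
from scipy.special import gammaln

def log_beta(alpha):
    """log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i)."""
    alpha = np.asarray(alpha, dtype=float)
    return float(np.sum(gammaln(alpha)) - gammaln(np.sum(alpha)))

def dm_log_evidence(counts, alpha):
    """Dirichlet-multinomial log-probability (marginal likelihood) of a count vector."""
    counts = np.asarray(counts, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    n_total = counts.sum()
    log_coef = gammaln(n_total + 1) - np.sum(gammaln(counts + 1))  # log multinomial coefficient
    return float(log_coef + log_beta(alpha + counts) - log_beta(alpha))

alpha = [1.0, 1.0, 1.0]
print(dm_log_evidence([5, 2, 3], alpha))  # approximately log(1/66)

# Sanity check: probabilities of all count vectors with the same total sum to 1
N, K = 4, 3
total = sum(
    np.exp(dm_log_evidence(c, alpha))
    for c in itertools.product(range(N + 1), repeat=K)
    if sum(c) == N
)
print(round(total, 12))  # 1.0
```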
DM_CRED_INT
This function computes marginal credible intervals for each category probability under a Dirichlet posterior. Each category marginal follows a Beta distribution with parameters derived from the full Dirichlet vector.
For posterior parameters \boldsymbol{\alpha}' and category i, the marginal is:
\theta_i \sim \mathrm{Beta}\left(\alpha_i', \alpha_0' - \alpha_i'\right), \qquad \alpha_0' = \sum_{j=1}^K \alpha_j'
For credibility level c, the lower and upper bounds are the quantiles at (1-c)/2 and 1-(1-c)/2.
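The Beta-marginal result can be sanity-checked by simulation. This sketch, assuming NumPy and SciPy, takes the posterior \boldsymbol{\alpha}'=(6,3,4) from Example 1, computes closed-form bounds with scipy.stats.beta.ppf, and compares them with empirical quantiles of Dirichlet draws. The sample size and seed are arbitrary choices.

```python
import numpy as np
from scipy.stats import beta, dirichlet

alpha_post = np.array([6.0, 3.0, 4.0])   # posterior from Example 1
cred_level = 0.95
tail = (1.0 - cred_level) / 2.0
a0 = alpha_post.sum()

# Closed form: theta_i ~ Beta(alpha_i', alpha_0' - alpha_i'), evaluated per category
lower = beta.ppf(tail, alpha_post, a0 - alpha_post)
upper = beta.ppf(1.0 - tail, alpha_post, a0 - alpha_post)

# Monte Carlo: empirical quantiles of each coordinate of Dirichlet draws
draws = dirichlet.rvs(alpha_post, size=200_000, random_state=0)
mc_lower = np.quantile(draws, tail, axis=0)
mc_upper = np.quantile(draws, 1.0 - tail, axis=0)

print(np.round(lower, 4), np.round(mc_lower, 4))
print(np.round(upper, 4), np.round(mc_upper, 4))
```

The two rows in each printout should agree to about two or three decimal places; the closed-form route is exact and much cheaper.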
Excel Usage
=DM_CRED_INT(alpha_posterior, cred_level)
alpha_posterior(list[list], required): Posterior Dirichlet hyperparameters as a 2D range of positive values.
cred_level(float, optional, default: 0.95): Central credibility level in (0, 1), such as 0.95.
Returns (list[list]): 2D array with lower bounds in the first row and upper bounds in the second row.
Example 1: Default 95 percent intervals for three-category posterior
Inputs:
| alpha_posterior | ||
|---|---|---|
| 6 | 3 | 4 |
Excel formula:
=DM_CRED_INT({6,3,4})
Expected output:
| Result | ||
|---|---|---|
| 0.210945 | 0.0548606 | 0.0992461 |
| 0.72333 | 0.484138 | 0.571858 |
Example 2: 90 percent intervals for skewed posterior
Inputs:
| alpha_posterior | | | cred_level |
|---|---|---|---|
| 15 | 2 | 1 | 0.9 |
Excel formula:
=DM_CRED_INT({15,2,1}, 0.9)
Expected output:
| Result | ||
|---|---|---|
| 0.673807 | 0.0213176 | 0.00301271 |
| 0.950102 | 0.250124 | 0.161566 |
Example 3: Matrix-shaped posterior parameters for interval computation
Inputs:
| alpha_posterior | | cred_level |
|---|---|---|
| 2 | 5 | 0.8 |
| 3 | 4 | |
Excel formula:
=DM_CRED_INT({2,5;3,4}, 0.8)
Expected output:
| Result | |||
|---|---|---|---|
| 0.041691 | 0.200502 | 0.0879964 | 0.141611 |
| 0.267836 | 0.523429 | 0.359776 | 0.444263 |
Example 4: Intervals under weaker posterior concentration
Inputs:
| alpha_posterior | | | | cred_level |
|---|---|---|---|---|
| 0.9 | 1.1 | 0.8 | 1.2 | 0.85 |
Excel formula:
=DM_CRED_INT({0.9,1.1,0.8,1.2}, 0.85)
Expected output:
| Result | |||
|---|---|---|---|
| 0.0179992 | 0.0346402 | 0.0117339 | 0.0449046 |
| 0.54483 | 0.60985 | 0.509197 | 0.639744 |
Python Code
```python
from scipy.special import betaincinv as scipy_betaincinv

def dm_cred_int(alpha_posterior, cred_level=0.95):
    """
    Compute category-wise credible intervals from posterior Dirichlet parameters.

    See: https://en.wikipedia.org/wiki/Dirichlet_distribution#Marginal_distributions

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha_posterior (list[list]): Posterior Dirichlet hyperparameters as a 2D range of positive values.
        cred_level (float, optional): Central credibility level in (0, 1), such as 0.95. Default is 0.95.

    Returns:
        list[list]: 2D array with lower bounds in the first row and upper bounds in the second row.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        alpha_posterior = to2d(alpha_posterior)
        cred_level = float(cred_level)

        if not isinstance(alpha_posterior, list) or not all(isinstance(row, list) for row in alpha_posterior):
            return "Error: alpha_posterior must be a 2D list"
        if cred_level <= 0 or cred_level >= 1:
            return "Error: cred_level must be strictly between 0 and 1"

        # Flatten the range, skipping blank or non-numeric cells
        alpha_flat = []
        for row in alpha_posterior:
            for val in row:
                try:
                    alpha_flat.append(float(val))
                except (TypeError, ValueError):
                    continue

        if len(alpha_flat) < 2:
            return "Error: alpha_posterior must contain at least two positive values"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha_posterior values must be positive"

        alpha_total = float(sum(alpha_flat))
        tail = (1.0 - cred_level) / 2.0
        lower = []
        upper = []
        for ai in alpha_flat:
            # Dirichlet marginal: theta_i ~ Beta(alpha_i, alpha_0 - alpha_i)
            bi = alpha_total - ai
            if bi <= 0:
                return "Error: alpha_posterior must contain at least two categories"
            # betaincinv(a, b, q) inverts the regularized incomplete beta, i.e. the Beta(a, b) quantile
            lower.append(float(scipy_betaincinv(ai, bi, tail)))
            upper.append(float(scipy_betaincinv(ai, bi, 1.0 - tail)))
        return [lower, upper]
    except Exception as e:
        return f"Error: {str(e)}"
```
DM_DIRICHLET_SUM
This function summarizes a Dirichlet distribution for Bayesian categorical modeling. It computes the Dirichlet mean and variance for each category and evaluates both the density and log-density at a supplied probability vector on the simplex.
For concentration parameters \boldsymbol{\alpha}=(\alpha_1,\ldots,\alpha_K), the mean and variance of each category probability are:
\mathbb{E}[\theta_i]=\frac{\alpha_i}{\alpha_0}, \qquad \mathrm{Var}(\theta_i)=\frac{\alpha_i(\alpha_0-\alpha_i)}{\alpha_0^2(\alpha_0+1)}
where \alpha_0=\sum_{i=1}^K \alpha_i. The input probability vector must satisfy \sum_i x_i=1 with each x_i\in(0,1).
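The moments and the density can be computed by hand from these formulas and the log multivariate beta normalizer, \log p(\mathbf{x}) = \sum_i (\alpha_i-1)\log x_i - \log B(\boldsymbol{\alpha}). The sketch below, assuming NumPy and SciPy, reproduces Example 1 (density 4.32, log-density 1.46326) and can be cross-checked against scipy.stats.dirichlet.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet

alpha = np.array([2.0, 2.0, 2.0])
x = np.array([0.3, 0.4, 0.3])
a0 = alpha.sum()

mean = alpha / a0                                    # E[theta_i] = alpha_i / alpha_0
var = alpha * (a0 - alpha) / (a0**2 * (a0 + 1.0))   # Var(theta_i)

# log density: sum_i (alpha_i - 1) log x_i - log B(alpha)
log_b = np.sum(gammaln(alpha)) - gammaln(a0)
logpdf = np.sum((alpha - 1.0) * np.log(x)) - log_b

print(mean, var)
print(np.exp(logpdf), logpdf)
```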
Excel Usage
=DM_DIRICHLET_SUM(alpha, x)
alpha(list[list], required): Dirichlet concentration parameters as a 2D range of positive values.
x(list[list], required): Category probability vector as a 2D range with entries in (0, 1) that sum to 1.
Returns (list[list]): 2D array with mean row, variance row, and a row containing density and log-density.
Example 1: Symmetric three-category Dirichlet summary
Inputs:
| alpha | | | x | | |
|---|---|---|---|---|---|
| 2 | 2 | 2 | 0.3 | 0.4 | 0.3 |
Excel formula:
=DM_DIRICHLET_SUM({2,2,2}, {0.3,0.4,0.3})
Expected output:
| Result | ||
|---|---|---|
| 0.333333 | 0.333333 | 0.333333 |
| 0.031746 | 0.031746 | 0.031746 |
| 4.32 | 1.46326 |
Example 2: Skewed four-category concentration with valid simplex point
Inputs:
| alpha | | | | x | | | |
|---|---|---|---|---|---|---|---|
| 5 | 2 | 3 | 4 | 0.4 | 0.1 | 0.2 | 0.3 |
Excel formula:
=DM_DIRICHLET_SUM({5,2,3,4}, {0.4,0.1,0.2,0.3})
Expected output:
| Result | |||
|---|---|---|---|
| 0.357143 | 0.142857 | 0.214286 | 0.285714 |
| 0.0153061 | 0.00816327 | 0.0112245 | 0.0136054 |
| 59.7794 | 4.09066 |
Example 3: Matrix-shaped input ranges are flattened correctly
Inputs:
| alpha | | x | |
|---|---|---|---|
| 3 | 1 | 0.25 | 0.1 |
| 2 | 4 | 0.15 | 0.5 |
Excel formula:
=DM_DIRICHLET_SUM({3,1;2,4}, {0.25,0.1;0.15,0.5})
Expected output:
| Result | |||
|---|---|---|---|
| 0.3 | 0.1 | 0.2 | 0.4 |
| 0.0190909 | 0.00818182 | 0.0145455 | 0.0218182 |
| 35.4375 | 3.56777 |
Example 4: Concentrated prior with dominant category probability
Inputs:
| alpha | x | ||||
|---|---|---|---|---|---|
| 10 | 2 | 1 | 0.75 | 0.15 | 0.1 |
Excel formula:
=DM_DIRICHLET_SUM({10,2,1}, {0.75,0.15,0.1})
Expected output:
| Result | ||
|---|---|---|
| 0.769231 | 0.153846 | 0.0769231 |
| 0.0126796 | 0.00929839 | 0.00507185 |
| 14.8668 | 2.69913 |
Python Code
```python
import numpy as np
from scipy.stats import dirichlet as scipy_dirichlet

def dm_dirichlet_sum(alpha, x):
    """
    Compute Dirichlet density and moments for a category-probability vector.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha (list[list]): Dirichlet concentration parameters as a 2D range of positive values.
        x (list[list]): Category probability vector as a 2D range with entries in (0, 1) that sum to 1.

    Returns:
        list[list]: 2D array with mean row, variance row, and a row containing density and log-density.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        def flatten_numeric(mat):
            # Flatten a 2D range, skipping blank or non-numeric cells
            flat = []
            for row in mat:
                if not isinstance(row, list):
                    return None
                for val in row:
                    try:
                        flat.append(float(val))
                    except (TypeError, ValueError):
                        continue
            return flat

        alpha = to2d(alpha)
        x = to2d(x)
        alpha_flat = flatten_numeric(alpha)
        x_flat = flatten_numeric(x)

        if alpha_flat is None or x_flat is None:
            return "Error: alpha and x must be 2D lists"
        if len(alpha_flat) < 2:
            return "Error: alpha must contain at least two positive values"
        if len(alpha_flat) != len(x_flat):
            return "Error: alpha and x must have the same number of elements"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha values must be positive"
        if any((xi <= 0) or (xi >= 1) for xi in x_flat):
            return "Error: x values must be strictly between 0 and 1"
        x_sum = float(sum(x_flat))
        if abs(x_sum - 1.0) > 1e-8:
            return "Error: x values must sum to 1"

        alpha_arr = np.asarray(alpha_flat, dtype=float)
        # Renormalize x so SciPy's own simplex check, which can be stricter than
        # the 1e-8 tolerance above, does not reject inputs that passed validation
        x_arr = np.asarray(x_flat, dtype=float) / x_sum

        mean = scipy_dirichlet.mean(alpha_arr).tolist()
        var = scipy_dirichlet.var(alpha_arr).tolist()
        pdf_value = float(scipy_dirichlet.pdf(x_arr, alpha_arr))
        logpdf_value = float(scipy_dirichlet.logpdf(x_arr, alpha_arr))

        # Pad the density row so all three rows have the same width
        width = len(alpha_flat)
        summary_row = [pdf_value, logpdf_value] + [""] * max(0, width - 2)
        return [mean, var, summary_row]
    except Exception as e:
        return f"Error: {str(e)}"
```
DM_LOGBETA
This function computes the log of the multivariate beta function for a Dirichlet concentration vector. The result is the normalization term used in Dirichlet and Dirichlet-Multinomial log-density and log-evidence calculations.
For concentration parameters \boldsymbol{\alpha}=(\alpha_1,\ldots,\alpha_K), the log-normalization term is:
\log B(\boldsymbol{\alpha}) = \sum_{i=1}^K \log\Gamma(\alpha_i) - \log\Gamma\!\left(\sum_{i=1}^K \alpha_i\right)
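Working with \log\Gamma(\cdot) avoids evaluating \Gamma(\cdot) itself, which overflows double precision for arguments above roughly 171, so the computation stays stable for both large and small positive parameters.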
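To see why the log form matters, the sketch below (assuming NumPy and SciPy; the concentration values are illustrative) evaluates the normalizer naively with scipy.special.gamma and in log space with gammaln. The naive product of gamma values overflows to infinity, so the ratio becomes nan, while the log-gamma version stays finite.

```python
import numpy as np
from scipy.special import gamma, gammaln

alpha = np.array([100.0, 80.0, 60.0])   # illustrative large concentrations

# Naive route: the product of Gamma values and Gamma(240) both overflow to inf,
# so the ratio is inf / inf = nan (floating-point warnings suppressed here)
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.log(np.prod(gamma(alpha)) / gamma(alpha.sum()))

# Log-domain route: a sum of log-gamma values stays finite
stable = np.sum(gammaln(alpha)) - gammaln(alpha.sum())

print(naive, stable)
```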
Excel Usage
=DM_LOGBETA(alpha)
alpha(list[list], required): Dirichlet concentration parameters as a 2D range of positive values.
Returns (float): Log of the multivariate beta function for the provided concentration parameters.
Example 1: Symmetric three-category concentration vector
Inputs:
| alpha | ||
|---|---|---|
| 2 | 2 | 2 |
Excel formula:
=DM_LOGBETA({2,2,2})
Expected output:
-4.78749
Example 2: Uneven four-category concentration vector
Inputs:
| alpha | |||
|---|---|---|---|
| 5 | 1.5 | 3 | 2.5 |
Excel formula:
=DM_LOGBETA({5,1.5,3,2.5})
Expected output:
-13.4672
Example 3: Matrix-shaped concentration input is flattened
Inputs:
| alpha | |
|---|---|
| 1.2 | 2.8 |
| 3.1 | 4.4 |
Excel formula:
=DM_LOGBETA({1.2,2.8;3.1,4.4})
Expected output:
-12.7572
Example 4: Larger concentration values remain stable in log-space
Inputs:
| alpha | ||
|---|---|---|
| 25 | 40 | 35 |
Excel formula:
=DM_LOGBETA({25,40,35})
Expected output:
-109.137
Python Code
```python
import numpy as np
from scipy.special import gammaln as scipy_gammaln

def dm_logbeta(alpha):
    """
    Compute the Dirichlet log-normalization term using log-gamma values.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.gammaln.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha (list[list]): Dirichlet concentration parameters as a 2D range of positive values.

    Returns:
        float: Log of the multivariate beta function for the provided concentration parameters.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        alpha = to2d(alpha)
        if not isinstance(alpha, list) or not all(isinstance(row, list) for row in alpha):
            return "Error: alpha must be a 2D list"

        # Flatten the range, skipping blank or non-numeric cells
        alpha_flat = []
        for row in alpha:
            for val in row:
                try:
                    alpha_flat.append(float(val))
                except (TypeError, ValueError):
                    continue

        if len(alpha_flat) < 2:
            return "Error: alpha must contain at least two positive values"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha values must be positive"

        # log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i)
        alpha_arr = np.asarray(alpha_flat, dtype=float)
        return float(np.sum(scipy_gammaln(alpha_arr)) - scipy_gammaln(np.sum(alpha_arr)))
    except Exception as e:
        return f"Error: {str(e)}"
```
DM_LOGSUM_NORM
This function uses a log-domain normalization step to transform unnormalized log-scores into categorical probabilities. It is useful for Bayesian workflows where posterior quantities are accumulated in log space.
Given log-values \ell_1,\ldots,\ell_K, it computes:
\log Z = \log\left(\sum_{i=1}^K e^{\ell_i}\right), \qquad p_i = \frac{e^{\ell_i}}{\sum_{j=1}^K e^{\ell_j}}
Computing \log Z with a stable log-sum-exp routine prevents overflow and underflow.
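The standard stabilization shifts by the largest log-value m=\max_i \ell_i before exponentiating: \log Z = m + \log\sum_{i} e^{\ell_i - m}. Every shifted exponent is at most 0, so no term overflows and at least one term equals 1, which keeps the sum away from underflowing to zero. A minimal sketch assuming NumPy, checked against scipy.special.logsumexp on the values from Example 4; `log_sum_exp` is an illustrative name.

```python
import numpy as np
from scipy.special import logsumexp

def log_sum_exp(log_values):
    """Stable log(sum(exp(l))) via the max-shift identity (assumes at least one finite value)."""
    l = np.asarray(log_values, dtype=float)
    m = np.max(l)
    # log Z = m + log(sum(exp(l_i - m))); every shifted exponent is <= 0
    return float(m + np.log(np.sum(np.exp(l - m))))

scores = np.array([-900.0, -901.0, -903.0])
with np.errstate(divide="ignore"):
    naive = np.log(np.sum(np.exp(scores)))   # exp underflows to 0, so log returns -inf
log_z = log_sum_exp(scores)
probs = np.exp(scores - log_z)

print(naive, log_z, probs)
```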
Excel Usage
=DM_LOGSUM_NORM(log_values)
log_values(list[list], required): Unnormalized log-scores as a 2D numeric range.
Returns (list[list]): 2D array with one row for log normalizer and one row for normalized probabilities.
Example 1: Three-category log-scores normalize to probabilities
Inputs:
| log_values | ||
|---|---|---|
| 0 | -1 | -2 |
Excel formula:
=DM_LOGSUM_NORM({0,-1,-2})
Expected output:
| Result | ||
|---|---|---|
| 0.407606 | ||
| 0.665241 | 0.244728 | 0.0900306 |
Example 2: Mixed-scale scores remain numerically stable
Inputs:
| log_values | |||
|---|---|---|---|
| 15 | 2 | -5 | -10 |
Excel formula:
=DM_LOGSUM_NORM({15,2,-5,-10})
Expected output:
| Result | |||
|---|---|---|---|
| 15 | |||
| 0.999998 | 2.26032e-6 | 2.06115e-9 | 1.38879e-11 |
Example 3: Matrix-shaped log-values are flattened for normalization
Inputs:
| log_values | |
|---|---|
| 1.2 | 0.4 |
| -0.6 | -2.5 |
Excel formula:
=DM_LOGSUM_NORM({1.2,0.4;-0.6,-2.5})
Expected output:
| Result | |||
|---|---|---|---|
| 1.6943 | |||
| 0.609997 | 0.274089 | 0.100832 | 0.0150813 |
Example 4: Very negative log-values avoid underflow issues
Inputs:
| log_values | ||
|---|---|---|
| -900 | -901 | -903 |
Excel formula:
=DM_LOGSUM_NORM({-900,-901,-903})
Expected output:
| Result | ||
|---|---|---|
| -899.651 | ||
| 0.705385 | 0.259496 | 0.035119 |
Python Code
```python
import numpy as np
from scipy.special import logsumexp as scipy_logsumexp

def dm_logsum_norm(log_values):
    """
    Compute a stable log normalizer and normalized probabilities from log-values.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        log_values (list[list]): Unnormalized log-scores as a 2D numeric range.

    Returns:
        list[list]: 2D array with one row for log normalizer and one row for normalized probabilities.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        log_values = to2d(log_values)
        if not isinstance(log_values, list) or not all(isinstance(row, list) for row in log_values):
            return "Error: log_values must be a 2D list"

        # Flatten the range, skipping blank or non-numeric cells
        flat = []
        for row in log_values:
            for val in row:
                try:
                    flat.append(float(val))
                except (TypeError, ValueError):
                    continue

        if len(flat) == 0:
            return "Error: log_values must contain at least one numeric value"

        arr = np.asarray(flat, dtype=float)
        log_z = float(scipy_logsumexp(arr))
        probs = np.exp(arr - log_z).tolist()

        # Pad the log-normalizer row to the width of the probability row
        width = len(probs)
        first_row = [log_z] + [""] * max(0, width - 1)
        return [first_row, probs]
    except Exception as e:
        return f"Error: {str(e)}"
```
DM_POST_UPDATE
This function performs conjugate Bayesian updating for categorical data. A Dirichlet prior combined with multinomial counts yields a Dirichlet posterior with category-wise hyperparameters incremented by the observed counts.
If prior hyperparameters are \boldsymbol{\alpha} and observed counts are \mathbf{n}, then:
\alpha_i' = \alpha_i + n_i
The function returns both the posterior hyperparameters and posterior predictive means \alpha_i' / \sum_j \alpha_j'.
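One way to read the update: the posterior mean is a weighted average of the prior mean \alpha_i/\alpha_0 and the observed frequency n_i/N, with weight \alpha_0/(\alpha_0+N) on the prior, where \alpha_0=\sum_i \alpha_i and N=\sum_i n_i>0. The sketch below, assuming NumPy, checks this on the inputs from Example 4, where the prior still carries weight 5/8 because only three observations were made.

```python
import numpy as np

alpha_prior = np.array([0.8, 1.5, 2.0, 0.7])   # Example 4 prior
counts = np.array([0, 1, 0, 2])                # Example 4 counts

alpha_post = alpha_prior + counts              # alpha_i' = alpha_i + n_i
post_mean = alpha_post / alpha_post.sum()

a0, n_total = alpha_prior.sum(), counts.sum()
w = a0 / (a0 + n_total)                        # weight on the prior mean
# Blend of prior mean and empirical frequencies (requires n_total > 0)
blend = w * alpha_prior / a0 + (1.0 - w) * counts / n_total

print(post_mean, blend, w)
```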
Excel Usage
=DM_POST_UPDATE(alpha_prior, counts)
alpha_prior(list[list], required): Prior Dirichlet hyperparameters as a 2D range of positive values.
counts(list[list], required): Observed category counts as a 2D range of nonnegative integer values.
Returns (list[list]): 2D array with posterior hyperparameters in the first row and posterior predictive means in the second row.
Example 1: Three-category posterior update with moderate counts
Inputs:
| alpha_prior | | | counts | | |
|---|---|---|---|---|---|
| 1 | 1 | 1 | 5 | 2 | 3 |
Excel formula:
=DM_POST_UPDATE({1,1,1}, {5,2,3})
Expected output:
| Result | ||
|---|---|---|
| 6 | 3 | 4 |
| 0.461538 | 0.230769 | 0.307692 |
Example 2: Informative prior combined with larger sample counts
Inputs:
| alpha_prior | | | counts | | |
|---|---|---|---|---|---|
| 10 | 4 | 6 | 12 | 8 | 5 |
Excel formula:
=DM_POST_UPDATE({10,4,6}, {12,8,5})
Expected output:
| Result | ||
|---|---|---|
| 22 | 12 | 11 |
| 0.488889 | 0.266667 | 0.244444 |
Example 3: Matrix-shaped prior and counts are flattened consistently
Inputs:
| alpha_prior | | counts | |
|---|---|---|---|
| 2 | 3 | 1 | 0 |
| 4 | 5 | 6 | 2 |
Excel formula:
=DM_POST_UPDATE({2,3;4,5}, {1,0;6,2})
Expected output:
| Result | |||
|---|---|---|---|
| 3 | 3 | 10 | 7 |
| 0.130435 | 0.130435 | 0.434783 | 0.304348 |
Example 4: Sparse counts preserve influence of prior
Inputs:
| alpha_prior | | | | counts | | | |
|---|---|---|---|---|---|---|---|
| 0.8 | 1.5 | 2 | 0.7 | 0 | 1 | 0 | 2 |
Excel formula:
=DM_POST_UPDATE({0.8,1.5,2,0.7}, {0,1,0,2})
Expected output:
| Result | |||
|---|---|---|---|
| 0.8 | 2.5 | 2 | 2.7 |
| 0.1 | 0.3125 | 0.25 | 0.3375 |
Python Code
```python
import numpy as np

def dm_post_update(alpha_prior, counts):
    """
    Update Dirichlet posterior parameters from prior hyperparameters and observed counts.

    See: https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha_prior (list[list]): Prior Dirichlet hyperparameters as a 2D range of positive values.
        counts (list[list]): Observed category counts as a 2D range of nonnegative integer values.

    Returns:
        list[list]: 2D array with posterior hyperparameters in the first row and posterior predictive means in the second row.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        def flatten_numeric(mat):
            # Flatten a 2D range, skipping blank or non-numeric cells
            if not isinstance(mat, list) or not all(isinstance(row, list) for row in mat):
                return None
            flat = []
            for row in mat:
                for val in row:
                    try:
                        flat.append(float(val))
                    except (TypeError, ValueError):
                        continue
            return flat

        alpha_prior = to2d(alpha_prior)
        counts = to2d(counts)
        alpha_flat = flatten_numeric(alpha_prior)
        counts_flat = flatten_numeric(counts)

        if alpha_flat is None or counts_flat is None:
            return "Error: alpha_prior and counts must be 2D lists"
        if len(alpha_flat) < 2:
            return "Error: alpha_prior must contain at least two positive values"
        if len(alpha_flat) != len(counts_flat):
            return "Error: alpha_prior and counts must have the same number of elements"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha_prior values must be positive"

        # Accept only (near-)integer, nonnegative counts
        clean_counts = []
        for c in counts_flat:
            c_int = int(round(c))
            if abs(c - c_int) > 1e-9 or c_int < 0:
                return "Error: counts must contain nonnegative integers"
            clean_counts.append(c_int)

        # Conjugate update alpha' = alpha + n, then posterior means alpha' / sum(alpha')
        posterior = (np.asarray(alpha_flat, dtype=float) + np.asarray(clean_counts, dtype=float)).tolist()
        total_post = float(sum(posterior))
        means = [p / total_post for p in posterior]
        return [posterior, means]
    except Exception as e:
        return f"Error: {str(e)}"
```
DM_PREDICTIVE
This function returns the one-step-ahead posterior predictive probabilities under a Dirichlet-Multinomial model. Once observed counts have been folded into the posterior hyperparameters, the predictive probability of each category equals its posterior mean.
Given posterior Dirichlet parameters \boldsymbol{\alpha}', the predictive probability for category i is:
p_i = \frac{\alpha_i'}{\sum_{j=1}^K \alpha_j'}
These probabilities sum to 1 and give the probability that the next single observation falls in each category.
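Chaining one-step predictives gives the probability of an entire ordered sequence: each draw is scored with the current predictive and then folded into the hyperparameters (the Pólya urn scheme). The product simplifies to B(\boldsymbol{\alpha}+\mathbf{n})/B(\boldsymbol{\alpha}), where \mathbf{n} holds the sequence's category counts. The sketch assumes NumPy and SciPy; the prior, the sequence, and the helper name `log_beta` are illustrative.

```python
import numpy as np
from scipy.special import gammaln

def log_beta(alpha):
    """log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i)."""
    alpha = np.asarray(alpha, dtype=float)
    return float(np.sum(gammaln(alpha)) - gammaln(np.sum(alpha)))

alpha = np.array([2.0, 1.0, 1.0])   # illustrative prior
sequence = [0, 2, 0, 1, 0]          # hypothetical ordered category draws

log_p = 0.0
alpha_t = alpha.copy()
for k in sequence:
    log_p += np.log(alpha_t[k] / alpha_t.sum())   # one-step predictive of the observed category
    alpha_t[k] += 1.0                              # fold the observation into the posterior

counts = np.bincount(sequence, minlength=len(alpha))
print(np.exp(log_p), np.exp(log_beta(alpha + counts) - log_beta(alpha)))
```

Both printed values equal 1/280 here: the chained predictives are (2/4)(1/5)(3/6)(1/7)(4/8).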
Excel Usage
=DM_PREDICTIVE(alpha_posterior)
alpha_posterior(list[list], required): Posterior Dirichlet hyperparameters as a 2D range of positive values.
Returns (list[list]): Single-row 2D array of posterior predictive probabilities for each category.
Example 1: Balanced posterior produces balanced predictive probabilities
Inputs:
| alpha_posterior | ||
|---|---|---|
| 3 | 3 | 3 |
Excel formula:
=DM_PREDICTIVE({3,3,3})
Expected output:
| Result | ||
|---|---|---|
| 0.333333 | 0.333333 | 0.333333 |
Example 2: Skewed posterior emphasizes dominant category
Inputs:
| alpha_posterior | ||
|---|---|---|
| 12 | 2 | 1 |
Excel formula:
=DM_PREDICTIVE({12,2,1})
Expected output:
| Result | ||
|---|---|---|
| 0.8 | 0.133333 | 0.0666667 |
Example 3: Matrix-shaped posterior parameters are flattened correctly
Inputs:
| alpha_posterior | |
|---|---|
| 4 | 1 |
| 2 | 3 |
Excel formula:
=DM_PREDICTIVE({4,1;2,3})
Expected output:
| Result | |||
|---|---|---|---|
| 0.4 | 0.1 | 0.2 | 0.3 |
Example 4: Weak posterior still returns normalized probabilities
Inputs:
| alpha_posterior | |||
|---|---|---|---|
| 0.6 | 0.9 | 1.2 | 0.8 |
Excel formula:
=DM_PREDICTIVE({0.6,0.9,1.2,0.8})
Expected output:
| Result | |||
|---|---|---|---|
| 0.171429 | 0.257143 | 0.342857 | 0.228571 |
Python Code
```python
def dm_predictive(alpha_posterior):
    """
    Compute posterior predictive category probabilities from Dirichlet parameters.

    See: https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha_posterior (list[list]): Posterior Dirichlet hyperparameters as a 2D range of positive values.

    Returns:
        list[list]: Single-row 2D array of posterior predictive probabilities for each category.
    """
    try:
        def to2d(v):
            return [[v]] if not isinstance(v, list) else v

        alpha_posterior = to2d(alpha_posterior)
        if not isinstance(alpha_posterior, list) or not all(isinstance(row, list) for row in alpha_posterior):
            return "Error: alpha_posterior must be a 2D list"

        # Flatten the range, skipping blank or non-numeric cells
        alpha_flat = []
        for row in alpha_posterior:
            for val in row:
                try:
                    alpha_flat.append(float(val))
                except (TypeError, ValueError):
                    continue

        if len(alpha_flat) < 2:
            return "Error: alpha_posterior must contain at least two positive values"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha_posterior values must be positive"

        # p_i = alpha_i' / sum_j alpha_j'
        total_alpha = float(sum(alpha_flat))
        probs = [a / total_alpha for a in alpha_flat]
        return [probs]
    except Exception as e:
        return f"Error: {str(e)}"
```