DIRICHLET_MULTINOMIAL
Overview
The DIRICHLET_MULTINOMIAL function computes properties of the Dirichlet multinomial distribution, which models the probability of observing a set of counts across multiple categories when the underlying probabilities are themselves uncertain and modeled by a Dirichlet prior. This is useful in Bayesian statistics, genetics, and machine learning for modeling overdispersed count data. The Dirichlet multinomial is a compound distribution: probabilities for each category are drawn from a Dirichlet distribution, then counts are drawn from a multinomial distribution using those probabilities.
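The two-stage draw described above can be sketched with the standard library alone. This is a minimal illustration, not how SciPy or this wrapper works internally (the wrapper does not sample at all). The helper name `sample_dirichlet_multinomial` is hypothetical, and the Dirichlet draw assumes the standard construction of normalizing independent Gamma variates.

```python
import random
from typing import List, Optional


def sample_dirichlet_multinomial(alpha: List[float], n: int,
                                 rng: Optional[random.Random] = None) -> List[int]:
    """Draw one count vector via the two-stage (compound) process."""
    rng = rng or random.Random()
    # Stage 1: p ~ Dirichlet(alpha), built by normalizing Gamma(alpha_k, 1) draws.
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    probs = [g / total for g in gammas]
    # Stage 2: counts ~ Multinomial(n, p), built from n categorical draws.
    counts = [0] * len(alpha)
    for k in rng.choices(range(len(alpha)), weights=probs, k=n):
        counts[k] += 1
    return counts


# Hypothetical usage: one draw with three categories and ten trials.
counts = sample_dirichlet_multinomial([2.0, 3.0, 5.0], n=10, rng=random.Random(0))
```

Because the category probabilities change from draw to draw, repeated samples spread out more than a plain multinomial with fixed probabilities would, which is the overdispersion mentioned above.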
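The probability mass function (pmf) is:

$$P(x \mid n, \alpha) = \frac{\Gamma(\alpha_0)\,\Gamma(n+1)}{\Gamma(n+\alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k+\alpha_k)}{\Gamma(\alpha_k)\,\Gamma(x_k+1)}, \qquad \alpha_0 = \sum_{k=1}^{K} \alpha_k$$

where $x = (x_1, \ldots, x_K)$ are observed counts, $\alpha = (\alpha_1, \ldots, \alpha_K)$ are concentration parameters, $n = \sum_{k} x_k$ is the total number of trials, and $K$ is the number of categories. For more details, see the scipy.stats.dirichlet_multinomial documentation.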
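As a cross-check on the pmf, the Gamma-function form of the Dirichlet multinomial can be evaluated in log space with `math.lgamma`. The sketch below is illustrative only: the helper name `dirmult_logpmf` is hypothetical, it takes n to be the sum of the counts, and it skips the input validation that the wrapper and SciPy perform.

```python
import math
from typing import Sequence


def dirmult_logpmf(x: Sequence[int], alpha: Sequence[float]) -> float:
    """Log-pmf of the Dirichlet multinomial, with n taken as sum(x)."""
    n = sum(x)
    alpha_0 = sum(alpha)
    # Leading factor: Gamma(alpha_0) * Gamma(n + 1) / Gamma(n + alpha_0)
    out = math.lgamma(alpha_0) + math.lgamma(n + 1) - math.lgamma(n + alpha_0)
    # Per-category factors: Gamma(x_k + alpha_k) / (Gamma(alpha_k) * Gamma(x_k + 1))
    for x_k, a_k in zip(x, alpha):
        out += math.lgamma(x_k + a_k) - math.lgamma(a_k) - math.lgamma(x_k + 1)
    return out


log_p = dirmult_logpmf([2, 3, 5], [1.0, 1.0, 1.0])
p = math.exp(log_p)  # equals 1/66 for these inputs, up to rounding
```

Working with log-Gamma values avoids overflow for large counts, which is also why a separate logpmf method is useful.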
This wrapper exposes the most commonly used methods: pmf, logpmf, mean, var, and cov. The seed parameter from the underlying Python function is not supported. All inputs must be scalars or 2D lists as supported by Excel. This example function is provided as-is without any representation of accuracy.
Usage
To use the function in Excel:
=DIRICHLET_MULTINOMIAL([x], alpha, [n], [method])
- x (2D list of integers, optional): Table of counts for each category. Required for the pmf and logpmf methods. Each row is a set of counts for one distribution.
- alpha (2D list of floats, required): Table of concentration parameters. Each row is a set of parameters for one distribution.
- n (list of integers, optional): Number of trials for each distribution. Required for all methods except cov, which uses n = 1 when n is omitted.
- method (string, optional, default=pmf): Which method to compute: pmf, logpmf, mean, var, or cov.
The function returns a 2D list of results for each input row, or an error message (string) if the input is invalid. For pmf and logpmf, the result is a probability or log-probability. For mean and var, the result is a list of values for each category. For cov, the result is a covariance matrix.
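The shapes of the mean, var, and cov results can be seen from the standard closed-form moments of the Dirichlet multinomial. The sketch below is not the wrapper's implementation (the wrapper delegates to SciPy); the helper name `dirmult_moments` is hypothetical, and it writes p_k = alpha_k / alpha_0 for the expected category probabilities.

```python
from typing import List, Sequence, Tuple


def dirmult_moments(alpha: Sequence[float], n: int
                    ) -> Tuple[List[float], List[float], List[List[float]]]:
    """Return (mean, var, cov) from the closed-form moment formulas."""
    alpha_0 = sum(alpha)
    p = [a / alpha_0 for a in alpha]
    # Overdispersion factor relative to a plain multinomial with probabilities p.
    scale = n * (n + alpha_0) / (1.0 + alpha_0)
    mean = [n * p_k for p_k in p]
    # Diagonal: scale * p_i * (1 - p_i); off-diagonal: -scale * p_i * p_j.
    cov = [[scale * (p_i * (1.0 - p_i) if i == j else -p_i * p_j)
            for j, p_j in enumerate(p)]
           for i, p_i in enumerate(p)]
    var = [cov[k][k] for k in range(len(p))]
    return mean, var, cov


mean, var, cov = dirmult_moments([2.0, 3.0, 5.0], n=10)
```

The mean and variance are one value per category, while the covariance is a K by K matrix whose diagonal is the variance.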
Examples
Example 1: Basic PMF Calculation
Inputs:
| x | | | alpha | | | n | method |
|---|---|---|---|---|---|---|---|
| 2 | 3 | 5 | 1.0 | 1.0 | 1.0 | 10 | pmf |
Excel formula:
=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "pmf")
Expected output:
| Result |
|---|
| 0.015 |
Example 2: Log-PMF Calculation
Inputs:
| x | | | alpha | | | n | method |
|---|---|---|---|---|---|---|---|
| 2 | 3 | 5 | 1.0 | 1.0 | 1.0 | 10 | logpmf |
Excel formula:
=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "logpmf")
Expected output:
| Result |
|---|
| -4.190 |
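Examples 1 and 2 use alpha = (1, 1, 1), a special case in which every count vector summing to n is equally likely, so the pmf is one over the number of such vectors. The short check below, a sketch using only the standard library, reproduces the two expected outputs under that observation.

```python
import math

n, k = 10, 3  # trials and categories from Examples 1 and 2
# Number of ways to split n trials across k categories (stars and bars).
n_outcomes = math.comb(n + k - 1, k - 1)
pmf = 1.0 / n_outcomes
logpmf = -math.log(n_outcomes)
print(f"{pmf:.3f} {logpmf:.3f}")  # prints: 0.015 -4.190
```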
Example 3: Mean Calculation
Inputs:
| alpha | | | n | method |
|---|---|---|---|---|
| 2.0 | 3.0 | 5.0 | 10 | mean |
Excel formula:
=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, {10}, "mean")
Expected output:
| Result | ||
|---|---|---|
| 2.000 | 3.000 | 5.000 |
Example 4: Covariance Matrix Calculation
Inputs:
| alpha | | | method |
|---|---|---|---|
| 2.0 | 3.0 | 5.0 | cov |
Excel formula:
=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, , "cov")
Expected output:
| Result | ||
|---|---|---|
| 0.160 | -0.060 | -0.100 |
| -0.060 | 0.210 | -0.150 |
| -0.100 | -0.150 | 0.250 |
Python Code
from typing import List, Optional, Union

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial


def dirichlet_multinomial(
    x: Optional[List[List[int]]] = None,
    alpha: Optional[List[List[float]]] = None,
    n: Optional[List[int]] = None,
    method: str = 'pmf'
) -> Union[List[List[Optional[float]]], str]:
    """
    Computes the probability mass function, log probability mass function, mean, variance, or covariance of the Dirichlet multinomial distribution.

    Args:
        x: 2D list of integer counts for each category. Required for 'pmf' and 'logpmf' methods.
        alpha: 2D list of float concentration parameters. Each row is a set of parameters for one distribution.
        n: List of integers, number of trials for each distribution. Required for all methods except 'cov'.
        method: Which method to compute (str): 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        2D list of results for each input, or an error message (str) if input is invalid.

    This example function is provided as-is without any representation of accuracy.
    """
    # Validate method
    valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
    if method not in valid_methods:
        return f"Invalid method: {method}. Must be one of {sorted(valid_methods)}."

    # Validate alpha
    if not isinstance(alpha, list) or not all(isinstance(row, list) and len(row) > 0 for row in alpha):
        return "Invalid input: alpha must be a 2D list of floats."
    if len(alpha) < 1:
        return "Invalid input: alpha must be a 2D list with at least one row."

    # Validate n (optional only for 'cov')
    if method != 'cov':
        if n is None or not isinstance(n, list) or len(n) != len(alpha):
            return "Invalid input: n must be a list of integers with the same length as alpha."
        try:
            n = [int(val) for val in n]
        except Exception:
            return "Invalid input: n must contain integers."
        if any(val < 0 for val in n):
            return "Invalid input: n must contain non-negative integers."

    # Validate x (needed only for 'pmf' and 'logpmf')
    if method in {'pmf', 'logpmf'}:
        if x is None or not isinstance(x, list) or len(x) != len(alpha):
            return "Invalid input: x must be a 2D list of integers with the same number of rows as alpha."
        for i, row in enumerate(x):
            # Compare each row of x with its own row of alpha, not only the first one
            if not isinstance(row, list) or len(row) != len(alpha[i]):
                return "Invalid input: each row of x must have the same length as the matching row of alpha."
            try:
                if any(int(val) < 0 for val in row):
                    return "Invalid input: x must contain non-negative integers."
            except Exception:
                return "Invalid input: x must contain integers."

    # Compute one result per row of alpha
    results = []
    for i, alpha_row in enumerate(alpha):
        try:
            if method == 'cov':
                # n is optional for 'cov'; fall back to n=1 when it is not provided
                n_val = n[i] if n is not None and len(n) > i else 1
            else:
                n_val = n[i]
            dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)
            if method == 'pmf':
                res = dist.pmf(x[i])
            elif method == 'logpmf':
                res = dist.logpmf(x[i])
            elif method == 'mean':
                res = dist.mean()
            elif method == 'var':
                res = dist.var()
            else:
                res = dist.cov()

            # Convert NumPy scalars and arrays to plain Python values
            values = res.tolist() if hasattr(res, 'tolist') else res
            if method == 'cov':
                # Covariance is a K x K matrix: append its rows so the result stays 2D.
                # With several alpha rows, the matrices are stacked vertically instead
                # of each one overwriting the previous result.
                results.extend([float(val) for val in row] for row in values)
            elif isinstance(values, list):
                # mean and var: one value per category
                results.append([float(val) for val in values])
            else:
                # pmf and logpmf: a single scalar
                results.append([float(values)])
        except Exception as e:
            results.append([f"scipy.stats.dirichlet_multinomial error: {e}"])
    return results