DIRICHLET_MULTINOMIAL
Overview
The DIRICHLET_MULTINOMIAL function computes properties of the Dirichlet multinomial distribution, which models the probability of observing a set of counts across multiple categories when the underlying probabilities are themselves uncertain and modeled by a Dirichlet prior. This is useful in Bayesian statistics, genetics, and machine learning for modeling overdispersed count data. The Dirichlet multinomial is a compound distribution: probabilities for each category are drawn from a Dirichlet distribution, then counts are drawn from a multinomial distribution using those probabilities.
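The two-stage construction can be simulated directly. The following is a minimal sketch, not part of the wrapper: the helper name `sample_dirichlet_multinomial`, the seed, and the sample size are illustrative assumptions. It draws probabilities from a Dirichlet distribution, draws counts from a multinomial with those probabilities, and compares the empirical mean and variance with what a plain multinomial with fixed probabilities would give.

```python
import numpy as np


def sample_dirichlet_multinomial(alpha, n, size, seed=0):
    """Draw `size` count vectors by compounding Dirichlet and multinomial draws.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    alpha = np.asarray(alpha, dtype=float)
    # Step 1: one probability vector per draw from Dirichlet(alpha).
    probs = rng.dirichlet(alpha, size=size)
    # Step 2: one multinomial count vector per probability vector.
    return np.array([rng.multinomial(n, p) for p in probs])


alpha = np.array([2.0, 3.0, 5.0])
n = 10
counts = sample_dirichlet_multinomial(alpha, n, size=20000)

p = alpha / alpha.sum()
empirical_mean = counts.mean(axis=0)
theoretical_mean = n * p              # same mean as a multinomial with probabilities p
empirical_var = counts.var(axis=0)
multinomial_var = n * p * (1.0 - p)   # variance if the probabilities were fixed at p

print("empirical mean:", empirical_mean)
print("empirical variance:", empirical_var)
print("fixed-probability multinomial variance:", multinomial_var)
```

The empirical variance comes out larger than the fixed-probability multinomial variance, which is the overdispersion mentioned above.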
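The probability mass function (pmf) is:
$$
P(x \mid n, \alpha) = \frac{\Gamma(\alpha_0)\,\Gamma(n + 1)}{\Gamma(n + \alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{\Gamma(x_k + 1)\,\Gamma(\alpha_k)}, \qquad \alpha_0 = \sum_{k=1}^{K} \alpha_k
$$
where $x_k$ are observed counts, $\alpha_k$ are concentration parameters, $n = \sum_{k} x_k$ is the total number of trials, and $K$ is the number of categories. For more details, see the scipy.stats.dirichlet_multinomial documentation.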
This wrapper exposes the most commonly used methods: pmf, logpmf, mean, var, and cov. The seed parameter from the underlying Python function is not supported. All inputs must be scalars or 2D lists as supported by Excel. This example function is provided as-is without any representation of accuracy.
Usage
To use the function in Excel:
=DIRICHLET_MULTINOMIAL([x], alpha, [n], [method])
x (2D list of integers, optional): Table of counts for each category. Required for the pmf and logpmf methods. Each row is a set of counts for one distribution.
alpha (2D list of floats, required): Table of concentration parameters. Each row is a set of parameters for one distribution.
n (list of integers, optional): Number of trials for each distribution. Required for all methods except cov, which uses n = 1 when n is omitted.
method (string, optional, default=pmf): Which method to compute: pmf, logpmf, mean, var, or cov.
The function returns a 2D list of results for each input row, or an error message (string) if the input is invalid. For pmf and logpmf, the result is a probability or log-probability. For mean and var, the result is a list of values for each category. For cov, the result is a covariance matrix.
Examples
Example 1: Basic PMF Calculation
Inputs:
x | | | alpha | | | n | method |
---|---|---|---|---|---|---|---|
2 | 3 | 5 | 1.0 | 1.0 | 1.0 | 10 | pmf |
Excel formula:
=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "pmf")
Expected output:
Result |
---|
0.015 |
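As a quick check on this value, here is a sketch of the arithmetic, assuming the standard Dirichlet multinomial pmf. With every concentration parameter equal to 1, each factor in the product over categories equals 1 and only the normalizing term remains, so the result does not depend on how the 10 trials are split across the 3 categories:

$$
P\big(x = (2, 3, 5)\big) = \frac{\Gamma(3)\,\Gamma(11)}{\Gamma(13)} = \frac{2!\,10!}{12!} = \frac{2}{132} = \frac{1}{66} \approx 0.0152
$$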
Example 2: Log-PMF Calculation
Inputs:
x | | | alpha | | | n | method |
---|---|---|---|---|---|---|---|
2 | 3 | 5 | 1.0 | 1.0 | 1.0 | 10 | logpmf |
Excel formula:
=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "logpmf")
Expected output:
Result |
---|
-4.190 |
Example 3: Mean Calculation
Inputs:
alpha | | | n | method |
---|---|---|---|---|
2.0 | 3.0 | 5.0 | 10 | mean |
Excel formula:
=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, {10}, "mean")
Expected output:
Result | ||
---|---|---|
2.000 | 3.000 | 5.000 |
Example 4: Covariance Matrix Calculation
Inputs:
alpha | | | method |
---|---|---|---|
2.0 | 3.0 | 5.0 | cov |
Excel formula:
=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, , "cov")
Expected output:
Result | ||
---|---|---|
0.160 | -0.060 | -0.100 |
-0.060 | 0.210 | -0.150 |
-0.100 | -0.150 | 0.250 |
Python Code
from typing import List, Optional, Union

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial


def dirichlet_multinomial(
    x: Optional[List[List[int]]] = None,
    alpha: Optional[List[List[float]]] = None,
    n: Optional[List[int]] = None,
    method: str = 'pmf'
) -> Union[List[List[Optional[float]]], str]:
    """
    Computes the probability mass function, log probability mass function, mean,
    variance, or covariance of the Dirichlet multinomial distribution.

    Args:
        x: 2D list of integer counts for each category. Required for 'pmf' and 'logpmf' methods.
        alpha: 2D list of float concentration parameters. Each row is a set of parameters for one distribution.
        n: List of integers, number of trials for each distribution. Required for all methods
            except 'cov', which uses n=1 when n is omitted.
        method: Which method to compute (str): 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        2D list of results for each input, or an error message (str) if input is invalid.

    This example function is provided as-is without any representation of accuracy.
    """
    # Validate method
    valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
    if method not in valid_methods:
        return f"Invalid method: {method}. Must be one of {sorted(valid_methods)}."

    # Validate alpha: a non-empty list of non-empty rows
    if not isinstance(alpha, list) or not all(isinstance(row, list) and len(row) > 0 for row in alpha):
        return "Invalid input: alpha must be a 2D list of floats."
    if len(alpha) < 1:
        return "Invalid input: alpha must be a 2D list with at least one row."

    # Validate n (optional only for 'cov')
    if method != 'cov':
        if n is None or not isinstance(n, list) or len(n) != len(alpha):
            return "Invalid input: n must be a list of integers with the same length as alpha."
        try:
            n = [int(val) for val in n]
        except Exception:
            return "Invalid input: n must contain integers."
        if any(val < 0 for val in n):
            return "Invalid input: n must contain non-negative integers."

    # Validate x (needed only for 'pmf' and 'logpmf')
    if method in {'pmf', 'logpmf'}:
        if x is None or not isinstance(x, list) or len(x) != len(alpha):
            return "Invalid input: x must be a 2D list of integers with the same number of rows as alpha."
        for i, row in enumerate(x):
            # Compare each row of x with the matching row of alpha, not only alpha[0]
            if not isinstance(row, list) or len(row) != len(alpha[i]):
                return "Invalid input: each row of x must have the same length as the matching row of alpha."
            try:
                if any(int(val) < 0 for val in row):
                    return "Invalid input: x must contain non-negative integers."
            except Exception:
                return "Invalid input: x must contain integers."

    # Compute one result per row of alpha
    results = []
    for i, alpha_row in enumerate(alpha):
        try:
            # SciPy always needs n; for 'cov' fall back to n=1 when it is not provided
            if method == 'cov':
                n_val = n[i] if n is not None and len(n) > i else 1
            else:
                n_val = n[i]
            dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)

            if method == 'pmf':
                res = dist.pmf(x[i])
            elif method == 'logpmf':
                res = dist.logpmf(x[i])
            elif method == 'mean':
                res = dist.mean()
            elif method == 'var':
                res = dist.var()
            else:  # 'cov'
                res = dist.cov()

            if method == 'cov':
                # A covariance matrix is already 2D, so it replaces the result table;
                # with several alpha rows only the last row's matrix is kept.
                results[:] = [[float(val) for val in row] for row in res.tolist()]
            else:
                # NumPy scalars and arrays both expose tolist(); scalars become a one-element row
                values = res.tolist() if hasattr(res, 'tolist') else res
                if isinstance(values, list):
                    results.append([float(val) for val in values])
                else:
                    results.append([float(values)])
        except Exception as e:
            results.append([f"scipy.stats.dirichlet_multinomial error: {e}"])
    return results