
DIRICHLET_MULTINOMIAL

Overview

The DIRICHLET_MULTINOMIAL function computes properties of the Dirichlet multinomial distribution, which models the probability of observing a set of counts across multiple categories when the underlying probabilities are themselves uncertain and modeled by a Dirichlet prior. This is useful in Bayesian statistics, genetics, and machine learning for modeling overdispersed count data. The Dirichlet multinomial is a compound distribution: probabilities for each category are drawn from a Dirichlet distribution, then counts are drawn from a multinomial distribution using those probabilities.
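The two-stage construction described above can be sketched with the standard library alone. This is a minimal illustration, not the wrapper's implementation: the helper name `sample_dirichlet_multinomial` is hypothetical, and the Dirichlet draw assumes the usual construction from normalised Gamma variates.

```python
import random


def sample_dirichlet_multinomial(alpha, n, rng):
    """Hypothetical helper: draw p ~ Dirichlet(alpha), then x ~ Multinomial(n, p)."""
    # Stage 1: a Dirichlet draw is a vector of independent Gamma(alpha_i, 1)
    # variates divided by their sum.
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    probs = [g / total for g in gammas]

    # Stage 2: a multinomial draw is n categorical draws tallied per category.
    counts = [0] * len(alpha)
    for category in rng.choices(range(len(alpha)), weights=probs, k=n):
        counts[category] += 1
    return counts


rng = random.Random(0)  # fixed seed so the sketch is repeatable
draw = sample_dirichlet_multinomial([2.0, 3.0, 5.0], 10, rng)
print(draw, sum(draw))  # three counts that always sum to n = 10
```

Because the category probabilities are redrawn for every sample, the counts vary more than under a multinomial with fixed probabilities, which is the overdispersion mentioned above.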

The probability mass function (pmf) is:

$$
P(X = x) = \frac{\Gamma(n+1)}{\prod_{i=1}^K \Gamma(x_i+1)} \, \frac{\Gamma\left(\sum_{i=1}^K \alpha_i\right)}{\Gamma\left(n+\sum_{i=1}^K \alpha_i\right)} \prod_{i=1}^K \frac{\Gamma(x_i+\alpha_i)}{\Gamma(\alpha_i)}
$$

where $x_i$ are the observed counts, $\alpha_i$ are the concentration parameters, $n$ is the total number of trials, and $K$ is the number of categories. For more details, see the scipy.stats.dirichlet_multinomial documentation.
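The pmf above can be evaluated term by term in log space, which avoids overflow in the gamma functions. The sketch below uses only `math.lgamma`. The function name `dirichlet_multinomial_logpmf` is hypothetical, and it assumes that n equals the sum of the counts.

```python
import math


def dirichlet_multinomial_logpmf(x, alpha):
    """Hypothetical helper: log of the pmf above, with n taken as sum(x)."""
    n = sum(x)
    alpha_sum = sum(alpha)
    # First factor: Gamma(n + 1) / prod Gamma(x_i + 1)
    log_p = math.lgamma(n + 1) - sum(math.lgamma(xi + 1) for xi in x)
    # Second factor: Gamma(sum alpha) / Gamma(n + sum alpha)
    log_p += math.lgamma(alpha_sum) - math.lgamma(n + alpha_sum)
    # Third factor: prod Gamma(x_i + alpha_i) / Gamma(alpha_i)
    log_p += sum(math.lgamma(xi + ai) - math.lgamma(ai) for xi, ai in zip(x, alpha))
    return log_p


log_p = dirichlet_multinomial_logpmf([2, 3, 5], [1.0, 1.0, 1.0])
pmf = math.exp(log_p)
print(round(pmf, 3), round(log_p, 3))
```

With the inputs from Examples 1 and 2 below, this should agree with the documented results to the displayed precision.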

This wrapper exposes the most commonly used methods: pmf, logpmf, mean, var, and cov. The seed parameter from the underlying Python function is not supported. All inputs must be scalars or 2D lists as supported by Excel. This example function is provided as-is without any representation of accuracy.

Usage

To use the function in Excel:

=DIRICHLET_MULTINOMIAL([x], alpha, [n], [method])
  • x (2D list of integers, optional): Table of counts for each category. Required for pmf and logpmf methods. Each row is a set of counts for one distribution.
  • alpha (2D list of floats, required): Table of concentration parameters. Each row is a set of parameters for one distribution.
  • n (list of integers, optional): Number of trials for each distribution. Required for all methods except cov; for cov, n defaults to 1 when omitted.
  • method (string, optional, default=pmf): Which method to compute: pmf, logpmf, mean, var, or cov.

The function returns a 2D list of results for each input row, or an error message (string) if the input is invalid. For pmf and logpmf, the result is a probability or log-probability. For mean and var, the result is a list of values for each category. For cov, the result is a covariance matrix.

Examples

Example 1: Basic PMF Calculation

Inputs:

| x | alpha | n | method |
|---|---|---|---|
| 2, 3, 5 | 1.0, 1.0, 1.0 | 10 | pmf |

Excel formula:

=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "pmf")

Expected output:

Result
0.015
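As a hand check of this value: with every $\alpha_i = 1$, the $\Gamma(x_i + 1)$ terms in the pmf cancel, so the probability depends only on $n = 10$ and $K = 3$:

$$
P(X = x) = \frac{\Gamma(n+1)\,\Gamma(K)}{\Gamma(n+K)} = \frac{10!\,2!}{12!} = \frac{1}{66} \approx 0.01515
$$

Every count vector that sums to 10 is therefore equally likely, and the corresponding log-probability is $-\ln 66 \approx -4.190$.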

Example 2: Log-PMF Calculation

Inputs:

| x | alpha | n | method |
|---|---|---|---|
| 2, 3, 5 | 1.0, 1.0, 1.0 | 10 | logpmf |

Excel formula:

=DIRICHLET_MULTINOMIAL({2,3,5}, {1.0,1.0,1.0}, {10}, "logpmf")

Expected output:

Result
-4.190

Example 3: Mean Calculation

Inputs:

| alpha | n | method |
|---|---|---|
| 2.0, 3.0, 5.0 | 10 | mean |

Excel formula:

=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, {10}, "mean")

Expected output:

| Result | | |
|---|---|---|
| 2.000 | 3.000 | 5.000 |
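These values follow from the closed-form mean of the Dirichlet multinomial, $E[X_i] = n\,\alpha_i / \sum_j \alpha_j$. A minimal sketch of that formula, with a hypothetical helper name:

```python
def dirichlet_multinomial_mean(alpha, n):
    """Hypothetical helper: E[X_i] = n * alpha_i / sum(alpha)."""
    alpha_sum = sum(alpha)
    return [n * a / alpha_sum for a in alpha]


mean = dirichlet_multinomial_mean([2.0, 3.0, 5.0], 10)
print(mean)
```

Here the concentration parameters sum to 10, which equals n, so each expected count coincides with its own alpha value.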

Example 4: Covariance Matrix Calculation

Inputs:

| alpha | method |
|---|---|
| 2.0, 3.0, 5.0 | cov |

Excel formula:

=DIRICHLET_MULTINOMIAL(, {2.0,3.0,5.0}, , "cov")

Expected output:

| Result | | |
|---|---|---|
| 0.160 | -0.060 | -0.100 |
| -0.060 | 0.210 | -0.150 |
| -0.100 | -0.150 | 0.250 |
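This matrix can be reproduced from the closed-form covariance. With $p_i = \alpha_i / \alpha_0$ and $\alpha_0 = \sum_j \alpha_j$, the diagonal entries are $n\,p_i(1-p_i)\,(n+\alpha_0)/(1+\alpha_0)$ and the off-diagonal entries are $-n\,p_i p_j\,(n+\alpha_0)/(1+\alpha_0)$. The sketch below uses a hypothetical helper name and assumes n = 1, the value the wrapper falls back to when n is omitted for cov.

```python
def dirichlet_multinomial_cov(alpha, n):
    """Hypothetical helper: closed-form covariance matrix as a list of rows."""
    alpha_sum = sum(alpha)
    p = [a / alpha_sum for a in alpha]
    # Overdispersion factor relative to a plain multinomial; equals n when n == 1.
    scale = n * (n + alpha_sum) / (1.0 + alpha_sum)
    k = len(alpha)
    return [
        [scale * (p[i] * (1.0 - p[i]) if i == j else -p[i] * p[j]) for j in range(k)]
        for i in range(k)
    ]


cov = dirichlet_multinomial_cov([2.0, 3.0, 5.0], 1)
for row in cov:
    print([round(v, 3) for v in row])
```

Each row sums to zero because the counts are constrained to add up to n.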

Python Code

from typing import List, Optional, Union

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial


def dirichlet_multinomial(
    x: Optional[List[List[int]]] = None,
    alpha: Optional[List[List[float]]] = None,
    n: Optional[List[int]] = None,
    method: str = 'pmf'
) -> Union[List[List[Optional[float]]], str]:
    """
    Computes the probability mass function, log probability mass function, mean,
    variance, or covariance of the Dirichlet multinomial distribution.

    Args:
        x: 2D list of integer counts for each category. Required for 'pmf' and 'logpmf' methods.
        alpha: 2D list of float concentration parameters. Each row is a set of parameters for one distribution.
        n: List of integers, number of trials for each distribution. Required for all methods except 'cov'.
        method: Which method to compute (str): 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        2D list of results for each input, or an error message (str) if input is invalid.

    This example function is provided as-is without any representation of accuracy.
    """
    # Validate method
    valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
    if method not in valid_methods:
        return f"Invalid method: {method}. Must be one of {sorted(valid_methods)}."

    # Validate alpha
    if not isinstance(alpha, list) or not all(isinstance(row, list) and len(row) > 0 for row in alpha):
        return "Invalid input: alpha must be a 2D list of floats."
    if len(alpha) < 1:
        return "Invalid input: alpha must be a 2D list with at least one row."

    # Validate n (optional for 'cov', which falls back to n=1 below)
    if method != 'cov':
        if n is None or not isinstance(n, list) or len(n) != len(alpha):
            return "Invalid input: n must be a list of integers with the same length as alpha."
        try:
            n = [int(val) for val in n]
        except Exception:
            return "Invalid input: n must contain integers."
        if any(val < 0 for val in n):
            return "Invalid input: n must contain non-negative integers."

    # Validate x (only needed for 'pmf' and 'logpmf')
    if method in {'pmf', 'logpmf'}:
        if x is None or not isinstance(x, list) or len(x) != len(alpha):
            return "Invalid input: x must be a 2D list of integers with the same number of rows as alpha."
        for i, row in enumerate(x):
            # Compare each row of x with the matching row of alpha, not only alpha[0]
            if not isinstance(row, list) or len(row) != len(alpha[i]):
                return "Invalid input: each row of x must have the same length as the matching row of alpha."
            try:
                if any(int(val) < 0 for val in row):
                    return "Invalid input: x must contain non-negative integers."
            except Exception:
                return "Invalid input: x must contain integers."

    # Compute one result per row of alpha
    results = []
    for i, alpha_row in enumerate(alpha):
        try:
            # scipy always needs n to build the distribution; 'cov' uses n=1 if n is not provided
            if method == 'cov' and (n is None or len(n) <= i):
                n_val = 1
            else:
                n_val = int(n[i])
            dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)

            if method == 'pmf':
                results.append([float(dist.pmf(x[i]))])
            elif method == 'logpmf':
                results.append([float(dist.logpmf(x[i]))])
            elif method == 'mean':
                results.append([float(val) for val in dist.mean()])
            elif method == 'var':
                results.append([float(val) for val in dist.var()])
            else:
                # 'cov' yields a K x K matrix. Its rows are appended as-is, so a single
                # alpha row returns the matrix itself and several alpha rows return
                # their matrices stacked row-wise instead of overwriting each other.
                cov_rows = [[float(val) for val in row] for row in dist.cov()]
                results.extend(cov_rows)
        except Exception as e:
            results.append([f"scipy.stats.dirichlet_multinomial error: {e}"])
    return results

Example Workbook

Link to Workbook
