DIRICHLET_MULTINOM

Overview

The DIRICHLET_MULTINOM function computes statistical properties of the Dirichlet-multinomial distribution, a compound probability distribution that arises when category probabilities are uncertain. Also known as the Dirichlet compound multinomial (DCM) or multivariate Pólya distribution, it models scenarios where observations follow a multinomial distribution with probabilities drawn from a Dirichlet distribution.

This distribution is constructed by first drawing a probability vector \mathbf{p} from a Dirichlet distribution with concentration parameters \boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K), then drawing counts from a multinomial distribution with n trials and probability vector \mathbf{p}. The probability mass function is:

P(\mathbf{x} \mid n, \boldsymbol{\alpha}) = \frac{\Gamma(\alpha_0) \Gamma(n+1)}{\Gamma(n + \alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{\Gamma(\alpha_k) \Gamma(x_k + 1)}

where \alpha_0 = \sum_{k=1}^{K} \alpha_k is the sum of concentration parameters, and \mathbf{x} = (x_1, \ldots, x_K) represents counts in each of K categories with \sum x_k = n.
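
As a sketch of the formula above, the PMF can be evaluated in log space with the standard library's math.lgamma. Summing it over every count vector with \sum x_k = n should give 1. The helper names dm_logpmf and compositions are hypothetical and are not part of the function documented below.

```python
import math
from itertools import combinations


def dm_logpmf(x, alpha):
    """Log-PMF of the Dirichlet-multinomial, taking n = sum(x) (hypothetical helper)."""
    n = sum(x)
    a0 = sum(alpha)
    # log of Gamma(a0) * Gamma(n + 1) / Gamma(n + a0)
    out = math.lgamma(a0) + math.lgamma(n + 1) - math.lgamma(n + a0)
    # log of the product over the K categories
    for xk, ak in zip(x, alpha):
        out += math.lgamma(xk + ak) - math.lgamma(ak) - math.lgamma(xk + 1)
    return out


def compositions(n, k):
    """Yield every tuple of k non-negative integers summing to n (stars and bars)."""
    for bars in combinations(range(n + k - 1), k - 1):
        prev = -1
        parts = []
        for b in bars:
            parts.append(b - prev - 1)  # stars between consecutive bars
            prev = b
        parts.append(n + k - 2 - prev)  # stars after the last bar
        yield tuple(parts)


alpha = [2.0, 3.0, 5.0]
total = sum(math.exp(dm_logpmf(x, alpha)) for x in compositions(10, 3))
print(round(total, 12))  # prints 1.0
```

Working in log space keeps the gamma terms from overflowing when n is large; the probability is recovered with math.exp only at the end.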

The expected value for category i is E(X_i) = n \alpha_i / \alpha_0, and the variance is:

\text{Var}(X_i) = n \frac{\alpha_i}{\alpha_0} \left(1 - \frac{\alpha_i}{\alpha_0}\right) \frac{n + \alpha_0}{1 + \alpha_0}
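
The moment formulas can be checked numerically. The sketch below assumes NumPy and computes the mean, variance, and full covariance matrix in closed form. For the off-diagonal entries it uses the standard result \text{Cov}(X_i, X_j) = -n (\alpha_i \alpha_j / \alpha_0^2)(n + \alpha_0)/(1 + \alpha_0) for i \neq j. The function name dm_moments is hypothetical.

```python
import numpy as np


def dm_moments(alpha, n):
    """Closed-form mean, variance and covariance of the Dirichlet-multinomial."""
    alpha = np.asarray(alpha, dtype=float)
    a0 = alpha.sum()
    p = alpha / a0                      # expected category proportions
    inflation = (n + a0) / (1.0 + a0)   # overdispersion factor
    mean = n * p
    # Multinomial covariance n * (diag(p) - p p^T), scaled by the inflation factor
    cov = n * (np.diag(p) - np.outer(p, p)) * inflation
    var = np.diag(cov).copy()           # diagonal equals the variance formula above
    return mean, var, cov


mean, var, cov = dm_moments([2, 3, 5], n=10)
print(mean)  # [2. 3. 5.]
print(var)   # approximately [2.90909 3.81818 4.54545]
```

Each row of the covariance matrix sums to zero, because the counts are constrained to total n.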

The distribution exhibits overdispersion relative to the multinomial: the variance is inflated by a factor of (n + \alpha_0)/(1 + \alpha_0), which exceeds 1 whenever n > 1. This makes it suitable for modeling count data with extra variability, such as word frequencies in documents or allele counts in population genetics. The concentration parameter \alpha_0 controls the degree of overdispersion. Smaller values produce greater variability. As \alpha_0 grows with the ratios \alpha_k/\alpha_0 held fixed, the distribution approaches a multinomial with probabilities \alpha_k/\alpha_0.
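
The two-stage construction also explains the overdispersion. The Monte Carlo sketch below assumes NumPy's Generator API. It draws a probability vector from a Dirichlet and then counts from a multinomial, and it compares the empirical variance with both the Dirichlet-multinomial formula and the plain multinomial variance n p_i (1 - p_i). The sample size, the seed, and the helper name sample_dm are arbitrary illustrative choices.

```python
import numpy as np


def sample_dm(alpha, n, size, seed=0):
    """Draw `size` count vectors by the two-stage Dirichlet -> multinomial scheme."""
    rng = np.random.default_rng(seed)
    probs = rng.dirichlet(alpha, size=size)  # one probability vector per draw
    return np.array([rng.multinomial(n, p) for p in probs])


alpha = np.array([2.0, 3.0, 5.0])
n = 10
a0 = alpha.sum()
p = alpha / a0

draws = sample_dm(alpha, n, size=40_000)
empirical_var = draws.var(axis=0)
multinomial_var = n * p * (1 - p)               # variance if p were known exactly
dm_var = multinomial_var * (n + a0) / (1 + a0)  # inflated by (n + a0) / (1 + a0)

print(empirical_var)  # close to dm_var, well above multinomial_var
print(dm_var)
print(multinomial_var)
```

The extra spread comes entirely from the first stage: each draw uses a different p, so the counts vary more than they would around a single fixed p.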

This implementation uses the scipy.stats.dirichlet_multinomial distribution and supports computing the PMF, log-PMF, mean, variance, and covariance matrix. For additional theoretical background, see the Wikipedia article on the Dirichlet-multinomial distribution.

This example function is provided as-is without any representation of accuracy.

Excel Usage

=DIRICHLET_MULTINOM(x, alpha, n, dm_method)
  • x (list[list], optional, default: null): 2D list of integer counts for each category. Required for pmf and logpmf methods. Each row must sum to the corresponding value of n.
  • alpha (list[list], optional, default: null): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution.
  • n (list[list], optional, default: null): 2D list containing the number of trials for each distribution. Each row contains one integer. Required for all methods except cov, which uses n = 1 when n is omitted.
  • dm_method (str, optional, default: “pmf”): Computation method to use: “pmf”, “logpmf”, “mean”, “var”, or “cov”.

Returns (list[list]): 2D list of results, or error message string.

Example 1: Basic PMF calculation with uniform concentration

Inputs:

x | alpha | n | dm_method
2, 3, 5 | 1, 1, 1 | 10 | pmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "pmf")

Expected output:

0.0151515
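
As a hand check of this value: with \boldsymbol{\alpha} = (1, 1, 1) we have \alpha_0 = 3, and each factor \Gamma(x_k + 1)/(\Gamma(1)\Gamma(x_k + 1)) equals 1. The PMF therefore reduces to

```latex
P(\mathbf{x} \mid 10, \boldsymbol{\alpha})
  = \frac{\Gamma(3)\,\Gamma(11)}{\Gamma(13)}
  = \frac{2 \cdot 10!}{12!}
  = \frac{2}{12 \cdot 11}
  = \frac{1}{66} \approx 0.0151515
```

So with uniform unit concentration every count vector of three non-negative entries summing to 10 is equally likely, and there are \binom{12}{2} = 66 such vectors.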

Example 2: Log-PMF calculation for same distribution

Inputs:

x | alpha | n | dm_method
2, 3, 5 | 1, 1, 1 | 10 | logpmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "logpmf")

Expected output:

-4.18965
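
Here the log-PMF is simply -\ln 66. It matters more for larger n: \Gamma(n + 1) = n! exceeds the double-precision range once n \geq 171, so evaluating the PMF term by term with gamma functions overflows, while summing log-gamma terms stays finite. The minimal sketch below uses Python's math module; both helper names are hypothetical.

```python
import math


def naive_pmf(x, alpha):
    """Term-by-term gamma evaluation of the PMF (overflows for large n)."""
    n, a0 = sum(x), sum(alpha)
    out = math.gamma(a0) * math.gamma(n + 1) / math.gamma(n + a0)
    for xk, ak in zip(x, alpha):
        out *= math.gamma(xk + ak) / (math.gamma(ak) * math.gamma(xk + 1))
    return out


def log_pmf(x, alpha):
    """The same quantity accumulated in log space with math.lgamma."""
    n, a0 = sum(x), sum(alpha)
    out = math.lgamma(a0) + math.lgamma(n + 1) - math.lgamma(n + a0)
    for xk, ak in zip(x, alpha):
        out += math.lgamma(xk + ak) - math.lgamma(ak) - math.lgamma(xk + 1)
    return out


print(naive_pmf([2, 3, 5], [1, 1, 1]))  # ~0.0151515
print(log_pmf([2, 3, 5], [1, 1, 1]))    # ~-4.18965

try:
    naive_pmf([50, 60, 90], [1, 1, 1])  # n = 200: Gamma(201) is out of range
except OverflowError as err:
    print("naive form overflowed:", err)
print(log_pmf([50, 60, 90], [1, 1, 1]))  # finite: -log(comb(202, 2))
```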

Example 3: Expected mean counts for weighted concentration

Inputs:

alpha | n | dm_method
2, 3, 5 | 10 | mean

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "mean")

Expected output:

Result
2 3 5
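
Worked by hand with \alpha_0 = 2 + 3 + 5 = 10 and n = 10, the mean formula E(X_i) = n \alpha_i / \alpha_0 gives

```latex
E(\mathbf{X}) = \frac{10}{10}\,(2, 3, 5) = (2, 3, 5)
```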

Example 4: Variance for weighted concentration

Inputs:

alpha | n | dm_method
2, 3, 5 | 10 | var

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "var")

Expected output:

Result
2.90909 3.81818 4.54545
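
With \alpha_0 = 10 and n = 10 the inflation factor is (n + \alpha_0)/(1 + \alpha_0) = 20/11, so for the first category

```latex
\text{Var}(X_1) = 10 \cdot \frac{2}{10}\left(1 - \frac{2}{10}\right) \cdot \frac{20}{11}
  = 1.6 \cdot \frac{20}{11} = \frac{32}{11} \approx 2.90909
```

Likewise 2.1 \cdot 20/11 = 42/11 \approx 3.81818 and 2.5 \cdot 20/11 = 50/11 \approx 4.54545, compared with the multinomial variances 1.6, 2.1, and 2.5.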

Example 5: Covariance matrix for three categories

Inputs:

alpha | dm_method
2, 3, 5 | cov

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, , "cov")

Expected output:

Result
0.16 -0.06 -0.1
-0.06 0.21 -0.15
-0.1 -0.15 0.25
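
Because n is omitted, the function uses n = 1 for cov, so the inflation factor (n + \alpha_0)/(1 + \alpha_0) equals 1. With p = \alpha/\alpha_0 = (0.2, 0.3, 0.5) and the standard off-diagonal result \text{Cov}(X_i, X_j) = -n (\alpha_i \alpha_j / \alpha_0^2)(n + \alpha_0)/(1 + \alpha_0), the matrix reduces to the multinomial form:

```latex
\Sigma = \operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^{\top}, \qquad
\Sigma_{11} = 0.2 \cdot 0.8 = 0.16, \quad
\Sigma_{12} = -0.2 \cdot 0.3 = -0.06, \quad
\Sigma_{23} = -0.3 \cdot 0.5 = -0.15
```

Each row sums to zero, since the counts are constrained to total n.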

Python Code

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial

def dirichlet_multinom(x=None, alpha=None, n=None, dm_method='pmf'):
    """
    Computes the probability mass function, log probability mass function, mean, variance, or covariance of the Dirichlet multinomial distribution.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet_multinomial.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        x (list[list], optional): 2D list of integer counts for each category. Required for pmf and logpmf methods. Default is None.
        alpha (list[list], optional): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Default is None.
        n (list[list], optional): 2D list containing the number of trials for each distribution. Each row contains one integer. Required for all methods except cov, which uses n=1 when omitted. Default is None.
        dm_method (str, optional): Computation method to use. Valid options: 'pmf' (PMF), 'logpmf' (Log PMF), 'mean' (Mean), 'var' (Variance), 'cov' (Covariance). Default is 'pmf'.

    Returns:
        list[list]: 2D list of results, or error message string.
    """
    def to2d(val):
        if val is None:
            return None
        return [[val]] if not isinstance(val, list) else val

    def to_float_list(arr):
        if hasattr(arr, "tolist"):
            arr = arr.tolist()
        if isinstance(arr, (float, int)):
            return [float(arr)]
        return [float(v) for v in arr]

    try:
        valid_methods = {"pmf", "logpmf", "mean", "var", "cov"}
        if dm_method not in valid_methods:
            return f"Error: Invalid method '{dm_method}'. Must be one of {sorted(valid_methods)}."

        if alpha is None:
            return "Error: Invalid input: alpha is required."
        alpha = to2d(alpha)
        if not isinstance(alpha, list) or len(alpha) < 1 or not all(
            isinstance(row, list) and len(row) > 0 for row in alpha
        ):
            return "Error: alpha must be a 2D list of positive floats."

        try:
            alpha = [[float(v) for v in row] for row in alpha]
        except (TypeError, ValueError):
            return "Error: alpha must be a 2D list of positive floats."
        if any(any(v <= 0 for v in row) for row in alpha):
            return "Error: alpha must be a 2D list of positive floats."

        # n is required for pmf/logpmf/mean/var; for cov it may be omitted (n=1 is used below)
        if n is None:
            if dm_method != "cov":
                return "Error: Invalid input: n is required."
        else:
            n = to2d(n)
            if not isinstance(n, list) or len(n) != len(alpha):
                return "Error: n must be a 2D list with the same number of rows as alpha."
            for n_row in n:
                if not isinstance(n_row, list) or len(n_row) != 1:
                    return "Error: Each row of n must contain exactly one integer."
            try:
                n = [[int(val[0])] for val in n]
            except (TypeError, ValueError):
                return "Error: n must contain integers."
            if any(val[0] < 0 for val in n):
                return "Error: n must contain non-negative integers."

        if dm_method in {"pmf", "logpmf"}:
            if x is None:
                return "Error: Invalid input: x is required for pmf/logpmf."
            x = to2d(x)
            if not isinstance(x, list) or len(x) != len(alpha):
                return "Error: x must be a 2D list with the same number of rows as alpha."
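            # Compare each x row with its own alpha row (rows may differ in length)
            for row, alpha_row in zip(x, alpha):
                if not isinstance(row, list) or len(row) != len(alpha_row):
                    return "Error: Each row of x must have the same length as the corresponding row of alpha."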
                try:
                    if any(int(val) < 0 for val in row):
                        return "Error: x must contain non-negative integers."
                except (TypeError, ValueError):
                    return "Error: x must contain integers."

        results = []
        for i, alpha_row in enumerate(alpha):
            try:
                if dm_method == "cov":
                    n_val = 1 if n is None else n[i][0]
                else:
                    n_val = n[i][0]

                if dm_method in {"pmf", "logpmf"}:
                    row_sum = sum(int(v) for v in x[i])
                    if row_sum != n_val:
                        return "Error: Invalid input: each row of x must sum to n."

                dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)

                if dm_method == "pmf":
                    res = dist.pmf(x[i])
                elif dm_method == "logpmf":
                    res = dist.logpmf(x[i])
                elif dm_method == "mean":
                    res = dist.mean()
                elif dm_method == "var":
                    res = dist.var()
                elif dm_method == "cov":
                    res = dist.cov()
                else:
                    return "Error: Invalid method."

                if dm_method == "cov":
                    cov_matrix = res.tolist() if hasattr(res, "tolist") else res
                    for row in cov_matrix:
                        results.append([float(val) for val in row])
                else:
                    results.append(to_float_list(res))
            except Exception as e:
                return f"Error: computing {dm_method}: {e}"

        return results
    except Exception as e:
        return f"Error: {str(e)}"
