Dirichlet distribution tutorial¶
What are Dirichlet distributions?¶
The Dirichlet distribution is a multivariate probability distribution defined over the probability simplex. It is commonly used to model uncertainty over categorical probability vectors, i.e. vectors that are non-negative and sum to one.
In the 3-dimensional case, the simplex can be visualized as an equilateral triangle, making ternary plots a natural and intuitive way to understand Dirichlet behavior.
1. Imports:¶
First of all we import the necessary packages.
import numpy as np
from probly.visualization.dirichlet.dirichlet_visualization import create_dirichlet_plot
print("Imports successfully loaded.")
2. Test Dirichlet distribution plot:¶
Let’s generate a Dirichlet plot with the alpha values (8, 2, 2). The values are concentration parameters: the larger a value is relative to the others, the more likely the corresponding outcome.
alpha_values = np.array([8, 2, 2])
labels = ("Red", "Green", "Blue")
title = f"Distribution of likelihood if balls in an urn are: {labels}"
create_dirichlet_plot(alpha=alpha_values, labels=labels, title=title)
3. How to read the Dirichlet distribution plot:¶
Each corner corresponds to full probability mass on one category.
Each point inside the triangle represents a valid probability vector.
Higher density regions indicate more likely probability configurations.
Larger α values lead to more concentrated (confident) distributions.
Smaller α values produce flatter, more uncertain distributions.
In this example, the distribution is most concentrated toward Red, reflecting the larger α value.
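To make the last point concrete: the mean of a Dirichlet distribution is α_i / α_0, and when every α_i is greater than 1 its mode (the density peak seen in the plot) is (α_i − 1) / (α_0 − K), where α_0 is the sum of the alpha values and K the number of categories. The sketch below is a standalone illustration using only the standard library. The helper names `dirichlet_mean` and `dirichlet_mode` are made up for this example and are not part of probly.

```python
def dirichlet_mean(alpha):
    """Mean of a Dirichlet distribution: alpha_i / sum(alpha)."""
    total = sum(alpha)
    return [a / total for a in alpha]


def dirichlet_mode(alpha):
    """Mode (density peak); the formula only holds when every alpha_i > 1."""
    if any(a <= 1 for a in alpha):
        raise ValueError("mode formula requires all alpha values > 1")
    total, k = sum(alpha), len(alpha)
    return [(a - 1) / (total - k) for a in alpha]


alpha = [8, 2, 2]  # Red, Green, Blue
mean = dirichlet_mean(alpha)
mode = dirichlet_mode(alpha)
print("mean:", [round(m, 3) for m in mean])  # mean: [0.667, 0.167, 0.167]
print("mode:", [round(m, 3) for m in mode])  # mode: [0.778, 0.111, 0.111]
```

Both summaries put most of the mass on Red, which matches where the contours are densest.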
4. Understanding the implementation: plot_dirichlet.py¶
The visualization pipeline is organized into a single class:
DirichletTernaryVisualizer.
This class is responsible for:
Defining the ternary geometry
Converting between coordinate systems
Evaluating the Dirichlet PDF
Rendering contours and annotations
4.1 Ternary Triangle Geometry¶
def triangle_corners(self) -> np.ndarray:
    return np.array(
        [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.5, np.sqrt(3) / 2],
        ]
    )
This method defines the corners of an equilateral triangle in Cartesian space. These corners represent the pure categorical outcomes:
(1, 0, 0)
(0, 1, 0)
(0, 0, 1)
mapped into 2D coordinates.
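The mapping from a probability vector to a point in the triangle is a weighted average of the three corners. The following is a minimal standalone sketch of that idea, not code taken from probly. The function name `barycentric_to_xy` is hypothetical, and the corner order is assumed to match `triangle_corners()`.

```python
import math

# Same corner layout as triangle_corners() in the tutorial.
CORNERS = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]


def barycentric_to_xy(p):
    """Map a probability vector (p1, p2, p3) to a 2D point in the triangle."""
    x = sum(w * cx for w, (cx, _) in zip(p, CORNERS))
    y = sum(w * cy for w, (_, cy) in zip(p, CORNERS))
    return (x, y)


print(barycentric_to_xy((1, 0, 0)))              # the first corner
print(barycentric_to_xy((1 / 3, 1 / 3, 1 / 3)))  # the centroid of the triangle
```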
4.2 Cartesian to Barycentric coordinate conversion¶
def xy_to_barycentric(self, xy: np.ndarray, tol: float = 1e-4) -> np.ndarray:
    """Convert Cartesian coordinates to barycentric coordinates.

    Args:
        xy: Cartesian coordinates inside the triangle.
        tol: Numerical tolerance to avoid simplex boundaries.
    return: Barycentric coordinates.
    """
    corners = self.triangle_corners()

    def to3(v: np.ndarray) -> np.ndarray:
        """Promote 2D vector to 3D."""
        return np.array([v[0], v[1], 0.0])

    area = float(
        0.5
        * np.linalg.norm(
            np.cross(
                to3(corners[1] - corners[0]),
                to3(corners[2] - corners[0]),
            )
        )
    )
    pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]

    def tri_area(point: np.ndarray, pair: np.ndarray) -> float:
        area = 0.5 * np.linalg.norm(
            np.cross(
                to3(pair[0] - point),
                to3(pair[1] - point),
            )
        )
        return float(area)

    coords = np.array([tri_area(xy, p) for p in pairs]) / area
    clipped_coords = np.clip(coords, tol, 1.0 - tol)
    return clipped_coords
This function converts plot coordinates into Dirichlet-compatible inputs. Plotting happens in Cartesian space, but the PDF is evaluated on probability vectors, i.e. barycentric coordinates.
How it works:¶
compute the area of the full triangle
compute the areas of the three sub-triangles formed by the point and each pair of corners
divide each sub-triangle area by the full triangle area, which normalizes the coordinates
clip the values to avoid numerical instability near the edges
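The four steps can be reproduced without NumPy. The sketch below is an independent re-implementation for illustration, assuming the same corner layout and the same default `tol` as the method above. It uses the 2D cross product directly instead of promoting vectors to 3D.

```python
import math

CORNERS = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]


def tri_area(a, b, c):
    """Unsigned area of the triangle a-b-c via the 2D cross product."""
    return 0.5 * abs((b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0]))


def xy_to_barycentric(xy, tol=1e-4):
    """Area-ratio barycentric coordinates, clipped away from the edges."""
    total = tri_area(*CORNERS)
    coords = []
    for i in range(3):
        # The weight of corner i is the area of the sub-triangle opposite it.
        b, c = CORNERS[(i + 1) % 3], CORNERS[(i + 2) % 3]
        coords.append(tri_area(xy, b, c) / total)
    return [min(max(w, tol), 1.0 - tol) for w in coords]


centroid = (0.5, math.sqrt(3) / 6)
print(xy_to_barycentric(centroid))    # roughly [1/3, 1/3, 1/3]
print(xy_to_barycentric((0.0, 0.0)))  # first corner, clipped to [1 - tol, tol, tol]
```

Note that clipping means the coordinates no longer sum exactly to one at the boundary. For plotting, that small error is the price of keeping the PDF finite when an alpha value is below 1.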
4.3 Dirichlet distribution and PDF evaluation¶
class Dirichlet:
    """Dirichlet distribution."""

    def __init__(self, alpha: np.ndarray) -> None:
        """Initialize the distribution.

        Args:
            alpha: Dirichlet concentration parameters.
        """
        self.alpha = np.asarray(alpha)
        self.coef = gamma(np.sum(self.alpha)) / np.prod([gamma(a) for a in self.alpha])

    def pdf(self, x: np.ndarray) -> float:
        """Compute the Dirichlet pdf.

        Args:
            x: Barycentric coordinates.
        return: Pdf value.
        """
        return float(self.coef * np.prod([xx ** (aa - 1) for xx, aa in zip(x, self.alpha, strict=False)]))
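This inner class implements the Dirichlet PDF. The constructor precomputes the normalizing coefficient once, and pdf multiplies it by the product of the coordinates raised to α − 1.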
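Written out, the density is coef · Π x_i^(α_i − 1) with coef = Γ(Σ α_i) / Π Γ(α_i). The snippet above does not show where `gamma` is imported from, so the standalone sketch below assumes `math.gamma` from the standard library. The function name `dirichlet_pdf` is chosen for this example only.

```python
import math


def dirichlet_pdf(x, alpha):
    """Dirichlet density at the barycentric point x for concentration alpha."""
    coef = math.gamma(sum(alpha)) / math.prod(math.gamma(a) for a in alpha)
    return coef * math.prod(xi ** (a - 1) for xi, a in zip(x, alpha))


# With alpha = (1, 1, 1) the density is flat: Gamma(3) / Gamma(1)^3 = 2 everywhere.
flat = dirichlet_pdf((0.2, 0.3, 0.5), (1, 1, 1))
print(flat)

# With alpha = (8, 2, 2) a point near the Red corner is far more likely
# than the centre of the triangle.
near_red = dirichlet_pdf((0.8, 0.1, 0.1), (8, 2, 2))
centre = dirichlet_pdf((1 / 3, 1 / 3, 1 / 3), (8, 2, 2))
print(near_red > centre)
```

One caveat of this direct form: `math.gamma` raises `OverflowError` for arguments above roughly 171, so for very large alpha values a log-space variant based on `math.lgamma` would likely be more robust.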
4.4 Rendering the Dirichlet plot with contours¶
def dirichlet_plot(  # noqa: D417
    self,
    alpha: np.ndarray,
    labels: list[str],
    title: str,
    ax: Axes | None = None,
    subdiv: int = 7,
    nlevels: int = 200,
    cmap: str = "viridis",
    **contour_kwargs: Any,  # noqa: ANN401
) -> Axes:
    """Plot Dirichlet pdf contours on a ternary simplex.

    Args:
        alpha: Dirichlet concentration parameters.
        labels: the labels of the ternary corners.
        title: title of the plot.
        ax: matplotlib axes.Axes to plot on.
        subdiv: triangular mesh subdivision depth.
        nlevels: number of contour levels.
        cmap: matplotlib colormap.
    returns: Ternary plot with Dirichlet contours.
    """
    corners = self.triangle_corners()
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    dist = self.Dirichlet(alpha)
    pvals = [dist.pdf(self.xy_to_barycentric(np.array(xy))) for xy in zip(trimesh.x, trimesh.y, strict=False)]
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
        fig.subplots_adjust(bottom=0.25)
    ax.tricontourf(
        trimesh,
        pvals,
        nlevels,
        cmap=cmap,
        **contour_kwargs,
    )
    ax.plot(
        [corners[0, 0], corners[1, 0], corners[2, 0], corners[0, 0]],
        [corners[0, 1], corners[1, 1], corners[2, 1], corners[0, 1]],
        color=cfg.BLACK,
    )
    self.label_corners_and_vertices(ax, labels)
    ax.set_aspect("equal", "box")
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, np.sqrt(3) / 2)
    ax.axis("off")
    ax.set_title(title, pad=40)
    return ax
Here everything comes together.
Generate simplex mesh
Refine mesh for smooth contours
Convert mesh points -> barycentric coords
Evaluate Dirichlet PDF
Render filled contours
Overlay geometry and annotations
The result: We generate a visually interpretable Dirichlet plot.
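The core of this pipeline (sample points on the simplex, evaluate the PDF at each one, then look at where the values are highest) can be tried without matplotlib. In the sketch below a regular lattice of interior points stands in for the refined triangular mesh. That substitution, and the helper names `simplex_grid` and `dirichlet_pdf`, are assumptions made for illustration and not how probly builds its mesh.

```python
import math


def dirichlet_pdf(x, alpha):
    """Dirichlet density at the barycentric point x."""
    coef = math.gamma(sum(alpha)) / math.prod(math.gamma(a) for a in alpha)
    return coef * math.prod(xi ** (a - 1) for xi, a in zip(x, alpha))


def simplex_grid(n):
    """Interior lattice points (i/n, j/n, k/n) with i + j + k = n and i, j, k >= 1."""
    return [
        (i / n, j / n, (n - i - j) / n)
        for i in range(1, n)
        for j in range(1, n - i)
    ]


alpha = (8, 2, 2)
points = simplex_grid(9)
values = [dirichlet_pdf(p, alpha) for p in points]
peak = points[values.index(max(values))]
print(peak)  # the grid point with the highest density
```

With a grid resolution of 9 the lattice happens to contain the analytic mode (7/9, 1/9, 1/9), so the numerical peak lands exactly on it. A contour plot does the same evaluation on a much finer mesh and colours the result.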
5. ONE MODULE TO ~~RULE THEM ALL~~ generate the completed Dirichlet distribution plot: dirichlet_visualization.py¶
Here we have the code:
def create_dirichlet_plot(
    alpha: np.ndarray,
    labels: list[str] | None = None,
    title: str | None = None,
    *,
    show: bool = True,
) -> Axes | None:
    """Create a ternary Dirichlet distribution plot.

    Args:
        alpha: Dirichlet concentration parameters.
        labels: List of labels corresponding to the simplex corners.
        title: Custom plot title.
        show: Enables the user to decide whether to show the plot or not.
    """
    alpha = np.asarray(alpha)
    if alpha.shape != (3,):
        msg = "Dirichlet plot requires exactly three alpha values."
        raise ValueError(msg)
    if labels is None:
        labels = ["θ₁", "θ₂", "θ₃"]
    if title is None:
        title = f"Dirichlet Distribution (α = {alpha.tolist()})"  # noqa: RUF001
    visualizer = DirichletTernaryVisualizer()
    ax = visualizer.dirichlet_plot(
        alpha=alpha,
        labels=labels,
        title=title,
    )
    if show:
        plt.show()
    return ax
create_dirichlet_plot accepts the following arguments:
alpha
labels
title
show
Only alpha is required, and it must contain exactly three values; otherwise a ValueError is raised.
labels and title are filled with default text if nothing is passed, and show (keyword-only, True by default) controls whether plt.show() is called.
6. So how do alpha values affect the plot?¶
Let's compare two cases: bigger alpha values and smaller alpha values.
big_alpha_values = np.array([20, 80, 20])
create_dirichlet_plot(alpha=big_alpha_values)
Bigger alpha values lead to a concentrated distribution.
small_alpha_values = np.array([3, 8, 2])
create_dirichlet_plot(small_alpha_values)
Smaller alpha values lead to a less concentrated, more spread-out distribution.
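The effect can also be quantified. The variance of each component is α_i (α_0 − α_i) / (α_0² (α_0 + 1)), where α_0 is the sum of the alpha values, so for fixed proportions the spread shrinks as the total grows. The sketch below compares two alpha vectors with the same 1 : 4 : 1 proportions. The vector [2, 8, 2] is an illustrative choice that is not used elsewhere in this tutorial, and `dirichlet_std` is a hypothetical helper, not a probly function.

```python
def dirichlet_std(alpha):
    """Per-component standard deviation of a Dirichlet distribution."""
    a0 = sum(alpha)
    return [(a * (a0 - a) / (a0**2 * (a0 + 1))) ** 0.5 for a in alpha]


# Same proportions (1 : 4 : 1), different total concentration.
std_small = dirichlet_std([2, 8, 2])
std_big = dirichlet_std([20, 80, 20])
print("alpha = [2, 8, 2]:   ", [round(s, 3) for s in std_small])
print("alpha = [20, 80, 20]:", [round(s, 3) for s in std_big])
```

Scaling every alpha value by ten leaves the mean unchanged but makes the standard deviations about three times smaller, which is why the contours tighten around the same location.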