"""Public API to plot 2d, 3d or multidimensional data as credal sets."""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import mpltern # noqa: F401, required for ternary projection, do not remove
import numpy as np
from probly.visualization.credal.input_handling import dispatch_plot
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
else:
import matplotlib.pyplot as plt
[docs]
def create_credal_plot(
    input_data: np.ndarray,
    labels: list[str] | None = None,
    title: str | None = None,
    choice: str | None = None,
    minmax: bool | None = None,
    *,
    show: bool = True,
) -> Axes:
    """Plot credal-set data by delegating to the matching input_handling plotter.

    The concrete plot is chosen by ``dispatch_plot`` from
    ``probly.visualization.credal.input_handling``; this function only forwards
    the arguments to it and optionally displays the result.

    Args:
        input_data: NumPy array with probabilities.
        labels: List of labels corresponding to the input data.
        title: Custom or predefined title.
        choice: Either "MLE", "Credal", "Probability" or None.
        minmax: Whether to show the Min/Max lines; only relevant for ternary plots.
        show: If True (the default), call ``plt.show()`` after plotting.

    Returns:
        Whatever ``dispatch_plot`` returns (annotated here as ``Axes``).
    """
    # Everything except the data itself is forwarded by keyword.
    dispatch_kwargs = {
        "labels": labels,
        "title": title,
        "choice": choice,
        "minmax": minmax,
    }
    axes = dispatch_plot(input_data, **dispatch_kwargs)

    if show:
        plt.show()

    return axes
[docs]
def simplex_plot(probs: np.ndarray) -> tuple[Figure, Axes]:
    """Scatter probability vectors on a ternary (2-simplex) plot.

    A new figure with a single ``"ternary"`` axes is created; that projection is
    made available by the module-level ``mpltern`` import. Each row of
    ``probs`` becomes one point, using only its first three columns.

    Args:
        probs: numpy.ndarray of shape (n_instances, n_classes); at least three
            columns are required because columns 0, 1 and 2 are indexed.

    Returns:
        A ``(fig, ax)`` tuple holding the new matplotlib figure and its axes.
    """
    figure = plt.figure()
    ternary_ax = figure.add_subplot(projection="ternary")

    # Pull out the three coordinate columns in order (0, 1, 2).
    first, second, third = (probs[:, col] for col in range(3))
    ternary_ax.scatter(first, second, third)

    return figure, ternary_ax
[docs]
def credal_set_plot(probs: np.ndarray, *, atol: float = 1e-9) -> tuple[Figure, Axes]:
    """Plot credal sets based on intervals of lower and upper probabilities.

    The per-class lower/upper bounds are the column-wise min/max of ``probs``.
    The credal set drawn is the polygon of 3-class distributions whose
    coordinates all lie within those bounds.

    The figure shows:
    - the samples;
    - the filled polygon with its outline;
    - the samples that attain a per-class minimum or maximum.

    Args:
        probs: Array-like of shape (n_samples, 3), one distribution over three
            classes per row.
        atol: Absolute tolerance used when testing whether a candidate vertex
            lies within the bounds. Without it, floating-point round-off can
            reject valid vertices (in floats ``1 - 1/3 - 1/3 > 1/3``). In that
            case the uniform distribution would yield an empty vertex set.

    Returns:
        fig: matplotlib figure
        ax: matplotlib axes (ternary projection)

    Raises:
        ValueError: If ``probs`` is not a non-empty (n_samples, 3) array, or if
            the bounds admit no vertex (e.g. rows that do not sum to 1).
    """
    probs = np.asarray(probs, dtype=float)
    # ``ndim`` is checked first so ``shape[1]`` is only read for 2-D input.
    if probs.ndim != 2 or probs.shape[0] == 0 or probs.shape[1] != 3:
        msg = (
            "Expected probabilities of shape (n_samples, 3) with n_samples >= 1, "
            f"got shape {probs.shape}."
        )
        raise ValueError(msg)

    lower_probs = np.min(probs, axis=0)
    upper_probs = np.max(probs, axis=0)

    # Everything that can fail runs before plt.figure(). An invalid input
    # therefore does not leave an orphaned figure registered with pyplot.
    vertices = _credal_polygon_vertices(lower_probs, upper_probs, atol)
    if len(vertices) == 0:
        msg = "The set of vertices is empty. Please check the probabilities in the credal set."
        raise ValueError(msg)

    # Rows of ``probs`` attaining each class's minimum, then each class's maximum.
    edge_probs = np.vstack(
        (probs[np.argmin(probs, axis=0)], probs[np.argmax(probs, axis=0)])
    )

    fig = plt.figure()
    ax = fig.add_subplot(projection="ternary")
    ax.scatter(probs[:, 0], probs[:, 1], probs[:, 2])
    # Repeat the first vertex so the outline is closed.
    vertices_closed = np.vstack([vertices, vertices[0]])
    ax.fill(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2])
    ax.plot(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2])
    ax.scatter(edge_probs[:, 0], edge_probs[:, 1], edge_probs[:, 2])
    return fig, ax


def _credal_polygon_vertices(
    lower_probs: np.ndarray,
    upper_probs: np.ndarray,
    atol: float,
) -> np.ndarray:
    """Return the vertices of {p : sum(p) == 1, lower <= p <= upper}, angularly sorted.

    Every vertex of that polygon has two coordinates sitting at one of their
    bounds; the third is fixed by the sum-to-one constraint. So all bound
    combinations for each coordinate pair are enumerated, and those whose
    third coordinate lands within its bounds (up to ``atol``) are kept. The
    result may contain duplicate points, which do not change the drawn polygon.

    Returns:
        An array of shape (n_vertices, 3). If no vertex exists, it is an empty
        array of shape (0,).
    """
    candidates = []
    for i, j, k in ((0, 1, 2), (1, 2, 0), (0, 2, 1)):
        for x in (lower_probs[i], upper_probs[i]):
            for y in (lower_probs[j], upper_probs[j]):
                z = 1 - x - y
                if lower_probs[k] - atol <= z <= upper_probs[k] + atol:
                    vertex = [0.0, 0.0, 0.0]
                    vertex[i], vertex[j], vertex[k] = x, y, z
                    candidates.append(vertex)

    vertices = np.array(candidates)
    if len(vertices) == 0:
        return vertices

    # All points lie in the plane sum(p) == 1, so projecting onto (p0, p1) is an
    # affine bijection. That preserves their cyclic order around the centroid.
    center = np.mean(vertices, axis=0)
    angles = np.arctan2(vertices[:, 1] - center[1], vertices[:, 0] - center[0])
    return vertices[np.argsort(angles)]