probly.representation.sample.axis_tracking.track_axis

probly.representation.sample.axis_tracking.track_axis(index: ToIndices, special_axis: int, ndim: int, torch_indexing: bool = False) → AxisTrackingResult | None

Track the new position of a ‘special’ axis after a NumPy-style __getitem__ indexing operation.

Parameters:
  • index – The indexing object used in arr[index]. Can be a slice, int, None, list, ndarray, ellipsis, or a tuple of such.

  • special_axis – Index of the axis to track (0-based) before indexing.

  • ndim – Number of dimensions of the array before indexing.

  • torch_indexing – Whether to apply PyTorch’s mixed basic/advanced indexing rules instead of NumPy’s.

Returns:

The new axis index of the special axis after indexing, or None if the axis is removed (e.g., via integer indexing).
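
Example (illustrative sketch):

The following is a minimal, self-contained sketch of the idea behind axis tracking, restricted to basic indexing (int, slice, None, Ellipsis). It is not the library's implementation: the helper name track_axis_basic is hypothetical, it returns a plain int or None rather than an AxisTrackingResult, and it ignores advanced indexing (lists, ndarrays) as well as the torch_indexing flag.

```python
def track_axis_basic(index, special_axis, ndim):
    """Hypothetical helper: locate `special_axis` after basic indexing.

    Supports only int, slice, None and Ellipsis entries. Returns the new
    0-based position of the axis, or None if an integer index removes it.
    """
    if not 0 <= special_axis < ndim:
        raise ValueError("special_axis must satisfy 0 <= special_axis < ndim")
    if not isinstance(index, tuple):
        index = (index,)

    # None inserts a new axis and Ellipsis is a placeholder; every other
    # entry consumes exactly one input axis.
    consuming = sum(1 for item in index if item is not None and item is not Ellipsis)
    n_ellipsis = sum(1 for item in index if item is Ellipsis)
    if n_ellipsis > 1:
        raise IndexError("an index can only have a single Ellipsis")
    if consuming > ndim:
        raise IndexError("too many indices for array")

    # Expand the Ellipsis (or the implicit trailing one) into full slices.
    fill = ndim - consuming
    expanded = []
    for item in index:
        if item is Ellipsis:
            expanded.extend([slice(None)] * fill)
        else:
            expanded.append(item)
    if n_ellipsis == 0:
        expanded.extend([slice(None)] * fill)

    in_axis = 0   # position in the array before indexing
    out_axis = 0  # position in the array after indexing
    for item in expanded:
        if item is None:
            out_axis += 1              # new axis appears in the output only
        elif isinstance(item, int):
            if in_axis == special_axis:
                return None            # integer indexing removes the axis
            in_axis += 1
        elif isinstance(item, slice):
            if in_axis == special_axis:
                return out_axis        # slices keep the axis
            in_axis += 1
            out_axis += 1
        else:
            raise TypeError("only basic indexing is supported in this sketch")
    return None


# arr[0, :] on a 3-D array: axis 1 moves to position 0.
print(track_axis_basic((0, slice(None)), special_axis=1, ndim=3))  # 0
# arr[None, ...] on a 3-D array: axis 2 moves to position 3.
print(track_axis_basic((None, Ellipsis), special_axis=2, ndim=3))  # 3
# arr[:, 2] on a 2-D array: axis 1 is removed.
print(track_axis_basic((slice(None), 2), special_axis=1, ndim=2))  # None
```

The sketch walks the expanded index once while keeping separate counters for the input and output axes. Advanced indexing needs extra rules for where the broadcast index dimensions are placed, which is presumably where NumPy and PyTorch conventions can diverge and why torch_indexing exists.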