1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
|
# The point of this test is to make sure that the infrastructure for supporting
# custom attributes, like title in Hist, is working.
from __future__ import annotations
import pytest
import boost_histogram as bh
# First, make a new family to identify your library. This sentinel is passed
# as ``family=`` to the Regular, Integer and CustomHist subclasses in this file.
CUSTOM_FAMILY = object()
# Add named axes
class NamedAxesTuple(bh.axis.AxesTuple):
    """
    Axes tuple whose items can be looked up by axis name as well as by index.
    """

    __slots__ = ()

    def _get_index_by_name(self, name):
        """
        Translate an axis name into its position in this tuple.

        Anything that is not a string is handed back unchanged, so integer
        indices and ``None`` slice fields pass straight through.

        Raises:
            KeyError: no axis in this tuple carries the requested name.
        """
        if isinstance(name, str):
            for position, axis in enumerate(self):
                if axis.name == name:
                    return position
            raise KeyError(f"{name} not found in axes")
        return name

    def __getitem__(self, item):
        """
        Index or slice the tuple, accepting axis names in place of integers.
        """
        resolve = self._get_index_by_name
        if isinstance(item, slice):
            # Each field of the slice is resolved, so any of them may be a name.
            key = slice(resolve(item.start), resolve(item.stop), resolve(item.step))
        else:
            key = resolve(item)
        return super().__getitem__(key)

    @property
    def name(self):
        """
        Tuple holding the name of every axis, in order. Names may be empty.
        """
        return tuple(axis.name for axis in self)

    @name.setter
    def name(self, values):
        # Writes straight into each axis' raw metadata with a "test: " prefix
        # (presumably so a test can tell that this custom setter was used).
        # zip(strict=False) stops at the shorter input instead of raising.
        for axis, value in zip(self, values, strict=False):
            axis._ax.raw_metadata["name"] = f"test: {value}"
# When you subclass Histogram or an axis, you should register your family so
# boost-histogram will know what to convert C++ objects into.
class AxesMixin:
    """
    Mixin that adds a read-only ``name`` property backed by the axis metadata.
    """

    __slots__ = ()

    # This hook only matters when the mixin is listed first among the bases:
    # it forwards the class keyword arguments to the next class in the MRO.
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

    @property
    def name(self):
        """
        Name stored under the "name" key of the axis' raw metadata.

        Returns an empty string when no name has been stored.
        """
        metadata = self._ax.raw_metadata
        return metadata.get("name", "")
# The order of the mixin is important here - it must be first
# if it needs to override bh.axis.Regular, otherwise, last is simpler,
# as it doesn't need to forward __init_subclass__ kwargs then.
class Regular(bh.axis.Regular, AxesMixin, family=CUSTOM_FAMILY):
    """
    Regular axis of the custom family that requires a name.
    """

    __slots__ = ()

    def __init__(self, bins, start, stop, name):
        """
        Build the regular axis, then record ``name`` in its raw metadata.
        """
        super().__init__(bins, start, stop)
        metadata = self._ax.raw_metadata
        metadata["name"] = name
class Integer(AxesMixin, bh.axis.Integer, family=CUSTOM_FAMILY):
    """
    Integer axis of the custom family that requires a name.
    """

    __slots__ = ()

    def __init__(self, start, stop, name):
        """
        Build the integer axis, then record ``name`` in its raw metadata.
        """
        super().__init__(start, stop)
        metadata = self._ax.raw_metadata
        metadata["name"] = name
class CustomHist(bh.Histogram, family=CUSTOM_FAMILY):
    """
    Histogram of the custom family with name-addressable, uniquely named axes.
    """

    def _generate_axes_(self):
        """
        Wrap every axis of this histogram in a NamedAxesTuple.
        """
        return NamedAxesTuple(map(self._axis, range(self.ndim)))

    def __init__(self, *args, **kwargs):
        """
        Construct the histogram and reject duplicated non-empty axis names.

        Raises:
            KeyError: two or more axes share the same non-empty name.
        """
        super().__init__(*args, **kwargs)

        # Empty names are ignored: any number of axes may be unnamed.
        named = [axis.name for axis in self.axes if axis.name]
        if len(set(named)) == len(named):
            return

        raise KeyError(
            f"{self.__class__.__name__} instance cannot contain axes with duplicated names"
        )
def test_hist_creation():
    """Named, unnamed and duplicate-named axes at construction time."""
    named = CustomHist(Regular(10, 0, 1, name="a"), Integer(0, 4, name="b"))
    assert named.axes[0].name == "a"
    assert named.axes[1].name == "b"

    # Empty names are allowed to repeat.
    unnamed = CustomHist(Regular(10, 0, 1, name=""), Regular(20, 0, 4, name=""))
    assert not unnamed.axes[0].name
    assert not unnamed.axes[1].name

    # Two axes sharing a non-empty name must be rejected.
    with pytest.raises(KeyError):
        CustomHist(Regular(10, 0, 1, name="a"), Regular(20, 0, 4, name="a"))
def test_hist_index():
    """Integer indexing of the axes returns the axes in construction order."""
    hist = CustomHist(Regular(10, 0, 1, name="a"), Regular(20, 0, 4, name="b"))
    for position, expected in enumerate(("a", "b")):
        assert hist.axes[position].name == expected
def test_hist_convert():
    """Axis types follow the histogram family on conversion; names survive."""
    custom = CustomHist(Regular(10, 0, 1, name="a"), Integer(0, 4, name="b"))

    # Custom family -> plain boost-histogram classes.
    plain = bh.Histogram(custom)
    assert type(plain.axes[0]) is bh.axis.Regular
    assert type(plain.axes[1]) is bh.axis.Integer
    assert plain.axes[0].name == "a"
    assert plain.axes[1].name == "b"

    # Plain boost-histogram -> back to the custom family.
    round_trip = CustomHist(plain)
    assert type(round_trip.axes[0]) is Regular
    assert type(round_trip.axes[1]) is Integer
    assert round_trip.axes[0].name == "a"
    assert round_trip.axes[1].name == "b"

    # Just verify no-op status: custom -> custom keeps types and names.
    same_family = CustomHist(custom)
    assert type(same_family.axes[0]) is Regular
    assert type(same_family.axes[1]) is Integer
    assert same_family.axes[0].name == "a"
    assert same_family.axes[1].name == "b"
def test_access():
    """Axes can be fetched by name, including after conversion from bh."""
    native = CustomHist(Regular(10, 0, 1, name="a"), Regular(20, 0, 4, name="b"))
    assert native.axes["a"] == native.axes[0]
    assert native.axes["b"] == native.axes[1]

    # Names assigned on a plain boost-histogram must carry over on conversion.
    plain = bh.Histogram(bh.axis.Regular(10, 0, 1), bh.axis.Regular(20, 0, 4))
    plain.axes.name = "a", "b"
    converted = CustomHist(plain)
    assert converted.axes["a"] == converted.axes[0]
    assert converted.axes["b"] == converted.axes[1]
def test_hist_name_set():
    """Setting attributes through the axes tuple versus on a single axis."""
    hist = CustomHist(Regular(10, 0, 1, name="a"), Regular(20, 0, 4, name="b"))

    # The tuple-level setter is the custom one, which prefixes each value.
    hist.axes.name = ("c", "d")
    assert hist.axes.name == ("test: c", "test: d")

    # A single axis exposes ``name`` read-only.
    with pytest.raises(AttributeError):
        hist.axes[0].name = "a"

    hist.axes.label = ("one", "two")
    assert hist.axes.label == ("one", "two")

    # Too few and too many values are both rejected.
    with pytest.raises(ValueError):
        hist.axes.label = ("one",)
    with pytest.raises(ValueError):
        hist.axes.label = ("one", "two", "three")
|