# coding: utf-8
from __future__ import annotations
import os
import h5py
import h5py._hl.selections as selection
import numpy
from h5py import h5s as h5py_h5s
from silx.io.url import DataUrl
from silx.io.utils import get_data, h5py_read_dataset
from tomoscan.esrf.scan.utils import cwd_context
from tomoscan.io import HDF5File
from silx.io.utils import open as open_hdf5
from nxtomo.io import to_target_rel_path
from nxtomomill.utils.h5pyutils import from_data_url_to_virtual_source
from nxtomomill.utils.hdf5 import DatasetReader
__all__ = [
"FrameAppender",
]
class FrameAppender:
    """
    Insert 2D frame(s) at the start or the end of an existing dataset.

    Depending on the type of ``data`` and on the target dataset (virtual or
    not) the frames are either referenced (DataUrl, h5py.VirtualSource) or
    duplicated (numpy array, list, tuple).

    :param data: frames to insert. Either raw data (numpy array, list,
        tuple) or a reference to data stored elsewhere (DataUrl,
        h5py.VirtualSource)
    :param file_path: HDF5 file containing the dataset to extend or create
    :param data_path: location of the dataset inside ``file_path``
    :param where: insertion side, either "start" or "end"
    :param logger: optional logger used for user feedback
    :raises ValueError: if ``where`` is not "start" or "end"
    :raises TypeError: if ``data`` type is not handled
    """

    def __init__(
        self,
        data: numpy.ndarray | DataUrl,
        file_path,
        data_path,
        where,
        logger=None,
    ):
        if where not in ("start", "end"):
            raise ValueError("`where` should be `start` or `end`")
        if not isinstance(
            data, (DataUrl, numpy.ndarray, list, tuple, h5py.VirtualSource)
        ):
            # fix: the message now lists every type actually accepted above
            raise TypeError(
                "data should be an instance of DataUrl, h5py.VirtualSource, "
                f"a numpy array, a list or a tuple. Not {type(data)}"
            )
        self.data = data
        # normalize to an absolute path: relative paths to virtual sources
        # are later computed against this file location
        self.file_path = os.path.abspath(file_path)
        self.data_path = data_path
        self.where = where
        self.logger = logger

    def process(self) -> None:
        """
        main function. Will start the insertion of frame(s)
        """
        with HDF5File(self.file_path, mode="a") as h5s:
            if self.data_path in h5s:
                self._add_to_existing_dataset(h5s)
            else:
                self._create_new_dataset(h5s)
            if self.logger:
                self.logger.info(f"data added to {self.data_path}@{self.file_path}")

    def _add_to_existing_virtual_dataset(self, h5s):
        """
        Insert ``self.data`` (DataUrl or h5py.VirtualSource) into the
        existing *virtual* dataset at ``self.data_path``.

        :param h5s: opened h5py.File (mode "a") containing the dataset
        :raises TypeError: if ``self.data`` is raw data — raw data cannot be
            referenced from a virtual dataset
        """
        # hdf5 < 1.12 resolves relative virtual-source paths against the
        # current working directory instead of the vds file location
        if h5py.version.hdf5_version_tuple[:2] < (1, 12):
            if self.logger:
                self.logger.warning(
                    "You are working on virtual dataset"
                    "with a hdf5 version < 12. Frame "
                    "you want to change might be "
                    "modified depending on the working "
                    "directory without notifying."
                    "See https://github.com/silx-kit/silx/issues/3277"
                )
        if isinstance(self.data, h5py.VirtualSource):
            self.__insert_virtual_source_in_vds(h5s=h5s, new_virtual_source=self.data)
        elif isinstance(self.data, DataUrl):
            if self.logger is not None:
                self.logger.debug(
                    f"Update virtual dataset: {self.data_path}@{self.file_path}"
                )
            # store DataUrl in the current virtual dataset
            url = self.data

            def check_dataset(dataset_frm_url):
                """check if the dataset is valid or might need a reshape"""
                data_need_reshape = False
                if dataset_frm_url.ndim not in (2, 3):
                    raise ValueError(f"{url.path()} should point to 2D or 3D dataset ")
                if dataset_frm_url.ndim == 2:
                    # a single 2D frame must be inserted as a stack of one frame
                    new_shape = 1, dataset_frm_url.shape[0], dataset_frm_url.shape[1]
                    if self.logger is not None:
                        self.logger.info(
                            f"reshape provided data to 3D (from {dataset_frm_url.shape} to {new_shape})"
                        )
                    data_need_reshape = True
                return data_need_reshape

            loaded_dataset = None
            if url.data_slice() is None:
                # case we can avoid to load the data in memory
                with DatasetReader(url) as data_frm_url:
                    data_need_reshape = check_dataset(data_frm_url)
                # FIXME: avoid keeping some file open. not clear why this is needed
                data_frm_url = None
            else:
                data_frm_url = get_data(url)
                data_need_reshape = check_dataset(data_frm_url)
                loaded_dataset = data_frm_url

            if url.data_slice() is None and not data_need_reshape:
                # case we can avoid to load the data in memory
                with DatasetReader(self.data) as data_frm_url:
                    self.__insert_url_in_vds(h5s, url, data_frm_url)
                # FIXME: avoid keeping some file open. not clear why this is needed
                data_frm_url = None
            else:
                if loaded_dataset is None:
                    data_frm_url = get_data(url)
                else:
                    data_frm_url = loaded_dataset
                self.__insert_url_in_vds(h5s, url, data_frm_url)
        else:
            raise TypeError(
                "Provided data is a numpy array when given"
                "dataset path is a virtual dataset. "
                "You must store the data somewhere else "
                "and provide a DataUrl"
            )

    def __insert_url_in_vds(self, h5s, url, data_frm_url):
        """
        Wrap ``url`` into a h5py.VirtualSource matching ``data_frm_url``
        shape and insert it into the virtual dataset.

        :param h5s: opened h5py.File containing the virtual dataset
        :param url: DataUrl pointing to the source data
        :param data_frm_url: 2D or 3D dataset / array the url points to
        """
        if data_frm_url.ndim == 2:
            # promote a single frame to a 3D stack of one frame
            dim_2, dim_1 = data_frm_url.shape
            data_frm_url = data_frm_url.reshape(1, dim_2, dim_1)
        elif data_frm_url.ndim != 3:
            # fix: original message was garbled ("data to had ... 2 or 3 d")
            raise ValueError("data to add is expected to be 2 or 3D")
        new_virtual_source = h5py.VirtualSource(
            path_or_dataset=url.file_path(),
            name=url.data_path(),
            shape=data_frm_url.shape,
        )
        if url.data_slice() is not None:
            # in the case we have to process to a FancySelection
            with open_hdf5(os.path.abspath(url.file_path())) as h5sd:
                dst = h5sd[url.data_path()]
                sel = selection.select(
                    h5sd[url.data_path()].shape, url.data_slice(), dst
                )
                new_virtual_source.sel = sel
        self.__insert_virtual_source_in_vds(
            h5s=h5s, new_virtual_source=new_virtual_source, relative_path=True
        )

    def __insert_virtual_source_in_vds(
        self, h5s, new_virtual_source: h5py.VirtualSource, relative_path=True
    ):
        """
        Rebuild the virtual dataset at ``self.data_path`` with
        ``new_virtual_source`` prepended or appended (according to
        ``self.where``).

        :param h5s: opened h5py.File containing the virtual dataset
        :param new_virtual_source: 3D virtual source to insert
        :param relative_path: if True, recreate ``new_virtual_source`` with a
            path relative to ``self.file_path`` (keeps the vds relocatable)
        """
        if not isinstance(new_virtual_source, h5py.VirtualSource):
            raise TypeError(
                f"{new_virtual_source} is expected to be an instance of h5py.VirtualSource and not {type(new_virtual_source)}"
            )
        if not len(new_virtual_source.shape) == 3:
            raise ValueError(
                f"virtual source shape is expected to be 3D and not {len(new_virtual_source.shape)}D."
            )
        # preprocess virtualSource to insure having a relative path
        if relative_path:
            vds_file_path = to_target_rel_path(new_virtual_source.path, self.file_path)
            # the selection is not carried over by the constructor: save/restore it
            new_virtual_source_sel = new_virtual_source.sel
            new_virtual_source = h5py.VirtualSource(
                path_or_dataset=vds_file_path,
                name=new_virtual_source.name,
                shape=new_virtual_source.shape,
                dtype=new_virtual_source.dtype,
            )
            new_virtual_source.sel = new_virtual_source_sel

        virtual_sources_len = []
        virtual_sources = []
        # we need to recreate the VirtualSource they are not
        # store or available from the API
        for vs_info in h5s[self.data_path].virtual_sources():
            length, vs = self._recreate_vs(vs_info=vs_info, vds_file=self.file_path)
            virtual_sources.append(vs)
            virtual_sources_len.append(length)
        n_frames = h5s[self.data_path].shape[0] + new_virtual_source.shape[0]
        data_type = h5s[self.data_path].dtype

        if self.where == "start":
            virtual_sources.insert(0, new_virtual_source)
            virtual_sources_len.insert(0, new_virtual_source.shape[0])
        else:
            virtual_sources.append(new_virtual_source)
            virtual_sources_len.append(new_virtual_source.shape[0])

        # create the new virtual dataset
        layout = h5py.VirtualLayout(
            shape=(
                n_frames,
                new_virtual_source.shape[-2],
                new_virtual_source.shape[-1],
            ),
            dtype=data_type,
        )
        # map each source to its contiguous slab along axis 0
        last = 0
        for v_source, vs_len in zip(virtual_sources, virtual_sources_len):
            layout[last : vs_len + last] = v_source
            last += vs_len
        if self.data_path in h5s:
            del h5s[self.data_path]
        h5s.create_virtual_dataset(self.data_path, layout)

    def _add_to_existing_none_virtual_dataset(self, h5s):
        """
        for now when we want to add data *to a none virtual dataset*
        we always duplicate data if provided from a DataUrl.
        We could create a virtual dataset as well but seems to complicated for
        a use case that we don't really have at the moment.

        :param h5s: opened h5py.File (mode "a") containing the dataset
        """
        if self.logger is not None:
            # fix: the original message was a plain string with unexpanded
            # (and undefined) placeholders
            self.logger.debug(f"Update dataset: {self.data_path}@{self.file_path}")
        if isinstance(self.data, (numpy.ndarray, list, tuple)):
            new_data = self.data
        else:
            url = self.data
            new_data = get_data(url)

        if isinstance(new_data, numpy.ndarray):
            if not new_data.shape[1:] == h5s[self.data_path].shape[1:]:
                raise ValueError(
                    f"Data shapes are incoherent: {new_data.shape} vs {h5s[self.data_path].shape}"
                )
            new_shape = (
                new_data.shape[0] + h5s[self.data_path].shape[0],
                new_data.shape[1],
                new_data.shape[2],
            )
            # fix: keep the existing dataset dtype — numpy.empty defaults to
            # float64 and would silently promote e.g. uint16 frames
            data_to_store = numpy.empty(new_shape, dtype=h5s[self.data_path].dtype)
            if self.where == "start":
                data_to_store[: new_data.shape[0]] = new_data
                data_to_store[new_data.shape[0] :] = h5py_read_dataset(
                    h5s[self.data_path]
                )
            else:
                data_to_store[: h5s[self.data_path].shape[0]] = h5py_read_dataset(
                    h5s[self.data_path]
                )
                data_to_store[h5s[self.data_path].shape[0] :] = new_data
        else:
            assert isinstance(
                self.data, (list, tuple)
            ), f"Unmanaged data type {type(self.data)}"
            o_data = h5s[self.data_path]
            o_data = list(h5py_read_dataset(o_data))
            if self.where == "start":
                new_data.extend(o_data)
                data_to_store = numpy.asarray(new_data)
            else:
                o_data.extend(new_data)
                data_to_store = numpy.asarray(o_data)

        del h5s[self.data_path]
        h5s[self.data_path] = data_to_store

    def _add_to_existing_dataset(self, h5s):
        """Add the frame to an existing dataset"""
        if h5s[self.data_path].is_virtual:
            self._add_to_existing_virtual_dataset(h5s=h5s)
        else:
            self._add_to_existing_none_virtual_dataset(h5s=h5s)

    def _create_new_dataset(self, h5s):
        """
        needs to create a new dataset. In this case the policy is:
           - if a DataUrl is provided then we create a virtual dataset
           - if a numpy array is provided then we create a 'standard' dataset

        :param h5s: opened h5py.File (mode "a") to create the dataset in
        """
        if isinstance(self.data, DataUrl):
            url = self.data
            # make the target relative to the output file so the virtual
            # dataset stays valid when both files move together
            url_file_path = to_target_rel_path(url.file_path(), self.file_path)
            url = DataUrl(
                file_path=url_file_path,
                data_path=url.data_path(),
                scheme=url.scheme(),
                data_slice=url.data_slice(),
            )
            with cwd_context(os.path.dirname(self.file_path)):
                vs, vs_shape, data_type = from_data_url_to_virtual_source(url)
                layout = h5py.VirtualLayout(shape=vs_shape, dtype=data_type)
                layout[:] = vs
                h5s.create_virtual_dataset(self.data_path, layout)
        elif isinstance(self.data, h5py.VirtualSource):
            virtual_source = self.data
            layout = h5py.VirtualLayout(
                shape=virtual_source.shape,
                dtype=virtual_source.dtype,
            )
            # convert path to relative
            vds_file_path = to_target_rel_path(virtual_source.path, self.file_path)
            virtual_source_rel_path = h5py.VirtualSource(
                path_or_dataset=vds_file_path,
                name=virtual_source.name,
                shape=virtual_source.shape,
                dtype=virtual_source.dtype,
            )
            virtual_source_rel_path.sel = virtual_source.sel
            layout[:] = virtual_source_rel_path
            h5s.create_virtual_dataset(self.data_path, layout)
        elif not isinstance(self.data, numpy.ndarray):
            raise TypeError(
                f"self.data should be an instance of DataUrl, a numpy array or a VirtualSource. Not {type(self.data)}"
            )
        else:
            h5s[self.data_path] = self.data

    @staticmethod
    def _recreate_vs(vs_info, vds_file):
        """Simple util to retrieve a h5py.VirtualSource from virtual source
        information.

        to understand clearly this function you might first have a look at
        the use case exposed in issue:
        https://gitlab.esrf.fr/tomotools/nxtomomill/-/issues/40

        :param vs_info: one entry of ``Dataset.virtual_sources()``
        :param vds_file: file containing the virtual dataset; relative
            source paths are resolved from its directory
        :return: (number of frames mapped on axis 0, recreated VirtualSource)
        """
        with cwd_context(os.path.dirname(vds_file)):
            dataset_file_path = vs_info.file_name
            # in case the virtual source is in the same file
            if dataset_file_path == ".":
                dataset_file_path = vds_file
            with open_hdf5(dataset_file_path) as vs_node:
                dataset = vs_node[vs_info.dset_name]
                select_bounds = vs_info.vspace.get_select_bounds()
                left_bound = select_bounds[0]
                right_bound = select_bounds[1]
                # number of frames mapped in the virtual space (axis 0)
                length = right_bound[0] - left_bound[0] + 1
                # warning: for now step is not managed with virtual
                # dataset
                virtual_source = h5py.VirtualSource(
                    vs_info.file_name,
                    vs_info.dset_name,
                    shape=dataset.shape,
                )
                # here we could provide dataset but we won't to
                # insure file path will be relative.
                type_code = vs_info.src_space.get_select_type()
                # check for unlimited selections in case where selection is regular
                # hyperslab, which is the only allowed case for h5s.UNLIMITED to be
                # in the selection
                if (
                    type_code == h5py_h5s.SEL_HYPERSLABS
                    and vs_info.src_space.is_regular_hyperslab()
                ):
                    (
                        source_start,
                        _stride,
                        _count,
                        _block,
                    ) = vs_info.src_space.get_regular_hyperslab()
                    source_end = source_start[0] + length
                    sel = selection.select(
                        dataset.shape,
                        slice(source_start[0], source_end),
                        dataset=dataset,
                    )
                    virtual_source.sel = sel
                return (
                    length,
                    virtual_source,
                )