File: utils.py

# coding: utf-8

"""An :class:`.Enum` class with additional features."""

from __future__ import annotations
import logging
import os
from datetime import datetime

import numpy
from silx.io.url import DataUrl
from silx.io.utils import get_data
from silx.io.utils import open as open_hdf5
from silx.utils.deprecation import deprecated
from silx.utils.enum import Enum as _Enum
from tomoscan.esrf.scan.utils import cwd_context
from tomoscan.io import HDF5File

from nxtomo.nxobject.nxdetector import ImageKey
from nxtomo.utils.frameappender import FrameAppender
from nxtomo.application.nxtomo import NXtomo

import uuid
from typing import Iterable

from silx.io.utils import h5py_read_dataset

try:
    import hdf5plugin  # noqa F401
except ImportError:
    pass


__all__ = [
    "embed_url",
    "FileExtension",
    "get_file_name",
    "get_tuple_of_keys_from_cmd",
    "is_nx_tomo_entry",
    "add_dark_flat_nx_file",
    "change_image_key_control",
    "str_datetime_to_numpy_datetime64",
    "strip_extension",
]


def embed_url(url: DataUrl, output_file: str) -> DataUrl:
    """
    Create a dataset under duplicate_data and with a random name
    to store it

    :param DataUrl url: dataset to be copied
    :param output_file: where to store the dataset
    :param expected_type: some metadata to put in copied dataset attributes
    :param data: data loaded from url is already loaded
    """
    if not isinstance(url, DataUrl):
        return url
    elif url.file_path() == output_file:
        return url
    else:
        embed_data_path = "/".join(("/duplicate_data", str(uuid.uuid1())))
        with cwd_context(os.path.dirname(os.path.abspath(output_file))):
            with HDF5File(output_file, "a") as h5s:
                h5s[embed_data_path] = get_data(url)
                h5s[embed_data_path].attrs["original_url"] = url.path()
            return DataUrl(
                file_path=output_file, data_path=embed_data_path, scheme="silx"
            )
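
# Illustrative usage sketch (file names and data paths below are hypothetical):
# `embed_url` copies the dataset behind a DataUrl into `output_file` under a
# random "/duplicate_data/<uuid>" path and returns a DataUrl to the copy:
#
#     url = DataUrl(file_path="scan.h5", data_path="/entry/darks", scheme="silx")
#     embedded = embed_url(url, output_file="patched.nx")
#     # embedded.file_path() -> "patched.nx"
#     # embedded.data_path() -> "/duplicate_data/<uuid>"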


class FileExtension(_Enum):
    H5 = ".h5"
    HDF5 = ".hdf5"
    NX = ".nx"


def get_file_name(file_name, extension, check=True):
    """
    set the given extension

    :param file_name: name of the file
    :param extension: extension to give
    :param check: if check, already check if the file as one of the
                       '_FileExtension'
    """
    if isinstance(extension, str):
        extension = FileExtension.from_value(extension.lower())
    assert isinstance(extension, FileExtension)
    if check:
        for value in FileExtension.values():
            if file_name.lower().endswith(value):
                return file_name
    return file_name + extension.value
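
# Illustrative sketch of the extension handling (file names are hypothetical):
#
#     get_file_name("scan", ".nx")                  -> "scan.nx"
#     get_file_name("scan.h5", ".nx")               -> "scan.h5" (known extension kept)
#     get_file_name("scan.h5", ".nx", check=False)  -> "scan.h5.nx"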


def get_tuple_of_keys_from_cmd(cmd_value: str) -> tuple:
    """Return a tuple"""
    return tuple(cmd_value.split(","))
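
# Illustrative sketch (the key names are hypothetical):
#
#     get_tuple_of_keys_from_cmd("rotation_angle,count_time")
#     # -> ("rotation_angle", "count_time")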


def is_nx_tomo_entry(file_path, entry):
    """

    :param file_path: hdf5 file path
    :param entry: entry to check
    :return: True if the entry is an NXTomo entry
    """
    if not os.path.exists(file_path):
        return False
    else:
        with open_hdf5(file_path) as h5s:
            if entry not in h5s:
                return False
            node = h5s[entry]
            return NXtomo.node_is_nxtomo(node)
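
# Illustrative sketch (file and entry names are hypothetical):
#
#     is_nx_tomo_entry("acquisition.nx", "entry0000")
#     # -> True if "entry0000" exists and is an NXtomo entry,
#     #    False if the file or the entry is missing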


def add_dark_flat_nx_file(
    file_path: str,
    entry: str,
    darks_start: numpy.ndarray | DataUrl | None = None,
    flats_start: numpy.ndarray | DataUrl | None = None,
    darks_end: numpy.ndarray | DataUrl | None = None,
    flats_end: numpy.ndarray | DataUrl | None = None,
    extras: dict | None = None,
    logger: None | logging.Logger = None,
    embed_data: bool = False,
):
    """
    This will get all data from `entry`@`file_path` and patch them with the
    provided dark and / or flat frames.
    We consider the sequence as: dark, start_flat, projections, end_flat.

    Behavior regarding data type and target dataset:

    * if dataset at `entry` already exists:
        * if dataset at `entry` is a 'standard' dataset:
            * data will be loaded if necessary and `entry` will be updated
        * if dataset at `entry` is a virtual dataset:
            * if `data` is a numpy array then we raise an error: the data should
              already be saved somewhere and you should provide a DataUrl
            * if `data` is a DataUrl then the virtual dataset is updated and
              a virtual source pointing to the
              DataUrl.file_path()@DataUrl.data_path() is added to the layout
    * if a new dataset `entry` needs to be added:
        * if `data` is a numpy array then we create a new 'standard' dataset
        * if `data` is a DataUrl then a new virtual dataset will be created

    note: Datasets `image_key`, `image_key_control`, `rotation_angle` and
    `count_time` will be copied each time.

    :param file_path: NXtomo file containing data to be patched
    :param entry: entry to be patched
    :param darks_start: (3D) numpy array or DataUrl containing the first dark
                        series, if any
    :param flats_start: (3D) numpy array or DataUrl containing the first flat
                        series, if any
    :param darks_end: (3D) numpy array or DataUrl containing the second dark
                      series, if any
    :param flats_end: (3D) numpy array or DataUrl containing the second flat
                      series, if any
    :param extras: dictionary to specify some parameters for flats and darks,
                   like the rotation angle.
                   Valid keys: 'darks_start', 'darks_end', 'flats_start',
                   'flats_end'.
                   Values should be a dictionary of 'NXtomo' keys with
                   values to be set instead of the 'default values'.
                   Possible keys are:
                   * `count_time`
                   * `rotation_angle`
    :param logger: object for logs
    :param embed_data: if True then each external dataset will be copied
                       under a 'duplicate_data' group
    """
    if extras is None:
        extras = {}
    else:
        valid_extra_keys = ("darks_start", "darks_end", "flats_start", "flats_end")
        for key in extras:
            if key not in valid_extra_keys:
                raise ValueError(
                    f"{key} is not recognized. Valid values are {valid_extra_keys}"
                )

    if embed_data is True:
        darks_start = embed_url(darks_start, output_file=file_path)
        darks_end = embed_url(darks_end, output_file=file_path)
        flats_start = embed_url(flats_start, output_file=file_path)
        flats_end = embed_url(flats_end, output_file=file_path)
    else:
        for url in (darks_start, darks_end, flats_start, flats_end):
            if url is not None and isinstance(url, DataUrl):
                if isinstance(url.data_slice(), slice):
                    if url.data_slice().step not in (None, 1):
                        raise ValueError(
                            "When data is not embedded, slice `step` "
                            "must be None or 1. Other values are "
                            f"not handled. Failing url is {url}"
                        )

    # !!! warning: the order of dark / flat treatment matters
    data_names = "flats_start", "darks_end", "flats_end", "darks_start"
    datas = flats_start, darks_end, flats_end, darks_start
    keys_value = (
        ImageKey.FLAT_FIELD.value,
        ImageKey.DARK_FIELD.value,
        ImageKey.FLAT_FIELD.value,
        ImageKey.DARK_FIELD.value,
    )
    wheres = "start", "end", "end", "start"  # warning: order import

    for d_n, data, key, where in zip(data_names, datas, keys_value, wheres):
        if data is None:
            continue
        n_frames_to_insert = 1
        if isinstance(data, str):
            data = DataUrl(path=data)
        if isinstance(data, numpy.ndarray) and data.ndim == 3:
            n_frames_to_insert = data.shape[0]
        elif isinstance(data, DataUrl):
            with open_hdf5(data.file_path()) as h5s:
                if data.data_path() not in h5s:
                    raise KeyError(
                        f"Path given ({data.data_path()}) is not in {data.file_path}"
                    )
            data_node = get_data(data)
            if data_node.ndim == 3:
                n_frames_to_insert = data_node.shape[0]
        else:
            raise TypeError(f"{type(data)} as input is not managed")

        if logger is not None:
            logger.info(
                f"insert {n_frames_to_insert} {type(data).__name__} frame(s) "
                f"with image key {key} at the {where}"
            )
        # update 'data' dataset
        data_path = os.path.join(entry, "instrument", "detector", "data")
        FrameAppender(
            data, file_path, data_path=data_path, where=where, logger=logger
        ).process()
        # update image_key and image_key_control (we are not managing
        # 'alignment' projections here, so the two values are identical)
        ik_path = os.path.join(entry, "instrument", "detector", "image_key")
        ikc_path = os.path.join(entry, "instrument", "detector", "image_key_control")
        for path in (ik_path, ikc_path):
            FrameAppender(
                [key] * n_frames_to_insert,
                file_path,
                data_path=path,
                where=where,
                logger=logger,
            ).process()

        # add the other necessary keys:
        count_time_path = os.path.join(
            entry,
            "instrument",
            "detector",
            "count_time",
        )
        rotation_angle_path = os.path.join(entry, "sample", "rotation_angle")
        x_translation_path = os.path.join(entry, "sample", "x_translation")
        translation_y_path = os.path.join(entry, "sample", "translation_y")
        translation_z_path = os.path.join(entry, "sample", "translation_z")
        control_data_path = os.path.join(entry, "control", "data")
        data_key_paths = (
            count_time_path,
            rotation_angle_path,
            x_translation_path,
            translation_y_path,
            translation_z_path,
            control_data_path,
        )
        mandatory_keys = (
            "count_time",
            "rotation_angle",
        )
        optional_keys = (
            "x_translation",
            "translation_y",
            "translation_z",
            "control/data",
        )

        data_keys = mandatory_keys + optional_keys

        for data_key, data_key_path in zip(data_keys, data_key_paths):
            data_to_insert = None
            if d_n in extras and data_key in extras[d_n]:
                provided_value = extras[d_n][data_key]
                if isinstance(provided_value, Iterable):
                    if len(provided_value) != n_frames_to_insert:
                        raise ValueError(
                            "Given value to store from extras has an "
                            f"inconsistent length ({len(provided_value)}) compared to "
                            f"the number of frames to save ({n_frames_to_insert})"
                        )
                    else:
                        data_to_insert = provided_value
                else:
                    try:
                        data_to_insert = [provided_value] * n_frames_to_insert
                    except Exception as e:
                        if logger is not None:
                            logger.error(
                                f"Failed to create data to insert. Error is {e}"
                            )
                        return
            else:
                # get default values
                def get_default_value(location, where_):
                    with open_hdf5(file_path) as h5s:
                        if location not in h5s:
                            return None
                        existing_data = h5s[location]
                        if where_ == "start":
                            return existing_data[0]
                        else:
                            return existing_data[-1]

                try:
                    default_value = get_default_value(data_key_path, where)
                except Exception:
                    default_value = None
                if default_value is None:
                    msg = f"Unable to define a default value for {data_key_path}. Location empty in {file_path}"
                    if data_key in mandatory_keys:
                        raise ValueError(msg)
                    elif logger:
                        logger.warning(msg)
                    continue
                elif logger:
                    logger.debug(
                        f"No value(s) provided for {data_key_path}. Extract some default value ({default_value})."
                    )
                data_to_insert = [default_value] * n_frames_to_insert

            if data_to_insert is not None:
                FrameAppender(
                    data_to_insert,
                    file_path,
                    data_path=data_key_path,
                    where=where,
                    logger=logger,
                ).process()
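
# Illustrative sketch of a patch call (file name, entry and array shapes are
# hypothetical): insert a series of three start darks and three end flats,
# overriding the rotation angles of the darks instead of reusing the default
# values read from the existing datasets:
#
#     darks = numpy.zeros((3, 128, 128), dtype=numpy.float32)
#     flats = numpy.ones((3, 128, 128), dtype=numpy.float32)
#     add_dark_flat_nx_file(
#         file_path="patched.nx",
#         entry="entry0000",
#         darks_start=darks,
#         flats_end=flats,
#         extras={"darks_start": {"rotation_angle": [0.0, 0.0, 0.0]}},
#     )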


@deprecated(replacement="_FrameAppender", since_version="0.5.0")
def _insert_frame_data(data, file_path, data_path, where, logger=None):
    """
    This function is used to insert some frame(s) (numpy 2D or 3D) into an
    existing dataset, before or after the existing array.

    :param data:
    :param file_path:
    :param data_path: If the path points to a virtual dataset then this one
                      will be updated, but data should be a DataUrl of the
                      same shape. Else we will update the data_path by
                      extending the dataset.
    :param where:
    :raises TypeError: In the case the data type and existing data_path are
                       incompatible.
    """
    fa = FrameAppender(
        data=data, file_path=file_path, data_path=data_path, where=where, logger=logger
    )
    return fa.process()


def change_image_key_control(
    file_path: str,
    entry: str,
    frames_indexes: slice | Iterable,
    image_key_control_value: int | ImageKey,
    logger=None,
):
    """
    Will modify image_key and image_key_control values for the requested
    frames.

    :param file_path: path the nexus file
    :param entry: name of the entry to modify
    :param frames_indexes: indexes of the frames for which we want to modify
                           the image key
    :param image_key_control_value:
    :param logging.Logger logger: logger
    """
    if not isinstance(frames_indexes, (Iterable, slice)):
        raise TypeError("`frames_indexes` should be an instance of Iterable or slice")
    if logger:
        logger.info(
            f"Update frames {frames_indexes} to "
            f"{image_key_control_value} of {entry}@{file_path}"
        )

    image_key_control_value = ImageKey.from_value(image_key_control_value)
    with HDF5File(file_path, mode="a") as h5s:
        node = h5s[entry]
        image_keys_path = "/".join(("instrument", "detector", "image_key"))
        image_keys = h5py_read_dataset(node[image_keys_path])
        image_keys_control_path = "/".join(
            ("instrument", "detector", "image_key_control")
        )
        image_keys_control = h5py_read_dataset(node[image_keys_control_path])
        # filter frame indexes
        if isinstance(frames_indexes, slice):
            start = frames_indexes.start
            if start is None:
                start = 0
            step = frames_indexes.step
            if step is None:
                step = 1
            stop = frames_indexes.stop
            if stop in (None, -1):
                stop = len(image_keys)
            frames_indexes = list(range(start, stop, step))
        frames_indexes = list(
            filter(lambda x: 0 <= x < len(image_keys_control), frames_indexes)
        )
        # manage image_key_control
        image_keys_control[frames_indexes] = image_key_control_value.value
        node[image_keys_control_path][:] = image_keys_control
        # manage image_key. In this case we should get rid of Alignment values
        # and replace them with Projection values
        image_key_value = image_key_control_value
        if image_key_value is ImageKey.ALIGNMENT:
            image_key_value = ImageKey.PROJECTION
        image_keys[frames_indexes] = image_key_value.value
        node[image_keys_path][:] = image_keys
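
# Illustrative sketch (file and entry names are hypothetical): mark the first
# ten frames as flat fields. Note that an ALIGNMENT value would only be written
# to `image_key_control`; `image_key` would receive PROJECTION instead.
#
#     change_image_key_control(
#         file_path="acquisition.nx",
#         entry="entry0000",
#         frames_indexes=slice(0, 10),
#         image_key_control_value=ImageKey.FLAT_FIELD,
#     )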


def str_datetime_to_numpy_datetime64(my_datetime: str | datetime) -> numpy.datetime64:
    # numpy deprecates time zone awareness conversion to numpy.datetime64,
    # so we remove the time zone info.
    if isinstance(my_datetime, str):
        datetime_as_datetime = datetime.fromisoformat(my_datetime)
    elif isinstance(my_datetime, datetime):
        datetime_as_datetime = my_datetime
    else:
        raise TypeError(
            f"my_datetime is expected to be a str or an instance of datetime. Not {type(my_datetime)}"
        )

    datetime_as_utc_datetime = datetime_as_datetime.astimezone(None)
    tz_free_datetime_as_datetime = datetime_as_utc_datetime.replace(tzinfo=None)
    return numpy.datetime64(tz_free_datetime_as_datetime).astype("<M8[ms]")
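
# Illustrative sketch: time-zone aware inputs are converted to the system
# local time and stripped of their tzinfo before the numpy conversion (the
# date below is hypothetical):
#
#     str_datetime_to_numpy_datetime64("2023-05-02T08:13:27+02:00")
#     # -> numpy.datetime64 with millisecond ('<M8[ms]') resolution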


def strip_extension(filename, logger=None):
    if filename.endswith((".nx", ".h5")):
        return filename[:-3]
    elif filename.endswith(".hdf5"):
        return filename[:-5]
    else:
        if logger is not None:
            logger.warning(f"Unusual file name {filename} has no known extension")
        return filename
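
# Illustrative sketch (file names are hypothetical):
#
#     strip_extension("scan.nx")    -> "scan"
#     strip_extension("scan.hdf5")  -> "scan"
#     strip_extension("scan.tif")   -> "scan.tif" (warning emitted if a logger is given)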