File: binary_partition_utils.py

package info (click to toggle)
meep-openmpi 1.25.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 64,556 kB
  • sloc: cpp: 32,214; python: 27,958; lisp: 1,225; makefile: 505; sh: 249; ansic: 131; javascript: 5
file content (310 lines) | stat: -rw-r--r-- 11,384 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import warnings
from typing import Dict, Generator, List, Tuple

import numpy as onp

import meep as mp


def is_leaf_node(partition: mp.BinaryPartition) -> bool:
    """Checks whether a BinaryPartition node is a leaf.

    A node counts as a leaf when both of its child slots are empty.

    Args:
      partition: the BinaryPartition node to inspect

    Returns:
      True when neither `left` nor `right` is set, otherwise False.
    """
    has_children = partition.left is not None or partition.right is not None
    return not has_children


def enumerate_leaf_nodes(
    partition: mp.BinaryPartition,
) -> Generator[mp.BinaryPartition, None, None]:
    """Walks a partition tree and yields its leaves from left to right.

    The walk is depth-first and driven by an explicit stack rather than
    recursion. The right child is pushed before the left one, so leaves are
    produced in the same left-to-right order as a recursive traversal.

    Args:
      partition: root of the BinaryPartition (sub)tree to walk

    Yields:
      Every leaf node found under `partition`, including `partition` itself
      when it is a leaf.
    """
    pending = [partition]
    while pending:
        node = pending.pop()
        if is_leaf_node(node):
            yield node
            continue
        # LIFO order: push right first so the left subtree is visited first.
        pending.append(node.right)
        pending.append(node.left)


def partition_has_duplicate_proc_ids(partition: mp.BinaryPartition) -> bool:
    """Checks whether two or more leaves of a partition share a proc_id.

    Args:
      partition: root of the BinaryPartition (sub)tree to inspect

    Returns:
      True if at least one proc_id appears on more than one leaf node,
      otherwise False.
    """
    leaf_ids = [leaf.proc_id for leaf in enumerate_leaf_nodes(partition)]
    distinct_ids = set(leaf_ids)
    # Collapsing into a set can only shrink the collection; any shrinkage
    # means some proc_id was repeated.
    return len(leaf_ids) > len(distinct_ids)


def get_total_weight(
    partition: mp.BinaryPartition, weights_by_proc_id: List[float]
) -> float:
    """Computes the total weights contained within a BinaryPartition subtree.

    The duplicate-proc_id validation runs once, on `partition` itself. The
    recursive summation is delegated to `_sum_subtree_weights`, which does not
    repeat the check. Any duplicate inside a subtree is also a duplicate in
    the whole tree, so checking only the root catches the same cases without
    re-scanning every subtree at every level of the recursion.

    Args:
      partition: a BinaryPartition subtree to compute the associated weights for
      weights_by_proc_id: a list of weights indexed by proc_id

    Returns:
      The sum of all weights for each proc_id encountered in the subtree.

    Raises:
      ValueError: if the partition includes leaf nodes with duplicate
        proc_ids, or if a node reached during summation has no proc_id and
        lacks a left or right child.
    """
    if partition_has_duplicate_proc_ids(partition):
        raise ValueError("Duplicate proc_ids found in chunk_layout!")
    return _sum_subtree_weights(partition, weights_by_proc_id)


def _sum_subtree_weights(
    partition: mp.BinaryPartition, weights_by_proc_id: List[float]
) -> float:
    """Recursively sums weights over a subtree, without any duplicate check."""
    if partition.proc_id is not None:
        # A node carrying a proc_id contributes that process's weight directly.
        return weights_by_proc_id[partition.proc_id]
    if partition.left is not None and partition.right is not None:
        left_weight = _sum_subtree_weights(partition.left, weights_by_proc_id)
        right_weight = _sum_subtree_weights(partition.right, weights_by_proc_id)
        return left_weight + right_weight
    raise ValueError("Partition missing proc_id or left, right attributes!")


def get_left_right_total_weights(
    partition: mp.BinaryPartition, weights_by_proc_id: List[float]
) -> Tuple[float, float]:
    """Sums the per-process weights separately for the two children of a node.

    Args:
      partition: an interior BinaryPartition node whose children are summed
      weights_by_proc_id: a list of weights indexed by proc_id

    Returns:
      A (left_weight, right_weight) pair, each computed with get_total_weight
      on the corresponding child subtree.

    Raises:
      ValueError: if `partition` lacks a left or a right child (e.g. it is a
        leaf node).
    """
    if partition.left is None or partition.right is None:
        raise ValueError("Partition missing left, right attributes!")
    # Left is evaluated before right, matching the order of the tuple.
    return tuple(
        get_total_weight(child, weights_by_proc_id)
        for child in (partition.left, partition.right)
    )


def pixel_volume(grid_volume: mp.grid_volume) -> int:
    """Counts the pixels in a 2D or 3D grid_volume.

    NOTE: A zero z-extent (nz() == 0) is treated as a 2D simulation; any
    positive z-extent is treated as 3D.

    Args:
      grid_volume: a meep grid_volume object

    Returns:
      nx * ny for a 2D grid_volume, or nx * ny * nz for a 3D one.
    """
    is_3d = grid_volume.nz() > 0
    planar_pixels = grid_volume.nx() * grid_volume.ny()
    return planar_pixels * grid_volume.nz() if is_3d else planar_pixels


def get_total_volume(
    partition: mp.BinaryPartition,
    chunk_volumes: Tuple[mp.grid_volume],
    chunk_owners: onp.ndarray,
) -> float:
    """Adds up the pixel volume of every chunk owned by processes in a subtree.

    NOTE: If multiple chunks are owned by the same process, this function may
    over-count the volume, since all provided grid volumes associated with a
    given process are counted.

    Args:
      partition: a BinaryPartition subtree to compute the associated volumes for
      chunk_volumes: associated grid volumes from a simulation, obtained by
        calling sim.structure.get_chunk_volumes()
      chunk_owners: list of processes associated with each grid volume, obtained
        by calling sim.structure.get_chunk_owners()

    Returns:
      The total pixel volume occupied by all chunks owned by the partition
      (0 when the partition owns no chunks).
    """
    total_pixels = 0
    owned = get_grid_volumes_in_tree(partition, chunk_volumes, chunk_owners)
    for owned_volume in owned:
        total_pixels += pixel_volume(owned_volume)
    return total_pixels


def get_left_right_total_volumes(
    partition: mp.BinaryPartition,
    chunk_volumes: Tuple[mp.grid_volume],
    chunk_owners: onp.ndarray,
) -> Tuple[float, float]:
    """Computes the total pixel volume in left and right subtrees.

    Args:
      partition: an interior BinaryPartition node whose two subtrees are measured
      chunk_volumes: associated grid volumes from a simulation, obtained by
        calling sim.structure.get_chunk_volumes()
      chunk_owners: list of processes associated with each grid volume, obtained
        by calling sim.structure.get_chunk_owners()

    Returns:
      A (left_volume, right_volume) tuple. Each entry is a single number: the
      total pixel volume of all chunks owned by any process in that subtree.
      A process appearing in both subtrees contributes to both totals.

    Raises:
      ValueError: if the BinaryPartition is a leaf node or improperly formatted.
    """
    if partition.left is None or partition.right is None:
        raise ValueError("Partition missing left, right attributes!")
    left_volume = get_total_volume(partition.left, chunk_volumes, chunk_owners)
    right_volume = get_total_volume(partition.right, chunk_volumes, chunk_owners)
    return (left_volume, right_volume)


def get_grid_volumes_in_tree(
    partition: mp.BinaryPartition,
    chunk_volumes: Tuple[mp.grid_volume],
    chunk_owners: onp.ndarray,
) -> List[mp.grid_volume]:
    """Fetches a list of grid_volumes contained in a BinaryPartition subtree.

    NOTE: If multiple chunks are owned by the same process, this function may
    over-count the volume, since all provided grid volumes associated with a
    given process are counted.

    Args:
      partition: a BinaryPartition subtree to find grid_volumes for
      chunk_volumes: associated grid volumes from a simulation, obtained by
        calling sim.structure.get_chunk_volumes()
      chunk_owners: list of processes associated with each grid volume, obtained
        by calling sim.structure.get_chunk_owners()

    Returns:
      A list of every grid_volume whose owner in `chunk_owners` is the proc_id
      of some leaf in the partition. Entries follow chunk order (the order of
      `chunk_owners`), not proc_id order.
    """
    if partition_has_duplicate_proc_ids(partition):
        warnings.warn("Partition has duplicate proc_ids, overcounting possible!")

    # Use a set so each chunk's ownership test is O(1) instead of a scan over
    # all leaf proc_ids. NumPy integer scalars from chunk_owners hash and
    # compare like the equivalent Python ints, so they still match.
    my_proc_ids = {node.proc_id for node in enumerate_leaf_nodes(partition)}

    return [
        chunk_volumes[i]
        for (i, owner) in enumerate(chunk_owners)
        if owner in my_proc_ids
    ]


def get_total_volume_per_process(
    partition: mp.BinaryPartition,
    chunk_volumes: Tuple[mp.grid_volume],
    chunk_owners: onp.ndarray,
) -> Dict[int, float]:
    """Maps each leaf proc_id of a partition to the pixel volume it owns.

    Args:
      partition: a BinaryPartition subtree to compute the associated volumes for
      chunk_volumes: associated grid volumes from a simulation, obtained by
        calling sim.structure.get_chunk_volumes()
      chunk_owners: list of processes associated with each grid volume, obtained
        by calling sim.structure.get_chunk_owners()

    Returns:
      A dict keyed by leaf proc_id (in leaf traversal order) whose values are
      the summed pixel volumes of all chunks owned by that process. Processes
      owning no chunk map to 0.
    """
    totals = {}
    for leaf in enumerate_leaf_nodes(partition):
        owned_pixels = [
            pixel_volume(chunk_volumes[idx])
            for idx, owner in enumerate(chunk_owners)
            if owner == leaf.proc_id
        ]
        totals[leaf.proc_id] = sum(owned_pixels)
    return totals


def get_box_ranges(
    partition: mp.BinaryPartition,
    chunk_volumes: Tuple[mp.grid_volume],
    chunk_owners: onp.ndarray,
) -> Tuple[float, float, float, float, float, float]:
    """Gets the max and min x, y, z dimensions spanned by a partition.

    The extent is taken over `surroundings()` of every chunk grid_volume whose
    owner in `chunk_owners` is the proc_id of one of the partition's leaves.

    Args:
      partition: a BinaryPartition subtree to compute the range for
      chunk_volumes: associated grid volumes from a simulation, obtained by
        calling sim.structure.get_chunk_volumes()
      chunk_owners: list of processes associated with each grid volume, obtained
        by calling sim.structure.get_chunk_owners()

    Returns:
      A 6-tuple which enumerates:
      (
        xmin: the min x value of any grid_volume associated with the partition,
        xmax: the max x value of any grid_volume associated with the partition,
        ymin: the min y value of any grid_volume associated with the partition,
        ymax: the max y value of any grid_volume associated with the partition,
        zmin: the min z value of any grid_volume associated with the partition,
        zmax: the max z value of any grid_volume associated with the partition
      )

    Raises:
      ValueError: if no chunk is owned by any proc_id in the partition, so
        there is no extent to report.
    """
    xmins, xmaxs = [], []
    ymins, ymaxs = [], []
    zmins, zmaxs = [], []
    for leaf in enumerate_leaf_nodes(partition):
        for i, owner in enumerate(chunk_owners):
            if owner == leaf.proc_id:
                # Fetch the surroundings and its corners once per chunk rather
                # than once per coordinate.
                surroundings = chunk_volumes[i].surroundings()
                min_corner = surroundings.get_min_corner()
                max_corner = surroundings.get_max_corner()
                xmins.append(min_corner.x())
                ymins.append(min_corner.y())
                zmins.append(min_corner.z())
                xmaxs.append(max_corner.x())
                ymaxs.append(max_corner.y())
                zmaxs.append(max_corner.z())
    if not xmins:
        # Without this check min() would fail with an opaque
        # "empty sequence" ValueError.
        raise ValueError(
            "Partition does not own any chunk volumes; cannot compute box ranges!"
        )
    return (min(xmins), max(xmaxs), min(ymins), max(ymaxs), min(zmins), max(zmaxs))


def partitions_are_equal(bp1: mp.BinaryPartition, bp2: mp.BinaryPartition) -> bool:
    """Compares two partition trees node by node.

    The trees are equal when they have the same shape, corresponding leaves
    carry equal proc_ids, and corresponding interior nodes have equal
    split_dir and split_pos values.

    Args:
      bp1: a BinaryPartition object to compare equality for
      bp2: the other BinaryPartition object to compare equality for

    Returns:
      True if every pair of corresponding nodes matches, otherwise False.
    """
    first_is_leaf = is_leaf_node(bp1)
    second_is_leaf = is_leaf_node(bp2)
    if first_is_leaf != second_is_leaf:
        # One tree ends here while the other keeps splitting.
        return False
    if first_is_leaf:
        return bp1.proc_id == bp2.proc_id
    # All four comparisons (including both recursive ones) are evaluated
    # before combining; there is deliberately no short-circuit.
    comparisons = (
        bp1.split_dir == bp2.split_dir,
        bp1.split_pos == bp2.split_pos,
        partitions_are_equal(bp1.left, bp2.left),
        partitions_are_equal(bp1.right, bp2.right),
    )
    return all(comparisons)