1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
# to support Cell type inside Cell
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Iterator
import numpy as np
from .point import Func, Point, ValuedPoint
def vertices_from_extremes(dim: int, pmin: Point, pmax: Point, fn: Func) -> list[ValuedPoint]:
    """Build the 2**dim corner vertices of the box spanned by pmin and pmax.

    Requires pmin[d] <= pmax[d] along every axis d (e.g. pmin.x <= pmax.x,
    pmin.y <= pmax.y).

    Corner number ``i`` takes its coordinate along axis ``d`` from ``pmin``
    when bit ``d`` of ``i`` is 0 and from ``pmax`` when that bit is 1, so the
    first corner is ``pmin`` and the last is ``pmax``. Each corner is wrapped
    in a ValuedPoint and passed through ``.calc(fn)``.
    """
    extent = pmax - pmin
    corners = []
    for corner_index in range(1 << dim):
        coords = [
            # bit `axis` of the corner index picks the low (0) or high (1) side
            pmin[axis] + ((corner_index >> axis) & 1) * extent[axis]
            for axis in range(dim)
        ]
        corners.append(ValuedPoint(np.array(coords)).calc(fn))
    return corners
@dataclass
class MinimalCell:
    """A dim-dimensional cell described only by its list of corner vertices."""

    dim: int
    # In 2 dimensions, vertices = [bottom-left, bottom-right, top-left, top-right] points.
    # get_subcell relies on bit `axis` of a vertex's list index telling which side
    # of the cell (0 = low, 1 = high) the vertex lies on along that axis.
    vertices: list[ValuedPoint]

    def get_subcell(self, axis: int, dir: int) -> MinimalCell:
        """Return the (n-1)-cell forming one face of this n-cell.

        Keeps the half of the vertices whose index has bit ``axis`` equal to
        ``dir`` (0 for the low side, 1 for the high side), preserving order.
        """
        mask = 1 << axis
        face_vertices = []
        for index, vertex in enumerate(self.vertices):
            on_high_side = (index & mask) > 0
            if on_high_side == dir:
                face_vertices.append(vertex)
        return MinimalCell(self.dim - 1, face_vertices)

    def get_dual(self, fn: Func) -> ValuedPoint:
        """Return ValuedPoint.midpoint of the first and last vertices, evaluated with fn."""
        first_corner = self.vertices[0]
        last_corner = self.vertices[-1]
        return ValuedPoint.midpoint(first_corner, last_corner, fn)
@dataclass
class Cell(MinimalCell):
    """A node of the subdivision tree: a MinimalCell plus tree links.

    A cell is a leaf while ``children`` is empty; after ``compute_children``
    it has one child per vertex.
    """

    # Number of subdivisions between the root (depth 0) and this cell.
    depth: int
    # Children go in same order as vertices: bottom-left, bottom-right, top-left, top-right
    children: list[Cell]
    # Enclosing cell, or None for the root of the tree.
    parent: Cell | None
    # Index of this cell within parent.children (bit d = side along axis d).
    child_direction: int

    def compute_children(self, fn: Func) -> None:
        """Subdivide this leaf, appending one child per vertex to ``children``.

        Child ``i`` is the box spanned by the midpoint of the first vertex and
        vertex ``i`` and the midpoint of the last vertex and vertex ``i``; it
        is created with ``depth + 1``, this cell as parent and
        ``child_direction == i``. Must only be called on a cell without children.
        """
        assert self.children == [], "compute_children called on a cell that already has children"
        low_corner = self.vertices[0].pos
        high_corner = self.vertices[-1].pos
        for i, vertex in enumerate(self.vertices):
            pmin = (low_corner + vertex.pos) / 2
            pmax = (high_corner + vertex.pos) / 2
            vertices = vertices_from_extremes(self.dim, pmin, pmax, fn)
            self.children.append(Cell(self.dim, vertices, self.depth + 1, [], self, i))

    def get_leaves_in_direction(self, axis: int, dir: int) -> Iterator[Cell]:
        """
        Yield the leaves below this cell (or the cell itself if it is a leaf)
        reached by descending only into children on the ``dir`` side of ``axis``.

        Axis = 0,1,2,etc for x,y,z,etc.
        Dir = 0 for -x, 1 for +x.
        """
        if not self.children:
            yield self
            return
        m = 1 << axis
        for i in range(1 << self.dim):
            if (i & m > 0) == dir:
                yield from self.children[i].get_leaves_in_direction(axis, dir)

    def walk_in_direction(self, axis: int, dir: int) -> Cell | None:
        """
        Same arguments as get_leaves_in_direction.
        Returns the quad (with depth <= self.depth) that shares a (dim-1)-cell
        with self, where that (dim-1)-cell is the side of self defined by
        axis and dir. Returns None when that side lies on the boundary of the
        root cell, so no neighbor exists.
        """
        m = 1 << axis
        if self.parent is None:
            # the root has no neighbors in any direction
            return None
        if (self.child_direction & m > 0) != dir:
            # moving towards the interior of the parent: the neighbor is the
            # sibling mirrored across this axis
            return self.parent.children[self.child_direction ^ m]
        # on the right side of the parent cell and moving right (or analogous)
        # so need to go up through the parent's neighbor
        parent_walked = self.parent.walk_in_direction(axis, dir)
        if parent_walked is not None and parent_walked.children:
            # end at same depth
            return parent_walked.children[self.child_direction ^ m]
        # end at lesser depth (or None when the parent has no neighbor)
        return parent_walked

    def walk_leaves_in_direction(self, axis: int, dir: int) -> Iterator[Cell | None]:
        """
        Yield the leaves of the neighbor found by walk_in_direction(axis, dir),
        or a single None when there is no such neighbor.
        """
        # NOTE(review): the neighbor's leaves are collected with the same `dir`,
        # i.e. those on the neighbor's side facing away from self rather than
        # the side shared with self. Confirm this is intended by the callers.
        walked = self.walk_in_direction(axis, dir)
        if walked is None:
            yield None
        else:
            yield from walked.get_leaves_in_direction(axis, dir)
def should_descend_deep_cell(cell: Cell, tol: np.ndarray) -> bool:
    """Decide whether a cell is worth subdividing further.

    Returns False when the cell is smaller than 10 * tol along every axis or
    when the function value is NaN at every vertex. Returns True when only
    some vertex values are NaN. Otherwise returns True exactly when some
    vertex value has a different sign from the first vertex's value.
    """
    corners = cell.vertices

    # Too small of a cell to be worth descending.
    # We compare to 10*tol instead of tol because the simplices are smaller than the quads.
    # The factor 10 itself is arbitrary.
    extent = corners[-1].pos - corners[0].pos
    if np.all(extent < 10 * tol):
        return False

    def is_undefined(vertex):
        return np.isnan(vertex.val)

    # In a region where the function is undefined everywhere: nothing to refine.
    if all(map(is_undefined, corners)):
        return False
    # Straddling defined and undefined regions: refine to locate the boundary.
    if any(map(is_undefined, corners)):
        return True

    # Simple approach: only descend if we cross the isoline.
    # TODO: This could very much be improved, e.g. by incorporating gradient or
    # second-derivative tests, etc., to cancel descending in approximately linear regions
    reference_sign = np.sign(corners[0].val)
    return any(np.sign(vertex.val) != reference_sign for vertex in corners[1:])
def build_tree(
    dim: int,
    fn: Func,
    pmin: Point,
    pmax: Point,
    min_depth: int,
    max_cells: int,
    tol: np.ndarray,
) -> Cell:
    """Build the cell tree over the box [pmin, pmax] and return its root.

    Cells are subdivided breadth-first. A cell is split when it is shallower
    than ``min_depth`` or when ``should_descend_deep_cell(cell, tol)`` says so.
    Splitting stops once the number of leaves reaches the cell budget or no
    cells remain to examine.
    """
    children_per_cell = 1 << dim
    # min_depth takes precedence over max_cells: the budget is raised, if
    # needed, to the leaf count of a full tree of depth min_depth.
    cell_budget = max(children_per_cell**min_depth, max_cells)

    # The root has no parent; its child_direction of 0 is only a placeholder.
    root = Cell(dim, vertices_from_extremes(dim, pmin, pmax, fn), 0, [], None, 0)

    pending = deque([root])
    leaves = 1
    while pending and leaves < cell_budget:
        cell = pending.popleft()
        if cell.depth >= min_depth and not should_descend_deep_cell(cell, tol):
            # deep enough and not interesting: stays a leaf
            continue
        cell.compute_children(fn)
        pending.extend(cell.children)
        # the new children are leaves, while the split cell no longer is one
        leaves += children_per_cell - 1
    return root
|