File: quad_tree_debug.pyx

package info (click to toggle)
opentsne 1.0.2-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 26,328 kB
  • sloc: python: 4,721; cpp: 1,959; makefile: 20
file content (72 lines) | stat: -rw-r--r-- 2,174 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from openTSNE.quad_tree cimport QuadTree, Node
import numpy as np


def print_tree(QuadTree tree):
    """Print a textual dump of ``tree`` to stdout, starting at its root node.

    The formatting and recursion are delegated to ``_print_tree``.
    """
    cdef Node * root_node = &tree.root
    _print_tree(root_node)


cdef _print_tree(Node * node, name=None, indent=0):
    """Print the quad tree in a readable textual format.

    Writes one line for ``node`` of the form
    ``<name>: <marker> (<num_points>) (<center of mass>)``, prefixed by
    ``indent`` tab characters, then recurses into the children one level
    deeper. ``<marker>`` is ``[+]`` for internal nodes and empty for leaves.
    Nodes holding no points are skipped together with their subtrees.

    Children of a 2-D tree are labelled with compass quadrant names. For any
    other dimensionality there are not exactly four children, so the numeric
    child index is used as the label instead.

    Parameters
    ----------
    node: Node *
        The node to print.
    name: str, optional
        Label for this node; ``'Root'`` is printed when omitted.
    indent: int
        Nesting depth, i.e. the number of leading tab characters.
    """
    cdef Py_ssize_t sector
    cdef Py_ssize_t n_children = 1 << node.n_dims

    if not node.num_points:
        return

    # NOTE(review): these names assume a specific mapping from child-index
    # bits to axes; confirm against how openTSNE.quad_tree orders children.
    quadrants = ('SW', 'NW', 'SE', 'NE')

    print('\t' * indent + '%s: %s (%d) %s' % (
        'Root' if name is None else name,
        '' if node.is_leaf else '[+]',
        node.num_points,
        _str_point(<double [:node.n_dims]>node.center_of_mass),
    ))

    if node.is_leaf:
        return

    for sector in range(n_children):
        # Only a 2-D tree has exactly the four quadrants named above; a
        # 3-D tree has 8 children and would otherwise index past them.
        label = quadrants[sector] if node.n_dims == 2 else str(sector)
        _print_tree(&node.children[sector], label, indent + 1)


def _str_point(double[:] point):
    """Return ``point`` formatted as ``'(x0, x1, ...)'`` with 4 decimals each."""
    cdef Py_ssize_t i
    coords = []
    for i in range(point.shape[0]):
        coords.append('%.4f' % point[i])
    return '(' + ', '.join(coords) + ')'


def plot_tree(QuadTree tree, data):
    """Plot the cells of ``tree`` together with the points in ``data``.

    Parameters
    ----------
    tree: QuadTree
        The tree whose node boundaries are drawn.
    data: np.ndarray
        A 2-D array of points. Columns 0 and 1 are plotted as x and y.
        Non-float64 or read-only arrays are copied to a writable float64
        array first.

    Raises
    ------
    AssertionError
        If ``data`` is not a ``np.ndarray``.
    """
    assert isinstance(data, np.ndarray), '`data` must be np.ndarray'
    # ``_plot_tree`` takes a non-const typed memoryview (``double[:, :]``),
    # which Cython refuses to build from a read-only buffer. ``np.require``
    # returns ``data`` unchanged when it is already writable float64 and
    # copies only when the dtype differs or the array is read-only.
    data = np.require(data, dtype=np.float64, requirements=['W'])
    _plot_tree(&tree.root, data)


cdef _plot_tree(Node * root, double[:, :] data):
    """Draw the quad tree cells under ``root`` and scatter ``data`` on top.

    Every node's square outline is added via ``_add_patch``. Columns 0 and 1
    of ``data`` are plotted as x and y. The figure is written to
    ``quadtree.png`` in the current working directory and then shown with
    ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')

    # ``_add_patch`` also collects the centers of mass of non-empty nodes
    # into this list; they are not drawn here.
    centers = []
    _add_patch(ax, root, centers)

    points = np.asarray(data)
    ax.scatter(points[:, 0], points[:, 1], s=20)

    # ``rasterize`` is not a ``savefig`` option (PNG output is raster anyway),
    # and recent Matplotlib versions reject unknown savefig keyword arguments.
    fig.savefig("quadtree.png", dpi=80, transparent=True, bbox_inches="tight")
    plt.show()


cdef _add_patch(ax, Node * node, centers):
    """Add the square outline of ``node`` and of all its descendants to ``ax``.

    Outlines are added in pre-order (a node before its children). After the
    children have been handled, a node that contains at least one point
    appends the first two coordinates of its center of mass to ``centers``.
    As a result, ``centers`` is filled in post-order.
    """
    cdef Py_ssize_t child
    cdef double * com = node.center_of_mass
    import matplotlib.patches as patches

    side = node.length
    # Lower corner of the node's square: its center shifted by half a side.
    corner = np.asarray(<double [:node.n_dims]>node.center) - side / 2
    ax.add_patch(patches.Rectangle(corner, side, side, fill=False))

    if not node.is_leaf:
        for child in range(1 << node.n_dims):
            _add_patch(ax, &node.children[child], centers)

    if node.num_points > 0:
        centers.append([com[0], com[1]])