File: _common.py

package info (click to toggle)
python-matplotlib-venn 1.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 340 kB
  • sloc: python: 1,514; makefile: 8
file content (142 lines) | stat: -rw-r--r-- 4,687 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Venn diagram plotting routines.
Functionality, common to venn2 and venn3.

Copyright 2012-2024, Konstantin Tretyakov.
http://kt.era.ee/

Licensed under MIT license.
"""

from typing import Optional, Sequence
import numpy as np
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from matplotlib.text import Text

from matplotlib_venn._math import Point2D


class VennDiagram:
    """
    Bundle of everything that makes up a drawn venn diagram: the region patches,
    the per-region (subset) labels, the per-set labels, and the circle geometry.
    Instances of this class are what a venn2 or venn3 function call returns.
    """

    # Lookup table from a textual id to a position in the stored sequences.
    id2idx = {
        # Regions of a 2-circle diagram.
        "10": 0,
        "01": 1,
        "11": 2,
        # Regions of a 3-circle diagram.
        "100": 0,
        "010": 1,
        "110": 2,
        "001": 3,
        "101": 4,
        "011": 5,
        "111": 6,
        # Whole sets (used for set labels).
        "A": 0,
        "B": 1,
        "C": 2,
    }

    def __init__(
        self,
        patches: Sequence[Patch],
        subset_labels: Sequence[Text],
        set_labels: Sequence[Text],
        centers: Sequence[Point2D],
        radii: Sequence[float],
    ):
        # Region-indexed sequences (positions follow id2idx).
        self.patches, self.subset_labels = patches, subset_labels
        # Set-indexed sequence; get_label_by_id tolerates this being None.
        self.set_labels = set_labels
        # Circle geometry, indexed by circle number.
        self.centers, self.radii = centers, radii

    def get_patch_by_id(self, id: str) -> Patch:
        """
        Looks up the patch of the region with the given "region id".
        Region ids are the strings '10', '01', '11' (2-circle diagram) or
        three-character strings such as '001', '010', ... (3-circle diagram).
        An id missing from id2idx raises KeyError.
        """
        position = self.id2idx[id]
        return self.patches[position]

    def get_label_by_id(self, id: str) -> Text:
        """
        Looks up a label by id.
        Multi-character ids ('10', '01', '11', or '001', '010', ... for three circles)
        select the subset label of that region.
        Single-character ids ('A', 'B', and 'C' for a 3-circle diagram) select the
        label of the corresponding set; None is returned if there are no set labels.
        """
        if len(id) != 1:
            return self.subset_labels[self.id2idx[id]]
        if self.set_labels is None:
            # No set labels at all: answer None without consulting id2idx.
            return None
        return self.set_labels[self.id2idx[id]]

    def get_circle_center(self, id: int) -> Point2D:
        """
        Gives the stored center of circle number `id`
        (0, 1 or 2 for the first, second or third circle).
        Read-only accessor: modifying the returned value does not redraw the diagram.
        """
        center = self.centers[id]
        return center

    def get_circle_radius(self, id: int) -> float:
        """
        Gives the stored radius of circle number `id` (0, 1 or 2).
        Read-only accessor: modifying the returned value does not redraw the diagram.
        """
        radius = self.radii[id]
        return radius

    def hide_zeroes(self) -> None:
        """
        Makes every subset label whose text is exactly "0" invisible.
        Handy when empty subsets should not be annotated.
        """
        for label in self.subset_labels:
            if label is None:
                # Some regions carry no label object; nothing to hide there.
                continue
            if label.get_text() == "0":
                label.set_visible(False)


def mix_colors(
    col1: np.ndarray, col2: np.ndarray, col3: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Mixes two or three colors to compute a "mixed" color (for purposes of computing
    colors of the intersection regions based on the colors of the sets).
    Note that we do not simply compute averages of given colors as those seem
    too dark for some default configurations. Thus, we lighten the combination up a bit:
    the sum of two colors is scaled by 0.7, the sum of three colors by 0.4, and every
    component of the result is capped at 1.0.

    Inputs are (up to) three RGB triples of floats 0.0-1.0 given as numpy arrays.
    The components are capped element-wise, so arrays of any other matching shape
    (for example, 4-component RGBA values) are handled as well.

    Returns a new numpy array of the same shape as the inputs; the inputs are not modified.

    >>> mix_colors(np.array([1.0, 0., 0.]), np.array([1.0, 0., 0.])).tolist()
    [1.0, 0.0, 0.0]
    >>> np.round(mix_colors(np.array([1.0, 1., 0.]), np.array([1.0, 0.9, 0.]), np.array([1.0, 0.8, 0.1])), 3).tolist()
    [1.0, 1.0, 0.04]
    >>> mix_colors(np.array([1.0, 0., 0., 1.0]), np.array([1.0, 0., 0., 0.5])).tolist()
    [1.0, 0.0, 0.0, 1.0]
    """
    if col3 is None:
        mix_color = 0.7 * (col1 + col2)
    else:
        mix_color = 0.4 * (col1 + col2 + col3)
    # Cap against a scalar rather than a hard-coded [1.0, 1.0, 1.0] triple, so the
    # clamp does not depend on the number of color components.
    return np.minimum(mix_color, 1.0)


def prepare_venn_axes(
    ax: Axes, centers: Sequence[Point2D], radii: Sequence[float]
) -> None:
    """
    Configures the given axes for drawing a venn diagram: equal aspect ratio, no ticks,
    x/y limits that enclose every circle (with a small margin), and the axis frame turned off.

    `centers` and `radii` describe the circles; each center must expose `x` and `y`.
    """
    # Margin added on every side of the circles' bounding box.
    padding = 0.1

    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])

    # Extreme points of each circle along both axes.
    lefts = [center.x - radius for center, radius in zip(centers, radii)]
    rights = [center.x + radius for center, radius in zip(centers, radii)]
    bottoms = [center.y - radius for center, radius in zip(centers, radii)]
    tops = [center.y + radius for center, radius in zip(centers, radii)]

    ax.set_xlim([min(lefts) - padding, max(rights) + padding])
    ax.set_ylim([min(bottoms) - padding, max(tops) + padding])
    ax.set_axis_off()