File: bandpass_plot.py

package info (click to toggle)
sncosmo 2.12.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,628 kB
  • sloc: python: 7,278; cpp: 184; makefile: 130; sh: 1
file content (149 lines) | stat: -rw-r--r-- 4,854 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Helper function to plot a set of bandpasses in sphinx docs."""

import numpy as np
from matplotlib import rc
from matplotlib import pyplot as plt
import sncosmo

# Colormap used by plot_bandpass_interpolators and
# plot_general_bandpass_interpolators to colour the individual curves.
cmap = plt.get_cmap('viridis')


def plot_bandpass_set(setname, label_prefix=''):
    """Plot every registered bandpass that belongs to one filter set.

    Parameters
    ----------
    setname : str
        Compared against the ``'filterset'`` entry of each registered
        bandpass's metadata; only matching bandpasses are drawn.
    label_prefix : str, optional
        Text prepended to each bandpass name in the legend.

    Notes
    -----
    Draws on a new 9x3 figure and calls ``plt.show()``; nothing is
    returned. When no bandpass matches ``setname`` the empty axes are
    shown without a legend.
    """

    rc("font", family="serif")

    bandpass_meta = sncosmo.bandpasses._BANDPASSES.get_loaders_metadata()

    plt.figure(figsize=(9, 3))
    ax = plt.axes()

    nbands = 0
    for m in bandpass_meta:
        if m['filterset'] != setname:
            continue
        # special case of ZTF position-dependent bandpasses
        if 'ztf::' in m['name']:
            continue
        b = sncosmo.get_bandpass(m['name'])

        # Repeat the first and last wavelengths with zero transmission so
        # that every curve drops to the baseline at both ends.
        wave = np.pad(b.wave, 1, mode='edge')
        trans = np.pad(b.trans, 1)

        ax.plot(wave, trans, label=label_prefix + m['name'])
        nbands += 1

    ax.set_xlabel("Wavelength ($\\AA$)")
    ax.set_ylabel("Transmission")

    # Only build a legend when something was plotted: with zero bands the
    # column count below would be 0, which is not a valid legend layout.
    if nbands > 0:
        ncol = 1 + (nbands - 1) // 9  # 9 labels per column
        ax.legend(loc='upper right', frameon=False, fontsize='small',
                  ncol=ncol)

        # Looks like each legend column takes up about 0.125 of the
        # figure. Make room for the legend.
        xmin, xmax = ax.get_xlim()
        xmax += ncol * 0.125 * (xmax - xmin)
        ax.set_xlim(xmin, xmax)

    plt.tight_layout()
    plt.show()


def plot_bandpass_interpolators(names):
    """Plot position-dependent bandpasses, one subplot per interpolator.

    For each name, the interpolator is retrieved from the sncosmo
    registry and evaluated at 8 evenly spaced positions between
    ``bi.minpos()`` and just below ``bi.maxpos()``. Each resulting
    bandpass is drawn on 1000 wavelength samples and coloured by its
    relative position; only the first and last positions get a legend
    entry.

    Parameters
    ----------
    names : sequence of str
        Names of registered bandpass interpolators. Must not be empty.

    Raises
    ------
    ValueError
        If ``names`` is empty.

    Notes
    -----
    Calls ``plt.show()``; nothing is returned.
    """
    if len(names) == 0:
        raise ValueError("at least one bandpass interpolator name is "
                         "required")

    # squeeze=False always yields a 2-d array of axes, so a single name
    # is handled the same way as several.
    _, axes = plt.subplots(nrows=len(names), ncols=1,
                           figsize=(9., 2.5*len(names)), squeeze=False,
                           sharex=True)
    axes = axes[:, 0]

    for ax, name in zip(axes, names):
        bi = sncosmo.bandpasses._BANDPASS_INTERPOLATORS.retrieve(name)
        minpos = bi.minpos()
        maxpos = bi.maxpos()

        # Stop just short of maxpos (presumably the upper bound is not a
        # valid position for bi.at -- TODO confirm).
        radii = np.linspace(minpos, maxpos-0.000001, 8)

        for r in radii:
            band = bi.at(r)

            wave = np.linspace(band.minwave(), band.maxwave(), 1000)
            trans = band(wave)
            label = ("radius = {:4.1f}cm".format(r)
                     if (r == radii[0] or r == radii[-1])
                     else None)
            ax.plot(wave, trans,
                    color=cmap((r - minpos) / (maxpos - minpos)),
                    label=label)

        ax.legend(loc='upper right')
        ax.set_ylabel("Transmission")
        ax.text(0.03, 0.92, name, transform=ax.transAxes,
                va='top', ha='left')

    axes[-1].set_xlabel("Wavelength ($\\AA$)")
    plt.tight_layout()
    plt.show()


def plot_general_bandpass_interpolators(name):
    """Plot sensor-dependent bandpasses of one filter set.

    The ``'hsc'`` set is delegated to ``plot_bandpass_set`` with the
    label prefix ``'averaged '``. For any other set, one subplot is drawn
    per registered interpolator whose ``'filterset'`` metadata equals
    ``name``. Each subplot shows the bandpass registered under the same
    name (labelled ``'average'``) plus the interpolated bandpass at
    ``x=1000, y=1000`` for four sensor ids spread between 1 and
    ``nsensors - 1``.

    Parameters
    ----------
    name : str
        Filter set name.

    Raises
    ------
    ValueError
        If no registered interpolator belongs to the filter set.

    Notes
    -----
    Calls ``plt.show()``; nothing is returned by this code path.
    """
    if name == 'hsc':
        return plot_bandpass_set(name, label_prefix='averaged ')

    interpolators = sncosmo.bandpasses._BANDPASS_INTERPOLATORS
    names = [m['name'] for m in interpolators.get_loaders_metadata()
             if m['filterset'] == name]
    if not names:
        raise ValueError("no bandpass interpolators found for filterset "
                         "{!r}".format(name))

    # squeeze=False always yields a 2-d array of axes, so a filter set
    # with a single interpolator is handled the same way as several.
    _, axes = plt.subplots(nrows=len(names), ncols=1,
                           figsize=(9., 2.5*len(names)), squeeze=False,
                           sharex=True)
    axes = axes[:, 0]

    # Fixed position at which every sampled sensor is evaluated.
    x = 1000
    y = 1000

    for ax, bandname in zip(axes, names):
        bi = interpolators.retrieve(bandname)
        b = sncosmo.bandpasses._BANDPASSES.retrieve(bandname)

        # Repeat the first and last wavelengths with zero transmission so
        # that the averaged curve drops to the baseline at both ends.
        wave = np.pad(b.wave, 1, mode='edge')
        trans = np.pad(b.trans, 1)
        ax.plot(wave, trans, label='average')

        nsensors = len(bi.transforms._to_focalplane)
        sensor_ids = np.linspace(1, nsensors-1, 4, dtype=int)
        for n, sid in enumerate(sensor_ids):
            band = bi.at(x=x, y=y, sensor_id=sid)

            wave = np.linspace(band.minwave(), band.maxwave(), 1000)
            trans = band(wave)
            label = 'x={}, y={}, sensor_id={}'.format(x, y, sid)
            ax.plot(wave, trans, label=label, color=cmap(n / 4),
                    linestyle='dotted')

        ax.legend(loc='upper right')
        ax.set_ylabel("Transmission")
        ax.text(0.03, 0.92, bandname, transform=ax.transAxes,
                va='top', ha='left')

    axes[-1].set_xlabel("Wavelength ($\\AA$)")
    plt.tight_layout()
    plt.show()