File: pyani_graphics.py

package info (click to toggle)
python-pyani 0.2.13-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 159,944 kB
  • sloc: python: 3,106; makefile: 86; sh: 30
file content (400 lines) | stat: -rw-r--r-- 14,091 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
# -*- coding: utf-8 -*-
# (c) The James Hutton Institute 2013-2019
# (c) University of Strathclyde 2019-2020
# Author: Leighton Pritchard
#
# Contact: leighton.pritchard@strath.ac.uk
#
# Leighton Pritchard,
# Strathclyde Institute for Pharmacy and Biomedical Sciences,
# Cathedral Street,
# Glasgow,
# G4 0RE
# Scotland,
# UK
#
# The MIT License
#
# Copyright (c) 2013-2019 The James Hutton Institute
# Copyright (c) 2019-2020 University of Strathclyde
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""Module to implement graphics output for ANI analyses."""

# Force matplotlib NOT to use an Xwindows backend on *nix, so that
# _tkinter.TclError is avoided when there is no $DISPLAY env: this can occur
# when running the package/script via ssh
# See http://stackoverflow.com/questions/2801882/\
#            generating-a-png-with-matplotlib-when-display-is-undefined
# This needs to be done before importing pyplot

from math import floor, log10
import warnings

import matplotlib

# Specify matplotlib backend
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402
import matplotlib.gridspec as gridspec  # noqa: E402

import numpy as np  # noqa: E402

import scipy.cluster.hierarchy as sch  # noqa: E402
import scipy.spatial.distance as distance  # noqa: E402

import seaborn as sns  # noqa: E402
import pandas as pd  # noqa: E402

from . import pyani_config  # noqa: E402


# Register Matplotlib colourmaps
# The three colourmap objects come from pyani_config and are registered once,
# when this module is imported. Presumably this is so that they can later be
# looked up by name (Params calls plt.get_cmap() on a colourmap name) - the
# names themselves are defined in pyani_config, not here.
plt.colormaps.register(cmap=pyani_config.CMAP_SPBND_BURD)
plt.colormaps.register(cmap=pyani_config.CMAP_HADAMARD_BURD)
plt.colormaps.register(cmap=pyani_config.CMAP_BURD)


# Convenience class to hold heatmap graphics parameters
class Params(object):  # pylint: disable=too-few-public-methods
    """Container for the settings used when rendering a heatmap.

    Attributes set on construction:
    - cmap - colormap object obtained from plt.get_cmap(params[0])
    - vmin - lower bound of the colour scale (params[1])
    - vmax - upper bound of the colour scale (params[2])
    - labels - stored exactly as passed (None by default)
    - classes - stored exactly as passed (None by default)
    """

    def __init__(self, params, labels=None, classes=None):
        """Unpack the (colormap, vmin, vmax) sequence and keep the extras."""
        self.cmap = plt.get_cmap(params[0])
        self.vmin, self.vmax = params[1], params[2]
        self.labels, self.classes = labels, classes

    @property
    def vdiff(self):
        """Return vmax - vmin, but never less than 0.01."""
        span = self.vmax - self.vmin
        return span if span > 0.01 else 0.01


# helper for cleaning up matplotlib axes by removing ticks etc.
def clean_axis(axis):
    """Strip an axis bare: no x/y ticks and no visible spines (frame)."""
    # Emptying the tick lists on both sub-axes
    for sub_axis in (axis.get_xaxis(), axis.get_yaxis()):
        sub_axis.set_ticks([])
    # Hide every spine so that no frame is drawn around the axis
    for border in axis.spines.values():
        border.set_visible(False)


# Add classes colorbar to Seaborn plot
def get_seaborn_colorbar(dfr, classes):
    """Return a pd.Series of class colours, for use as a Seaborn colorbar.

    - dfr - dataframe whose index entries are to be coloured
    - classes - dictionary of class assignments, keyed by index entry

    The returned Series shares dfr's index and holds, for each entry,
    the colour allotted to that entry's class:
    0    colour for class of entry 0
    1    colour for class of entry 1
    ...  colour for class of entry ...
    n    colour for class of entry n
    """
    # One palette colour per distinct class, allotted in sorted class order
    levels = sorted(set(classes.values()))
    palette = sns.cubehelix_palette(len(levels), light=0.9, dark=0.1, reverse=True, start=1, rot=-2)
    colour_of_level = dict(zip(levels, palette))

    # Translate the class mapping into a direct name -> colour mapping
    colour_of_name = {}
    for name, level in classes.items():
        colour_of_name[name] = colour_of_level[level]

    col_cb = pd.Series(dfr.index).map(colour_of_name)
    # The Series is built with a default index and only afterwards given
    # dfr.index: constructing it with that index directly does not work
    col_cb.index = dfr.index
    return col_cb


# Get safe Seaborn labels
def get_safe_seaborn_labels(dfr, labels):
    """Return one label per dataframe index entry, in index order.

    - dfr - dataframe whose index entries need labels
    - labels - dictionary of alternative labels keyed by the string form
               of the index entry, or None

    Entries with no alternative label (or all entries, if labels is None)
    are labelled with the string form of the index entry itself.
    """
    keys = [str(entry) for entry in dfr.index]
    if labels is None:
        return keys
    return [labels.get(key, key) for key in keys]


# Return a clustermap
def get_seaborn_clustermap(dfr, params, title=None, annot=True):
    """Return a Seaborn clustermap drawn from the dataframe.

    - dfr - pandas DataFrame with the data to plot
    - params - object providing cmap, vmin, vmax, colorbar, figsize,
               linewidths and labels attributes
    - title - optional label for the colour scale axis
    - annot - passed through to sns.clustermap
    """
    clustermap_kwargs = {
        "cmap": params.cmap,
        "vmin": params.vmin,
        "vmax": params.vmax,
        "col_colors": params.colorbar,
        "row_colors": params.colorbar,
        "figsize": (params.figsize, params.figsize),
        "linewidths": params.linewidths,
        "xticklabels": params.labels,
        "yticklabels": params.labels,
        "annot": annot,
    }
    grid = sns.clustermap(dfr, **clustermap_kwargs)

    # The colour scale label goes on the left; the title (if any) is its text
    grid.cax.yaxis.set_label_position("left")
    if title:
        grid.cax.set_ylabel(title)

    # Column labels are turned vertical, row labels kept horizontal
    heatmap = grid.ax_heatmap
    heatmap.set_xticklabels(heatmap.get_xticklabels(), rotation=90)
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0)

    return grid


# Generate Seaborn heatmap output
def heatmap_seaborn(dfr, outfilename=None, title=None, params=None):
    """Returns seaborn heatmap with cluster dendrograms.

    - dfr - pandas DataFrame with relevant data
    - outfilename - path to output file (indicates output format); no file
                    is written if this is None or empty
    - title - optional label for the colour scale
    - params - object holding rendering parameters (cmap, vmin, vmax,
               labels, classes). Although the default is None, a real
               object is required: None raises AttributeError.

    The passed params object is left unmodified. The extra rendering
    attributes (colorbar, figsize, linewidths) and the resolved label list
    are set on a shallow copy, so the same params can safely be reused for
    further heatmap calls.
    """
    import copy  # local import: only needed to protect the caller's params

    # Decide on figure layout size: a minimum size is required for
    # aesthetics, and a maximum to avoid core dumps on rendering.
    # If we hit the maximum size, we should modify font size.
    maxfigsize = 120
    calcfigsize = dfr.shape[0] * 1.1
    figsize = min(max(8, calcfigsize), maxfigsize)
    if figsize == maxfigsize:
        sns.set_context("notebook", font_scale=maxfigsize / calcfigsize)

    # Add a colorbar only when class information was supplied
    if params.classes is None:
        col_cb = None
    else:
        col_cb = get_seaborn_colorbar(dfr, params.classes)

    # Work on a copy: replacing params.labels (a dict) with a list on the
    # caller's object would break any later call that expects the dict
    render_params = copy.copy(params)

    # Labels are defined before we build the clustering
    # If a label mapping is missing, use the key text as fall back
    render_params.labels = get_safe_seaborn_labels(dfr, params.labels)

    # Add attributes to the (copied) parameter object, and draw heatmap
    render_params.colorbar = col_cb
    render_params.figsize = figsize
    render_params.linewidths = 0.25
    fig = get_seaborn_clustermap(dfr, render_params, title=title)

    # Save to file
    if outfilename:
        fig.savefig(outfilename)

    # Return clustermap
    return fig


# Add dendrogram and axes to passed figure
def add_mpl_dendrogram(dfr, fig, heatmap_gs, orientation="col"):
    """Return a dendrogram and corresponding gridspec, attached to the fig.

    - dfr - pandas DataFrame with the data to cluster
    - fig - Matplotlib figure, modified in-place
    - heatmap_gs - 2x2 gridspec of the heatmap layout
    - orientation - 'row' clusters the rows and draws the dendrogram on the
                    left; any other value clusters the columns and draws
                    it on top

    Returns a dictionary with keys 'dendrogram' (the dictionary returned by
    scipy's dendrogram()) and 'gridspec' (the sub-gridspec holding the
    dendrogram axis, with a spare cell next to it).
    """
    # Row or column axes?
    if orientation == "row":
        data = dfr
        spec = heatmap_gs[1, 0]
        orient = "left"
        nrows, ncols = 1, 2
        height_ratios = [1]
    else:  # Column dendrogram
        data = dfr.T
        spec = heatmap_gs[0, 1]
        orient = "top"
        nrows, ncols = 2, 1
        height_ratios = [1, 0.15]

    # Create dendrogram axis (row or column, as selected above)
    gspec = gridspec.GridSpecFromSubplotSpec(
        nrows,
        ncols,
        subplot_spec=spec,
        wspace=0.0,
        hspace=0.1,
        height_ratios=height_ratios,
    )
    dend_axes = fig.add_subplot(gspec[0, 0])

    # pdist() already returns the condensed distances that linkage() needs.
    # The axis is passed explicitly so that drawing does not depend on
    # which figure/axes pyplot currently considers active.
    dend = sch.dendrogram(
        sch.linkage(distance.pdist(data), method="complete"),
        color_threshold=np.inf,
        orientation=orient,
        ax=dend_axes,
    )
    clean_axis(dend_axes)
    return {"dendrogram": dend, "gridspec": gspec}


# Create heatmap axes for Matplotlib output
def get_mpl_heatmap_axes(dfr, fig, heatmap_gs):
    """Return axis for Matplotlib heatmap.

    - dfr - pandas DataFrame to be plotted; its shape sets the tick count
    - fig - Matplotlib figure, to which the axis is added
    - heatmap_gs - gridspec of the heatmap layout; the heatmap takes
                   cell [1, 1]

    One tick is placed per column on the x-axis (shown at the bottom) and
    one per row on the y-axis (shown on the right); the grid is disabled.
    """
    nrows, ncols = dfr.shape[0], dfr.shape[1]

    # Create heatmap axis
    heatmap_axes = fig.add_subplot(heatmap_gs[1, 1])
    # x runs over columns and y over rows, so each axis uses its own count
    heatmap_axes.set_xticks(np.arange(ncols))
    heatmap_axes.set_yticks(np.arange(nrows))
    heatmap_axes.grid(False)
    heatmap_axes.xaxis.tick_bottom()
    heatmap_axes.yaxis.tick_right()
    return heatmap_axes


def add_mpl_colorbar(dfr, fig, dend, params, orientation="row"):
    """Add class colorbars to Matplotlib heatmap.

    - dfr - pandas DataFrame whose index provides the sequence names
    - fig - Matplotlib figure, modified in-place
    - dend - dictionary with 'dendrogram' and 'gridspec' keys, as returned
             by add_mpl_dendrogram()
    - params - object whose classes attribute is a dictionary of classes
               keyed by sequence name
    - orientation - 'row' draws a vertical bar in gridspec cell [0, 1];
                    any other value draws a horizontal bar in cell [1, 0]

    Names with no class are assigned themselves as their class; this
    updates params.classes in-place.

    Returns a pd.Series of the integer class codes, in dendrogram leaf
    order.
    """
    # Names in the order the dendrogram places its leaves
    leaf_names = dfr.index[dend["dendrogram"]["leaves"]]

    # Any name without a class becomes a class of its own
    for name in leaf_names:
        params.classes.setdefault(name, name)

    # Assign a numerical value to each class, for mpl
    classdict = {cls: idx for (idx, cls) in enumerate(params.classes.values())}

    # colourbar: every leaf name is now a key of params.classes, and every
    # class is a key of classdict, so the lookups cannot fail
    colbar = pd.Series([classdict[params.classes[name]] for name in leaf_names])

    # Create colourbar axis - could capture if needed
    if orientation == "row":
        cell = dend["gridspec"][0, 1]
        image = [[cbar] for cbar in colbar.values]  # one column, n rows
    else:
        cell = dend["gridspec"][1, 0]
        image = [colbar]  # one row, n columns
    cbaxes = fig.add_subplot(cell)
    cbaxes.imshow(
        image,
        cmap=plt.get_cmap(pyani_config.MPL_CBAR),
        interpolation="nearest",
        aspect="auto",
        origin="lower",
    )
    clean_axis(cbaxes)
    return colbar


# Add labels to the heatmap axes
def add_mpl_labels(heatmap_axes, rowlabels, collabels, params):
    """Set row and column tick labels on the heatmap axes, in-place.

    - heatmap_axes - Matplotlib axes holding the heatmap
    - rowlabels - labels for the y-axis ticks
    - collabels - labels for the x-axis ticks
    - params - object whose labels attribute optionally maps default
               labels to alternative ones

    Column labels are rotated by 90 degrees; all labels use font size 8.
    """
    mapping = params.labels
    if mapping:
        # If a label mapping is missing, use the key text as fall back
        rowlabels = [mapping.get(name, name) for name in rowlabels]
        collabels = [mapping.get(name, name) for name in collabels]

    col_texts = heatmap_axes.set_xticklabels(collabels)
    row_texts = heatmap_axes.set_yticklabels(rowlabels)

    for text in col_texts:  # Column labels: rotated and small
        text.set_rotation(90)
        text.set_fontsize(8)
    for text in row_texts:  # Row labels: small only
        text.set_fontsize(8)


# Add colour scale to heatmap
def add_mpl_colorscale(fig, heatmap_gs, ax_map, params, title=None):
    """Add a colour scale to the heatmap figure and return the colorbar.

    - fig - Matplotlib figure, modified in-place
    - heatmap_gs - gridspec of the heatmap layout; the scale is placed in
                   the middle third of cell [0, 0]
    - ax_map - the plotted heatmap image that the scale describes
    - params - object providing vmin, vmax and vdiff
    - title - optional label for the colour scale
    """
    # Five ticks, evenly spread from vmin across the vdiff range
    fractions = (0, 0.25, 0.5, 0.75, 1)
    cbticks = [params.vmin + frac * params.vdiff for frac in fractions]
    if params.vmax > 10:
        # Large scales get integer ticks, rounded relative to vmax's magnitude
        ndigits = 1 - int(floor(log10(params.vmax)))
        cbticks = [int(round(tick, ndigits)) for tick in cbticks]

    scale_subplot = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=heatmap_gs[0, 0], wspace=0.0, hspace=0.0)
    scale_ax = fig.add_subplot(scale_subplot[0, 1])
    cbar = fig.colorbar(ax_map, cax=scale_ax, ticks=cbticks)
    if title:
        cbar.set_label(title, fontsize=6)

    # Ticks and label go on the left of the scale, in a small font
    scale_yaxis = cbar.ax.yaxis
    scale_yaxis.set_ticks_position("left")
    scale_yaxis.set_label_position("left")
    cbar.ax.tick_params(labelsize=6)
    cbar.outline.set_linewidth(0)
    return cbar


# Generate Matplotlib heatmap output
def heatmap_mpl(dfr, outfilename=None, title=None, params=None):
    """Returns matplotlib heatmap with cluster dendrograms.

    - dfr - pandas DataFrame with relevant data
    - outfilename - path to output file (indicates output format); no file
                    is written if this is None or empty
    - title - optional label for the colour scale
    - params - object holding rendering parameters: cmap, vmin, vmax,
               labels (dictionary of alternative labels, keyed by default
               sequence labels, or None) and classes (dictionary of
               sequence classes, keyed by default sequence labels, or
               None). Although the default is None, a real object is
               required: None raises AttributeError.

    The figure is created with plt.figure() and is returned without being
    closed.
    """
    # Layout figure grid
    # Set figure size by the number of rows in the dataframe
    figsize = max(8, dfr.shape[0] * 0.175)
    fig = plt.figure(figsize=(figsize, figsize))
    heatmap_gs = gridspec.GridSpec(2, 2, wspace=0.0, hspace=0.0, width_ratios=[0.3, 1], height_ratios=[0.3, 1])

    # Add column and row dendrograms/axes to figure
    coldend = add_mpl_dendrogram(dfr, fig, heatmap_gs, orientation="col")
    rowdend = add_mpl_dendrogram(dfr, fig, heatmap_gs, orientation="row")
    row_leaves = rowdend["dendrogram"]["leaves"]
    col_leaves = coldend["dendrogram"]["leaves"]

    # Add heatmap axes to figure, with rows/columns as in the dendrograms
    heatmap_axes = get_mpl_heatmap_axes(dfr, fig, heatmap_gs)
    ax_map = heatmap_axes.imshow(
        dfr.iloc[row_leaves, col_leaves],
        interpolation="nearest",
        cmap=params.cmap,
        origin="lower",
        vmin=params.vmin,
        vmax=params.vmax,
        aspect="auto",
    )

    # Are there class colourbars to add?
    if params.classes is not None:
        add_mpl_colorbar(dfr, fig, coldend, params, orientation="col")
        add_mpl_colorbar(dfr, fig, rowdend, params, orientation="row")

    # Add heatmap labels: rows are named by the index, columns by the
    # column names, each in dendrogram leaf order
    add_mpl_labels(
        heatmap_axes,
        dfr.index[row_leaves],
        dfr.columns[col_leaves],
        params,
    )

    # Add colour scale
    add_mpl_colorscale(fig, heatmap_gs, ax_map, params, title)

    # Adjust this figure explicitly, rather than whichever figure pyplot
    # currently considers active
    fig.subplots_adjust(top=0.85)
    # We know that there is a UserWarning here about tight_layout and
    # using the Agg renderer on OSX, so catch and ignore it, for cleanliness.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        heatmap_gs.tight_layout(fig, h_pad=0.1, w_pad=0.5)

    # Write figure output, if required, and return it
    if outfilename:
        fig.savefig(outfilename)
    return fig