File: hmmplot.py

package info (click to toggle)
nanopolish 0.14.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 11,760 kB
  • sloc: cpp: 22,200; ansic: 1,478; python: 814; makefile: 210; sh: 43; perl: 17
file content (118 lines) | stat: -rw-r--r-- 4,744 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#! /usr/bin/env python3
"""
Plot a random segmentation from a dataset.

Usage:
  $ python3 hmmplot.py polya.out.tsv reads.fastq.readdb.index
"""
import h5py
import pandas as pd
import numpy as np
import argparse
import os
from random import choice
from collections import OrderedDict

# plotting libraries:
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import seaborn as sns


def load_fast5_signal(read_path):
    """Load a fast5 signal from read path; return as numpy array.

    The integer-encoded samples of the first entry under `Raw/Reads` are
    rescaled with the `offset`, `digitisation` and `range` attributes stored
    on `UniqueGlobalKey/channel_id`:

        signal = (range / digitisation) * (raw + offset)

    Args:
        read_path: path to the fast5 (HDF5) file to open read-only.

    Returns:
        Numpy array holding the rescaled signal.

    Raises:
        KeyError: if an expected group, dataset or attribute is missing.
        IndexError: if `Raw/Reads` contains no reads.
    """
    # the context manager closes the HDF5 handle even if a lookup below raises:
    with h5py.File(read_path, 'r') as read_h5:
        # get scaling parameters:
        channel_attrs = read_h5['UniqueGlobalKey']['channel_id'].attrs
        offset = channel_attrs['offset']
        digitisation = channel_attrs['digitisation']
        read_range = channel_attrs['range']

        # get raw integer-encoded signal; it must be copied into memory
        # before the file is closed:
        reads = read_h5['Raw']['Reads']
        rn = list(reads.keys())[0]
        raw_signal = np.array(reads[rn]['Signal'])

    return (read_range / digitisation) * (raw_signal + offset)


def get_state_names(header):
    """Pick out the state-start column names from a header.

    Keeps, in their original order, the entries of `header` whose name ends
    with `_start`, e.g. `[leader_start, adapter_start, ..., transcript_start]`.
    """
    return [column for column in header if column.endswith('_start')]


def generate_color_palette(num_colors):
    """Generate a list (of length `num_colors`) of color IDs for matplotlib.

    Colors are drawn from a fixed palette of distinct named colors. When more
    colors are requested than the palette holds, the palette is cycled so the
    returned list always has exactly `num_colors` entries.

    Args:
        num_colors: number of colors wanted; values <= 0 give an empty list.

    Returns:
        List of matplotlib color-name strings.
    """
    # TODO: this is a hack-ish solution; generate the palette mathematically.
    # every entry is distinct so that neighbouring regions never share a color:
    palette = ['cyan', 'yellow', 'red', 'green', 'blue', 'orange', 'purple']
    return [palette[k % len(palette)] for k in range(max(num_colors, 0))]


def _select_segmentation(polya, readdb, state_starts, read_name=None):
    """Pick one PASS-ing segmentation row and resolve the read's file location.

    Args:
        polya: dataframe with `readname`, `qc_tag` and the state-start columns.
        readdb: dataframe with `readname` and `location` columns.
        state_starts: ordered list of state-start column names.
        read_name: restrict the choice to this read; `None` picks among all reads.

    Returns:
        Tuple `(read_id, state_start_indices, read_path)`, where
        `state_start_indices` is an OrderedDict mapping each state-start column
        to its value, in the order given by `state_starts`.

    Raises:
        IndexError: if no PASS-ing row matches or the read is absent from `readdb`.
        KeyError: if a required column is missing.
    """
    candidates = polya[polya['qc_tag'] == 'PASS']
    if read_name is not None:
        candidates = candidates[candidates['readname'] == read_name]
    row_values = choice(candidates[['readname', *state_starts]].values).tolist()
    read_id = row_values.pop(0)
    state_start_indices = OrderedDict(zip(state_starts, row_values))
    # the second readdb column holds the location of the read's file:
    read_path = readdb[readdb['readname'] == read_id].values[0][1]
    return read_id, state_start_indices, read_path


def main(args):
    """Filter-in PASS-ing segmentations and plot a random segmented read to file.

    Args:
        args: parsed command-line namespace providing `polya_tsv` and `readdb`
            (paths to tab-separated input files), `read` (a specific read name,
            or `None` for a random read) and `out` (output image path, or `None`
            for `segmentation.<READ_ID>.png` in the working directory).

    Raises:
        Exception: if `args.read` is given but no PASS-ing segmentation or
            readdb entry can be found for it.
    """
    # load dataframes:
    polya = pd.read_csv(args.polya_tsv, sep='\t')
    readdb = pd.read_csv(args.readdb, sep='\t', header=None, names=['readname','location'])

    # get the names of all state-index columns:
    state_starts = get_state_names(polya.columns.values.tolist())

    # get a random read, its segmentation, and its location:
    if args.read is None:
        read_id, state_start_indices, read_path = _select_segmentation(polya, readdb, state_starts)
    else:
        try:
            read_id, state_start_indices, read_path = _select_segmentation(
                polya, readdb, state_starts, args.read)
        except (IndexError, KeyError) as err:
            # only lookup failures are translated; interrupts and unrelated
            # errors propagate unchanged, and the original cause stays attached.
            raise Exception("[hmmplot.py] read id={} could not be resolved".format(args.read)) from err

    # load fast5 file:
    signal = load_fast5_signal(read_path)

    # create dictionary of start-stop indices for each region; every region
    # stops where the next one starts, and the last one runs to the signal end:
    start_stop_indices = {}
    stop_idxs = [state_start_indices[name] for name in state_starts[1:]] + [signal.shape[0]]
    colors = generate_color_palette(len(state_start_indices))
    for n, (name, start_idx) in enumerate(state_start_indices.items()):
        start_stop_indices[name] = (start_idx, stop_idxs[n], colors[n])

    # make segmentation plot:
    plt.figure(figsize=(18,6))
    plt.plot(signal)
    for name, (start_idx, stop_idx, color) in start_stop_indices.items():
        # strip the trailing `_start` from the column name for the legend label:
        plt.axvspan(start_idx, stop_idx, color=color, alpha=0.35, label=name[:-6])
    plt.legend(loc='best')
    plt.xlim(0, signal.shape[0])
    plt.title("Segmentation: {}".format(read_id))
    plt.xlabel("Sample Index (3' to 5')")
    plt.ylabel("Current (pA)")
    if args.out is None:
        plt.savefig("segmentation.{}.png".format(read_id))
    else:
        plt.savefig(args.out)
    
    

if __name__ == '__main__':
    # command-line entry point: parse arguments, validate input paths, then plot.
    parser = argparse.ArgumentParser(description="Plot a random passing segmentation from a polya output file.")
    parser.add_argument("polya_tsv", help="Output TSV of `nanopolish polya {...}`")
    parser.add_argument("readdb", help="ReadDB index file from `nanopolish index {...}`")
    parser.add_argument("--out", default=None, help="Where to put the output file. [./segmentation.<READ_ID>.png]")
    parser.add_argument("--read", default=None, help="Visualize a specific read. [random read]")
    args = parser.parse_args()
    # validate explicitly instead of with `assert`, which is stripped under
    # `python -O`; parser.error prints the usage message and exits with status 2.
    for input_path in (args.polya_tsv, args.readdb):
        if not os.path.exists(input_path):
            parser.error("[ERR] {} does not exist".format(input_path))
    main(args)