File: plotEnrichment.py

package info (click to toggle)
python-deeptools 3.5.6%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 34,456 kB
  • sloc: python: 14,503; xml: 4,212; sh: 33; makefile: 5
file content (588 lines) | stat: -rwxr-xr-x 25,241 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
#!/usr/bin/python3
# -*- coding: utf-8 -*-

import sys
import argparse
import numpy as np
import matplotlib
matplotlib.use('Agg')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['svg.fonttype'] = 'none'
from deeptools import cm  # noqa: F401
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import plotly.offline as py
import plotly.graph_objs as go

from deeptools.mapReduce import mapReduce, getUserRegion, blSubtract
from deeptools.getFragmentAndReadSize import get_read_and_fragment_length
from deeptools.utilities import getCommonChrNames, mungeChromosome, getTLen, smartLabels
from deeptools.bamHandler import openBam
from deeptoolsintervals import Enrichment, GTF
from deeptools.countReadsPerBin import CountReadsPerBin as cr
from deeptools import parserCommon


# Tell NumPy to ignore all floating-point errors (divide, over, under, invalid)
# from here on; np.seterr returns the previous settings, which are kept in
# old_settings.
# NOTE(review): presumably this keeps the percentage divisions in
# plotEnrichment() quiet when a total read count is zero -- confirm.
old_settings = np.seterr(all='ignore')


def parse_arguments(args=None):
    """
    Build the full command-line parser for plotEnrichment.

    The parser combines three parent parsers: the tool specific options from
    plot_enrichment_args(), the common options from
    parserCommon.getParentArgParse() (--region, --blackListFileName, -p and
    -v) and the read handling options from parserCommon.read_options()
    (--extendReads and such).

    The ``args`` parameter is accepted but not used here; the parser is
    returned unparsed and the caller invokes ``parse_args`` on it.
    """
    description = """
Tool for calculating and plotting the signal enrichment in either regions in BED
format or feature types (column 3) in GTF format. The underlying datapoints can also be output.
Metrics are plotted as a fraction of total reads. Regions in a BED file are assigned to the 'peak' feature.

detailed help:

  plotEnrichment -h

"""
    epilog = ('example usages:\n'
              'plotEnrichment -b file1.bam file2.bam --BED peaks.bed -o enrichment.png\n\n'
              ' \n\n')
    usage = ('plotEnrichment -b sample1.bam sample2.bam --BED peaks.bed -o enrichment.png\n'
             'help: plotEnrichment -h / plotEnrichment --help\n')

    # The order of the parents is preserved: tool options first, then the
    # shared options, then the read options
    parents = [
        plot_enrichment_args(),
        parserCommon.getParentArgParse(binSize=False),
        parserCommon.read_options(),
    ]

    return argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=description,
        epilog=epilog,
        parents=parents,
        usage=usage)


def plot_enrichment_args():
    """
    Create the parser holding the plotEnrichment specific options.

    The parser is built with ``add_help=False`` so that it can be used as a
    parent of the parser assembled in parse_arguments(). Options are split
    into three groups: required arguments (--bamfiles, --BED), optional
    arguments (output files, labels, plot appearance, --Offset) and BED12
    arguments (--keepExons).

    Returns the argparse.ArgumentParser instance.
    """
    parser = argparse.ArgumentParser(add_help=False)
    required = parser.add_argument_group('Required arguments')

    # define the arguments
    required.add_argument('--bamfiles', '-b',
                          metavar='file1.bam file2.bam',
                          help='List of indexed bam files separated by spaces.',
                          nargs='+',
                          required=True)

    required.add_argument('--BED',
                          help='Limits the enrichment analysis to '
                          'the regions specified in these BED/GTF files. Enrichment '
                          'is calculated as the number of reads overlapping each '
                          'feature type. The feature type is column 3 in a GTF file '
                          'and "peak" for BED files.',
                          metavar='FILE1.bed FILE2.bed',
                          nargs='+',
                          required=True)

    optional = parser.add_argument_group('Optional arguments')

    # NOTE(review): the help text below talks about a "heatmap"; it looks like
    # it was copied from another tool -- confirm before rewording.
    optional.add_argument('--plotFile', '-o',
                          help='File to save the plot to. The file extension determines the format, '
                          'so heatmap.pdf will save the heatmap in PDF format. '
                          'The available formats are: .png, '
                          '.eps, .pdf and .svg.',
                          type=parserCommon.writableFile,
                          metavar='FILE')

    optional.add_argument('--attributeKey',
                          help='Instead of deriving labels from the feature column in a GTF file, '
                          'use the given attribute key, such as gene_biotype. For BED files or '
                          'entries without the attribute key, None is used as the label.')

    optional.add_argument('--labels', '-l',
                          metavar='sample1 sample2',
                          help='User defined labels instead of default labels from '
                          'file names. '
                          'Multiple labels have to be separated by spaces, e.g. '
                          '--labels sample1 sample2 sample3',
                          nargs='+')

    optional.add_argument('--smartLabels',
                          action='store_true',
                          help='Instead of manually specifying labels for the input '
                          'BAM/BED/GTF files, this causes deepTools to use the file name '
                          'after removing the path and extension. For BED/GTF files, the '
                          'eventual region name will be overriden if specified inside '
                          'the file.')

    optional.add_argument('--regionLabels',
                          metavar="region1 region2",
                          help="For BED files, the label given to its region is "
                          "the file name, but this can be overridden by providing "
                          "a custom label. For GTF files this is ignored. Note "
                          "that if you provide labels, you MUST provide one for each "
                          "BED/GTF file, even though it will be ignored for GTF files.",
                          nargs='+')

    optional.add_argument('--plotTitle', '-T',
                          help='Title of the plot, to be printed on top of '
                          'the generated image. Leave blank for no title. (Default: %(default)s)',
                          default='')

    # 'plotly' is handled separately from the matplotlib formats when plotting
    optional.add_argument('--plotFileFormat',
                          metavar='FILETYPE',
                          help='Image format type. If given, this option '
                          'overrides the image format based on the plotFile '
                          'ending. The available options are: png, '
                          'eps, pdf, plotly and svg.',
                          choices=['png', 'pdf', 'svg', 'eps', 'plotly'])

    optional.add_argument('--outRawCounts',
                          help='Save the counts per region to a tab-delimited file.',
                          type=parserCommon.writableFile,
                          metavar='FILE')

    optional.add_argument('--perSample',
                          help='Group the plots by sample, rather than by feature type (the default).',
                          action='store_true')

    optional.add_argument('--variableScales',
                          help='By default, the y-axes are always 0-100. This allows the axis range to be restricted.',
                          action='store_true')

    optional.add_argument('--plotHeight',
                          help='Plot height in cm. (Default: %(default)s)',
                          type=float,
                          default=20)

    # NOTE(review): the help mentions a 1 cm minimum, but nothing in this
    # function enforces it.
    optional.add_argument('--plotWidth',
                          help='Plot width in cm. The minimum value is 1 cm. (Default: %(default)s)',
                          type=float,
                          default=20)

    optional.add_argument('--colors',
                          help='List of colors to use '
                          'for the plotted lines. Color names '
                          'and html hex strings (e.g., #eeff22) '
                          'are accepted. The color names should '
                          'be space separated. For example, '
                          '--colors red blue green ',
                          nargs='+')

    optional.add_argument('--numPlotsPerRow',
                          help='Number of plots per row (Default: %(default)s)',
                          type=int,
                          default=4)

    optional.add_argument('--alpha',
                          default=0.9,
                          type=parserCommon.check_float_0_1,
                          help='The alpha channel (transparency) to use for the bars. '
                          'The default is 0.9 and values must be between 0 and 1.')

    # One or two integers; they are interpreted by getBAMBlocks()
    optional.add_argument('--Offset',
                          help='Uses this offset inside of each read as the signal. This is useful in '
                          'cases like RiboSeq or GROseq, where the signal is 12, 15 or 0 bases past the '
                          'start of the read. This can be paired with the --filterRNAstrand option. '
                          'Note that negative values indicate offsets from the end of each read. A value '
                          'of 1 indicates the first base of the alignment (taking alignment orientation '
                          'into account). Likewise, a value of -1 is the last base of the alignment. An '
                          'offset of 0 is not permitted. If two values are specified, then they will be '
                          'used to specify a range of positions. Note that specifying something like '
                          '--Offset 5 -1 will result in the 5th through last position being used, which '
                          'is equivalent to trimming 4 bases from the 5-prime end of alignments.',
                          metavar='INT',
                          type=int,
                          nargs='+',
                          required=False)

    bed12 = parser.add_argument_group('BED12 arguments')

    bed12.add_argument('--keepExons',
                       help="For BED12 files, use each exon as a region, rather than columns 2/3",
                       action="store_true")

    return parser


def getBAMBlocks(read, defaultFragmentLength, centerRead, offset=None):
    """
    Return the reference intervals an alignment contributes as signal.

    This is basically get_fragment_from_read from countReadsPerBin.

    Parameters
    ----------
    read : alignment object
        Must provide ``get_blocks()`` and ``is_reverse``; when reads are
        extended it must also provide ``reference_start``, ``reference_end``,
        ``next_reference_start``, ``template_length``, ``query_name`` and
        ``infer_query_length()``.
    defaultFragmentLength : int or str
        Either the string 'read length' (use the aligned blocks as they are)
        or the length to which reads that are not properly paired are extended.
    centerRead : bool
        If true (and reads are extended), the returned interval has the read
        length and is centered on the fragment.
    offset : list of int, optional
        One or two 1-based positions inside the read (negative values count
        from the end) selecting the bases used as signal. The list is only
        read, never modified, so the same list can be passed for every read.

    Returns
    -------
    list of tuple
        (start, end) intervals. ``[(None, None)]`` is returned when the
        offset selects no position of the read.

    Raises
    ------
    AssertionError
        If the computed fragment start is not smaller than the fragment end.
    """
    if defaultFragmentLength == 'read length':
        blocks = read.get_blocks()
    else:
        maxPairedFragmentLength = 4 * defaultFragmentLength
        if cr.is_proper_pair(read, maxPairedFragmentLength):
            if read.is_reverse:
                fragmentStart = read.next_reference_start
                fragmentEnd = read.reference_end
            else:
                fragmentStart = read.reference_start
                # the end of the fragment is defined as
                # the start of the forward read plus the insert length
                fragmentEnd = read.reference_start + abs(read.template_length)
        # Extend using the default fragment length
        else:
            if read.is_reverse:
                fragmentStart = read.reference_end - defaultFragmentLength
                fragmentEnd = read.reference_end
            else:
                fragmentStart = read.reference_start
                fragmentEnd = read.reference_start + defaultFragmentLength
        if centerRead:
            fragmentCenter = fragmentEnd - (fragmentEnd - fragmentStart) / 2
            fragmentStart = fragmentCenter - read.infer_query_length(always=False) / 2
            fragmentEnd = fragmentStart + read.infer_query_length(always=False)

        assert fragmentStart < fragmentEnd, "fragment start greater than fragment" \
                                            "end for read {}".format(read.query_name)
        blocks = [(int(fragmentStart), int(fragmentEnd))]

    # Handle read offsets, if needed
    if offset is None:
        return blocks

    rv = [(None, None)]

    # Convert the 1-based user offsets into python slice bounds. This is done
    # on local copies: the caller passes the same list for every read, so
    # adjusting it in place would shift the offset a little more on each call.
    sliceStart = offset[0]
    if len(offset) > 1:
        sliceEnd = offset[1]
        if sliceStart > 0:
            sliceStart -= 1
        if sliceEnd < 0:
            sliceEnd += 1
    elif sliceStart > 0:
        # A single positive offset selects exactly one base
        sliceStart -= 1
        sliceEnd = sliceStart + 1
    else:
        sliceEnd = None
    if sliceEnd == 0:
        # -1 gets switched to 0, which screws things up
        sliceEnd = None

    # For the sake of simplicity, convert [(10, 20), (30, 40)] to [10, 11, 12, 13, ..., 40]
    # Then subset accordingly
    stretch = []
    for block in blocks:
        stretch.extend(range(block[0], block[1]))
    if read.is_reverse:
        # Offsets are relative to the orientation of the alignment
        stretch = stretch[::-1]
    try:
        selected = stretch[sliceStart:sliceEnd]
    except (TypeError, ValueError):
        return rv

    if len(selected) == 0:
        return rv
    if read.is_reverse:
        selected = selected[::-1]

    # Convert the stretch back to a list of tuples
    selected = np.array(selected)
    steps = selected[1:] - selected[:-1]
    # Indices of the last position before each gap, i.e. the interval bounds
    bounds = np.argwhere(steps > 1).flatten().tolist()
    bounds.append(-1)
    first = 0
    blocks = []
    for bound in bounds:
        blocks.append((selected[first].astype("int"), selected[bound].astype("int") + 1))
        first = bound + 1
    return blocks


def getEnrichment_worker(arglist):
    """
    This is the worker function of plotEnrichment.

    In short, given a region, iterate over all reads **starting** in it.
    Filter/extend them as requested and check each for an overlap with
    findOverlaps. For each overlap, increment the counter for that feature.

    Parameters
    ----------
    arglist : sequence
        ``(chrom, start, end, args, defaultFragmentLength)`` where ``args``
        is the parsed command line and ``defaultFragmentLength`` is what
        getBAMBlocks() expects.

    Returns
    -------
    tuple
        ``(olist, features, total)``: ``olist`` holds one dict per BAM file
        mapping each feature to its read count in this region, ``features``
        is ``gtf.features`` and ``total`` holds the number of reads per BAM
        file that passed the filters.

    The feature annotation is read from the module-level ``gtf`` object,
    which main() sets before the workers are run.
    """
    chrom, start, end, args, defaultFragmentLength = arglist
    if args.verbose:
        sys.stderr.write("Processing {}:{}-{}\n".format(chrom, start, end))

    olist = []
    total = [0] * len(args.bamfiles)
    for idx, f in enumerate(args.bamfiles):
        odict = dict()
        for x in gtf.features:
            odict[x] = 0
        fh = openBam(f)
        # Make sure the file handle is released again, even if processing a
        # read raises
        try:
            chrom = mungeChromosome(chrom, fh.references)

            lpos = None
            prev_pos = set()
            for read in fh.fetch(chrom, start, end):
                # Filter
                if read.pos < start:
                    # Ensure that a given alignment is processed only once
                    continue
                if read.flag & 4:
                    # Flag bit 4 is set
                    continue
                if args.minMappingQuality and read.mapq < args.minMappingQuality:
                    continue
                if args.samFlagInclude and read.flag & args.samFlagInclude != args.samFlagInclude:
                    continue
                if args.samFlagExclude and read.flag & args.samFlagExclude != 0:
                    continue
                tLen = getTLen(read)
                if args.minFragmentLength > 0 and tLen < args.minFragmentLength:
                    continue
                if args.maxFragmentLength > 0 and tLen > args.maxFragmentLength:
                    continue
                if args.ignoreDuplicates:
                    # Assuming more or less concordant reads, use the fragment bounds, otherwise the start positions
                    if tLen >= 0:
                        s = read.pos
                        e = s + tLen
                    else:
                        s = read.pnext
                        e = s - tLen
                    if read.reference_id != read.next_reference_id:
                        e = read.pnext
                    # A read counts as a duplicate if one with the same bounds,
                    # mate reference and strand was seen at this start position
                    if lpos is not None and lpos == read.reference_start \
                            and (s, e, read.next_reference_id, read.is_reverse) in prev_pos:
                        continue
                    if lpos != read.reference_start:
                        prev_pos.clear()
                    lpos = read.reference_start
                    prev_pos.add((s, e, read.next_reference_id, read.is_reverse))
                total[idx] += 1

                # Get blocks, possibly extending
                features = gtf.findOverlaps(chrom, getBAMBlocks(read, defaultFragmentLength, args.centerReads, args.Offset))

                if features is not None and len(features) > 0:
                    for x in features:
                        odict[x] += 1
        finally:
            fh.close()
        olist.append(odict)
    return olist, gtf.features, total


def plotEnrichment(args, featureCounts, totalCounts, features):
    """
    Draw one bar plot per sample or per feature and save it to args.plotFile.

    Parameters
    ----------
    args : parsed command line
        The attributes read here are bamfiles, labels, perSample,
        numPlotsPerRow, colors, plotFileFormat, plotTitle, plotWidth,
        plotHeight, variableScales, alpha and plotFile. ``args.colors`` is
        filled in with default colors if it is empty.
    featureCounts : list of dict
        One dict per BAM file, mapping a feature to its read count.
    totalCounts : list
        Total number of counted reads per BAM file.
    features : sequence
        The feature names; indexed by position when plotting per feature.

    Values are plotted as percentages (100 * feature count / total count).
    With the 'plotly' file format the figure is written with plotly,
    otherwise with matplotlib.
    """
    usePlotly = args.plotFileFormat == 'plotly'

    # get the number of rows and columns
    # With --perSample each panel is a sample and each bar a feature,
    # otherwise each panel is a feature and each bar a sample
    if args.perSample:
        totalPlots, barsPerPlot = len(args.bamfiles), len(features)
    else:
        totalPlots, barsPerPlot = len(features), len(args.bamfiles)
    cols = min(args.numPlotsPerRow, totalPlots)
    rows = np.ceil(totalPlots / float(args.numPlotsPerRow)).astype(int)

    # Handle the colors
    if not args.colors:
        jet = plt.get_cmap('jet')
        args.colors = jet(np.arange(barsPerPlot, dtype=float) / float(barsPerPlot))
        if usePlotly:
            args.colors = range(barsPerPlot)
    elif len(args.colors) < barsPerPlot:
        sys.exit("Error: {0} colors were requested, but {1} were needed!".format(len(args.colors), barsPerPlot))

    traces = []
    if usePlotly:
        fig = go.Figure()
        fig['layout'].update(title=args.plotTitle)
        # 90% of the figure is used by the panels, 10% by the gaps between them
        domainWidth = .9 / cols
        domainHeight = .9 / rows
        bufferHeight = 0.1 / (rows - 1) if rows > 1 else 0.0
        bufferWidth = 0.1 / (cols - 1) if cols > 1 else 0.0
    else:
        grids = gridspec.GridSpec(rows, cols)
        plt.rcParams['font.size'] = 10.0

        # convert cm values to inches
        figureSize = (args.plotWidth / 2.54, args.plotHeight / 2.54)
        fig = plt.figure(figsize=figureSize)
        fig.suptitle(args.plotTitle, y=(1 - (0.06 / args.plotHeight)))

    for i in range(totalPlots):
        col = i % cols
        row = np.floor(i / float(args.numPlotsPerRow)).astype(int)

        if args.perSample:
            xlabels = features
            ylabel = "% alignments in {0}".format(args.labels[i])
            counts = [featureCounts[i][feature] for feature in features]
            denominator = totalCounts[i]
        else:
            xlabels = args.labels
            ylabel = "% {0}".format(features[i])
            counts = [sampleCounts[features[i]] for sampleCounts in featureCounts]
            denominator = np.array(totalCounts, dtype='float64')
        vals = 100 * np.array(counts, dtype='float64') / denominator

        if usePlotly:
            axisNumber = i + 1
            xanchor = 'x{}'.format(axisNumber)
            yanchor = 'y{}'.format(axisNumber)
            xaxisKey = 'xaxis{}'.format(axisNumber)
            yaxisKey = 'yaxis{}'.format(axisNumber)

            # NOTE(review): the x-axis domain is derived from the row index and
            # the y-axis domain from the column index; this mirrors the
            # existing behaviour -- confirm that it is intended.
            rowBase = row * (domainHeight + bufferHeight)
            fig['layout'][xaxisKey] = {'domain': [rowBase, rowBase + domainHeight], 'anchor': yanchor}
            colBase = col * (domainWidth + bufferWidth)
            fig['layout'][yaxisKey] = {'domain': [colBase, colBase + domainWidth], 'anchor': xanchor, 'title': ylabel}
            if args.variableScales is False:
                fig['layout'][yaxisKey].update(range=[0, 100])

            barColors = {'color': args.colors, 'line': {'color': args.colors}}
            traces.append(go.Bar(x=xlabels,
                                 y=vals,
                                 opacity=args.alpha,
                                 orientation='v',
                                 showlegend=False,
                                 xaxis=xanchor,
                                 yaxis=yanchor,
                                 name=ylabel,
                                 marker=barColors))
        else:
            ax = plt.subplot(grids[row, col])
            barPositions = np.arange(vals.shape[0])
            ax.bar(barPositions, vals, width=1.0, bottom=0.0, align='center',
                   color=args.colors, edgecolor=args.colors, alpha=args.alpha)
            ax.set_ylabel(ylabel)
            ax.set_xticks(np.arange(vals.shape[0]))
            ax.set_xticklabels(xlabels, rotation='vertical')
            if args.variableScales is False:
                ax.set_ylim(0.0, 100.0)

    if usePlotly:
        fig.add_traces(traces)
        py.plot(fig, filename=args.plotFile, auto_open=False)
    else:
        plt.subplots_adjust(wspace=0.05, hspace=0.3, bottom=0.15, top=0.80)
        plt.tight_layout()
        plt.savefig(args.plotFile, dpi=200, format=args.plotFileFormat)
        plt.close()


def getChunkLength(args, chromSize):
    """
    Choose the genome chunk length handed to each worker.

    There's no point in parsing the GTF file over and over again needlessly.
    Empirically, it seems that adding ~4x the number of workers is ideal, since
    coverage is non-uniform. This is a heuristic way of approximating that.

    Note that if there are MANY small contigs and a few large ones (e.g., the
    max and median lengths are >10x different), then it's best to take a
    different tack.

    Parameters
    ----------
    args : parsed command line
        Uses ``region``, ``numberOfProcessors`` and ``blackListFileName``.
    chromSize : iterable of (name, length)
        The chromosomes to process.

    Returns
    -------
    int
        The chunk length, always at least 1.
    """

    if args.region:
        chromSize, region_start, region_end, genomeChunkLength = getUserRegion(chromSize, args.region)
        # Split the length of the region (end - start) into ~4 chunks per worker
        rv = np.ceil((region_end - region_start) / float(4 * args.numberOfProcessors)).astype(int)
        return max(1, rv)

    bl = None
    if args.blackListFileName:
        bl = GTF(args.blackListFileName)

    # Lengths of the stretches that blSubtract returns for each chromosome
    lengths = []
    for k, v in chromSize:
        regs = blSubtract(bl, k, [0, v])
        for reg in regs:
            lengths.append(reg[1] - reg[0])

    if len(lengths) >= 4 * args.numberOfProcessors:
        rv = np.median(lengths).astype(int)
        # In cases like dm6 or GRCh38, there are a LOT of really small contigs, which will cause the median to be small and performance to tank
        if np.max(lengths) >= 10 * rv:
            rv = np.ceil(np.sum(lengths) / (4.0 * args.numberOfProcessors)).astype(int)
    else:
        rv = np.ceil(np.sum(lengths) / (4.0 * args.numberOfProcessors)).astype(int)

    return max(1, rv)


def main(args=None):
    """
    Entry point of plotEnrichment.

    Parses the command line, counts for every BAM file the reads overlapping
    each feature of the given BED/GTF files (in parallel via mapReduce) and
    then writes the plot (--plotFile) and/or the raw counts (--outRawCounts).

    Parameters
    ----------
    args : list of str, optional
        The command line arguments; passed on to ``parse_args``.

    The process exits via ``sys.exit`` with an error message if neither
    output was requested, if the number of labels does not match the number
    of BAM files, or if the read extension settings are unusable.
    """

    args = parse_arguments().parse_args(args)

    if not args.outRawCounts and not args.plotFile:
        sys.exit("Error: You need to specify at least one of --plotFile or --outRawCounts!\n")

    if args.labels is None:
        args.labels = args.bamfiles
    if args.smartLabels:
        args.labels = smartLabels(args.bamfiles)
    if len(args.labels) != len(args.bamfiles):
        sys.exit("Error: The number of labels ({0}) does not match the number of BAM files ({1})!".format(len(args.labels), len(args.bamfiles)))

    # Ensure that if we're given an attributeKey that it's not empty
    if args.attributeKey == "":
        args.attributeKey = None

    # The annotation is stored in a module-level variable because
    # getEnrichment_worker() reads it from there
    global gtf
    if not args.regionLabels and args.smartLabels:
        args.regionLabels = smartLabels(args.BED)
    gtf = Enrichment(args.BED, keepExons=args.keepExons, labels=args.regionLabels, attributeKey=args.attributeKey)

    # Get fragment size and chromosome dict
    fhs = [openBam(x) for x in args.bamfiles]
    chromSize, non_common_chr = getCommonChrNames(fhs, verbose=args.verbose)
    for fh in fhs:
        fh.close()

    # The read/fragment lengths are taken from the first BAM file only
    frag_len_dict, read_len_dict = get_read_and_fragment_length(args.bamfiles[0],
                                                                return_lengths=False,
                                                                blackListFileName=args.blackListFileName,
                                                                numberOfProcessors=args.numberOfProcessors,
                                                                verbose=args.verbose)
    if args.extendReads:
        if args.extendReads is True:
            # try to guess fragment length if the bam file contains paired end reads
            if frag_len_dict:
                defaultFragmentLength = frag_len_dict['median']
            else:
                sys.exit("*ERROR*: library is not paired-end. Please provide an extension length.")
            if args.verbose:
                print("Fragment length based on paired en data "
                      "estimated to be {0}".format(frag_len_dict['median']))
        elif args.extendReads < read_len_dict['median']:
            sys.stderr.write("*WARNING*: read extension is smaller than read length (read length = {}). "
                             "Reads will not be extended.\n".format(int(read_len_dict['median'])))
            defaultFragmentLength = 'read length'
        elif args.extendReads > 2000:
            sys.exit("*ERROR*: read extension must be smaller that 2000. Value give: {} ".format(args.extendReads))
        else:
            defaultFragmentLength = args.extendReads
    else:
        defaultFragmentLength = 'read length'

    # Get the chunkLength
    chunkLength = getChunkLength(args, chromSize)

    # Map reduce to get the counts/file/feature
    res = mapReduce([args, defaultFragmentLength],
                    getEnrichment_worker,
                    chromSize,
                    genomeChunkLength=chunkLength,
                    region=args.region,
                    blackListFileName=args.blackListFileName,
                    numberOfProcessors=args.numberOfProcessors,
                    verbose=args.verbose)

    # Every worker returns the same feature list, so take it from the first
    features = res[0][1]
    featureCounts = [dict.fromkeys(features, 0) for _ in args.bamfiles]

    # res is a list, with each element a list (length len(args.bamfiles)) of dicts
    totalCounts = [0] * len(args.bamfiles)
    for x in res:
        for i, y in enumerate(x[2]):
            totalCounts[i] += y
        for i, y in enumerate(x[0]):
            for k, v in y.items():
                featureCounts[i][k] += v

    # Make a plot
    if args.plotFile:
        plotEnrichment(args, featureCounts, totalCounts, features)

    # Raw counts
    if args.outRawCounts:
        with open(args.outRawCounts, "w") as of:
            of.write("file\tfeatureType\tpercent\tfeatureReadCount\ttotalReadCount\n")
            for i, x in enumerate(args.labels):
                sampleTotal = totalCounts[i]
                for k, v in featureCounts[i].items():
                    # A sample without any counted read has no defined
                    # percentage; write nan instead of dividing by zero
                    if sampleTotal:
                        percent = (100.0 * v) / sampleTotal
                    else:
                        percent = float('nan')
                    of.write("{0}\t{1}\t{2:5.2f}\t{3}\t{4}\n".format(x, k, percent, v, sampleTotal))