File: heatmapper.py

import sys
import gzip
from collections import OrderedDict
import numpy as np
from copy import deepcopy

import pyBigWig
from deeptools import getScorePerBigWigBin
from deeptools import mapReduce
from deeptools.utilities import toString, toBytes, smartLabels
from deeptools.heatmapper_utilities import getProfileTicks


old_settings = np.seterr(all='ignore')


def chopRegions(exonsInput, left=0, right=0):
    """
    exonsInput is a list of (start, end) tuples. The goal is to chop these
    into separate lists of tuples, to take care of unscaled regions. "left"
    and "right" denote regions of a given size to exclude from the normal
    binning process (unscaled regions).

    This outputs three lists of (start, end) tuples:

    leftBins: 5' unscaled regions
    bodyBins: body bins for scaling
    rightBins: 3' unscaled regions

    In addition, two integers are returned:
    padLeft: number of bases of padding on the left (due to not being able to fulfill "left")
    padRight: as above, but on the right side
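
    A sketch of the expected behavior (illustrative values):

    >>> chopRegions([(0, 200), (300, 400), (800, 900)], left=250)
    ([(0, 200), (300, 350)], [(350, 400), (800, 900)], [], 0, 0)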
    """
    leftBins = []
    rightBins = []
    padLeft = 0
    padRight = 0
    exons = deepcopy(exonsInput)
    while len(exons) > 0 and left > 0:
        width = exons[0][1] - exons[0][0]
        if width <= left:
            leftBins.append(exons[0])
            del exons[0]
            left -= width
        else:
            leftBins.append((exons[0][0], exons[0][0] + left))
            exons[0] = (exons[0][0] + left, exons[0][1])
            left = 0
    if left > 0:
        padLeft = left

    while len(exons) > 0 and right > 0:
        width = exons[-1][1] - exons[-1][0]
        if width <= right:
            rightBins.append(exons[-1])
            del exons[-1]
            right -= width
        else:
            rightBins.append((exons[-1][1] - right, exons[-1][1]))
            exons[-1] = (exons[-1][0], exons[-1][1] - right)
            right = 0
    if right > 0:
        padRight = right

    return leftBins, exons, rightBins[::-1], padLeft, padRight


def chopRegionsFromMiddle(exonsInput, left=0, right=0):
    """
    Like chopRegions(), above, but returns two lists of tuples on each side of
    the center point of the exons.

    The steps are as follows:

     1) Find the center point of the set of exons (e.g., [(0, 200), (300, 400), (800, 900)] would be centered at 200)
       * If a given exon spans the center point then the exon is split
     2) The given number of bases at the end of the left-of-center list are extracted
       * If the set of exons doesn't contain enough bases, then padLeft is incremented accordingly
     3) As above, but for the right-of-center list
     4) A tuple of (#2, #3, padding on the left, padding on the right) is returned
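
    A sketch of the expected behavior (illustrative values; the exons below
    cover 400 bases, so the center falls at the end of the first exon):

    >>> chopRegionsFromMiddle([(0, 200), (300, 400), (800, 900)], left=100, right=100)
    ([(100, 200)], [(300, 400)], 0, 0)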
    """
    leftBins = []
    rightBins = []
    size = sum([x[1] - x[0] for x in exonsInput])
    middle = size // 2
    cumulativeSum = 0
    padLeft = 0
    padRight = 0
    exons = deepcopy(exonsInput)

    # Split exons in half
    for exon in exons:
        size = exon[1] - exon[0]
        if cumulativeSum >= middle:
            rightBins.append(exon)
        elif cumulativeSum + size < middle:
            leftBins.append(exon)
        else:
            # Don't add 0-width exonic bins!
            if exon[0] < exon[1] - cumulativeSum - size + middle:
                leftBins.append((exon[0], exon[1] - cumulativeSum - size + middle))
            if exon[1] - cumulativeSum - size + middle < exon[1]:
                rightBins.append((exon[1] - cumulativeSum - size + middle, exon[1]))
        cumulativeSum += size

    # Trim leftBins/adjust padLeft
    lSum = sum([x[1] - x[0] for x in leftBins])
    if lSum > left:
        lSum = 0
        for i, exon in enumerate(leftBins[::-1]):
            size = exon[1] - exon[0]
            if lSum + size > left:
                leftBins[-i - 1] = (exon[1] + lSum - left, exon[1])
                break
            lSum += size
            if lSum == left:
                break
        i += 1
        if i < len(leftBins):
            leftBins = leftBins[-i:]
    elif lSum < left:
        padLeft = left - lSum

    # Trim rightBins/adjust padRight
    rSum = sum([x[1] - x[0] for x in rightBins])
    if rSum > right:
        rSum = 0
        for i, exon in enumerate(rightBins):
            size = exon[1] - exon[0]
            if rSum + size > right:
                rightBins[i] = (exon[0], exon[1] - rSum - size + right)
                break
            rSum += size
            if rSum == right:
                break
        rightBins = rightBins[:i + 1]
    elif rSum < right:
        padRight = right - rSum

    return leftBins, rightBins, padLeft, padRight


def trimZones(zones, maxLength, binSize, padRight):
    """
    Given a (variable length) list of lists of (start, end) tuples, trim or remove any tuple that extends past maxLength (e.g., the end of a chromosome)

    Returns the trimmed zones and padding
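
    A sketch of the expected behavior (illustrative values, with a chromosome
    ending at position 1050):

    >>> trimZones([([(950, 1000), (1000, 1100)], 15)], 1050, 10, 0)
    ([([(950, 1000), (1000, 1050)], 10)], 50)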
    """
    output = []
    for zone, nbins in zones:
        outZone = []
        changed = False
        for reg in zone:
            if reg[0] >= maxLength:
                changed = True
                padRight += reg[1] - reg[0]
                continue

            if reg[1] > maxLength:
                changed = True
                padRight += reg[1] - maxLength
                reg = (reg[0], maxLength)
            if reg[1] > reg[0]:
                outZone.append(reg)
        if changed:
            nBins = sum(x[1] - x[0] for x in outZone) // binSize
        else:
            nBins = nbins
        output.append((outZone, nBins))
    return output, padRight


def compute_sub_matrix_wrapper(args):
    return heatmapper.compute_sub_matrix_worker(*args)


class heatmapper(object):
    """
    Class to handle the reading and
    plotting of matrices.
    """

    def __init__(self):
        self.parameters = None
        self.lengthDict = None
        self.matrix = None
        self.regions = None
        self.blackList = None
        self.quiet = True
        # These are parameters that were single values in versions <3 but are now internally lists. See issue #614
        self.special_params = set(['unscaled 5 prime', 'unscaled 3 prime', 'body', 'downstream', 'upstream', 'ref point', 'bin size'])

    def getTicks(self, idx):
        """
        This is essentially a wrapper around getProfileTicks to accommodate the fact that each column has its own ticks.
        """
        xticks, xtickslabel = getProfileTicks(self, self.reference_point_label[idx], self.startLabel, self.endLabel, idx)
        return xticks, xtickslabel

    def computeMatrix(self, score_file_list, regions_file, parameters, blackListFileName=None, verbose=False, allArgs=None):
        """
        Splits the computation of the per-bin scores for each
        region across multiple cores. Groups of regions are
        delimited by a hash '#' in the regions (BED/GFF) file.
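
        A minimal sketch of the 'parameters' dictionary this method
        expects (an illustrative subset of the keys used below; the
        computeMatrix command line builds the full set):

            parameters = {'upstream': 1000, 'downstream': 1000, 'body': 0,
                          'bin size': 50, 'ref point': 'TSS',
                          'unscaled 5 prime': 0, 'unscaled 3 prime': 0,
                          'proc number': 4}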
        """
        if parameters['body'] > 0 and \
                parameters['body'] % parameters['bin size'] > 0:
            exit("The --regionBodyLength has to be "
                 "a multiple of --binSize.\nCurrently the "
                 "values are {} {} for\nregionsBodyLength and "
                 "binSize respectively\n".format(parameters['body'],
                                                 parameters['bin size']))

        # all of the flanking and unscaled region lengths must be
        # multiples of the bin size
        if parameters['downstream'] % parameters['bin size'] > 0:
            exit("Length of region after the body has to be "
                 "a multiple of --binSize.\nCurrent value "
                 "is {}\n".format(parameters['downstream']))

        if parameters['upstream'] % parameters['bin size'] > 0:
            exit("Length of region before the body has to be a multiple of "
                 "--binSize\nCurrent value is {}\n".format(parameters['upstream']))

        if parameters['unscaled 5 prime'] % parameters['bin size'] > 0:
            exit("Length of the unscaled 5 prime region has to be a multiple of "
                 "--binSize\nCurrent value is {}\n".format(parameters['unscaled 5 prime']))

        if parameters['unscaled 3 prime'] % parameters['bin size'] > 0:
            exit("Length of the unscaled 3 prime region has to be a multiple of "
                 "--binSize\nCurrent value is {}\n".format(parameters['unscaled 3 prime']))

        if parameters['unscaled 5 prime'] + parameters['unscaled 3 prime'] > 0 and parameters['body'] == 0:
            exit('Unscaled 5- and 3-prime regions only make sense with the scale-regions subcommand.\n')

        # Take care of GTF options
        transcriptID = "transcript"
        exonID = "exon"
        transcript_id_designator = "transcript_id"
        keepExons = False
        self.quiet = False
        if allArgs is not None:
            allArgs = vars(allArgs)
            transcriptID = allArgs.get("transcriptID", transcriptID)
            exonID = allArgs.get("exonID", exonID)
            transcript_id_designator = allArgs.get("transcript_id_designator", transcript_id_designator)
            keepExons = allArgs.get("keepExons", keepExons)
            self.quiet = allArgs.get("quiet", self.quiet)

        chromSizes, _ = getScorePerBigWigBin.getChromSizes(score_file_list)
        res, labels = mapReduce.mapReduce([score_file_list, parameters],
                                          compute_sub_matrix_wrapper,
                                          chromSizes,
                                          self_=self,
                                          bedFile=regions_file,
                                          blackListFileName=blackListFileName,
                                          numberOfProcessors=parameters['proc number'],
                                          includeLabels=True,
                                          transcriptID=transcriptID,
                                          exonID=exonID,
                                          transcript_id_designator=transcript_id_designator,
                                          keepExons=keepExons,
                                          verbose=verbose)
        # each worker in the pool returns a tuple containing
        # the submatrix data, the regions that correspond to the
        # submatrix, and the number of regions lacking scores
        # Since this is largely unsorted, we need to sort by group

        # merge all the submatrices into matrix
        matrix = np.concatenate([r[0] for r in res], axis=0)
        regions = []
        regions_no_score = 0
        for idx in range(len(res)):
            if len(res[idx][1]):
                regions.extend(res[idx][1])
                regions_no_score += res[idx][2]
        groups = [x[3] for x in regions]
        foo = sorted(zip(groups, list(range(len(regions))), regions))
        sortIdx = [x[1] for x in foo]
        regions = [x[2] for x in foo]
        matrix = matrix[sortIdx]

        # mask invalid (nan) values
        matrix = np.ma.masked_invalid(matrix)

        assert matrix.shape[0] == len(regions), \
            "matrix length does not match regions length"

        if len(regions) == 0:
            sys.stderr.write("\nERROR: Either the BED file does not contain any valid regions or there are none remaining after filtering.\n")
            exit(1)
        if regions_no_score == len(regions):
            exit("\nERROR: None of the BED regions could be found in the bigWig"
                 "file.\nPlease check that the bigwig file is valid and "
                 "that the chromosome names between the BED file and "
                 "the bigWig file correspond to each other\n")

        if regions_no_score > len(regions) * 0.75:
            file_type = 'bigwig' if score_file_list[0].endswith(".bw") else "BAM"
            prcnt = 100 * float(regions_no_score) / len(regions)
            sys.stderr.write(
                "\n\nWarning: {0:.2f}% of regions are *not* associated\n"
                "to any score in the given {1} file. Check that the\n"
                "chromosome names from the BED file are consistent with\n"
                "the chromosome names in the given {2} file and that both\n"
                "files refer to the same species\n\n".format(prcnt,
                                                             file_type,
                                                             file_type))

        self.parameters = parameters

        numcols = matrix.shape[1]
        num_ind_cols = self.get_num_individual_matrix_cols()
        sample_boundaries = list(range(0, numcols + num_ind_cols, num_ind_cols))
        if allArgs is not None and allArgs['samplesLabel'] is not None:
            sample_labels = allArgs['samplesLabel']
        else:
            sample_labels = smartLabels(score_file_list)

        # Determine the group boundaries
        group_boundaries = []
        group_labels_filtered = []
        last_idx = -1
        for x in range(len(regions)):
            if regions[x][3] != last_idx:
                last_idx = regions[x][3]
                group_boundaries.append(x)
                group_labels_filtered.append(labels[last_idx])
        group_boundaries.append(len(regions))

        # check if a given group is too small. Groups that
        # are too small can't be plotted and a warning is printed.
        group_len = np.diff(group_boundaries)
        if len(group_len) > 1:
            sum_len = sum(group_len)
            group_frac = [float(x) / sum_len for x in group_len]
            if min(group_frac) <= 0.002:
                sys.stderr.write(
                    "One of the groups defined in the bed file is "
                    "too small.\nGroups that are too small can't be plotted. "
                    "\n")

        self.matrix = _matrix(regions, matrix,
                              group_boundaries,
                              sample_boundaries,
                              group_labels_filtered,
                              sample_labels)

        if parameters['skip zeros']:
            self.matrix.removeempty()

    @staticmethod
    def compute_sub_matrix_worker(self, chrom, start, end, score_file_list, parameters, regions):
        """
        Returns
        -------
        numpy matrix
            A numpy matrix that contains per each row the values found per each of the regions given
        """
        if parameters['verbose']:
            sys.stderr.write("Processing {}:{}-{}\n".format(chrom, start, end))

        # read BAM or scores file
        score_file_handles = []
        for sc_file in score_file_list:
            score_file_handles.append(pyBigWig.open(sc_file))

        # determine the number of matrix columns based on the lengths
        # given by the user, times the number of score files
        matrix_cols = len(score_file_list) * \
            ((parameters['downstream'] +
              parameters['unscaled 5 prime'] + parameters['unscaled 3 prime'] +
              parameters['upstream'] + parameters['body']) //
             parameters['bin size'])

        # create an empty matrix to store the values
        sub_matrix = np.zeros((len(regions), matrix_cols))
        sub_matrix[:] = np.nan

        j = 0
        sub_regions = []
        regions_no_score = 0
        for transcript in regions:
            feature_chrom = transcript[0]
            exons = transcript[1]
            feature_start = exons[0][0]
            feature_end = exons[-1][1]
            feature_name = transcript[2]
            feature_strand = transcript[4]
            padLeft = 0
            padRight = 0
            padLeftNaN = 0
            padRightNaN = 0
            upstream = []
            downstream = []

            # get the body length
            body_length = np.sum([x[1] - x[0] for x in exons]) - parameters['unscaled 5 prime'] - parameters['unscaled 3 prime']

            # skip regions whose body is shorter than the bin size
            if parameters['body'] > 0 and \
                    body_length < parameters['bin size']:
                if not self.quiet:
                    sys.stderr.write("A region that is shorter than the bin size (possibly only after accounting for unscaled regions) was found: "
                                     "({0}) {1} {2}:{3}:{4}. Skipping...\n".format(body_length,
                                                                                   feature_name, feature_chrom,
                                                                                   feature_start, feature_end))
                coverage = np.zeros(matrix_cols)
                if not parameters['missing data as zero']:
                    coverage[:] = np.nan
            else:
                if feature_strand == '-':
                    if parameters['downstream'] > 0:
                        upstream = [(feature_start - parameters['downstream'], feature_start)]
                    if parameters['upstream'] > 0:
                        downstream = [(feature_end, feature_end + parameters['upstream'])]
                    unscaled5prime, body, unscaled3prime, padLeft, padRight = chopRegions(exons, left=parameters['unscaled 3 prime'], right=parameters['unscaled 5 prime'])
                    # bins per zone
                    a = parameters['downstream'] // parameters['bin size']
                    b = parameters['unscaled 3 prime'] // parameters['bin size']
                    d = parameters['unscaled 5 prime'] // parameters['bin size']
                    e = parameters['upstream'] // parameters['bin size']
                else:
                    if parameters['upstream'] > 0:
                        upstream = [(feature_start - parameters['upstream'], feature_start)]
                    if parameters['downstream'] > 0:
                        downstream = [(feature_end, feature_end + parameters['downstream'])]
                    unscaled5prime, body, unscaled3prime, padLeft, padRight = chopRegions(exons, left=parameters['unscaled 5 prime'], right=parameters['unscaled 3 prime'])
                    a = parameters['upstream'] // parameters['bin size']
                    b = parameters['unscaled 5 prime'] // parameters['bin size']
                    d = parameters['unscaled 3 prime'] // parameters['bin size']
                    e = parameters['downstream'] // parameters['bin size']
                c = parameters['body'] // parameters['bin size']

                # build zones (each is a list of tuples)
                #  zone0: region before the region start,
                #  zone1: unscaled 5 prime region
                #  zone2: the body of the region
                #  zone3: unscaled 3 prime region
                #  zone4: the region from the end of the region downstream
                #  the format for each zone is: [(start, end), ...], number of bins
                # Note that for "reference-point", upstream/downstream will go
                # through the exons (if requested) and then possibly continue
                # on the other side (unless parameters['nan after end'] is true)
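                # For example (illustrative values): a + strand single-exon
                # feature at 5000-7000 with upstream=downstream=1000,
                # body=2000, no unscaled regions and a bin size of 10 yields
                # zones = [([(4000, 5000)], 100), ([], 0), ([(5000, 7000)], 200), ([], 0), ([(7000, 8000)], 100)]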
                if parameters['body'] > 0:
                    zones = [(upstream, a), (unscaled5prime, b), (body, c), (unscaled3prime, d), (downstream, e)]
                elif parameters['ref point'] == 'TES':  # around TES
                    if feature_strand == '-':
                        downstream, body, unscaled3prime, padRight, _ = chopRegions(exons, left=parameters['upstream'])
                        if padRight > 0 and parameters['nan after end'] is True:
                            padRightNaN += padRight
                        elif padRight > 0:
                            downstream.append((downstream[-1][1], downstream[-1][1] + padRight))
                        padRight = 0
                    else:
                        unscaled5prime, body, upstream, _, padLeft = chopRegions(exons, right=parameters['upstream'])
                        if padLeft > 0 and parameters['nan after end'] is True:
                            padLeftNaN += padLeft
                        elif padLeft > 0:
                            upstream.insert(0, (upstream[0][0] - padLeft, upstream[0][0]))
                        padLeft = 0
                    e = np.sum([x[1] - x[0] for x in downstream]) // parameters['bin size']
                    a = np.sum([x[1] - x[0] for x in upstream]) // parameters['bin size']
                    zones = [(upstream, a), (downstream, e)]
                elif parameters['ref point'] == 'center':  # at the region center
                    if feature_strand == '-':
                        upstream, downstream, padLeft, padRight = chopRegionsFromMiddle(exons, left=parameters['downstream'], right=parameters['upstream'])
                    else:
                        upstream, downstream, padLeft, padRight = chopRegionsFromMiddle(exons, left=parameters['upstream'], right=parameters['downstream'])
                    if padLeft > 0 and parameters['nan after end'] is True:
                        padLeftNaN += padLeft
                    elif padLeft > 0:
                        if len(upstream) > 0:
                            upstream.insert(0, (upstream[0][0] - padLeft, upstream[0][0]))
                        else:
                            upstream = [(downstream[0][0] - padLeft, downstream[0][0])]
                    padLeft = 0
                    if padRight > 0 and parameters['nan after end'] is True:
                        padRightNaN += padRight
                    elif padRight > 0:
                        downstream.append((downstream[-1][1], downstream[-1][1] + padRight))
                    padRight = 0
                    a = np.sum([x[1] - x[0] for x in upstream]) // parameters['bin size']
                    e = np.sum([x[1] - x[0] for x in downstream]) // parameters['bin size']
                    # It's possible for a/e to be floats or 0 yet upstream/downstream isn't empty
                    if a < 1:
                        upstream = []
                        a = 0
                    if e < 1:
                        downstream = []
                        e = 0
                    zones = [(upstream, a), (downstream, e)]
                else:  # around TSS
                    if feature_strand == '-':
                        unscaled5prime, body, upstream, _, padLeft = chopRegions(exons, right=parameters['downstream'])
                        if padLeft > 0 and parameters['nan after end'] is True:
                            padLeftNaN += padLeft
                        elif padLeft > 0:
                            upstream.insert(0, (upstream[0][0] - padLeft, upstream[0][0]))
                        padLeft = 0
                    else:
                        downstream, body, unscaled3prime, padRight, _ = chopRegions(exons, left=parameters['downstream'])
                        if padRight > 0 and parameters['nan after end'] is True:
                            padRightNaN += padRight
                        elif padRight > 0:
                            downstream.append((downstream[-1][1], downstream[-1][1] + padRight))
                        padRight = 0
                    a = np.sum([x[1] - x[0] for x in upstream]) // parameters['bin size']
                    e = np.sum([x[1] - x[0] for x in downstream]) // parameters['bin size']
                    zones = [(upstream, a), (downstream, e)]

                foo = parameters['upstream']
                bar = parameters['downstream']
                if feature_strand == '-':
                    foo, bar = bar, foo
                if padLeftNaN > 0:
                    expected = foo // parameters['bin size']
                    padLeftNaN = int(round(float(padLeftNaN) / parameters['bin size']))
                    if expected - padLeftNaN - a > 0:
                        padLeftNaN += 1
                if padRightNaN > 0:
                    expected = bar // parameters['bin size']
                    padRightNaN = int(round(float(padRightNaN) / parameters['bin size']))
                    if expected - padRightNaN - e > 0:
                        padRightNaN += 1

                coverage = []
                # compute the values for each of the files being processed.
                # "cov" is a numpy array of bins
                for sc_handler in score_file_handles:
                    # We're only supporting bigWig files at this point
                    cov = heatmapper.coverage_from_big_wig(
                        sc_handler, feature_chrom, zones,
                        parameters['bin size'],
                        parameters['bin avg type'],
                        parameters['missing data as zero'],
                        not self.quiet)

                    if padLeftNaN > 0:
                        cov = np.concatenate([[np.nan] * padLeftNaN, cov])
                    if padRightNaN > 0:
                        cov = np.concatenate([cov, [np.nan] * padRightNaN])

                    if feature_strand == "-":
                        cov = cov[::-1]

                    coverage = np.hstack([coverage, cov])

            if coverage is None:
                regions_no_score += 1
                if not self.quiet:
                    sys.stderr.write(
                        "No data was found for region "
                        "{0} {1}:{2}-{3}. Skipping...\n".format(
                            feature_name, feature_chrom,
                            feature_start, feature_end))

                coverage = np.zeros(matrix_cols)
                if not parameters['missing data as zero']:
                    coverage[:] = np.nan

            try:
                temp = coverage.copy()
                temp[np.isnan(temp)] = 0
            except Exception:
                if not self.quiet:
                    sys.stderr.write(
                        "No scores defined for region "
                        "{0} {1}:{2}-{3}. Skipping...\n".format(feature_name,
                                                                feature_chrom,
                                                                feature_start,
                                                                feature_end))
                coverage = np.zeros(matrix_cols)
                if not parameters['missing data as zero']:
                    coverage[:] = np.nan

            if parameters['min threshold'] is not None and coverage.min() <= parameters['min threshold']:
                continue
            if parameters['max threshold'] is not None and coverage.max() >= parameters['max threshold']:
                continue
            if parameters['scale'] != 1:
                coverage = parameters['scale'] * coverage

            sub_matrix[j, :] = coverage

            sub_regions.append(transcript)
            j += 1

        # remove empty rows
        sub_matrix = sub_matrix[0:j, :]
        if len(sub_regions) != len(sub_matrix[:, 0]):
            sys.stderr.write("regions lengths do not match\n")
        return sub_matrix, sub_regions, regions_no_score

    @staticmethod
    def coverage_from_array(valuesArray, zones, binSize, avgType):
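        """
        Summarizes a per-base-pair values array into the bins defined by
        zones, applying avgType within each bin. A sketch of the expected
        behavior (illustrative values): ten base positions averaged into
        two bins of five bases each:

        >>> heatmapper.coverage_from_array(np.arange(10.), [([(0, 10)], 2)], 5, 'mean')
        array([2., 7.])
        """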
        try:
            valuesArray[0]
        except (IndexError, TypeError) as detail:
            sys.stderr.write("{0}\nvalues array value: {1}, zones {2}\n".format(detail, valuesArray, zones))

        cvglist = []
        zoneEnd = 0
        valStart = 0
        valEnd = 0
        for zone, nBins in zones:
            if nBins:
                # linspace is used to more or less evenly partition the data points into the given number of bins
                zoneEnd += nBins
                valStart = valEnd
                valEnd += np.sum([x[1] - x[0] for x in zone])
                counts_list = []

                # Partition the space into bins
                if nBins == 1:
                    pos_array = np.array([valStart])
                else:
                    pos_array = np.linspace(valStart, valEnd, nBins, endpoint=False, dtype=int)
                pos_array = np.append(pos_array, valEnd)

                idx = 0
                while idx < nBins:
                    idxStart = int(pos_array[idx])
                    idxEnd = max(int(pos_array[idx + 1]), idxStart + 1)
                    try:
                        counts_list.append(heatmapper.my_average(valuesArray[idxStart:idxEnd], avgType))
                    except Exception as detail:
                        sys.stderr.write("Exception found: {0}\n".format(detail))
                    idx += 1
                cvglist.append(np.array(counts_list))

        return np.concatenate(cvglist)

    @staticmethod
    def change_chrom_names(chrom):
        """
        Changes UCSC chromosome names to Ensembl chromosome names
        and vice versa.
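
        For example:

        >>> heatmapper.change_chrom_names("chr1")
        '1'
        >>> heatmapper.change_chrom_names("MT")
        'chrM'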
        """
        if chrom.startswith('chr'):
            # remove the chr part from chromosome name
            chrom = chrom[3:]
            if chrom == "M":
                chrom = "MT"
        else:
            # prefix with 'chr' the chromosome name
            chrom = 'chr' + chrom
            if chrom == "chrMT":
                chrom = "chrM"

        return chrom

    @staticmethod
    def coverage_from_big_wig(bigwig, chrom, zones, binSize, avgType, nansAsZeros=False, verbose=True):
        """
        Uses pyBigWig to query a region defined by chrom and zones.
        The output is an array that contains the bigWig value per
        base pair. The summary over bins is done in a later step when
        coverage_from_array is called. This method is more reliable
        than querying the bins directly from the bigWig file, even
        though the latter should be more efficient.

        By default, any region produces a result, even if no chromosome
        match is found in the bigWig file. In other words, no regions
        are skipped.

        zones: list as follows zone0: region before the region start,
                               zone1: 5' unscaled region (if present)
                               zone2: the body of the region (not always present)
                               zone3: 3' unscaled region (if present)
                               zone4: the region from the end of the region downstream

               each zone is a tuple of a list of (start, end) tuples and a number of bins
        """
        nVals = 0
        for zone, _ in zones:
            for region in zone:
                nVals += region[1] - region[0]

        values_array = np.zeros(nVals)
        if not nansAsZeros:
            values_array[:] = np.nan
        if chrom not in list(bigwig.chroms().keys()):
            unmod_name = chrom
            chrom = heatmapper.change_chrom_names(chrom)
            if chrom not in list(bigwig.chroms().keys()):
                if verbose:
                    sys.stderr.write("Warning: Your chromosome names do not match.\nPlease check that the "
                                     "chromosome names in your BED file\ncorrespond to the names in your "
                                     "bigWig file.\nAn empty line will be added to your heatmap.\nThe problematic "
                                     "chromosome name is {0}\n\n".format(unmod_name))

                # return empty nan array
                return heatmapper.coverage_from_array(values_array, zones, binSize, avgType)

        maxLen = bigwig.chroms(chrom)
        startIdx = 0
        endIdx = 0
        for zone, _ in zones:
            for region in zone:
                startIdx = endIdx
                if region[0] < 0:
                    endIdx += abs(region[0])
                    values_array[startIdx:endIdx] = np.nan
                    startIdx = endIdx
                start = max(0, region[0])
                end = min(maxLen, region[1])
                endIdx += end - start
                if start < end:
                    # This won't be the case if we extend off the front of a chromosome, such as (-100, 0)
                    values_array[startIdx:endIdx] = bigwig.values(chrom, start, end)
                if end < region[1]:
                    startIdx = endIdx
                    endIdx += region[1] - end
                    values_array[startIdx:endIdx] = np.nan

        # replaces nans for zeros
        if nansAsZeros:
            values_array[np.isnan(values_array)] = 0

        return heatmapper.coverage_from_array(values_array, zones,
                                              binSize, avgType)

    @staticmethod
    def my_average(valuesArray, avgType='mean'):
        """
        Computes the mean, median, etc., but only over those
        values that are not NaN.
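
        For example, NaNs are ignored rather than propagated:

        >>> heatmapper.my_average(np.array([1.0, np.nan, 3.0]))
        2.0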
        """
        valuesArray = np.ma.masked_invalid(valuesArray)
        avg = np.ma.__getattribute__(avgType)(valuesArray)
        if isinstance(avg, np.ma.core.MaskedConstant):
            return np.nan
        else:
            return avg

    def matrix_from_dict(self, matrixDict, regionsDict, parameters):
        self.regionsDict = regionsDict
        self.matrixDict = matrixDict
        self.parameters = parameters
        self.lengthDict = OrderedDict()
        self.matrixAvgsDict = OrderedDict()

    def read_matrix_file(self, matrix_file):
        # Reads a gzipped matrix file as produced by save_matrix().
        # The '@' header line holds the JSON-encoded parameters;
        # group membership of the rows is determined from the
        # 'group_boundaries' parameter stored in that header.

        import json
        regions = []
        matrix_rows = []
        current_group_index = 0
        max_group_bound = None

        fh = gzip.open(matrix_file)
        for line in fh:
            line = toString(line).strip()
            # read the header file containing the parameters
            # used
            if line.startswith("@"):
                # the parameters used are saved using
                # json
                self.parameters = json.loads(line[1:].strip())
                max_group_bound = self.parameters['group_boundaries'][1]
                continue

            # split the line into bed interval and matrix values
            region = line.split('\t')
            chrom, start, end, name, score, strand = region[0:6]
            matrix_row = np.ma.masked_invalid(np.fromiter(region[6:], float))
            matrix_rows.append(matrix_row)
            starts = start.split(",")
            ends = end.split(",")
            regs = [(int(x), int(y)) for x, y in zip(starts, ends)]
            # get the group index
            if len(regions) >= max_group_bound:
                current_group_index += 1
                max_group_bound = self.parameters['group_boundaries'][current_group_index + 1]
            regions.append([chrom, regs, name, max_group_bound, strand, score])

        matrix = np.vstack(matrix_rows)
        self.matrix = _matrix(regions, matrix, self.parameters['group_boundaries'],
                              self.parameters['sample_boundaries'],
                              group_labels=self.parameters['group_labels'],
                              sample_labels=self.parameters['sample_labels'])

        if 'sort regions' in self.parameters:
            self.matrix.set_sorting_method(self.parameters['sort regions'],
                                           self.parameters['sort using'])

        # Versions of computeMatrix before 3.0 didn't have an entry of these per column, fix that
        nSamples = len(self.matrix.sample_labels)
        h = dict()
        for k, v in self.parameters.items():
            if k in self.special_params and type(v) is not list:
                v = [v] * nSamples
                if len(v) == 0:
                    v = [None] * nSamples
            h[k] = v
        self.parameters = h

        return

    def save_matrix(self, file_name):
        """
        Saves the data required to reconstruct the matrix.
        The format is:
        a header line containing the parameters used to create the matrix,
        encoded as '@' followed by a JSON dictionary.
        Each remaining line starts with the first six columns of a
        BED file: chromosome name, start, end, name, score and strand,
        all separated by tabs; start and end are comma-separated lists
        with one entry per exon. After the sixth column the matrix
        values are appended, separated by tabs. Group boundaries and
        labels are stored in the header rather than as separator lines.

        This is useful if several matrices need to be merged, or if the
        sorted BED output of one computeMatrix operation needs to be
        used for other cases.

        The file is gzipped.
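
        A sketch of one data row (tab-separated; illustrative values for a
        region with two exons and an arbitrary number of matrix columns):

            chr1  100,300  200,400  geneA  1.25  +  1.50  2.25  nan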
        """
        import json
        self.parameters['sample_labels'] = self.matrix.sample_labels
        self.parameters['group_labels'] = self.matrix.group_labels
        self.parameters['sample_boundaries'] = self.matrix.sample_boundaries
        self.parameters['group_boundaries'] = self.matrix.group_boundaries

        # Redo the parameters, ensuring things related to ticks and labels are repeated appropriately
        nSamples = len(self.matrix.sample_labels)
        h = dict()
        for k, v in self.parameters.items():
            if type(v) is list and len(v) == 0:
                v = None
            if k in self.special_params and type(v) is not list:
                v = [v] * nSamples
                if len(v) == 0:
                    v = [None] * nSamples
            h[k] = v
        fh = gzip.open(file_name, 'wb')
        params_str = json.dumps(h, separators=(',', ':'))
        fh.write(toBytes("@" + params_str + "\n"))
        score_list = np.ma.masked_invalid(np.mean(self.matrix.matrix, axis=1))
        for idx, region in enumerate(self.matrix.regions):
            # join np_array values
            # keeping nans while converting them to strings
            matrix_values = "\t".join(
                np.char.mod('%f', self.matrix.matrix[idx, :]))
            starts = ["{0}".format(x[0]) for x in region[1]]
            ends = ["{0}".format(x[1]) for x in region[1]]
            starts = ",".join(starts)
            ends = ",".join(ends)
            # BEDish format (we don't currently store the score)
            fh.write(
                toBytes('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                        region[0],
                        starts,
                        ends,
                        region[2],
                        region[5],
                        region[4],
                        matrix_values)))
        fh.close()

    def save_tabulated_values(self, file_handle, reference_point_label='TSS', start_label='TSS', end_label='TES', averagetype='mean'):
        """
        Saves the values averaged by column using the given average type.

        Args:
            file_handle: name of the file to save to
            reference_point_label: name of the reference point label
            start_label: name of the start label
            end_label: name of the end label
            averagetype: average type (e.g. mean, median, std)

        """
        #  get X labels
        w = self.parameters['bin size']
        b = self.parameters['upstream']
        a = self.parameters['downstream']
        c = self.parameters.get('unscaled 5 prime', 0)
        d = self.parameters.get('unscaled 3 prime', 0)
        m = self.parameters['body']

        xticks = []
        xtickslabel = []
        for idx in range(self.matrix.get_num_samples()):
            if b[idx] < 1e5:
                quotient = 1000
                symbol = 'Kb'
            else:
                quotient = 1e6
                symbol = 'Mb'

            if m[idx] == 0:
                last = 0
                if len(xticks):
                    last = xticks[-1]
                xticks.extend([last + (k / w[idx]) for k in [w[idx], b[idx], b[idx] + a[idx]]])
                xtickslabel.extend(['{0:.1f}{1}'.format(-(float(b[idx]) / quotient), symbol), reference_point_label,
                                    '{0:.1f}{1}'.format(float(a[idx]) / quotient, symbol)])

            else:
                xticks_values = [w[idx]]

                # only if upstream region is set, add a x tick
                if b[idx] > 0:
                    xticks_values.append(b[idx])
                    xtickslabel.append('{0:.1f}{1}'.format(-(float(b[idx]) / quotient), symbol))

                xtickslabel.append(start_label)

                if c[idx] > 0:
                    xticks_values.append(b[idx] + c[idx])
                    xtickslabel.append("")

                if d[idx] > 0:
                    xticks_values.append(b[idx] + c[idx] + m[idx])
                    xtickslabel.append("")

                xticks_values.append(b[idx] + c[idx] + m[idx] + d[idx])
                xtickslabel.append(end_label)

                if a[idx] > 0:
                    xticks_values.append(b[idx] + c[idx] + m[idx] + d[idx] + a[idx])
                    xtickslabel.append('{0:.1f}{1}'.format(float(a[idx]) / quotient, symbol))

                last = 0
                if len(xticks):
                    last = xticks[-1]
                xticks.extend([last + (k / w[idx]) for k in xticks_values])
        x_axis = np.arange(xticks[-1]) + 1
        labs = []
        for x_value in x_axis:
            if x_value in xticks and xtickslabel[xticks.index(x_value)]:
                labs.append(xtickslabel[xticks.index(x_value)])
            elif x_value in xticks:
                labs.append("tick")
            else:
                labs.append("")

        with open(file_handle, 'w') as fh:
            # write labels
            fh.write("bin labels\t\t{}\n".format("\t".join(labs)))
            fh.write('bins\t\t{}\n'.format("\t".join([str(x) for x in x_axis])))

            for sample_idx in range(self.matrix.get_num_samples()):
                for group_idx in range(self.matrix.get_num_groups()):
                    sub_matrix = self.matrix.get_matrix(group_idx, sample_idx)
                    values = [str(x) for x in np.ma.__getattribute__(averagetype)(sub_matrix['matrix'], axis=0)]
                    fh.write("{}\t{}\t{}\n".format(sub_matrix['sample'], sub_matrix['group'], "\t".join(values)))

    def save_matrix_values(self, file_name):
        # print a header telling the group names and their length
        fh = open(file_name, 'wb')
        info = []
        groups_len = np.diff(self.matrix.group_boundaries)
        for i in range(len(self.matrix.group_labels)):
            info.append("{}:{}".format(self.matrix.group_labels[i],
                                       groups_len[i]))
        fh.write(toBytes("#{}\n".format("\t".join(info))))
        # add to header the x axis values
        fh.write(toBytes("#downstream:{}\tupstream:{}\tbody:{}\tbin size:{}\tunscaled 5 prime:{}\tunscaled 3 prime:{}\n".format(
                 self.parameters['downstream'],
                 self.parameters['upstream'],
                 self.parameters['body'],
                 self.parameters['bin size'],
                 self.parameters.get('unscaled 5 prime', 0),
                 self.parameters.get('unscaled 3 prime', 0))))
        sample_len = np.diff(self.matrix.sample_boundaries)
        for i in range(len(self.matrix.sample_labels)):
            info.extend([self.matrix.sample_labels[i]] * sample_len[i])
        fh.write(toBytes("{}\n".format("\t".join(info))))

        fh.close()
        # reopen again using append mode
        fh = open(file_name, 'ab')
        np.savetxt(fh, self.matrix.matrix, fmt="%.4g", delimiter="\t")
        fh.close()

    def save_BED(self, file_handle):
        boundaries = np.array(self.matrix.group_boundaries)
        # Add a header
        file_handle.write("#chrom\tstart\tend\tname\tscore\tstrand\tthickStart\tthickEnd\titemRGB\tblockCount\tblockSizes\tblockStart\tdeepTools_group")
        if self.matrix.silhouette is not None:
            file_handle.write("\tsilhouette")
        file_handle.write("\n")
        for idx, region in enumerate(self.matrix.regions):
            # the label id corresponds to the last boundary
            # that is smaller than the region index.
            # for example for a boundary array = [0, 10, 20]
            # and labels ['a', 'b', 'c'],
            # for index 5, the label is 'a', for
            # index 10, the label is 'b' etc
            label_idx = np.flatnonzero(boundaries <= idx)[-1]
            starts = ["{0}".format(x[0]) for x in region[1]]
            ends = ["{0}".format(x[1]) for x in region[1]]
            starts = ",".join(starts)
            ends = ",".join(ends)
            file_handle.write(
                '{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{1}\t{2}\t0'.format(
                    region[0],
                    region[1][0][0],
                    region[1][-1][1],
                    region[2],
                    region[5],
                    region[4]))
            file_handle.write(
                '\t{0}\t{1}\t{2}\t{3}'.format(
                    len(region[1]),
                    ",".join([str(int(y) - int(x)) for x, y in region[1]]),
                    ",".join([str(int(x) - int(starts[0])) for x, y in region[1]]),
                    self.matrix.group_labels[label_idx]))
            if self.matrix.silhouette is not None:
                file_handle.write("\t{}".format(self.matrix.silhouette[idx]))
            file_handle.write("\n")
        file_handle.close()

    @staticmethod
    def matrix_avg(matrix, avgType='mean'):
        matrix = np.ma.masked_invalid(matrix)
        return np.ma.__getattribute__(avgType)(matrix, axis=0)

    def get_individual_matrices(self, matrix):
        """In case multiple matrices are saved one after the other
        this method splits them apart.
        Returns a list containing the matrices
        """
        num_cols = matrix.shape[1]
        num_ind_cols = self.get_num_individual_matrix_cols()
        matrices_list = []
        for i in range(0, num_cols, num_ind_cols):
            if i + num_ind_cols > num_cols:
                break
            matrices_list.append(matrix[:, i:i + num_ind_cols])
        return matrices_list

    def get_num_individual_matrix_cols(self):
        """
        returns the number of columns that each individual
        matrix should have. This is done because the final
        matrix that is plotted can be composed of smaller
        matrices that are merged one after
        the other.
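
        For example, with upstream and downstream of 1000, a body of 2000,
        no unscaled regions and a bin size of 50, each individual matrix
        has (1000 + 1000 + 2000) // 50 = 80 columns.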
        """
        matrixCols = ((self.parameters['downstream'] + self.parameters['upstream'] + self.parameters['body'] + self.parameters['unscaled 5 prime'] + self.parameters['unscaled 3 prime']) //
                      self.parameters['bin size'])

        return matrixCols


def computeSilhouetteScore(d, idx, labels):
    """
    Given a square distance matrix with NaN diagonals, compute the silhouette score
    of a given row (idx). Each row should have an associated label (labels).
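
    The score for row idx is (inter - intra) / max(inter, intra), where
    intra is the mean distance to the other members of its own group and
    inter is the smallest mean distance to any other group.

    A sketch of the expected behavior (illustrative values):

    >>> d = np.array([[np.nan, 1., 3., 10.],
    ...               [1., np.nan, 2., 9.],
    ...               [3., 2., np.nan, 8.],
    ...               [10., 9., 8., np.nan]])
    >>> labels = np.array([0, 0, 0, 1])
    >>> computeSilhouetteScore(d, 0, labels)
    0.8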
    """
    keep = ~np.isnan(d[idx, ])
    foo = np.bincount(labels[keep], weights=d[idx, ][keep])
    groupSizes = np.bincount(labels[keep])
    intraIdx = labels[idx]
    if groupSizes[intraIdx] == 1:
        return 0
    intra = foo[labels[idx]] / groupSizes[intraIdx]
    interMask = np.arange(len(foo))[np.arange(len(foo)) != labels[idx]]
    inter = np.min(foo[interMask] / groupSizes[interMask])
    return (inter - intra) / max(inter, intra)


class _matrix(object):
    """
    class to hold heatmapper matrices
    The base data is a large matrix
    with boundary definitions for its row and column divisions.
    Row divisions represent groups of regions, e.g. active and
    inactive genes from PolII bigWig data.

    Column divisions represent different samples, for example
    PolII in males vs. PolII in females.

    This is an internal class of the heatmapper class
    """

    def __init__(self, regions, matrix, group_boundaries, sample_boundaries,
                 group_labels=None, sample_labels=None):

        # simple checks
        assert matrix.shape[0] == group_boundaries[-1], \
            "row max do not match matrix shape"
        assert matrix.shape[1] == sample_boundaries[-1], \
            "col max do not match matrix shape"

        self.regions = regions
        self.matrix = matrix
        self.group_boundaries = group_boundaries
        self.sample_boundaries = sample_boundaries
        self.sort_method = None
        self.sort_using = None
        self.silhouette = None

        if group_labels is None:
            self.group_labels = ['group {}'.format(x)
                                 for x in range(len(group_boundaries) - 1)]
        else:
            assert len(group_labels) == len(group_boundaries) - 1, \
                "number of group labels does not match number of groups"
            self.group_labels = group_labels

        if sample_labels is None:
            self.sample_labels = ['sample {}'.format(x)
                                  for x in range(len(sample_boundaries) - 1)]
        else:
            assert len(sample_labels) == len(sample_boundaries) - 1, \
                "number of sample labels does not match number of samples"
            self.sample_labels = sample_labels

    def get_matrix(self, group, sample):
        """
        Returns a sub matrix from the large
        matrix. Group and sample are indices;
        thus, group=0, sample=0 gets the first group
        of the first sample.

        Returns
        -------
        dictionary containing the matrix,
        the group label and the sample label
        """
        group_start = self.group_boundaries[group]
        group_end = self.group_boundaries[group + 1]
        sample_start = self.sample_boundaries[sample]
        sample_end = self.sample_boundaries[sample + 1]

        return {'matrix': np.ma.masked_invalid(self.matrix[group_start:group_end, :][:, sample_start:sample_end]),
                'group': self.group_labels[group],
                'sample': self.sample_labels[sample]}

    def get_num_samples(self):
        return len(self.sample_labels)

    def get_num_groups(self):
        return len(self.group_labels)

    def set_group_labels(self, new_labels):
        """ sets new labels for groups
        """
        if len(new_labels) != len(self.group_labels):
            raise ValueError("length new labels != length original labels")
        self.group_labels = new_labels

    def set_sample_labels(self, new_labels):
        """ sets new labels for groups
        """
        if len(new_labels) != len(self.sample_labels):
            raise ValueError("length new labels != length original labels")
        self.sample_labels = new_labels

    def set_sorting_method(self, sort_method, sort_using):
        self.sort_method = sort_method
        self.sort_using = sort_using

    def get_regions(self):
        """Returns the regions per group

        Returns
        -------
        list

            Each element of the list is itself a list
            of dictionaries containing the regions info:
            chrom, start, end, strand, name etc.

            Each element of the outer list corresponds to
            one of the groups.
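
            For instance (hypothetical shape), with two groups the
            result looks like [[region, region, ...], [region, ...]],
            each inner list holding that group's regions in matrix
            row order.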
        """
        regions = []
        for idx in range(len(self.group_labels)):
            start = self.group_boundaries[idx]
            end = self.group_boundaries[idx + 1]
            regions.append(self.regions[start:end])

        return regions

    def sort_groups(self, sort_using='mean', sort_method='no', sample_list=None):
        """
        Sorts and rearranges the submatrices according to the
        sorting method given.
        """
        if sort_method == 'no':
            return

        if (sample_list is not None) and (len(sample_list) > 0):
            # get the ids that correspond to the selected sample list
            idx_to_keep = []
            for sample_idx in sample_list:
                idx_to_keep += range(self.sample_boundaries[sample_idx], self.sample_boundaries[sample_idx + 1])

            matrix = self.matrix[:, idx_to_keep]

        else:
            matrix = self.matrix

        # compute the per-row statistic used for sorting:
        if sort_using == 'region_length':
            matrix_avgs = list()
            for x in self.regions:
                matrix_avgs.append(np.sum([bar[1] - bar[0] for bar in x[1]]))
            matrix_avgs = np.array(matrix_avgs)
        elif sort_using == 'mean':
            matrix_avgs = np.nanmean(matrix, axis=1)
        elif sort_using == 'median':
            matrix_avgs = np.nanmedian(matrix, axis=1)
        elif sort_using == 'max':
            matrix_avgs = np.nanmax(matrix, axis=1)
        elif sort_using == 'min':
            matrix_avgs = np.nanmin(matrix, axis=1)
        elif sort_using == 'sum':
            matrix_avgs = np.nansum(matrix, axis=1)
        else:
            sys.exit("{} is an unsupported sorting statistic".format(sort_using))

        # order per group
        _sorted_regions = []
        _sorted_matrix = []
        for idx in range(len(self.group_labels)):
            start = self.group_boundaries[idx]
            end = self.group_boundaries[idx + 1]
            order = matrix_avgs[start:end].argsort()
            if sort_method == 'descend':
                order = order[::-1]
            _sorted_matrix.append(self.matrix[start:end, :][order, :])
            # sort the regions in the same order
            _reg = self.regions[start:end]
            for region_idx in order:
                _sorted_regions.append(_reg[region_idx])

        self.matrix = np.vstack(_sorted_matrix)
        self.regions = _sorted_regions
        self.set_sorting_method(sort_method, sort_using)
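
    # Usage sketch (hedged; `mat` is an assumed, already-built _matrix):
    #   mat.sort_groups(sort_using='mean', sort_method='descend', sample_list=[0])
    # reorders the rows within each group by decreasing row mean, computed
    # over the first sample only.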

    def hmcluster(self, k, evaluate_silhouette=True, method='kmeans', clustering_samples=None):
        matrix = np.asarray(self.matrix)
        matrix_to_cluster = matrix
        if clustering_samples is not None:
            assert all(i > 0 for i in clustering_samples),\
                "all indices should be >= 1"
            assert all(i <= len(self.sample_labels) for i in
                       clustering_samples),\
                "each index should be <= {} (the total "\
                "number of samples)".format(len(self.sample_labels))

            clustering_samples = np.asarray(clustering_samples) - 1

            samples_cols = []
            for idx in clustering_samples:
                samples_cols += range(self.sample_boundaries[idx],
                                      self.sample_boundaries[idx + 1])

            matrix_to_cluster = matrix_to_cluster[:, samples_cols]
        if np.any(np.isnan(matrix_to_cluster)):
            # replace NaNs with 0, otherwise kmeans misbehaves
            sys.stderr.write("*Warning* For clustering, NaN values have to be replaced by zeros.\n")
            matrix_to_cluster[np.isnan(matrix_to_cluster)] = 0

        if method == 'kmeans':
            from scipy.cluster.vq import vq, kmeans

            centroids, _ = kmeans(matrix_to_cluster, k)
            # assign each row to its nearest centroid; the clusters are
            # reordered by mean signal further below
            cluster_labels, _ = vq(matrix_to_cluster, centroids)

        elif method == 'hierarchical':
            # normally too slow for large data sets
            from scipy.cluster.hierarchy import fcluster, linkage
            Z = linkage(matrix_to_cluster, method='ward', metric='euclidean')
            cluster_labels = fcluster(Z, k, criterion='maxclust')
            # hierarchical clustering labels run from 1 .. k
            # while k-means labels run from 0 .. k - 1;
            # for consistency, subtract 1
            cluster_labels -= 1
        else:
            sys.exit("{} is an unsupported clustering method".format(method))

        # sort clusters
        _clustered_mean = []
        _cluster_ids_list = []
        for cluster in range(k):
            cluster_ids = np.flatnonzero(cluster_labels == cluster)
            _cluster_ids_list.append(cluster_ids)
            _clustered_mean.append(matrix_to_cluster[cluster_ids, :].mean())

        # reorder clusters based on mean
        cluster_order = np.argsort(_clustered_mean)[::-1]
        # create groups using the clustering
        self.group_labels = []
        self.group_boundaries = [0]
        _clustered_regions = []
        _clustered_matrix = []
        cluster_number = 1
        for cluster in cluster_order:
            self.group_labels.append("cluster_{}".format(cluster_number))
            cluster_number += 1
            cluster_ids = _cluster_ids_list[cluster]
            self.group_boundaries.append(self.group_boundaries[-1] +
                                         len(cluster_ids))
            _clustered_matrix.append(self.matrix[cluster_ids, :])
            for idx in cluster_ids:
                _clustered_regions.append(self.regions[idx])

        self.regions = _clustered_regions
        self.matrix = np.vstack(_clustered_matrix)

        # the matrix and regions were reordered in place; return the cluster
        # label of each row in the original (pre-reordering) row order
        return cluster_labels
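
    # Usage sketch (hedged; `mat` is an assumed, already-built _matrix):
    #   mat.hmcluster(3, method='kmeans', clustering_samples=[1])
    # splits the rows into 3 k-means clusters using only the first sample
    # (clustering_samples is 1-based), after which mat.group_labels is
    # ['cluster_1', 'cluster_2', 'cluster_3'].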

    def computeSilhouette(self, k):
        if k > 1:
            from scipy.spatial.distance import pdist, squareform

            silhouette = np.repeat(0.0, self.group_boundaries[-1])
            groupSizes = np.subtract(self.group_boundaries[1:], self.group_boundaries[:-1])
            labels = np.repeat(np.arange(k), groupSizes)

            # note that squareform materialises the full n x n distance
            # matrix, so memory use is quadratic in the number of rows
            d = pdist(self.matrix)
            d2 = squareform(d)
            np.fill_diagonal(d2, np.nan)  # exclude self-distances (the diagonal)
            for idx in range(len(labels)):
                silhouette[idx] = computeSilhouetteScore(d2, idx, labels)
            sys.stderr.write("The average silhouette score is: {}\n".format(np.mean(silhouette)))
            self.silhouette = silhouette
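
    # Cross-check sketch (hedged; scikit-learn is not a dependency of this
    # module): np.mean(self.silhouette) should agree with
    #   from sklearn.metrics import silhouette_score
    #   silhouette_score(self.matrix, labels, metric='euclidean')
    # up to floating-point error, with `labels` built as in computeSilhouette().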

    def removeempty(self):
        """
        removes matrix rows containing only zeros or nans
        """
        to_keep = []
        score_list = np.ma.masked_invalid(np.mean(self.matrix, axis=1))
        for idx, region in enumerate(self.regions):
            if np.ma.is_masked(score_list[idx]) or float(score_list[idx]) == 0:
                continue
            else:
                to_keep.append(idx)
        self.regions = [self.regions[x] for x in to_keep]
        self.matrix = self.matrix[to_keep, :]
        # adjust group boundaries: each boundary becomes the number of
        # kept rows that fall before it
        to_keep = np.array(to_keep)
        self.group_boundaries = [len(to_keep[to_keep < x]) for x in self.group_boundaries]

    def flatten(self):
        """
        flatten and remove nans from matrix. Useful
        to get max and mins from matrix.

        :return flattened matrix
        """
        matrix_flatten = np.asarray(self.matrix.flatten())
        # nans are removed from the flattened array
        matrix_flatten = matrix_flatten[~np.isnan(matrix_flatten)]
        if len(matrix_flatten) == 0:
            num_nan = len(np.flatnonzero(np.isnan(self.matrix.flatten())))
            raise ValueError("matrix only contains nans "
                             "(total nans: {})".format(num_nan))
        return matrix_flatten
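
    # Usage sketch (hedged; `mat` is an assumed, already-built _matrix):
    # robust color limits for plotting could be derived with
    #   vmin, vmax = np.percentile(mat.flatten(), [2.0, 98.0])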