File: downsample.py

#!/usr/bin/env python3
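"""
Pick a representative sample for each unique sequence, downsampling sequences
that lie within a given number of nucleotide variants of an already retained
sequence.
"""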

import csv
import datetime
import operator
# import argparse
from Bio import SeqIO


# def parse_args():
#     parser = argparse.ArgumentParser(description="""Pick a representative sample for each unique sequence""",
#                                     formatter_class=argparse.RawTextHelpFormatter)
#     parser.add_argument('--in-metadata', dest = 'in_metadata', required=True, help='CSV of containing sequence_name and nucleotide_variants columns, the latter being | separated list of variants')
#     parser.add_argument('--in-fasta', dest = 'in_fasta', required=True, help='FASTA of all input seqs')
#     parser.add_argument('--diff', dest = 'diff', required=True, type=int, help='Samples within distance DIFF of included others may be excluded by the downsampler')
#     parser.add_argument('--out-metadata', dest = 'out_metadata', required=True, help='CSV to write out')
#     parser.add_argument('--out-fasta', dest = 'out_fasta', required=True, help='FASTA to write downsampled seqs')
#     parser.add_argument('--outgroups', dest = 'outgroups', required=False, help='Lineage splits file containing representative outgroups to protect')
#     parser.add_argument('--downsample_date_excluded', action='store_true', help='Downsample from those excluded as outside date window')
#     parser.add_argument('--downsample_included', action='store_true', help='Downsample from all included sequences')
#     parser.add_argument('--downsample_lineage_size', type=int, default=None, help='Min size of lineages to downsample, if unspecified no lineage-aware downsampling')

#     args = parser.parse_args()
#     return args

def parse_outgroups(outgroup_file):
    """
    Input is a CSV whose last column holds the representative outgroup names;
    returns the list of outgroups, or an empty list if no file was given.
    """
    outgroups = []
    if not outgroup_file:
        return outgroups
    with open(outgroup_file, "r") as outgroup_handle:
        for line in outgroup_handle:
            line = line.strip()
            if not line:
                continue
            outgroups.append(line.split(",")[-1])
    return outgroups

def get_count_dict(in_metadata):
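    """
    Count, for each nucleotide variant, how many samples in the metadata CSV
    carry it. Returns the counts sorted by ascending frequency together with
    the total number of samples.
    """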
    count_dict = {}
    num_samples = 0
    with open(in_metadata,"r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            num_samples += 1
            for var in row["nucleotide_variants"].split("|"):
                if var in count_dict:
                    count_dict[var] += 1
                else:
                    count_dict[var] = 1
    print("Found", len(count_dict), "variants")

    sorted_tuples = sorted(count_dict.items(), key=operator.itemgetter(1))
    count_dict = {k: v for k, v in sorted_tuples}
    return count_dict, num_samples

def get_lineage_dict(in_metadata, min_size):
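    """
    Map each lineage in the metadata CSV to its list of sequence names,
    dropping lineages with fewer than min_size members. Returns an empty
    dict when min_size is None.
    """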
    lineage_dict = {}
    if min_size is None:
        return lineage_dict

    with open(in_metadata,"r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if "lineage" in row:
                lin = row["lineage"]
                if lin in lineage_dict:
                    lineage_dict[lin].append(row["sequence_name"])
                else:
                    lineage_dict[lin] = [row["sequence_name"]]
    print("Found", len(lineage_dict), "lineages")

    small_lineages = [lin for lin in lineage_dict if len(lineage_dict[lin]) < min_size]
    for lin in small_lineages:
        del lineage_dict[lin]
    print("Found", len(lineage_dict), "lineages with at least", min_size, "representative sequences")

    return lineage_dict

def get_by_frequency(count_dict, num_samples, band=[0.1,1.0]):
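    """
    Return the variants whose sample count lies in the half-open band
    (num_samples * band[0], num_samples * band[1]].
    """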
    lower_bound = num_samples*band[0]
    upper_bound = num_samples*band[1]
    most_frequent = [k for k in count_dict if lower_bound < count_dict[k] <= upper_bound]
    print(len(most_frequent), "lie in frequency band", band)
    return most_frequent

def num_unique(muts1, muts2):
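    """
    Size of the symmetric difference between two variant lists, i.e. the
    number of variants present in exactly one of them.
    """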
    u1 = [m for m in muts1 if m not in muts2]
    u2 = [m for m in muts2 if m not in muts1]
    return len(u1+u2)

def should_downsample_row(row, downsample_date_excluded=True, downsample_included=False, downsample_lineage_size=None, lineage_dict={}):
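    """
    Decide whether a row may be downsampled: rows with no exclusion reason when
    downsample_included is set, rows whose date_filter marks them as older than
    the date window when downsample_date_excluded is set, and rows belonging to
    a sufficiently large lineage when lineage-aware downsampling is enabled.
    """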
    if downsample_included and row["why_excluded"] in [None, "None", ""]:
        return True
    if downsample_date_excluded and row["why_excluded"] in [None, "None", ""] and "date_filter" in row and row["date_filter"].startswith("sample_date older than"):
        return True
    if downsample_lineage_size and row["lineage"] in lineage_dict:
        return True
    return False

def downsample(in_metadata, out_metadata, in_fasta, out_fasta, max_diff, outgroup_file, downsample_date_excluded, downsample_included, downsample_lineage_size):
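    """
    Stream the metadata CSV and write every row whose sequence is present in
    in_fasta to out_metadata. Outgroups and rows not eligible for downsampling
    are kept unchanged; for the remaining rows, a sequence is retained (and
    written to out_fasta) only if it has at least max_diff variants and no
    previously retained sequence lies within max_diff nucleotide variants of it.
    """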
    original_num_seqs = 0
    sample_dict = {}
    var_dict = {}

    count_dict, num_samples = get_count_dict(in_metadata)
    most_frequent = get_by_frequency(count_dict, num_samples, band=[0.05,1.0])
    very_most_frequent = get_by_frequency(count_dict, num_samples, band=[0.5,1.0])

    lineage_dict = get_lineage_dict(in_metadata, downsample_lineage_size)

    outgroups = parse_outgroups(outgroup_file)
    indexed_fasta = SeqIO.index(in_fasta, "fasta")

    with open(in_metadata, 'r', newline = '') as csv_in, \
        open(out_fasta, 'w', newline = '') as fa_out, \
        open(out_metadata, 'w', newline = '') as csv_out:

        reader = csv.DictReader(csv_in, delimiter=",", quotechar='\"', dialect = "unix")
        writer = csv.DictWriter(csv_out, fieldnames = reader.fieldnames, delimiter=",", quotechar='\"', quoting=csv.QUOTE_MINIMAL, dialect = "unix")
        writer.writeheader()

        for row in reader:
            fasta_header = row["sequence_name"]
            if fasta_header not in indexed_fasta:
                continue
            if original_num_seqs % 1000 == 0:
                now = datetime.datetime.now()
                print("%s Downsampled from %i seqs to %i seqs" %(str(now), original_num_seqs, len(sample_dict)))
            original_num_seqs += 1

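            # Outgroups and rows not eligible for downsampling are passed
            # through unchanged; their sequence is written to the output fasta
            # as long as the row has no exclusion reason.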
            if fasta_header in outgroups or not should_downsample_row(row, downsample_date_excluded, downsample_included,
                                                                      downsample_lineage_size, lineage_dict):
                if fasta_header in outgroups:
                    row["why_excluded"]=""
                writer.writerow(row)
                if row["why_excluded"] in [None, "None", ""] and fasta_header in indexed_fasta:
                    seq_rec = indexed_fasta[fasta_header]
                    fa_out.write(">" + seq_rec.id + "\n")
                    fa_out.write(str(seq_rec.seq) + "\n")
                else:
                    print(row["why_excluded"], fasta_header, (fasta_header in indexed_fasta))
                continue

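            # Rows with fewer variants than the diff threshold are treated as
            # downsampled: their metadata row is kept but the sequence is not
            # retained.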
            muts = row["nucleotide_variants"].split("|")
            if len(muts) < max_diff:
                #if not row["why_excluded"]:
                #    row["why_excluded"] = "downsampled with diff threshold %i" %max_diff
                writer.writerow(row)
                continue

            found_close_seq = False

            samples = set()
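            # Look up candidate neighbours via this sequence's rarest variants,
            # falling back to progressively more common ones if every variant is
            # frequent. max_diff+1 variants suffice: any sequence within max_diff
            # of this one must share at least one of them.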
            low_frequency_muts = [mut for mut in muts if mut not in most_frequent]
            if len(low_frequency_muts) == 0:
                low_frequency_muts = [mut for mut in muts if mut not in very_most_frequent]
                if len(low_frequency_muts) == 0:
                    low_frequency_muts = muts
            if len(low_frequency_muts) > max_diff + 1:
                low_frequency_muts = low_frequency_muts[:max_diff+1]
            for mut in low_frequency_muts:
                if mut in var_dict:
                    samples.update(var_dict[mut])
            if downsample_lineage_size:
                samples = list(samples & set(lineage_dict[row["lineage"]]))

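            # If any previously retained sequence is within max_diff variants,
            # this row is downsampled: keep the metadata row, drop the sequence.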
            for sample in samples:
                if num_unique(muts, sample_dict[sample]) <= max_diff:
                    found_close_seq = True
                    #if not row["why_excluded"]:
                    #    row["why_excluded"] = "downsampled with diff threshold %i" %max_diff
                    writer.writerow(row)
                    break
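            # Otherwise retain this sequence as a new representative and index
            # its variants so later rows can find it as a neighbour.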
            if not found_close_seq:
                sample_dict[fasta_header] = muts
                for mut in muts:
                    if mut not in var_dict:
                        var_dict[mut] = [fasta_header]
                    else:
                        var_dict[mut].append(fasta_header)
                row["why_excluded"] = ""
                writer.writerow(row)
                if fasta_header in indexed_fasta:
                    seq_rec = indexed_fasta[fasta_header]
                    fa_out.write(">" + seq_rec.id + "\n")
                    fa_out.write(str(seq_rec.seq) + "\n")

    now = datetime.datetime.now()
    print("%s Downsampled from %i seqs to %i seqs" %(str(now), original_num_seqs, len(sample_dict)))
    # return sample_dict.keys()

# def main():
#     args = parse_args()
#     subsample = downsample(args.in_metadata, args.out_metadata, args.in_fasta, args.out_fasta, args.diff, args.outgroups, args.downsample_date_excluded, args.downsample_included, args.downsample_lineage_size)

# if __name__ == '__main__':
#     main()