File: chunked_call

#!/usr/bin/python3
"""
Generate a VCF from a GAM and XG by splitting into GAM/VG chunks.
Chunks are then called in series, and the VCFs stitched together.
Any step whose expected output exists is skipped unless --overwrite
is specified.
"""

import argparse, sys, os, os.path, subprocess
import json

def parse_args(args):
    parser = argparse.ArgumentParser(description=__doc__, 
        formatter_class=argparse.RawDescriptionHelpFormatter)
        
    # General options
    parser.add_argument("xg_path", type=str,
                        help="input xg file")
    parser.add_argument("gam_path", type=str,
                        help="input alignment")
    parser.add_argument("path_name", type=str,
                        help="name of reference path in graph (ex chr21)")
    parser.add_argument("path_size", type=int,
                        help="size of the reference path in graph")
    parser.add_argument("sample_name", type=str,
                        help="sample name (ex NA12878)")
    parser.add_argument("out_dir", type=str,
                        help="directory where all output will be written")    
    parser.add_argument("--chunk", type=int, default=10000000,
                        help="chunk size")
    parser.add_argument("--overlap", type=int, default=2000,
                        help="amount of overlap between chunks")
    parser.add_argument("--filter_opts", type=str,
                        default="-r 0.9 -fu -s 2 -o 0 -q 15 --defray-ends 999",
                        help="options to pass to vg filter. wrap in \"\"")
    parser.add_argument("--pileup_opts", type=str,
                        default="-q 10",
                        help="options to pass to vg pileup. wrap in \"\"")
    parser.add_argument("--call_opts", type=str,
                        default="",
                        help="options to pass to vg call. wrap in \"\"")
    parser.add_argument("--threads", type=int, default=20,
                        help="number of threads to use in vg call and vg pileup")
    parser.add_argument("--overwrite", action="store_true",
                        help="always overwrite existing files")
                        
    args = args[1:]
        
    return parser.parse_args(args)

def merge_call_opts(contig, offset, length, call_opts, sample_name):
    """ combine input vg call  options with generated options, by adding user offset
    and overriding contigs, sample and sequence lenght"""
    user_opts = call_opts.split()
    user_offset, user_contig, user_ref, user_sample, user_length  = None, None, None, None, None
    for i, uo in enumerate(user_opts):
        if uo in ["-o", "--offset"]:
            user_offset = int(user_opts[i + 1])
            user_opts[i + 1] = str(user_offset + offset)
        elif uo in ["-c", "--contig"]:
            user_contig = user_opts[i + 1]
        elif uo in ["-r", "--ref"]:
            user_ref = user_opts[i + 1]
        elif uo in ["-S", "--sample"]:
            user_sample = user_opts[i + 1]
            user_opts[i + 1] = sample_name
        elif uo in ["-l", "--length"]:
            user_length = user_opts[i + 1]
    opts = " ".join(user_opts)
    if user_offset is None:
        opts += " -o {}".format(offset)
    if user_contig is None:
        opts += " -c {}".format(contig)
    if user_ref is None:
        opts += " -r {}".format(contig)        
    if user_sample is None:
        opts += " -S {}".format(sample_name)
    if user_length is None:
        opts += " -l {}".format(length)
    return opts
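
# A worked example of the merging above (all values illustrative):
#
#   merge_call_opts("chr21", 100, 5000, "-o 10 -S foo", "NA12878")
#
# shifts the user offset by the chunk offset (10 + 100 -> "-o 110"), replaces
# the user sample with sample_name, and appends the unspecified options,
# yielding "-o 110 -S NA12878 -c chr21 -r chr21 -l 5000".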
    
def run(cmd, proc_stdout = sys.stdout, proc_stderr = sys.stderr,
        check = True):
    """ run command in shell and throw exception if it doesn't work 
    """
    print(cmd)
    proc = subprocess.Popen(cmd, shell=True, bufsize=-1,
                            stdout=proc_stdout, stderr=proc_stderr)
    output, errors = proc.communicate()
    sts = proc.wait()
    if check and sts != 0:
        raise RuntimeError("Command: %s exited with non-zero status %i" % (cmd, sts))
    return output, errors
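
# note: when called with proc_stdout=subprocess.PIPE, run() returns raw bytes;
# callers below parse them with int(), json.loads() and decode().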

def make_chunks(path_name, path_size, chunk_size, overlap):
    """ compute chunks as BED (0-based) 3-tuples: ie
    (chr1, 0, 10) is the range from 0-9 inclusive of chr1
    """
    assert chunk_size > overlap
    covered = 0
    chunks = []
    while covered < path_size:
        start = max(0, covered - overlap)
        end = min(path_size, start + chunk_size)
        chunks.append((path_name, start, end))
        covered = end
    return chunks
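
# A small worked example of the chunking loop (toy numbers):
#
#   make_chunks("x", 25, 10, 2) == [("x", 0, 10), ("x", 8, 18), ("x", 16, 25)]
#
# each chunk after the first starts `overlap` bases before the previous end,
# and the last chunk is truncated to the path size.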

def chunk_base_name(path_name, out_dir, chunk_i = None, tag= ""):
    """ centralize naming of output chunk-related files """
    bn = os.path.join(out_dir, "{}-chunk".format(path_name))
    if chunk_i is not None:
        bn += "-{}".format(chunk_i)
    return "{}{}".format(bn, tag)

def chunk_gam(gam_path, xg_path, path_name, out_dir, chunks, filter_opts, threads, overwrite):
    """ use vg filter to chunk up the gam """
    # make bed chunks
    chunk_path = os.path.join(out_dir, path_name + "_chunks.bed")
    with open(chunk_path, "w") as f:
        for chunk in chunks:
            f.write("{}\t{}\t{}\n".format(chunk[0], chunk[1], chunk[2]))
    # run vg filter on the gam
    if overwrite or not any(
            os.path.isfile(chunk_base_name(path_name, out_dir, i, ".gam")) \
               for i in range(len(chunks))):
        run("vg filter {} -x {} -R {} -B {} {} -t {}".format(
            gam_path, xg_path, chunk_path,
            os.path.join(out_dir, path_name + "-chunk"), filter_opts, threads))

def xg_path_node_id(xg_path, path_name, offset):
    """ use vg find to get the node containing a given path position """
    #NOTE: vg find -p range offsets are 0-based inclusive.  
    stdout, stderr = run("vg find -x {} -p {}:{}-{} | vg mod -o - | vg view -j - | jq .node[0].id".format(
        xg_path, path_name, offset, offset),
                         proc_stdout=subprocess.PIPE)
    return int(stdout)

def xg_path_predecessors(xg_path, path_name, node_id, context = 1):
    """ get nodes before given node in a path. """
    stdout, stderr = run("vg find -x {} -n {} -c {} | vg view -j -".format(
        xg_path, node_id, context), proc_stdout=subprocess.PIPE)

    # get our json graph
    j = json.loads(stdout)
    paths = j["path"]
    path = [x for x in paths if x["name"] == path_name][0]
    mappings = path["mapping"]
    assert len(mappings) > 0
    # check that we have a node_mapping
    assert len([x for x in mappings if x["position"]["node_id"] == node_id]) == 1
    # collect mappings that come before
    out_ids = []
    for mapping in mappings:
        if mapping["position"]["node_id"] == node_id:
            break
        out_ids.append(mapping["position"]["node_id"])
    return out_ids
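
# For example, if the extracted context shows path_name visiting nodes
# 5 -> 6 -> 7 and node_id is 7, this returns [5, 6].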

def chunk_vg(xg_path, path_name, out_dir, chunks, chunk_i, overwrite):
    """ use vg find to make one chunk of the graph """
    chunk = chunks[chunk_i]
    vg_chunk_path = chunk_base_name(chunk[0], out_dir, chunk_i, ".vg")
    if overwrite or not os.path.isfile(vg_chunk_path):
        first_node = xg_path_node_id(xg_path, chunk[0], int(chunk[1]))
        # the xg path query takes 0-based inclusive coordinates, so we
        # subtract 1 below to convert from the BED chunk end (0-based, exclusive)
        last_node = xg_path_node_id(xg_path, chunk[0], chunk[2] - 1)
        assert first_node > 0 and last_node >= first_node
        # todo: would be cleaner to not have to pad context here
        run("vg find -x {} -r {}:{} -c 1 > {}".format(
            xg_path, first_node, last_node, vg_chunk_path))
        # but because we got a context, manually go in and make sure
        # our path starts at first_node by deleting everything before
        left_path_padding = xg_path_predecessors(xg_path, path_name, first_node,
                                                 context = 1)
        for destroy_id in left_path_padding:
            # todo: it would be cleaner if vg mod -y accepted a list of nodes,
            # so we didn't have to rewrite the chunk once per node
            run("vg mod -y {} {} | vg mod -o - > {}".format(
                destroy_id, vg_chunk_path, vg_chunk_path + ".destroy"))
            run("mv {} {}".format(
                vg_chunk_path + ".destroy", vg_chunk_path))
                
def xg_path_node_offset(xg_path, path_name, offset):
    """ get the offset of the node containing the given position of a path
    """
    # first we find the node
    node_id = xg_path_node_id(xg_path, path_name, offset)

    # now we find the offset of the beginning of the node
    stdout, stderr = run("vg find -x {} -P {} -n {}".format(
        xg_path, path_name, node_id),
                         proc_stdout=subprocess.PIPE)
    # run() pipes raw bytes, so decode before comparing against strings
    toks = stdout.decode().split()
    # if len > 2 then we have a cyclic path, which we're assuming we don't
    assert len(toks) == 2
    assert toks[0] == str(node_id)
    node_offset = int(toks[1])
    # node_offset must be before
    assert node_offset <= offset
    # sanity check (should really use node size instead of 1000 here)
    assert offset - node_offset < 1000

    return node_offset
    
def sort_vcf(vcf_path, sorted_vcf_path):
    """ from vcflib """
    run("head -10000 {} | grep \"^#\" > {}".format(
        vcf_path, sorted_vcf_path))
    run("grep -v \"^#\" {} | sort -k1,1d -k2,2n >> {}".format(
        vcf_path, sorted_vcf_path))
    
def call_chunk(xg_path, path_name, out_dir, chunks, chunk_i, path_size, overlap,
               pileup_opts, call_options, sample_name, threads,
               overwrite):
    """ create VCF from a given chunk """
    # make the graph chunk
    chunk_vg(xg_path, path_name, out_dir, chunks, chunk_i, overwrite)

    chunk = chunks[chunk_i]
    path_name = chunk[0]
    vg_path = chunk_base_name(path_name, out_dir, chunk_i, ".vg")
    gam_path = chunk_base_name(path_name, out_dir, chunk_i, ".gam")

    # a chunk can be empty if nothing aligns there.
    if not os.path.isfile(gam_path):
        sys.stderr.write("Warning: chunk not found: {}\n".format(gam_path))
        return
    
    # do the augmentation via pileup.  this is the most resource-intensive step,
    # especially in terms of memory used.
    aug_graph_path = chunk_base_name(path_name, out_dir, chunk_i, ".aug")
    support_path = chunk_base_name(path_name, out_dir, chunk_i, ".sup")
    translation_path = chunk_base_name(path_name, out_dir, chunk_i, ".trans")
    
    if overwrite or not all(os.path.isfile(f) for f in [
            aug_graph_path, support_path, translation_path]):
        run("vg augment -a pileup {} {} -t {} {} -Z {} -S {}  > {}".format(
            vg_path, gam_path, threads, pileup_opts, translation_path,
            support_path, aug_graph_path))

    # do the calling.
    vcf_path = chunk_base_name(path_name, out_dir, chunk_i, ".vcf")
    if overwrite or not os.path.isfile(vcf_path + ".gz"):
        offset = xg_path_node_offset(xg_path, chunk[0], chunk[1])
        merged_call_opts = merge_call_opts(chunk[0], offset, path_size,
                                           call_options, sample_name)
        run("vg call {} -b {} -s {} -z {} -t {} {} > {}".format(
            aug_graph_path, vg_path, support_path, translation_path,
            threads, merged_call_opts, vcf_path + ".us"))
        sort_vcf(vcf_path + ".us", vcf_path)
        run("rm {}".format(vcf_path + ".us"))
        run("bgzip {}".format(vcf_path))
        run("tabix -f -p vcf {}".format(vcf_path + ".gz"))

    # do the vcf clip
    # integer division keeps the bcftools region bounds integral in Python 3
    left_clip = 0 if chunk_i == 0 else overlap // 2
    right_clip = 0 if chunk_i == len(chunks) - 1 else overlap // 2
    clip_path = chunk_base_name(path_name, out_dir, chunk_i, "_clip.vcf")
    if overwrite or not os.path.isfile(clip_path):
        call_toks = call_options.split()
        offset = 0
        # accept both forms of the offset option, as merge_call_opts does
        for i, tok in enumerate(call_toks):
            if tok in ["-o", "--offset"]:
                offset = int(call_toks[i + 1])
        run("bcftools view -r {}:{}-{} {} > {}".format(
            path_name, offset + chunk[1] + left_clip + 1,
            offset + chunk[2] - right_clip, vcf_path + ".gz", clip_path))
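
# A worked example of the clipping arithmetic in call_chunk, using the default
# --chunk 10000000 and --overlap 2000 (and no user -o offset): chunk 0 spans
# [0, 10000000) and chunk 1 spans [9998000, 19998000), so chunk 0 is clipped to
# the 1-based region chr:1-9999000 and chunk 1 to chr:9999001-19997000; the
# clipped regions are adjacent and disjoint, so merge_vcf_chunks can just
# concatenate them.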

            
def merge_vcf_chunks(out_dir, path_name, path_size, chunks, overwrite):
    """ merge a bunch of clipped vcfs created above, taking care to 
    fix up the headers.  everything expected to be sorted already """
    vcf_path = os.path.join(out_dir, path_name + ".vcf")
    if overwrite or not os.path.isfile(vcf_path):
        first = True
        for chunk_i, chunk in enumerate(chunks):
            clip_path = chunk_base_name(path_name, out_dir, chunk_i, "_clip.vcf")
            if os.path.isfile(clip_path):
                if first is True:
                    # copy everything including the header
                    run("cat {} > {}".format(clip_path, vcf_path))
                    first = False
                else:
                    # append everything but the header
                    run("grep -v \"^#\" {} >> {}".format(clip_path, vcf_path), check=False)
                
    # add a compressed indexed version
    if overwrite or not os.path.isfile(vcf_path + ".gz"):
        run("bgzip -c {} > {}".format(vcf_path, vcf_path + ".gz"))
        run("tabix -f -p vcf {}".format(vcf_path + ".gz"))

def main(args):
    
    options = parse_args(args)

    if not os.path.isdir(options.out_dir):
        os.makedirs(options.out_dir)

    # require an even overlap, since we split it evenly
    # between adjacent chunks
    assert options.overlap % 2 == 0

    # compute overlapping chunks
    chunks = make_chunks(options.path_name, options.path_size,
                options.chunk, options.overlap)

    # split the gam in one go
    chunk_gam(options.gam_path, options.xg_path,
              options.path_name, options.out_dir,
              chunks, options.filter_opts, options.threads,
              options.overwrite)

    # call every chunk in series
    for chunk_i, chunk in enumerate(chunks):
        call_chunk(options.xg_path, options.path_name,
                   options.out_dir, chunks, chunk_i,
                   options.path_size, options.overlap,
                   options.pileup_opts, options.call_opts,
                   options.sample_name, options.threads,
                   options.overwrite)
    
    # stitch together the vcf
    merge_vcf_chunks(options.out_dir, options.path_name,
                     options.path_size,
                     chunks, options.overwrite)
    
if __name__ == "__main__":
    sys.exit(main(sys.argv))