File: Snakefile

package info (click to toggle)
snakemake 7.32.4-8.1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 25,836 kB
  • sloc: python: 32,846; javascript: 1,287; makefile: 247; sh: 163; ansic: 57; lisp: 9
file content (50 lines) | stat: -rw-r--r-- 2,454 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Run every shell directive in this workflow with bash.
shell.executable("bash")

from itertools import product

# Source and target language identifiers are read from the workflow config;
# a missing key raises KeyError while the Snakefile is being loaded.
src_lang, trg_lang = config['src_lang'], config['trg_lang']

# Default target: the merged corpus plus the shard listing of each language.
rule all:
    input:
        "data/corpus.txt",
        f"data/batches.{src_lang}",
        f"data/batches.{trg_lang}"

# Checkpoint that fabricates the sharded input data for one language.
#
# For the language given by the {lang} wildcard, the shell command creates
# data/<lang>/<i>/<j>/text for i in 0..3 and j in 0..2 (12 directories), each
# file holding the line "hello <lang>", and then lists those directories
# (ls -d data/<lang>/*/*) into the output file data/batches.<lang>.
#
# get_batches_pairs() reads that listing through checkpoints.shard.get(),
# which is why this is declared as a checkpoint rather than a plain rule.
checkpoint shard:                    
    output: "data/batches.{lang}"    
    # {src_lang} below is not a wildcard of this checkpoint (only {lang} is);
    # it is presumably filled in from the module-level src_lang variable.
    # The source language sleeps 1 s per batch, any other language 1.5 s.
    # NOTE(review): the differing sleeps look intended to make the two
    # checkpoint jobs finish at different times - confirm before changing them.
    shell: '''                     
        for i in $(seq 0 $(( 2 % 5 + 1))); do # 0 .. 3 -> 4
            for j in $(seq 0 $(( 1 % 2 + 1))); do # 0 .. 2 -> 3
                mkdir -p data/{wildcards.lang}/$i/$j
                echo 'hello {wildcards.lang}' > data/{wildcards.lang}/$i/$j/text
                if [[ "{wildcards.lang}" == "{src_lang}" ]]; then
                    sleep 1
                else
                    sleep 1.5
                fi
            done
        done
        ls -d data/{wildcards.lang}/*/* > {output}  
        '''
                                                                
# Paste one source-language batch next to one target-language batch of the
# same shard. Note that {src_lang} and {trg_lang} here are wildcards matched
# from the requested output path, not the module-level variables of the same
# names (these are plain strings, not f-strings).
rule combine:
    input:
        l1="data/{src_lang}/{shard}/{src_batch}/text",
        l2="data/{trg_lang}/{shard}/{trg_batch}/text"
    output:
        "data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined"
    shell:
        " paste {input.l1} {input.l2} > {output} "
                                                                                                
def get_batches_pairs(src_lang, trg_lang):
    """Pair every source batch with every target batch of the same shard.

    Opens the first output of the ``shard`` checkpoint for each language and
    reads it line by line. Each line is treated as a '/'-separated path whose
    last two components are taken as (shard, batch).

    Args:
        src_lang: value passed as ``lang`` to look up the source checkpoint.
        trg_lang: value passed as ``lang`` to look up the target checkpoint.

    Returns:
        A list of ``(shard, (src_batch, trg_batch))`` tuples, ordered by the
        source listing first and the target listing second, restricted to
        combinations whose shard components are equal.
    """
    def shard_and_batch(listing):
        # Keep only the trailing two path components of every listed entry.
        return [entry.strip().split('/')[-2:] for entry in listing]

    # Both listings are opened before either is read, source first.
    with checkpoints.shard.get(lang=src_lang).output[0].open() as src_f:
        with checkpoints.shard.get(lang=trg_lang).output[0].open() as trg_f:
            src_batches = shard_and_batch(src_f)
            trg_batches = shard_and_batch(trg_f)

    pairs = []
    for (src_shard, src_batch), (trg_shard, trg_batch) in product(src_batches, trg_batches):
        # Batches from different shards are never combined.
        if src_shard == trg_shard:
            pairs.append((src_shard, (src_batch, trg_batch)))
    return pairs

def _combined_files(wildcards):
    """Input function for ``corpus``: one .combined path per batch pair.

    The paths are built from the module-level src_lang/trg_lang and the
    (shard, (src_batch, trg_batch)) tuples returned by get_batches_pairs().
    The ``wildcards`` argument is accepted but not used.
    """
    return [
        f'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
        for shard, (src_batch, trg_batch) in get_batches_pairs(src_lang, trg_lang)
    ]

# Concatenate every combined batch file into the final corpus.
rule corpus:
    input:
        _combined_files
    output:
        "data/corpus.txt"
    shell:
        " cat {input} > {output} "