# Snakemake workflow: shard dummy data per language via a checkpoint, pair up
# batches that belong to the same shard, and concatenate them into one corpus.
from itertools import product

# Run every `shell:` directive through bash (the checkpoint below uses `[[ ]]`).
shell.executable("bash")

# Language pair for this run; both keys must be present in the workflow config.
src_lang, trg_lang = config['src_lang'], config['trg_lang']
# Default target: the final corpus plus the batch listing of each language.
rule all:
    input:
        "data/corpus.txt",
        f"data/batches.{src_lang}",
        f"data/batches.{trg_lang}"
# Checkpoint that fabricates a tree of dummy batches for one language:
# data/<lang>/<i>/<j>/text for i in 0..3 and j in 0..2. It then writes the list
# of created <i>/<j> directories to data/batches.<lang>.
# The source language sleeps 1s per batch and any other language 1.5s.
# NOTE: `{src_lang}` in the script is not a wildcard of this rule; it is
# expected to be filled in from the module-level `src_lang` variable.
checkpoint shard:
    output:
        "data/batches.{lang}"
    shell:
        '''
        for i in $(seq 0 $(( 2 % 5 + 1))); do # 0 .. 3 -> 4
        for j in $(seq 0 $(( 1 % 2 + 1))); do # 0 .. 2 -> 3
        mkdir -p data/{wildcards.lang}/$i/$j
        echo 'hello {wildcards.lang}' > data/{wildcards.lang}/$i/$j/text
        if [[ "{wildcards.lang}" == "{src_lang}" ]]; then
        sleep 1
        else
        sleep 1.5
        fi
        done
        done
        ls -d data/{wildcards.lang}/*/* > {output}
        '''
# Paste one source batch and one target batch of the same shard side by side.
# In this rule `src_lang`, `trg_lang`, `shard`, `src_batch` and `trg_batch` are
# wildcards (plain strings, not f-strings), resolved from the requested output.
rule combine:
    input:
        l1="data/{src_lang}/{shard}/{src_batch}/text",
        l2="data/{trg_lang}/{shard}/{trg_batch}/text"
    output:
        "data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined"
    shell:
        ''' paste {input.l1} {input.l2} > {output} '''
def get_batches_pairs(src_lang, trg_lang):
    """Pair up source and target batches that live in the same shard.

    Reads the listing written by the ``shard`` checkpoint for each language.
    The last two path components of every listed line are taken as
    ``(shard, batch)``.

    Args:
        src_lang: Language whose listing supplies the source batches.
        trg_lang: Language whose listing supplies the target batches.

    Returns:
        A list of ``(shard, (src_batch, trg_batch))`` tuples, one for every
        source/target combination whose shard components are equal. Order
        follows the source listing first, then the target listing.
    """
    with checkpoints.shard.get(lang=src_lang).output[0].open() as src_listing, \
            checkpoints.shard.get(lang=trg_lang).output[0].open() as trg_listing:
        src_entries = [row.strip().split('/')[-2:] for row in src_listing]
        trg_entries = [row.strip().split('/')[-2:] for row in trg_listing]

    matched = []
    for (left_shard, left_batch), (right_shard, right_batch) in product(src_entries, trg_entries):
        # Only batches from the same shard are combined.
        if left_shard == right_shard:
            matched.append((left_shard, (left_batch, right_batch)))
    return matched
# Concatenate every combined shard/batch file into the final corpus. The input
# list is computed lazily from the `shard` checkpoint listings of both
# languages via get_batches_pairs().
rule corpus:
    input:
        lambda wildcards: [
            f'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
            for (shard, (src_batch, trg_batch)) in get_batches_pairs(src_lang, trg_lang)
        ]
    output:
        "data/corpus.txt"
    shell:
        ''' cat {input} > {output} '''
# End of Snakefile.