1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
|
from collections import defaultdict
from Bio import SeqIO
import scipy.stats as ss
import numpy as np
#import cPickle
import pickle #_pickle as cPickle
import sys
import re
import os
# Complement of every symbol that can appear in a (methylated) reference:
# 'N' (unknown) and 'M' (the methylated-base placeholder written by the
# methylate_* helpers) are their own complements.
base_comps = dict(A='T', C='G', T='A', G='C', N='N', M='M')


#@profile
def comp(seq,base_comps=base_comps):
    """Return the base-by-base complement of ``seq`` (NOT reversed).

    Raises KeyError for any symbol absent from ``base_comps``
    (e.g. lowercase bases).
    """
    complemented = (base_comps[nucleotide] for nucleotide in seq)
    return ''.join(complemented)
#@profile
def revcomp(seq,rev=True):
    """Return the reverse complement of ``seq`` when ``rev`` is truthy.

    With a falsy ``rev`` the sequence is returned untouched, so callers can
    pass a "reverse strand?" flag straight through.
    """
    if rev:
        return comp(seq)[::-1]
    return seq
#@profile
def strand(rev):
    """Translate a reverse-strand flag into its symbol: '-' if truthy, otherwise '+'."""
    return '-' if rev else '+'
#find positions of motifs (eg. CG bases) in reference sequence and change to M
#@profile
def methylate_motifs(ref_seq,motif,meth_base,meth_position=None):
    """Mark the methylated base of every occurrence of ``motif`` as 'M'.

    Args:
        ref_seq: reference sequence string.
        motif: motif to search for (e.g. 'GATC').
        meth_base: methylated base inside the motif. Used only when
            ``meth_position`` is None, in which case EVERY occurrence of
            ``meth_base`` within the motif becomes 'M'.
        meth_position: optional non-negative 0-based index, inside ``motif``,
            of the single base to mark as 'M'.

    Returns:
        ``ref_seq`` with each motif occurrence replaced by its marked copy.
        Matching uses ``str.replace``, which scans left to right without
        re-matching inside an already replaced occurrence, so overlapping
        occurrences are not all marked.
    """
    # `is not None` so that index 0 is honoured (0 is falsy).
    if meth_position is not None:
        # Splice 'M' in at the requested index; the tail slice is simply
        # empty when the methylated base is the last base of the motif.
        meth_motif = motif[:meth_position]+'M'+motif[meth_position+1:]
    else:
        meth_motif = 'M'.join(motif.split(meth_base))
    return ref_seq.replace(motif,meth_motif)
#change specified positions to M in reference sequence
#@profile
def methylate_positions(ref_seq,positions,meth_base):
    """Return ``ref_seq`` with every listed 0-based position replaced by 'M'.

    Args:
        ref_seq: reference sequence string.
        positions: iterable of non-negative 0-based indices into ``ref_seq``.
        meth_base: base expected at each listed position. A position that
            already holds 'M' (e.g. listed twice) is accepted as well.

    Raises:
        SystemExit: with status 1 if a listed position holds any other base,
            which usually means the positions file is 1-based.
        IndexError: if a position lies outside ``ref_seq``.
    """
    # Edit a list of characters and join once at the end: rebuilding the
    # whole string per position costs O(len(ref_seq) * len(positions)),
    # which is prohibitive for genome-sized references.
    meth_seq = list(ref_seq)
    for pos in positions: #changed to 0-based - else have to subtract from pos
        if meth_seq[pos] == meth_base or meth_seq[pos] == 'M':
            meth_seq[pos] = 'M'
        else:
            print('Base {} does not correspond to methylated base - check reference positions are 0-based - quitting thread now'.format(pos))
            # Non-zero status so the parent process can tell failure from success.
            sys.exit(1)
    return ''.join(meth_seq)
#build forward- and reverse-strand methylated copies of one reference contig
#@profile
def methylate_references(ref_seq,base,motif=None,positions=None,train=False,contig=None):
    """Return ``(meth_fwd, meth_rev)``: two copies of ``ref_seq`` with methylated bases as 'M'.

    Both copies stay in forward-reference coordinates. ``meth_fwd`` marks
    ``base`` for '+' strand sites. ``meth_rev`` marks the complement of
    ``base`` for '-' strand sites: reverse-complemented motif occurrences, or
    '-' positions from the positions file.

    Args:
        ref_seq: upper-case reference sequence of the contig.
        base: methylated base on the forward strand (a key of ``base_comps``).
        motif: motif to mark; used only when ``positions`` is not given.
        positions: path to a whitespace-delimited file whose first three
            columns are contig, 0-based position and strand ('+' or '-').
            Lines with fewer than three columns, or for other contigs, are
            ignored. Takes precedence over ``motif``.
        train: unused.
        contig: name of the contig whose positions are taken from the file.

    Raises:
        SystemExit: with status 1 if neither ``motif`` nor ``positions`` is given.
    """
    if not positions and motif:
        meth_fwd = methylate_motifs(ref_seq,motif,base)
        meth_rev = methylate_motifs(ref_seq,revcomp(motif),base_comps[base])
    elif positions:
        fwd_pos,rev_pos = [],[]
        # Read the file once (and close it), splitting each line only once.
        with open(positions,'r') as posfi:
            for line in posfi:
                fields = line.split()
                # Need contig, position and strand columns.
                if len(fields) < 3 or fields[0] != contig:
                    continue
                if fields[2] == '+':
                    fwd_pos.append(int(fields[1]))
                elif fields[2] == '-':
                    rev_pos.append(int(fields[1]))
        meth_fwd = methylate_positions(ref_seq,fwd_pos,base)
        meth_rev = methylate_positions(ref_seq,rev_pos,base_comps[base])
    else:
        print('no motifs or positions specified')
        sys.exit(1)
    return meth_fwd,meth_rev
#@profile
def find_and_methylate(refname,contigname,base,motif,positions_list):
    """Return methylated (forward, reverse) copies of one contig of a FASTA file.

    Scans ``refname`` for the first record whose id equals ``contigname``,
    upper-cases its sequence and passes it to methylate_references. Returns
    None when no record matches; extract_features relies on the TypeError
    raised when unpacking that None.
    """
    for record in SeqIO.parse(refname,"fasta"):
        if record.id != contigname:
            continue
        sequence = str(record.seq).upper()
        fwd,rev = methylate_references(sequence,base,motif=motif,positions=positions_list,contig=contigname)
        return fwd,rev
    return None
def writefi(data,fi):
    """Append ``data`` to the text file ``fi`` as tab-separated lines.

    Each entry of ``data`` is a sequence of strings; its items are joined
    with tabs and terminated by a newline. The file is created if missing.
    """
    rows = ('\t'.join(entry)+'\n' for entry in data)
    with open(fi,'a') as outfi:
        outfi.writelines(rows)
def adjust_scores(context_dict,context,diffs,prob,k):
    """Compute auxiliary scores for one observation's difference vector.

    NOTE(review): no caller is visible in this chunk; the scores are now
    returned so the computation is usable.

    Args:
        context_dict: mapping with 'm6A' and 'A' entries, each
            ``{context: {'mean': [...], 'sd': [...], 'num': n}}``.
        context: sequence-context key to look up.
        diffs: per-position differences (at least ``k`` values; the whole
            sequence is correlated against the stored means).
        prob: model probability for this observation.
        k: number of leading positions scored by the Gaussian likelihood.

    Returns:
        None if ``context`` has no 'm6A' entry. Otherwise the tuple
        ``(hmm_score, correlation_score, correlation_diff,
        representation_score)``, where ``correlation_diff`` is None when
        ``context`` has no 'A' entry.
    """
    if context not in context_dict['m6A']:
        return None
    meth = context_dict['m6A'][context]
    # Product of independent per-position normal densities under the m6A model.
    likelihood = np.prod([ss.norm(meth['mean'][i],meth['sd'][i]).pdf(diffs[i]) for i in range(k)])
    hmm_score = 1-(1/likelihood)
    # ss.pearsonr is the public API; the scipy.stats.stats namespace is deprecated.
    correlation_score = ss.pearsonr(meth['mean'],diffs)[0]
    correlation_diff = None
    if context in context_dict['A']:
        unmeth = context_dict['A'][context]
        correlation_diff = correlation_score - ss.pearsonr(unmeth['mean'],diffs)[0]
        frac_meth = meth['num']*1./unmeth['num']
    else:
        frac_meth = 1
    representation_score = prob + 1 - frac_meth #increases score for contexts not included in methylation training set
    return hmm_score,correlation_score,correlation_diff,representation_score
def base_models(base,twobase=False):
    """Map two-base contexts (centre base + following base) to a model name.

    Args:
        base: methylated base.
        twobase: when True and ``base == 'A'``, contexts whose second base is
            'G' map to the 'MG' model and the remaining listed contexts map to
            'MH'. Otherwise every context maps to the single 'general' model.

    Returns:
        dict from two-character context string to model name.
    """
    if base == 'A' and twobase:
        #TODO: fix error where sites not methylated
        return {'MG':'MG','MC':'MH','MA':'MH','MT':'MH','MM':'MH','MH':'MH',
                'AT':'MH','AC':'MH','AG':'MG','AA':'MH','AM':'MH'}
    next_bases = ['A','C','G','T','M']
    # 'T'-prefixed contexts shouldn't be necessary but are kept for safety.
    return {first+nextb:'general' for first in ['M','A','T'] for nextb in next_bases}
#determine difference between measurements and model for bases surrounding methylated positions
#@profile
def extract_features(tsv_input,fasta_input,read2qual,k,skip_thresh,qual_thresh,modelfile,classifier,startline,endline=None,train=False,pos_label=None,base=None,motif=None,positions_list=None):
    """Extract difference-from-model features around 'M' positions in one TSV chunk.

    Each data line of ``tsv_input`` (12+ whitespace-separated columns, see the
    example comment below) is matched against a methylated copy of its
    reference contig. For every reference position marked 'M', the values
    ``event_current - model_current`` from all kmers overlapping it are
    pooled per kmer offset. The per-offset means, followed by the read
    quality, form one observation.

    Observations are handled according to ``train``:
        - ``train=False``: classified with the pickled model.
        - ``train=True``: labelled from ``pos_label`` and accumulated.

    Either way, each observation is appended as a tab-separated row to
    ``<tsv_input without extension>.diffs.<k>[.train].tmp<startline>``.

    Args:
        tsv_input: path of the alignment TSV.
        fasta_input: reference FASTA, passed to find_and_methylate.
        read2qual: read name -> quality. Names missing from it are retried
            with the part before the first ':' and then before the first '_'.
        k: kmer length.
        skip_thresh: max number of kmer offsets allowed to have no data.
        qual_thresh: lines from reads with quality below this are ignored.
        modelfile: pickled classifier, or dict of classifiers keyed by model
            name (only read when ``train`` is False).
        classifier: not used in this function.
        startline, endline: character offsets bounding this chunk.
            ``endline`` must be supplied despite its ``None`` default.
        train: collect labelled signals instead of predicting.
        pos_label: (contig, position, strand) -> label; needed when ``train``.
        base, motif, positions_list: forwarded to find_and_methylate.

    Returns:
        ``(signals, contexts)`` when ``train`` is True, each shaped
        ``{model name: {label: [...]}}``; otherwise None.
    """
    #set position variables
    last_read,last_pos,last_pos_in_kmer,last_read_num = '',0,k,0
    last_contig = None
    #set count variables
    num_observations,w_skips,skipped_skips,pos_set,multi_meth_pos_set,read_set = 0,set(),set(),set(),set(),set()
    #set tracking variables for observation
    # mpos: reference position of the 'M' whose observation is being collected (None = none open)
    mpos = None
    # diff_col[i]: differences from kmers in which that 'M' sits at offset i
    diff_col = [[] for xi in range(k)]
    if not train:
        tsv_output = '.'.join(tsv_input.split('.')[:-1])+'.diffs.'+str(k)+'.tmp'+str(startline)
        # NOTE(review): pickle.load can execute arbitrary code - modelfile must come from a trusted source
        modfi = open(modelfile,'rb')
        model = pickle.load(modfi,encoding='latin')
        modfi.close()
        if type(model) != dict:
            model = {'general':model} #for compatibility with previously trained model
            twobase = False
        else:
            twobase = True
        base_model = base_models(base,twobase)
    else:
        base_model = base_models(base,False) #or set to False?
        tsv_output = '.'.join(tsv_input.split('.')[:-1])+'.diffs.'+str(k)+'.train.tmp'+str(startline)
    # one {label: [...]} store per model name; only filled in the training branch below
    signals,contexts = {bm:{} for bm in base_model.values()},{bm:{} for bm in base_model.values()}
    towrite = []
    #save only one set of adjoining methylated positions at a time - once the set complete, write the positions to a file
    #tsv format: ecoli 805 CGCCAT cc1da58e-3db3-4a4b-93c2-c78e1dbe6aba:1D_000:template t 1 102.16 0.963 0.00175 CGCCAT 102.23 1.93 -0.03 101.973,100.037,102.403,101.758,104.338,102.618,101.973
    with open(tsv_input,'r') as tsv:
        # start 500 characters early, presumably so positions near the chunk
        # boundary get their full context (confirm); the seek may land mid-line
        tsv.seek(max(startline-500,0))
        linepos = max(startline-500,0)
        #startline, endline, and linepos are in characters -- previously used tsv.tell(), but incompatible with python3
        # NOTE(review): len(line) counts characters while text-mode seek() takes an opaque (byte-like) offset; they agree for ASCII input
        # NOTE(review): if EOF is reached before linepos > endline-500, readlines() returns [] and this loop never ends; endline=None raises TypeError
        while linepos <= endline-500:
            #print('current position',linepos)
            lines = tsv.readlines(8000000) #TODO: why 8M? reasonable size for memory consumption, but could change
            for line in lines:
                linepos += len(line)
                try:
                    chrom, read_pos, read_kmer, read_name, x, read_ind, event_current, event_sd, y, ref_kmer, model_current, ref_sd = line.split()[:12]
                except ValueError:
                    # fewer than 12 fields (e.g. header or truncated line)
                    continue
                if chrom != last_contig:
                    # new contig: rebuild its methylated forward/reverse references
                    try:
                        meth_fwd,meth_rev = find_and_methylate(fasta_input,chrom,base,motif,positions_list)
                        last_contig = chrom
                    except TypeError: #ValueError
                        # find_and_methylate returned None: contig not in the FASTA
                        print('Error: could not find sequence for reference contig',chrom)
                        continue
                # first_read_ind: read_ind of the first line seen for a new read (used for the strand guess below)
                if read_name != last_read:
                    first_read_ind = int(read_ind)
                try:
                    qual = read2qual[read_name]
                except KeyError:
                    qual = read2qual[read_name.split(':')[0].split('_')[0]]
                # NOTE: 'NNNNNN' is a 6-mer literal regardless of k
                if (qual < qual_thresh) or ref_kmer == 'NNNNNN':
                    continue
                # choose which methylated reference (forward or reverse) applies to this line
                if (read_name != last_read and read_kmer == ref_kmer) or (read_name == last_read and int(read_ind) > first_read_ind): #takes into account complementary palindromes - temporarily sets new reads to positive strand
                    rev = False
                    meth_ref = meth_fwd
                else:
                    rev = True
                    meth_ref = meth_rev
                read_pos = int(read_pos)
                # slice of the methylated reference covered by this kmer (forward coordinates for both strands)
                reference_kmer = meth_ref[read_pos:read_pos+k]
                #if finished context for previous potentially modified position, save and reset
                # NOTE(review): mpos == 0 is falsy and is treated as "no open observation"
                if mpos and ((read_pos >= mpos+1 and read_name == last_read) or (read_name != last_read)):
                    #write to file
                    # a skip is a kmer offset with no recorded difference
                    num_skips = len([x for x in diff_col if x == []])
                    if num_skips <= skip_thresh: #accept max number of skips within an observation
                        if num_skips > 0:
                            w_skips.add((last_read,mpos))
                        # mean difference per offset; skipped offsets become 0
                        diffs = [np.mean(kmer_pos) if kmer_pos!=[] else 0 for kmer_pos in diff_col]
                        # reverse offset order for forward-strand observations (presumably so offsets follow the orientation of `context` - confirm)
                        if not last_rev:
                            diffs = diffs[::-1]
                        try:
                            last_qual = read2qual[last_read]
                        except KeyError:
                            last_qual = read2qual[last_read.split(':')[0].split('_')[0]]
                        # feature vector: k mean differences followed by the read quality
                        diffs = diffs+[last_qual]
                        # 2k-1 bases centred on mpos, reverse-complemented for '-' strand observations
                        context = revcomp(last_ref[mpos-k+1:mpos+k],last_rev)
                        if context[int(len(context)/2)] == 'M':
                            try:
                                # model name chosen from the centre base and the base after it
                                twobase_model = base_model[context[int(len(context)/2):int(len(context)/2)+2]]
                                if not train:
                                    mod_prob = model[twobase_model].predict_proba([diffs]) #TODO: call model only when batch ready to write
                                    if mod_prob[0][1] >= 0.5:
                                        if base == 'A':
                                            label = 'm6A' #TODO: ensure correct direction + label unmeth/meth as appropriate
                                        else:
                                            label = 'm'+base
                                    else:
                                        label = base
                                    # output label column holds the call and its probability, tab-separated
                                    label = label+'\t'+str(np.round(mod_prob[0][1],2))
                                else:
                                    mod_prob = ''
                                    # a position missing from pos_label raises KeyError, handled below
                                    label = pos_label[(chrom,mpos,strand(last_rev))]
                                    if label not in signals[twobase_model]:
                                        signals[twobase_model][label] = []
                                        contexts[twobase_model][label] = []
                                    signals[twobase_model][label].append(diffs)
                                    contexts[twobase_model][label].append(context)
                                towrite.append([chrom,last_read,str(mpos),context,','.join([str(diff) for diff in diffs]),strand(last_rev),label])
                                last_info = last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev)
                            except (IndexError,KeyError) as e:
                                # print diagnostics and stop this worker (exit status 0)
                                print(last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev),'- Index or Key Error')
                                print(model.keys(), base_model.keys(), context[int(len(context)/2):int(len(context)/2)+2])
                                print(e)
                                print(model[twobase_model].predict_proba([diffs]))
                                sys.exit(0)
                        else:
                            # centre of the context is not 'M': inconsistent state, print diagnostics and stop
                            # NOTE(review): mspacing / pos_in_kmer may be unbound on this path, raising NameError here
                            print(last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev))
                            print(read_name, rev, last_read, last_rev, last_first)
                            print(read_kmer,reference_kmer, ref_kmer, last_pos_in_kmer, mspacing, pos_in_kmer)
                            sys.exit(0)
                        num_observations += 1
                        # flush buffered rows every 5000 observations
                        if num_observations%5000 == 0:
                            writefi(towrite,tsv_output)
                            towrite = []
                        pos_set.add(mpos)
                        read_set.add(last_read)
                        if len(read_set)%1000 == 0 and len(read_set) > last_read_num:
                            #print(len(read_set), 'reads examined')
                            last_read_num = len(read_set)
                    else:
                        skipped_skips.add((last_read,mpos))
                    #reset variables
                    # close the observation unless this kmer holds another 'M' on the same read within skip_thresh+1 of mpos
                    if len(reference_kmer.split('M')) < 2 or read_name != last_read or read_pos > mpos+skip_thresh+1: #allow no more than skip_thresh skips
                        diff_col = [[] for i in range(k)]
                        mpos = None
                        last_pos_in_kmer = k
                    else:
                        # chain to the next 'M': record the previous mpos as part of a multi-'M' region
                        if reference_kmer[0] != 'M':
                            multi_meth_pos_set.add((last_read,mpos))
                        last_mpos = mpos
                        pos_in_kmer = len(reference_kmer.split('M')[0])
                        mpos = read_pos + pos_in_kmer
                        mspacing = min(k,mpos - last_mpos)
                        last_pos_in_kmer = pos_in_kmer
                        last_diff_col = diff_col
                        # shift collected differences right by mspacing so offsets are relative to the new mpos
                        diffs = [[] for i in range(mspacing)] + diff_col[:-mspacing]
                        diff_col = diffs
                        # the length can only change when mspacing <= 0
                        if len(diff_col) != k:
                            try:
                                print(last_info,'- n diffs off')
                            except:
                                # last_info is unset until the first observation is written
                                pass
                            # example of a failing case:
                            #GGCGCM 613883 613878 False 2289b392-746e-4fa0-8226-d3ac661c9620_Basecall_2D_template 2289b392-746e-4fa0-8226-d3ac661c9620_Basecall_2D_template [[], [], [], [], [], [], []] 7
                            print(reference_kmer,last_mpos,mpos,mspacing,read_pos,read_pos-last_mpos,read_name,last_read,diff_col,mspacing, last_diff_col, last_diff_col[:-mspacing])
                            diff_col = [[] for i in range(k)]
                            sys.exit(0)
                #if modified base in reference, save surrounding context to call that position
                if 'M' in set(list(reference_kmer)):
                    # offset of the first 'M' within this kmer
                    pos_in_kmer = [i for i,x in enumerate(list(reference_kmer)) if x == 'M'][0]
                    #if new read, reset differences variable and proceed
                    if mpos:
                        if read_name != last_read:
                            mpos = None
                            diff_col = [[] for i in range(k)]
                        elif rev != last_rev:
                            # NOTE(review): diff_col is not cleared on a strand change, unlike the new-read case - confirm intended
                            mpos = None
                    #if new read or new position
                    if not mpos:
                        mpos = read_pos+pos_in_kmer
                        last_pos_in_kmer = pos_in_kmer
                    # remember the state of the line that last contributed data
                    last_read = read_name
                    last_rev = rev
                    last_first = first_read_ind
                    last_ref = meth_ref
                    diff_col[pos_in_kmer].append(np.round(float(event_current)-float(model_current),4))
                    # last_pos is assigned but never read in this function
                    last_pos = read_pos
                elif mpos:
                    # open observation but this kmer has no 'M': discard the observation
                    mpos = None
                    diff_col = [[] for i in range(k)]
    # flush remaining rows
    # NOTE(review): an observation still open when the loop ends is never written
    writefi(towrite,tsv_output)
    print('thread finished processing...:')
    print('%d observations' %num_observations)
    num_pos = len(pos_set)
    print('%d positions' %num_pos)
    print('%d regions with multiple methylated bases' %len(multi_meth_pos_set))
    print('%d observations with skips included' %len(w_skips))
    print('%d observations with too many skips' %len(skipped_skips))
    if train:
        return signals, contexts
|