import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, make_scorer
from datetime import datetime
import joblib
import sys
import os
from sklearn.model_selection import cross_val_score
from Bio import SeqIO
import pickle
import time
import os.path
# file with lineage assignments
lineage_file = sys.argv[1]
# file with sequences
sequence_file = sys.argv[2]
# how much of the data will be used for testing instead of training
# (set to an effectively-zero fraction here, so virtually all sequences go into training)
testing_percentage = 0.0000000001
relevant_positions = pickle.load(open(sys.argv[5], 'rb'))
relevant_positions.add(0)
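# For reference, the relevant-positions file passed as sys.argv[5] is expected to be a
# pickled set of 0-based genome positions. A minimal sketch of how such a file might be
# produced (the positions and filename below are purely illustrative):
#
#   import pickle
#   positions_of_interest = {240, 3036, 14407, 23402}
#   with open("relevant_positions.pickle", "wb") as fh:
#       pickle.dump(positions_of_interest, fh)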
# the path to the reference file.
# This reference sequence must be the same as is used in the pangolearn script!!
referenceFile = sys.argv[3]
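# A hypothetical invocation, to make the positional arguments concrete (the script and
# file names are illustrative only; sys.argv[4] is the output directory used further down):
#
#   python train_decision_tree.py lineages.csv sequences.fasta reference.fasta \
#       /path/to/output_dir relevant_positions.pickle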
# data storage
dataList = []
# dict for lookup efficiency
indiciesToKeep = dict()
referenceId = "Wuhan/WH04/2020"
referenceSeq = ""
idToLineage = dict()
idToSeq = dict()
mustKeepIds = []
mustKeepLineages = []
# function for handling weird sequence characters
def clean(x, loc):
    x = x.upper()

    if x == 'T' or x == 'A' or x == 'G' or x == 'C' or x == '-':
        return x

    if x == 'U':
        return 'T'

    # otherwise return the value from the reference at this position
    return referenceSeq[loc]
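# For example, clean('u', 7) returns 'T', while an ambiguous base such as 'N' at
# position 7 falls back to whatever base the reference sequence has at that position.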
def findReferenceSeq():
    with open(referenceFile) as f:
        currentSeq = ""

        # concatenate every non-header line of the FASTA file
        for line in f:
            if ">" not in line:
                currentSeq = currentSeq + line.strip()

    return currentSeq
def getDataLine(seqId, seq):
    dataLine = []
    dataLine.append(seqId)

    newSeq = ""

    # clean each character in the sequence
    for index in range(len(seq)):
        newSeq = newSeq + clean(seq[index], index)

    dataLine.append(newSeq)
    return dataLine
def readInAndFormatData():
    # add the data line for the reference seq
    idToLineage[referenceId] = "A"
    dataList.append(getDataLine(referenceId, referenceSeq))

    # create a dictionary of sequence ids to their assigned lineages
    lineage_designations = pd.read_csv(lineage_file, delimiter=",", dtype=str)
    for index, row in lineage_designations.iterrows():
        idToLineage[row["sequence_name"]] = row["lineage"]

    seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(sequence_file, "fasta")}

    print("files read in, now processing")

    for key in seq_dict.keys():
        if key in idToLineage:
            dataList.append(getDataLine(key, seq_dict[key]))
        else:
            print("unable to find the lineage classification for: " + key)
# find the columns in the data list which vary between sequences; only these
# (restricted to the relevant positions) are recorded as columns to keep
def findColumnsWithoutSNPs():
    # for each index in the length of each sequence
    for index in range(len(dataList[0][1])):
        keep = False

        # loop through all lines
        for line in dataList:
            # if there is a difference somewhere, then we want to keep it
            # (index 0 is always kept; that slot is used for the id/lineage column downstream)
            if dataList[0][1][index] != line[1][index] or index == 0:
                keep = True
                break

        # if a difference was found and the position is one of the relevant positions, record it
        if keep and index in relevant_positions:
            indiciesToKeep[index] = True
# remove columns from the data list which don't have any SNPs. We do this because
# invariant columns carry no information for a classifier which is trying to use
# differences between sequences to assign lineages
def removeOtherIndices(indiciesToKeep):
    # instantiate the final list
    finalList = []

    indicies = list(indiciesToKeep.keys())
    indicies.sort()

    # while the dataList isn't empty
    while len(dataList) > 0:
        # pop the first line
        line = dataList.pop(0)
        seqId = line.pop(0)
        line = line[0]

        # initialize the finalLine
        finalLine = []

        for index in indicies:
            if index == 0:
                # index 0 is the placeholder for the sequence id (replaced by the
                # lineage assignment later), so keep the id here
                finalLine.append(seqId)
            else:
                # otherwise keep everything at the indices in indiciesToKeep
                finalLine.append(line[index])

        # save the finalLine to the finalList
        finalList.append(finalLine)

    return finalList
def allEqual(values):
    entries = dict()

    for i in values:
        if i not in entries:
            entries[i] = True

    return len(entries) == 1
# where several sequences have identical feature strings but different lineage labels,
# keep only the majority lineage (or any protected lineages/ids) and drop the rest
def removeAmbiguous():
    idsToRemove = set()
    lineMap = dict()
    idMap = dict()

    # group sequence ids and lineage labels by their feature string
    for line in dataList:
        keyString = ",".join(line[1:])

        if keyString not in lineMap:
            lineMap[keyString] = []
            idMap[keyString] = []

        if line[0] in idToLineage:
            lineMap[keyString].append(idToLineage[line[0]])
            idMap[keyString].append(line[0])
        else:
            print("diagnostics")
            print(line[0])
            print(keyString)
            print(line)

    for key in lineMap:
        if not allEqual(lineMap[key]):
            skipRest = False

            # see if any protected lineages are contained in the set, if so keep those ids
            for lineage in lineMap[key]:
                if lineage in mustKeepLineages:
                    skipRest = True
                    for i in idMap[key]:
                        if lineage != idToLineage[i] and i not in mustKeepIds:
                            idsToRemove.add(i)

            # none of the lineages are protected, fire at will
            if not skipRest:
                lineageToCounts = dict()
                aLineage = False

                # find most common lineage
                for lineage in lineMap[key]:
                    if lineage not in lineageToCounts:
                        lineageToCounts[lineage] = 0
                    lineageToCounts[lineage] = lineageToCounts[lineage] + 1
                    aLineage = lineage

                m = aLineage
                for lineage in lineageToCounts:
                    if lineageToCounts[lineage] > lineageToCounts[m]:
                        m = lineage

                for i in idMap[key]:
                    if m != idToLineage[i]:
                        idsToRemove.add(i)

    newList = []

    print("keeping sequences:")
    for line in dataList:
        if line[0] not in idsToRemove:
            print(line[0])
            # replace the sequence id with its lineage label, ready for training
            line[0] = idToLineage[line[0]]
            newList.append(line)

    return newList
print("reading in data " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
referenceSeq = findReferenceSeq()
readInAndFormatData()
print("processing snps, formatting data " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
findColumnsWithoutSNPs()
dataList = removeOtherIndices(indiciesToKeep)
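# To make the shapes concrete (ids and values here are hypothetical): before this step each
# entry of dataList was ["Wales/ABC-123/2021", "ACGT...-..."], a sequence id plus the cleaned
# full-length sequence; afterwards each entry is ["Wales/ABC-123/2021", "G", "T", "-", ...],
# the id followed only by the bases at the kept genome positions.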
print("# sequences before blacklisting")
print(len(dataList))
dataList = removeAmbiguous()
print("# sequences after blacklisting")
print(len(dataList))
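# At this point the first element of every surviving row holds the lineage label
# (substituted for the sequence id inside removeAmbiguous), so the rows line up with
# the "lineage" header assigned below.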
# headers are the original genome locations
headers = list(indiciesToKeep.keys())
headers[0] = "lineage"
print("setting up training " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
pima = pd.DataFrame(dataList, columns=headers)
# nucleotide symbols which can appear
categories = ['A', 'C', 'G', 'T', '-']
# one hot encoding of all headers other than the first which is the lineage
dummyHeaders = headers[1:]
# add extra rows to ensure all of the categories are represented, as otherwise
# not enough columns will be created when we call get_dummies
for i in categories:
    line = [i] * len(dataList[0])
    pima.loc[len(pima)] = line
# get one-hot encoding
pima = pd.get_dummies(pima, columns=dummyHeaders)
# get rid of the fake data we just added
pima.drop(pima.tail(len(categories)).index, inplace=True)
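# To illustrate the effect of get_dummies (the position and bases here are hypothetical),
# a single genome-position column is expanded into one indicator column per category, e.g.:
#
#   pd.get_dummies(pd.DataFrame({"23403": ["A", "G"]}), columns=["23403"])
#   # -> indicator columns 23403_A and 23403_G
#
# which is why the placeholder rows added above are needed: they guarantee every
# category gets its own column even if it never occurs in the real data.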
feature_cols = list(pima)
print(feature_cols)
# remove the first column from the feature list: it holds the lineage labels,
# which are the values we are trying to predict
h = feature_cols.pop(0)
X = pima[feature_cols]
y = pima[h]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=testing_percentage, random_state=0)
print("training " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
header_out = os.path.join(sys.argv[4],"decisionTreeHeaders_v1.joblib")
joblib.dump(headers, header_out, compress=('lzma', 9))
# instantiate the decision tree classifier
dt = DecisionTreeClassifier()
# fit the model
dt.fit(X,y)
print("testing " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
# classify the test data
y_pred=dt.predict(X_test)
print(y_pred)
# get the scores from these predictions
y_scores = dt.predict_proba(X_test)
print("generating statistics " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"), flush=True)
#print the confusion matrix
print("--------------------------------------------")
print("Confusion Matrix")
cnf_matrix = metrics.confusion_matrix(y_test, y_pred)
print(cnf_matrix)
print("--------------------------------------------")
print("Classification report")
print(metrics.classification_report(y_test, y_pred, digits=3))
# save the model files to compressed joblib files
# using joblib instead of pickle because these large files need to be compressed
model_out = os.path.join(sys.argv[4],"decisionTree_v1.joblib")
joblib.dump(dt, model_out, compress=('lzma', 9))
print("model files created", flush=True)
# this method is used below when running 10-fold cross validation. It ensures
# that the per-lineage statistics are generated for each cross-fold
def classification_report_with_accuracy_score(y_true, y_pred):
    print("--------------------------------------------")
    print("Crossfold Classification Report")
    print(metrics.classification_report(y_true, y_pred, digits=3))
    return accuracy_score(y_true, y_pred)
# optionally, run 10-fold cross validation (comment this out if not needed as it takes a while to run)
# cross_validation_scores = cross_val_score(dt, X=X, y=y, cv=10, scoring=make_scorer(classification_report_with_accuracy_score))