File: test_tools.py

package info (click to toggle)
pdb2pqr 2.1.1%2Bdfsg-7%2Bdeb11u1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 47,044 kB
  • sloc: python: 44,152; cpp: 9,847; xml: 9,092; sh: 79; makefile: 55; ansic: 36
file content (231 lines) | stat: -rw-r--r-- 7,587 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from itertools import zip_longest
from glob import glob
from os import path
import csv

def isAtomLine(line):
    """Return True when *line* is an ATOM or HETATM record.

    The record type is identified by the leading characters of the line:
    the first four must be 'ATOM' or the first six must be 'HETATM'.
    """
    # Slicing never raises IndexError (a short line simply yields a shorter
    # slice), so no exception handling is needed around the comparison.
    return line[:4] == 'ATOM' or line[:6] == 'HETATM'

def parsePQRAtomLine(line, has_chain):
    """Split an atom line into a ``(strings, ints, floats)`` triple.

    The first six characters are taken as the record type and the rest of
    the line is split on whitespace.  Cutting the record type off by column
    keeps parsing correct even when the serial number is wide enough to
    touch the record type.

    After the split, field 0 is converted to int, the following two fields
    (three when *has_chain* is true) are kept as strings, the next field is
    converted to int and every remaining field is converted to float.
    """
    record = line[:6].strip()
    fields = line[6:].split()

    # Index of the second integer field; the optional chain column pushes
    # it one position to the right.
    second_int = 4 if has_chain else 3

    labels = (record,) + tuple(fields[i] for i in range(1, second_int))
    numbers = (int(fields[0]), int(fields[second_int]))
    values = tuple(map(float, fields[second_int + 1:]))

    return labels, numbers, values

def compareParsedAtoms(atom1, atom2):
    """Return True when two parsed atoms are considered equivalent.

    Each argument is a ``(strings, ints, floats)`` triple.  Two atoms match
    when their string fields are equal and every pair of float fields
    differs by less than 0.1.  The integer fields (index 1) are not part of
    the comparison.

    Atoms with a different number of float fields never match; without this
    check ``zip`` would stop at the shorter tuple and ignore the extra
    columns.
    """
    # The [0:1] slice covers only the string tuple at index 0.
    if atom1[0:1] != atom2[0:1]:
        return False

    floats1, floats2 = atom1[2], atom2[2]
    if len(floats1) != len(floats2):
        return False

    return all(abs(x - y) < 0.1 for x, y in zip(floats1, floats2))

def ComparePQRAction(outputFileName, testFileName, correctFileName, has_chain=False):
    """Compare the atom lines of two files and write a report.

    Atom lines of *testFileName* and *correctFileName* are paired in order.
    A pair that does not compare equal produces a warning in the report.
    The comparison fails when one file has more atom lines than the other,
    or when at least one warning was produced and the summed integer and
    float fields of the two files differ by more than 20.0.

    The report, followed by 'FAILURE!' or 'SUCCESS!', is written to
    *outputFileName*.  Returns True on failure and False otherwise.
    """
    report = []
    failed = False
    saw_mismatch = False
    sum_test = 0.0
    sum_correct = 0.0

    with open(testFileName) as test_handle, open(correctFileName) as correct_handle:
        # The filters are lazy, so the pairs must be consumed while both
        # files are still open.
        atom_pairs = zip_longest(
            filter(isAtomLine, test_handle),
            filter(isAtomLine, correct_handle),
            fillvalue=None,
        )

        for test_line, correct_line in atom_pairs:
            if test_line is None or correct_line is None:
                report.append('TEST ERROR: Result file is the wrong length!\n')
                failed = True
                break

            test_fields = parsePQRAtomLine(test_line, has_chain)
            correct_fields = parsePQRAtomLine(correct_line, has_chain)

            if not compareParsedAtoms(test_fields, correct_fields):
                report.append('WARNING: Mismatch ->\n%s%s\n' % (test_line, correct_line))
                saw_mismatch = True

            sum_test += sum(test_fields[1]) + sum(test_fields[2])
            sum_correct += sum(correct_fields[1]) + sum(correct_fields[2])

        # Individual mismatches only become an error when the overall
        # numeric totals have drifted apart as well.
        if saw_mismatch and abs(sum_test - sum_correct) > 20.0:
            report.append('TEST ERROR: Result file does not match target close enough!\n')
            failed = True

    verdict = 'FAILURE!' if failed else 'SUCCESS!'
    with open(outputFileName, 'w') as out:
        out.writelines(report)
        out.write(verdict)

    return failed


def get_csv_data(in_file):
    """Read two-column CSV rows from *in_file* as ``(float, float)`` tuples.

    *in_file* is any iterable of text lines accepted by ``csv.reader``.
    Each non-empty row must contain exactly two values (a pH and a
    titration value); other row lengths raise ``ValueError`` from the
    unpacking, and non-numeric values raise ``ValueError`` from ``float``.
    """
    results = []
    for row in csv.reader(in_file):
        # csv.reader yields an empty list for a blank line; skip those
        # instead of failing on the two-value unpack below.
        if not row:
            continue
        pH, titr = row
        results.append((float(pH), float(titr)))
    return results

def get_curve_data(input_path):
    """Load every ``titration_curves/*.csv`` file found under *input_path*.

    Returns a dict that maps each file's base name, with its last extension
    removed, to the list produced by ``get_csv_data`` for that file.  The
    dict is empty when no matching file exists.
    """
    scatter_data = {}
    for file_name in glob(input_path+'/titration_curves/*.csv'):
        base_name = path.basename(file_name)

        name = base_name.rsplit('.', 1)[0]

        # csv.reader needs text lines: a file opened in binary mode yields
        # bytes and makes the reader raise.  newline='' lets the csv module
        # do its own newline handling.
        with open(file_name, newline='') as in_file:
            file_data = get_csv_data(in_file)

        scatter_data[name] = file_data

    return scatter_data

def merge_curves(curve1, curve2):
    """Align two curves on their pH values.

    Each curve is an iterable of ``(ph, value)`` pairs.  The result is a
    list of ``(value_from_curve1, value_from_curve2)`` tuples ordered by
    ascending pH, with ``None`` in the slot of a curve that has no point at
    that pH.  The pH values themselves are not part of the result.
    """
    merged = {ph: [value, None] for ph, value in curve1}

    for ph, value in curve2:
        # Create an empty pair for a pH seen only in curve2, then fill the
        # second slot either way.
        merged.setdefault(ph, [None, None])[1] = value

    return [tuple(merged[ph]) for ph in sorted(merged)]

def CompareTitCurvesAction(outputFileName, testDirName, correctDirName):
    """Compare the titration curves of two directories and write a report.

    Curves are loaded with ``get_curve_data`` from *testDirName* and
    *correctDirName* and matched by name.  An error line is reported for:

    * a curve present in the correct data but missing from the test data;
    * a test curve with pH points the correct curve lacks (extra data);
    * a test curve lacking pH points of the correct curve (missing data);
    * a curve with points whose values differ by more than EPSILON;
    * a curve present only in the test data.

    The report, followed by 'FAILURE!' or 'SUCCESS!', is written to
    *outputFileName*.  Returns True when any error was reported.
    """
    results = []
    EPSILON = 0.025

    test_data = get_curve_data(testDirName)
    correct_data = get_curve_data(correctDirName)

    # Iterate in sorted order so the report does not depend on the order in
    # which the curve files happened to be listed.
    for name in sorted(correct_data):
        correct_curve = correct_data[name]
        # pop() so that whatever remains in test_data afterwards is extra.
        test_curve = test_data.pop(name, None)

        if test_curve is None:
            results.append("ERROR: test results missing curve for residue " + name + "\n")
            continue

        combined = merge_curves(correct_curve, test_curve)

        report_extra_data = False
        report_missing_data = False
        total_error = 0.0
        bad_point_count = 0

        for correct_value, test_value in combined:
            if correct_value is None:
                report_extra_data = True
            elif test_value is None:
                report_missing_data = True
            else:
                diff = abs(correct_value - test_value)
                # Only points outside the tolerance add to the cumulative
                # error.
                if diff > EPSILON:
                    bad_point_count += 1
                    total_error += diff

        if report_extra_data:
            results.append("ERROR: test curve " + name +" has extra data\n")

        if report_missing_data:
            results.append("ERROR: test curve " + name +" has missing data\n")

        if bad_point_count > 0:
            results.append("ERROR: test curve {name} has {count}"
                           " bad points with {total} cumulative error.\n".format(name=name, count=bad_point_count, total=total_error))

    for name in sorted(test_data):
        # A space separates "residue" from the name, matching the wording
        # of the missing-curve message.
        results.append("ERROR: extra curve for residue " + name + " in test results\n")

    failure = bool(results)

    with open(outputFileName, 'w') as outputFile:
        for line in results:
            outputFile.write(line)

        outputFile.write('FAILURE!' if failure else 'SUCCESS!')

    return failure



def getSummaryLines(sourceFile):
    """Collect the rows of the SUMMARY section from an iterator of lines.

    Lines are consumed up to and including the first one that contains
    'SUMMARY', the line after it is skipped as a header, and every following
    line is parsed until one containing '-----------' is reached.

    Each parsed row becomes ``(strings, floats)``: the first three
    whitespace-separated fields as a tuple of strings and the remaining
    fields as a tuple of floats.

    *sourceFile* must support ``next()``.  ``StopIteration`` propagates if
    the input ends before the SUMMARY marker or the closing dashed line.
    """
    # Advance to the line that announces the summary.
    current = next(sourceFile)
    while 'SUMMARY' not in current:
        current = next(sourceFile)

    # Skip the column header.
    next(sourceFile)

    rows = []
    while True:
        current = next(sourceFile)
        if '-----------' in current:
            return rows

        fields = current.split()
        labels = tuple(fields[:3])
        numbers = tuple(map(float, fields[3:]))
        rows.append((labels, numbers))

def ComparePROPKAAction(outputFileName, testFileName, correctFileName):
    """Compare the SUMMARY sections of two files and write a report.

    The summary rows of *testFileName* and *correctFileName* are read with
    ``getSummaryLines`` and paired in order.  The comparison fails when the
    two files have a different number of rows, or when the sums of all
    float fields of the two files differ by more than 20.0.

    Each report line and the final 'FAILURE!' or 'SUCCESS!' verdict is
    written to *outputFileName* terminated by a newline.  Returns True on
    failure and False otherwise.
    """
    with open(testFileName) as test_handle, open(correctFileName) as correct_handle:
        test_rows = getSummaryLines(test_handle)
        correct_rows = getSummaryLines(correct_handle)

    messages = []
    failed = False
    sum_test = 0.0
    sum_correct = 0.0

    for test_row, correct_row in zip_longest(test_rows, correct_rows, fillvalue=None):
        if test_row is None or correct_row is None:
            messages.append('TEST ERROR: Result file is the wrong length!')
            failed = True
            break

        sum_test += sum(test_row[1])
        sum_correct += sum(correct_row[1])

    if abs(sum_test - sum_correct) > 20.0:
        messages.append('TEST ERROR: Result file does not match target close enough!')
        failed = True

    verdict = 'FAILURE!\n' if failed else 'SUCCESS!\n'
    with open(outputFileName, 'w') as out:
        out.writelines(message + '\n' for message in messages)
        out.write(verdict)

    return failed

def CompareStringFunc(outputFileName, targetfile, sourcefile, has_chain=None):
    """Return a one-line description of a file comparison.

    *has_chain* is accepted but does not appear in the message.
    """
    template = 'Comparing files ("{0}", "{1}") -> {2}'
    return template.format(targetfile, sourcefile, outputFileName)

def CompareDirectoryFunc(outputFileName, targetdir, sourcedir):
    """Return a one-line description of a directory comparison."""
    template = 'Comparing directories ("{0}", "{1}") -> {2}'
    return template.format(targetdir, sourcedir, outputFileName)