File: testutils.py

package info (click to toggle)
mrcal 2.5-3
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 8,992 kB
  • sloc: python: 40,651; ansic: 15,632; cpp: 1,754; perl: 303; makefile: 160; sh: 99; lisp: 84
file content (406 lines) | stat: -rw-r--r-- 12,909 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
import sys
import numpy as np
import numpysane as nps
import os
import re
from inspect import currentframe
import mrcal

# Running totals of the checks performed and the checks that failed. The
# confirm*() functions update these; finish() reports them
Nchecks, NchecksFailed = 0, 0

# Print numpy arrays without line breaks (useful for test reporting), using
# fixed-point notation. This modifies global numpy state, which is acceptable
# in a test harness
np.set_printoptions(suppress  = True,
                    linewidth = 1e10)



def percentile_compat(*args, **kwargs):
    r'''Wrapper for np.percentile() to handle their API change

In newer numpy the "interpolation" kwarg was renamed to "method". I need to pass
the right thing to work with both old and new numpy. This function tries the
newer method, and if that fails, uses the old one. The test is only done the
first time. The result is cached in percentile_compat.which: 'method' or
'interpolation', or None until it has been detected.

It is assumed that this is called with the old 'interpolation' key. Calls
without that key are passed straight to np.percentile().

Only a TypeError (an old numpy rejecting the unknown 'method' kwarg) triggers
the fallback to 'interpolation'. Any other error is a genuine problem with the
arguments. It is propagated, and does not change the cached choice.

    '''

    if 'interpolation' not in kwargs or \
       percentile_compat.which == 'interpolation':
        return np.percentile(*args, **kwargs)

    kwargs_no_interpolation = dict(kwargs)
    method = kwargs_no_interpolation.pop('interpolation')

    if percentile_compat.which == 'method':
        return np.percentile(*args, **kwargs_no_interpolation,
                             method = method)

    # Need to detect. The choice is recorded only once a call has succeeded,
    # so a bad argument on the very first call can't poison the cache
    try:
        result = np.percentile(*args, **kwargs_no_interpolation,
                               method = method)
    except TypeError:
        # Old numpy: 'method' is not a known kwarg
        result = np.percentile(*args, **kwargs)
        percentile_compat.which = 'interpolation'
        return result

    percentile_compat.which = 'method'
    return result

percentile_compat.which = None


def test_location():
    r'''Reports string describing current location in the test

    Returns "FILENAME:LINENO" for the first frame on the call stack that is
    outside this file and not inside a function named "check". The search
    starts two frames up, skipping this function and its immediate caller
    (normally one of the print_...() functions). If no frame qualifies, the
    outermost frame is used.

    Returns '' if the location cannot be determined. For instance,
    inspect.currentframe() returns None on Python implementations without
    stack-frame support.
    '''

    filename_this = os.path.split( __file__ )[1]

    try:
        frame = currentframe().f_back.f_back

        # I keep popping the stack until I leave the testutils file and I'm
        # not in a function called "check"
        while frame.f_back is not None and \
              (frame.f_code.co_filename.endswith(filename_this) or
               frame.f_code.co_name == "check"):
            frame = frame.f_back

        testfile = os.path.split(frame.f_code.co_filename)[1]
        return "{}:{}".format(testfile, frame.f_lineno)
    except AttributeError:
        # currentframe() returned None, or the stack was too shallow
        return ''


def _ansi_line(color_escape, text):
    """Wrap text in the given ANSI color escape, then reset the color and end the line"""
    return color_escape + text + "\x1b[0m\n"


def print_red(x):
    """Write x to stdout in red, prefixed with the current test location"""
    sys.stdout.write(_ansi_line("\x1b[31m", test_location() + ": " + x))


def print_green(x):
    """Write x to stdout in green, prefixed with the current test location"""
    sys.stdout.write(_ansi_line("\x1b[32m", test_location() + ": " + x))

def print_blue(x):
    """Write x to stdout in blue, prefixed with the current test location"""
    sys.stdout.write(_ansi_line("\x1b[34m", test_location() + ": " + x))



def _smooth_same_length(x, smooth_radius):
    r'''Box-filter x with a kernel of width 2*smooth_radius+1

    Returns an array with the same length as x (a scalar is treated as a
    length-1 array). Samples beyond the ends of x are taken to be 0. This is
    what np.convolve(x, kernel, mode='same') computes when x is at least as
    long as the kernel. However, mode='same' returns max(len(x), len(kernel))
    samples, so for a short x it would return the wrong number of samples.
    '''
    d      = smooth_radius*2 + 1
    kernel = np.ones((d,),) / d
    x      = np.atleast_1d(x)
    # The 'full' convolution has len(x)+d-1 samples. The len(x) samples
    # centered on x start at offset smooth_radius
    return np.convolve(x, kernel, mode='full')[smooth_radius:smooth_radius+len(x)]


def relative_scale(a,b,
                   *,
                   smooth_radius = None,
                   eps           = 1e-6):
    r'''Typical magnitude of a and b, used to normalize relative differences

    Returns (|a|+|b|)/2 + eps, elementwise. If smooth_radius > 0, a and b are
    each first smoothed with a moving-average (box) kernel of width
    2*smooth_radius+1. The result then has the same length as the inputs,
    even if they are shorter than the kernel. The smoothing uses
    np.convolve(), so it only supports 1-dimensional a,b.
    '''
    if smooth_radius is not None and smooth_radius > 0:
        a = _smooth_same_length(a, smooth_radius)
        b = _smooth_same_length(b, smooth_radius)
    return (np.abs(a) + np.abs(b)) / 2 + eps

def relative_diff(a,b,
                  *,
                  smooth_radius = None,
                  eps           = 1e-6):
    r'''Return a-b divided by the typical magnitude of a and b

    The denominator is relative_scale(a,b): (|a|+|b|)/2 + eps, with a,b
    optionally smoothed first (see smooth_radius there). A positive eps keeps
    the result finite where a and b are both 0.
    '''
    numerator   = a - b
    denominator = relative_scale(a, b,
                                 smooth_radius = smooth_radius,
                                 eps           = eps)
    return numerator / denominator

def confirm_equal(x, xref,
                  *,
                  msg='',
                  eps=1e-6,
                  reldiff_eps = 1e-6,
                  reldiff_smooth_radius = None,
                  relative=False,
                  worstcase=False,
                  percentile=None,
                  r=False):
    r'''If x is equal to xref, report test success.

    msg identifies this check. eps is the tolerance on the error metric
    selected below (RMS by default). The x,xref arguments can be given as many
    different types, and this function tries to do the right thing:

    - strings have trailing whitespace stripped from each line
    - everything is converted to a flattened numpy array if possible
    - a single-element xref is broadcast to the size of x
    - if the values cannot be subtracted, they are compared for exact
      equality instead

    if relative: I look at a relative error:
                 err = (a-b) / ((abs(a)+abs(b))/2 + reldiff_eps)
                 a,b in the denominator can be smoothed with a box kernel of
                 the given reldiff_smooth_radius
    else:        I look at absolute error:
                 err = a-b

    if worstcase: I look at the worst-case error
                  error = np.max(np.abs(err))
    elif percentile is not None: I look at the given point in the error distribution
                  error = percentile_compat(np.abs(err), percentile,
                                            interpolation='higher')
    else:         RMS error
                  error = np.sqrt(nps.norm2(err) / len(err))

    if r: we are comparing rodrigues rotations. x and xref must have shape
          (...,3); any number of rotations may be given. More than one set of
          r values can represent the same rotation

          Let k be an integer. r = th * vaxis.
          Changing th -> th + k*2pi implies the same rotation
          Changing vaxis -> -vaxis and th -> 2pi-th also implies the same rotation
          I normalize the inputs first by finding the rotation with the smallest
          th, in [0,pi]. A zero rotation stays 0

    Increments the global Nchecks, and also NchecksFailed on failure. Returns
    True if the check passed, False otherwise.
    '''

    global Nchecks
    global NchecksFailed

    if r:
        x    = np.asarray(x)
        xref = np.asarray(xref)
        if not (x.shape[-1] == 3 and xref.shape[-1] == 3):
            raise Exception("confirm_equal(r=True) only makes sense if x and xref have shape (...,3)")

        def normalize_r(rr):
            th = np.asarray(nps.mag(rr), dtype=float)

            # The same rotation, with the angle in [0,2pi)
            th_normalized = th % (2.*np.pi)

            # An angle past pi is the same as a rotation by 2pi-th about the
            # flipped axis. After this the angle is in [0,pi]. Vectorized, so
            # any number of rotations can be normalized at once
            flip          = th_normalized > np.pi
            th_normalized = np.where(flip, 2.*np.pi - th_normalized, th_normalized)
            th_signed     = np.where(flip, -th_normalized, th_normalized)

            # new r = th_signed * rr/th. A zero rotation has no axis: it is
            # left as 0 instead of becoming 0/0 = NaN
            scale = np.zeros(th.shape, dtype=float)
            np.divide(th_signed, th, out = scale, where = th > 0)
            return rr * nps.dummy(scale, -1)

        x    = normalize_r(x)
        xref = normalize_r(xref)

    Nchecks = Nchecks + 1

    msg_suffix = (': ' + msg) if msg else ''

    # strip all trailing whitespace in each line, in case these are strings
    if isinstance(x, str):
        x = re.sub('[ \t]+(\n|$)', '\\1', x)
    if isinstance(xref, str):
        xref = re.sub('[ \t]+(\n|$)', '\\1', xref)

    # convert data to numpy if possible
    try:
        xref = np.array(xref)
    except Exception:
        pass
    try:
        x = np.array(x)
    except Exception:
        pass

    try:  # flatten array if possible
        x = x.ravel()
        xref = xref.ravel()
    except Exception:
        pass

    try:
        N = x.shape[0]
    except Exception:
        N = 1
    try:
        Nref = xref.shape[0]
    except Exception:
        Nref = 1

    if N != Nref:

        # Comparing an array to a scalar reference is allowed
        if Nref == 1:
            xref = np.ones((N,), dtype=float) * xref
            Nref = N
        else:
            print_red(("FAILED{}: mismatched array sizes: N = {} but Nref = {}. Arrays: \n" +
                       "x = {}\n" +
                       "xref = {}").
                      format(msg_suffix,
                             N, Nref,
                             x, xref))
            NchecksFailed = NchecksFailed + 1
            return False

    if N != 0:
        try:  # If I can subtract, get the error that way
            if relative:
                diff = relative_diff(x, xref,
                                     eps           = reldiff_eps,
                                     smooth_radius = reldiff_smooth_radius)
            else:
                diff = x - xref

            if worstcase:
                what = 'worst-case'
                err  = np.max(np.abs(diff))
            elif percentile is not None:
                what = f'{percentile}%-percentile'
                err  = percentile_compat(np.abs(diff), percentile, interpolation='higher')
            else:
                what = 'RMS'
                err  = np.sqrt(nps.norm2(diff) / len(diff))

            if not np.all(np.isfinite(err)):
                print_red(f"FAILED{msg_suffix}: Some comparison results are NaN or Inf. {what}. error_x_xref =\n{nps.cat(err,x,xref)}")
                NchecksFailed = NchecksFailed + 1
                return False
            if err > eps:
                print_red(f"FAILED{msg_suffix}: {what} error = {err}. x_xref_err =\n{nps.cat(x,xref,diff)}")
                NchecksFailed = NchecksFailed + 1
                return False
        except Exception:  # Can't subtract. Do == instead
            if not np.array_equal(x, xref):
                print_red(f"FAILED{msg_suffix}: x_xref =\n{nps.cat(x,xref)}")
                NchecksFailed = NchecksFailed + 1
                return False

    # "OK" is always printed, with the msg suffix appended only if a msg was
    # given
    print_green("OK" + msg_suffix)
    return True


def confirm(x, msg=''):
    r'''Report test success if x is truthy, and test failure otherwise.

    msg identifies this check. Every call increments the global Nchecks, and
    failures also increment NchecksFailed. Returns True if the check passed,
    False if it failed'''

    global Nchecks, NchecksFailed

    Nchecks += 1
    passed = bool(x)
    suffix = (': ' + msg) if msg else ''

    if passed:
        print_green("OK" + suffix)
        return True

    print_red("FAILED" + suffix)
    NchecksFailed += 1
    return False


def confirm_raises(f, msg=''):
    r'''If f() raises an exception, report test success.

    msg identifies this check. f is called with no arguments. Any exception
    counts as success, including SystemExit, so code that calls sys.exit() can
    be tested. A KeyboardInterrupt is re-raised, so that interrupting the test
    run is not mistaken for a passing check.

    Increments the global Nchecks, and also NchecksFailed on failure. Returns
    True if f() raised, False if it returned normally'''

    global Nchecks
    global NchecksFailed
    Nchecks = Nchecks + 1

    try:
        f()
    except KeyboardInterrupt:
        raise
    except BaseException:
        print_green("OK{}".format((': ' + msg) if msg else ''))
        return True

    # f() returned normally: that's a failure
    print_red("FAILED{}".format((': ' + msg) if msg else ''))
    NchecksFailed = NchecksFailed + 1
    return False


def confirm_does_not_raise(f, msg=''):
    r'''If f() raises an exception, report test failure.

    msg identifies this check. f is called with no arguments. Any exception
    from f() (including SystemExit) counts as a failure. A KeyboardInterrupt
    is re-raised, so that interrupting the test run still stops it. Only f()
    is guarded: errors in the success reporting are not misreported as a test
    failure.

    Increments the global Nchecks, and also NchecksFailed on failure. Returns
    True if f() returned normally, False if it raised'''

    global Nchecks
    global NchecksFailed
    Nchecks = Nchecks + 1

    try:
        f()
    except KeyboardInterrupt:
        raise
    except BaseException:
        print_red("FAILED{}".format((': ' + msg) if msg else ''))
        NchecksFailed = NchecksFailed + 1
        return False

    print_green("OK{}".format((': ' + msg) if msg else ''))
    return True


def confirm_covariances_equal(var, var_ref,
                              *,
                              what,
                              # scalar to use for all the eigenvalues, or a
                              # sequence with one entry per eigenvalue, used in
                              # order from largest to smallest. None to skip
                              # that axis (a scalar None skips all of them)
                              eps_eigenvalues,
                              eps_eigenvectors_deg,
                              check_biggest_eigenvalue_only = False,

                              # In real units, the ellipse radii are of size
                              # sqrt(eigenvalue), so this SHOULD be true. But I
                              # default to False to make the old tests work. New
                              # tests should set this to True
                              check_sqrt_eigenvalue         = False):
    r'''Compare a covariance matrix to a reference via its eigen-decomposition

    var is the covariance being tested ("predicted"); var_ref is the
    reference ("observed"). Every comparison is reported through
    confirm_equal(), with "what" included in each check's message. The
    checks, in order:

    - var is symmetric (worst-case absolute error, default tolerance)
    - the eigenvalues match, largest first, as a worst-case relative error
      within eps_eigenvalues. If check_biggest_eigenvalue_only, only the
      largest eigenvalue is considered. If check_sqrt_eigenvalue, the square
      roots of the eigenvalues are compared instead
    - if the largest eigenvalue of var is more than 2x the second-largest
      (sqrt(2)x if check_sqrt_eigenvalue), the directions of the largest
      eigenvectors must match to within eps_eigenvectors_deg degrees. Rounder
      ellipses have poorly-defined axis directions, so this check is skipped
      for them

    Assumes var and var_ref are square (N,N) with N >= 2, and that
    mrcal.sorted_eig() returns the eigenvalues in ascending order with the
    eigenvectors in the columns; the [-1] indexing relies on this. TODO
    confirm against mrcal.sorted_eig().

    Returns nothing: the results are recorded by confirm_equal().
    '''

    # First, the thing is symmetric, right?
    confirm_equal(nps.transpose(var),
                  var,
                  worstcase = True,
                  msg = f"Var(dq) is symmetric for {what}")


    l_predicted,v_predicted = mrcal.sorted_eig(var)
    l_observed, v_observed  = mrcal.sorted_eig(var_ref)

    eccentricity_threshold = 2.

    if check_sqrt_eigenvalue:
        l_predicted = np.sqrt(l_predicted)
        l_observed  = np.sqrt(l_observed)
        eccentricity_threshold = np.sqrt(eccentricity_threshold)

    # This looks at JUST the most dominant modes
    eccentricity_predicted = l_predicted[-1] / l_predicted[-2]

    # "biggest only" means exactly that, even if its tolerance is None
    Neigenvalues_to_check = 1 if check_biggest_eigenvalue_only else var.shape[-1]

    for i in range(Neigenvalues_to_check):
        # check the eigenvalues, in order from largest to smallest. Any scalar
        # (float, int, numpy scalar) applies to all of them
        if np.ndim(eps_eigenvalues) == 0:
            eps = eps_eigenvalues
        else:
            eps = eps_eigenvalues[i]
        if eps is None:
            continue

        confirm_equal(l_observed[-1-i],
                      l_predicted[-1-i],
                      eps = eps,
                      worstcase = True,
                      relative  = True,
                      msg = f"Var(dq) largest[{i}] eigenvalue match for {what}")

    # I only check the eigenvector directions if the ellipse is sufficiently
    # non-circular. A circular ellipse has poorly-defined eigenvector directions
    if eccentricity_predicted > eccentricity_threshold:

        # I look at the direction of the largest ellipse axis only
        v0_predicted = v_predicted[:,-1]
        v0_observed  = v_observed [:,-1]

        # For (nearly) parallel unit vectors, rounding can push |cos| slightly
        # past 1, and arccos() would then return NaN. Clip to the valid domain
        cos_angle = np.clip(np.abs(nps.inner(v0_observed,v0_predicted)), 0., 1.)

        confirm_equal(np.arccos(cos_angle) * 180./np.pi,
                      0,
                      eps = eps_eigenvectors_deg,
                      worstcase = True,
                      msg = f"Var(dq) eigenvectors match for {what}")

    # I don't bother checking v1. I already made sure the matrix is
    # symmetric. Thus the eigenvectors are orthogonal, so any angle offset
    # in v0 will be exactly the same in v1



def finish():
    r'''Finalize the executed tests.

    Prints a summary of the checks and exits the process. The exit status is
    0 if every check passed (or if no checks were run at all), and 1 if any
    check failed.

    '''
    if not (Nchecks or NchecksFailed):
        print_red("No tests defined")
        exit_status = 0
    elif NchecksFailed:
        print_red("Some tests failed: {} out of {}".format(NchecksFailed, Nchecks))
        exit_status = 1
    else:
        print_green("All tests passed: {} total".format(Nchecks))
        exit_status = 0

    sys.exit(exit_status)