File: metrics.py

package info (click to toggle)
python-polsarpro 2026.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 17,024 kB
  • sloc: python: 3,830; xml: 293; sh: 91; javascript: 18; makefile: 3
file content (134 lines) | stat: -rw-r--r-- 4,104 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np
import xarray
import pandas as pd
from collections import OrderedDict
import warnings

# these are convenience functions used only for development


# returns a dict with various error metrics for two arrays (predicted and reference)
def compute_error_metrics(arr_pred, arr_ref):
    """Compute global error statistics of a predicted array against a reference.

    NaN entries are ignored by every statistic (``np.nan*`` reductions).
    Complex inputs are supported: squared and absolute errors are taken on
    the modulus of the difference, so RMSE, MAE and the normalized errors are
    real-valued. The difference of means and the bias keep the input dtype
    (complex in, complex out).

    Args:
        arr_pred: predicted (tested) array.
        arr_ref: reference array; must support ``arr_pred - arr_ref``.

    Returns:
        OrderedDict mapping a short key (``"dm"``, ``"rmse"``, ``"bias"``,
        ``"mae"``, ``"mnae"``, ``"p99"``, ``"p90"``, ``"p50"``, in this order)
        to a dict ``{"value": <statistic>, "descr": <description>}``.
    """
    stats = OrderedDict()

    # avoids 0/0 in the normalized error when both values are exactly zero
    eps = 1e-30

    def _add(key, value, descr):
        stats[key] = {"value": value, "descr": descr}

    # simple (signed) error
    err = arr_pred - arr_ref

    # absolute error; squaring the modulus (instead of err**2) keeps the
    # squared error real and non-negative for complex data, and is
    # bit-identical to err**2 for real data
    ae = np.abs(err)
    se = ae**2

    # normalized absolute error (between 0 and 1)
    naerr = ae / (eps + np.abs(arr_pred) + np.abs(arr_ref))

    # ---- Global statistics
    _add(
        "dm",
        np.nanmean(arr_pred) - np.nanmean(arr_ref),
        "Difference of Means",
    )
    # se >= 0, so no epsilon is needed inside the square root
    _add("rmse", np.sqrt(np.nanmean(se)), "Root Mean Square Error")
    _add("bias", np.nanmean(err), "Mean Error (bias)")
    _add("mae", np.nanmean(ae), "Mean Absolute Error")
    _add("mnae", np.nanmean(naerr), "Mean Normalized Absolute Error")

    for q in (99, 90, 50):
        _add(
            f"p{q}",
            np.nanpercentile(naerr, q=q),
            f"{q}% percentile of the Normalized Absolute Error",
        )

    return stats


# returns a dataframe that summarizes statistics for all variables
def summarize_metrics(
    out_py: xarray.Dataset,
    out_c: dict,
    short_titles: bool = False,
    verbose: bool = True,
):
    """Tabulate error metrics of the python outputs against the C outputs.

    For every data variable of ``out_py`` that also exists in ``out_c``,
    ``compute_error_metrics`` is evaluated with the python output as the
    prediction and the C output as the reference. Variables missing from
    ``out_c`` are skipped with a warning.

    Args:
        out_py: python outputs; its ``data_vars`` drive the comparison.
        out_c: C outputs indexed by variable name.
        short_titles: if True, columns are the short metric keys
            (``"rmse"``, ...); otherwise the long metric descriptions.
        verbose: if True, print the name of each variable being processed.

    Returns:
        ``pandas.DataFrame`` with one row per compared variable (indexed by
        name) and one column per metric. An empty DataFrame is returned when
        no variable could be compared (``pd.concat`` raises on an empty list).
    """
    # one single-row dataframe per compared variable, merged at the end
    dfs = []

    for var in out_py.data_vars:
        # check that python variable appears in C data
        if var not in out_c:
            warnings.warn(
                f"Skipping variable '{var}'! Not found in C outputs.", stacklevel=2
            )
            continue
        if verbose:
            print(f"Computing variable '{var}'.")
        stats = compute_error_metrics(out_py[var], out_c[var])
        row = {
            (key if short_titles else item["descr"]): [item["value"]]
            for key, item in stats.items()
        }
        df = pd.DataFrame.from_dict(row, orient="columns")
        df.index = [var]
        dfs.append(df)

    if not dfs:
        return pd.DataFrame()
    return pd.concat(dfs)

# helper function to visualize errors
def visualize_errors(out_py, out_c, clip=True, step=8):
    """Plot, for each variable, the python-vs-C error next to both outputs.

    One figure is created per data variable of ``out_py`` that is also
    present in ``out_c``. Each figure has three panels: the error, the C
    output and the python output. For complex variables the moduli are
    displayed and the error is ``|python - C|``. For real variables the
    signed difference ``python - C`` is displayed.

    Args:
        out_py: python outputs (iterated through ``data_vars``).
        out_c: C outputs indexed by variable name.
        clip: if True, clip the C and python images to ``[0, 2 * m]``, where
            ``m`` is the average of their NaN-ignoring means. The error panel
            is not clipped.
        step: subsampling step applied along the first axis of each image
            before display (default 8).
    """
    import matplotlib.pyplot as plt

    for var in out_py.data_vars:
        if var not in out_c:
            # emit the warning and skip: out_c[var] below would raise KeyError
            warnings.warn(
                f"Skipping variable '{var}'! Not found in C outputs.", stacklevel=2
            )
            continue
        var_py = out_py[var]
        var_c = out_c[var]
        if np.iscomplexobj(var_py):
            img_c = np.abs(var_c)
            img_py = np.abs(var_py)
            err = abs(var_py - var_c)
            t_err = "Absolute difference: python - C"
        else:
            img_c = var_c
            img_py = var_py
            err = var_py - var_c
            t_err = "Difference: python - C"

        if clip:
            # common color scale from the mean of both images
            m = 0.5 * (np.nanmean(img_py) + np.nanmean(img_c))
            img_c = img_c.clip(0, 2 * m)
            img_py = img_py.clip(0, 2 * m)

        plt.figure(figsize=(10, 6))
        plt.suptitle(var)
        panels = ((err, t_err), (img_c, "C"), (img_py, "python"))
        for idx, (img, title) in enumerate(panels):
            plt.subplot(1, 3, idx + 1)
            plt.imshow(img[::step], interpolation="none")
            plt.title(title)
            plt.axis("off")
            plt.colorbar(fraction=0.046, pad=0.04, location="bottom")