import warnings

import torch

from torch_geometric.explain import groundtruth_metrics
from torch_geometric.testing import withPackage


@withPackage('torchmetrics>=0.10.0')
def test_groundtruth_metrics():
    # Random, unrelated prediction and target masks: every metric should
    # still lie within the valid [0, 1] range.
    pred_mask = torch.rand(10)
    target_mask = torch.rand(10)

    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(
        pred_mask, target_mask)

    assert 0.0 <= accuracy <= 1.0
    assert 0.0 <= recall <= 1.0
    assert 0.0 <= precision <= 1.0
    assert 0.0 <= f1_score <= 1.0
    assert 0.0 <= auroc <= 1.0


@withPackage('torchmetrics>=0.10.0')
def test_perfect_groundtruth_metrics():
    # Identical prediction and target masks: every metric should be perfect.
    pred_mask = target_mask = torch.rand(10)

    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(
        pred_mask, target_mask)

    assert round(accuracy, 6) == 1.0
    assert round(recall, 6) == 1.0
    assert round(precision, 6) == 1.0
    assert round(f1_score, 6) == 1.0
    assert round(auroc, 6) == 1.0


@withPackage('torchmetrics>=0.10.0')
def test_groundtruth_true_negative():
    # All-zero masks contain no positive samples, so recall, precision,
    # F1-score and AUROC collapse to zero while accuracy stays perfect.
    warnings.filterwarnings('ignore', '.*No positive samples in targets.*')

    pred_mask = target_mask = torch.zeros(10)

    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(
        pred_mask, target_mask)

    assert round(accuracy, 6) == 1.0
    assert round(recall, 6) == 0.0
    assert round(precision, 6) == 0.0
    assert round(f1_score, 6) == 0.0
    assert round(auroc, 6) == 0.0