from abc import abstractmethod
from typing import Tuple
import torch
from ignite.metrics.metric import Metric, reinit__is_reduced


def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    y_pred, y = output

    c1 = y_pred.ndimension() == 2 and y_pred.shape[1] == 1
    if not (y_pred.ndimension() == 1 or c1):
        raise ValueError(f"Input y_pred should have shape (N,) or (N, 1), but given {y_pred.shape}")

    c2 = y.ndimension() == 2 and y.shape[1] == 1
    if not (y.ndimension() == 1 or c2):
        raise ValueError(f"Input y should have shape (N,) or (N, 1), but given {y.shape}")

    if y_pred.shape != y.shape:
        raise ValueError(f"Input data shapes should be the same, but given {y_pred.shape} and {y.shape}")


def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    y_pred, y = output

    if y_pred.dtype not in (torch.float16, torch.float32, torch.float64):
        raise TypeError(f"Input y_pred dtype should be float 16, 32 or 64, but given {y_pred.dtype}")

    if y.dtype not in (torch.float16, torch.float32, torch.float64):
        raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}")


def _torch_median(output: torch.Tensor) -> float:
    output = output.view(-1)
    len_ = len(output)

    if len_ % 2 == 0:
        return float((torch.kthvalue(output, len_ // 2)[0] + torch.kthvalue(output, len_ // 2 + 1)[0]) / 2)
    else:
        return float(torch.kthvalue(output, len_ // 2 + 1)[0])
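
# Illustrative example (comment-only): sorted([3., 1., 4., 2.]) has the two middle
# values 2. and 3., so for an even number of elements _torch_median returns their
# mean; for an odd number it returns the single middle value.
#   _torch_median(torch.tensor([3.0, 1.0, 4.0, 2.0]))  # -> 2.5
#   _torch_median(torch.tensor([3.0, 1.0, 2.0]))        # -> 2.0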


class _BaseRegression(Metric):
    # Base class for all regression metrics.
    # The `update` method checks the shapes and dtypes of the inputs, squeezes
    # (N, 1) tensors down to (N,) and then calls the internal overridden
    # method `_update`.

    @reinit__is_reduced
    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        _check_output_shapes(output)
        _check_output_types(output)
        y_pred, y = output[0].detach(), output[1].detach()

        if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
            y_pred = y_pred.squeeze(dim=-1)

        if y.ndimension() == 2 and y.shape[1] == 1:
            y = y.squeeze(dim=-1)

        self._update((y_pred, y))

    @abstractmethod
    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        pass
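

# Illustrative sketch (not part of the original module): a minimal concrete metric
# built on top of _BaseRegression. The class name `_ExampleMeanError` and its
# formula (mean of y - y_pred) are assumptions made for this example; real metrics
# follow the same pattern of overriding `reset`, `_update` and `compute`.
class _ExampleMeanError(_BaseRegression):
    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_errors = 0.0
        self._num_examples = 0

    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        # `_BaseRegression.update` has already validated and squeezed the inputs.
        y_pred, y = output
        self._sum_of_errors += torch.sum(y - y_pred).item()
        self._num_examples += y.shape[0]

    def compute(self) -> float:
        if self._num_examples == 0:
            raise RuntimeError("_ExampleMeanError must have at least one example before it can be computed.")
        return self._sum_of_errors / self._num_examples


if __name__ == "__main__":
    # Quick smoke test of the sketch above.
    m = _ExampleMeanError()
    m.update((torch.tensor([2.0, 4.0, 6.0]), torch.tensor([1.0, 5.0, 6.0])))
    print(m.compute())  # ((1 - 2) + (5 - 4) + (6 - 6)) / 3 = 0.0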