File: _base.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (63 lines) | stat: -rw-r--r-- 2,242 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from abc import abstractmethod
from typing import Tuple

import torch

from ignite.metrics.metric import Metric, reinit__is_reduced


def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    y_pred, y = output
    c1 = y_pred.ndimension() == 2 and y_pred.shape[1] == 1
    if not (y_pred.ndimension() == 1 or c1):
        raise ValueError(f"Input y_pred should have shape (N,) or (N, 1), but given {y_pred.shape}")

    c2 = y.ndimension() == 2 and y.shape[1] == 1
    if not (y.ndimension() == 1 or c2):
        raise ValueError(f"Input y should have shape (N,) or (N, 1), but given {y.shape}")

    if y_pred.shape != y.shape:
        raise ValueError(f"Input data shapes should be the same, but given {y_pred.shape} and {y.shape}")


def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
    y_pred, y = output
    if y_pred.dtype not in (torch.float16, torch.float32, torch.float64):
        raise TypeError(f"Input y_pred dtype should be float 16, 32 or 64, but given {y_pred.dtype}")

    if y.dtype not in (torch.float16, torch.float32, torch.float64):
        raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}")


def _torch_median(output: torch.Tensor) -> float:
    output = output.view(-1)
    len_ = len(output)

    if len_ % 2 == 0:
        return float((torch.kthvalue(output, len_ // 2)[0] + torch.kthvalue(output, len_ // 2 + 1)[0]) / 2)
    else:
        return float(torch.kthvalue(output, len_ // 2 + 1)[0])


class _BaseRegression(Metric):
    """Base class for all regression metrics.

    The public :meth:`update` validates the ``(y_pred, y)`` pair (shape and
    dtype) and prepares the tensors. It then delegates the metric-specific
    accumulation to :meth:`_update`, which subclasses must implement.
    """

    @reinit__is_reduced
    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Validate ``output``, normalise it, and forward it to :meth:`_update`.

        Args:
            output: ``(y_pred, y)`` pair of tensors.

        Raises:
            ValueError: propagated from ``_check_output_shapes``.
            TypeError: propagated from ``_check_output_types``.
        """
        _check_output_shapes(output)
        _check_output_types(output)

        prepared = []
        for tensor in (output[0], output[1]):
            # Detach so the stored/accumulated values do not keep the autograd graph alive.
            tensor = tensor.detach()
            # Collapse an (N, 1) column into (N,) so subclasses only see 1-D tensors.
            if tensor.ndimension() == 2 and tensor.shape[1] == 1:
                tensor = tensor.squeeze(dim=-1)
            prepared.append(tensor)

        self._update((prepared[0], prepared[1]))

    @abstractmethod
    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Accumulate a validated ``(y_pred, y)`` pair; implemented by subclasses.

        When called from :meth:`update`, both tensors have been detached. Any
        ``(N, 1)`` input has been squeezed to ``(N,)``.
        """