1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
|
from typing import Callable, Optional, Union
import torch
from packaging.version import Version
from ignite.metrics.metric import Metric
class InceptionModel(torch.nn.Module):
    r"""Inception-v3 network from torchvision, loaded with pre-trained (ImageNet) weights.

    The network is switched to evaluation mode on construction and its forward
    pass runs with gradient tracking disabled.

    Args:
        return_features: if ``True``, the final fully connected layer is replaced
            by an identity so the model returns the features that feed it; if
            ``False``, a softmax over dim 1 is appended to that layer so the
            model returns prediction probabilities.
        device: device the network is moved to. Inputs given to ``forward`` are
            moved to this device when they live elsewhere. By default, CPU.

    Raises:
        ModuleNotFoundError: if torchvision cannot be imported.
    """

    def __init__(self, return_features: bool, device: Union[str, torch.device] = "cpu") -> None:
        # torchvision is imported lazily so that it is only required when this
        # class is actually instantiated.
        try:
            import torchvision
            from torchvision import models
        except ImportError:
            raise ModuleNotFoundError("This module requires torchvision to be installed.")

        super().__init__()
        self._device = device

        # torchvision 0.13.0 replaced the boolean ``pretrained`` flag with the
        # ``weights`` enum argument; pick whichever the installed version uses.
        legacy_api = Version(torchvision.__version__) < Version("0.13.0")
        if legacy_api:
            network = models.inception_v3(pretrained=True)
        else:
            network = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.model = network.to(self._device)

        # Swap the classification head: drop it for feature extraction, or
        # extend it with a softmax to obtain probabilities.
        self.model.fc = (
            torch.nn.Identity()
            if return_features
            else torch.nn.Sequential(self.model.fc, torch.nn.Softmax(dim=1))
        )
        self.model.eval()

    @torch.no_grad()
    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Validate ``data``, move it to the model's device and run the network.

        Args:
            data: 4-dimensional tensor whose second dimension has size 3.

        Returns:
            The output of the wrapped network for ``data``.

        Raises:
            ValueError: if ``data`` is not 4-dimensional or does not have 3 channels.
        """
        is_4d = data.dim() == 4
        if not is_4d:
            raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}")

        has_three_channels = data.shape[1] == 3
        if not has_three_channels:
            raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}")

        target_device = torch.device(self._device)
        if data.device != target_device:
            data = data.to(self._device)

        return self.model(data)
class _BaseInceptionMetric(Metric):
    """Shared base for metrics that compare features produced by a feature extractor.

    Stores a feature-extractor module together with the number of features it
    is expected to return, and provides helpers to run the extractor and
    validate the shape of its output.

    Args:
        num_features: expected size of the second dimension of the extractor
            output. Must be a positive integer.
        feature_extractor: module used to turn inputs into features. When
            ``None``, ``torch.nn.Identity`` is used, i.e. inputs are treated as
            features already.
        output_transform: callable forwarded unchanged to ``Metric.__init__``.
        device: device the feature extractor is moved to; also forwarded to
            ``Metric.__init__``.

    Raises:
        ValueError: if ``num_features`` is ``None`` or not greater than zero.
        TypeError: if ``feature_extractor`` is not a ``torch.nn.Module``.
    """

    def __init__(
        self,
        num_features: Optional[int],
        feature_extractor: Optional[torch.nn.Module],
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:
        if num_features is None:
            raise ValueError("Argument num_features must be provided, if feature_extractor is specified.")

        if feature_extractor is None:
            feature_extractor = torch.nn.Identity()

        if num_features <= 0:
            raise ValueError(f"Argument num_features must be greater to zero, got: {num_features}")

        if not isinstance(feature_extractor, torch.nn.Module):
            # Report the type of the *argument*: ``self._feature_extractor`` is
            # not assigned yet at this point, so reading it here would raise
            # AttributeError and mask the intended TypeError.
            raise TypeError(
                f"Argument feature_extractor must be of type torch.nn.Module, got {type(feature_extractor)}"
            )

        self._num_features = num_features
        # Assigned before the base initializer runs, matching the original
        # ordering (the base class may rely on these attributes during init).
        self._feature_extractor = feature_extractor.to(device)
        super().__init__(output_transform=output_transform, device=device)

    def _check_feature_shapes(self, samples: torch.Tensor) -> None:
        """Validate the shape of a batch of extracted features.

        Args:
            samples: tensor returned by the feature extractor.

        Raises:
            ValueError: if ``samples`` is not 2-dimensional, if its first
                dimension is empty, or if its second dimension differs from
                ``num_features``.
        """
        if samples.dim() != 2:
            raise ValueError(f"feature_extractor output must be a tensor of dim 2, got: {samples.dim()}")

        if samples.shape[0] == 0:
            # NOTE(review): the message says "greater than one" although only an
            # empty batch is rejected; text kept as-is because callers/tests may
            # match on it.
            raise ValueError(f"Batch size should be greater than one, got: {samples.shape[0]}")

        if samples.shape[1] != self._num_features:
            raise ValueError(
                f"num_features returned by feature_extractor should be {self._num_features}, got: {samples.shape[1]}"
            )

    def _extract_features(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor on ``inputs`` and return validated features.

        The inputs are detached and moved to ``self._device`` if needed
        (``self._device`` is not set in this class — presumably by
        ``Metric.__init__``). The extractor runs without gradient tracking and
        its output is converted to ``torch.float64`` on ``self._device``.

        Args:
            inputs: batch to feed to the feature extractor.

        Returns:
            The extracted features as a ``float64`` tensor.

        Raises:
            ValueError: if the extractor output fails ``_check_feature_shapes``.
        """
        inputs = inputs.detach()

        if inputs.device != torch.device(self._device):
            inputs = inputs.to(self._device)

        with torch.no_grad():
            outputs = self._feature_extractor(inputs).to(self._device, dtype=torch.float64)
        self._check_feature_shapes(outputs)

        return outputs
|