1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
|
import torch
import numpy as np
from scipy.special import softmax
from pytorch_tabnet.utils import SparsePredictDataset, PredictDataset, filter_weights
from pytorch_tabnet.abstract_model import TabModel
from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim
from torch.utils.data import DataLoader
import scipy
class TabNetMultiTaskClassifier(TabModel):
    """TabNet model for multi-task classification.

    Each column of the target matrix is an independent classification task:
    the network emits one logit tensor per task and the training loss is the
    mean of the per-task losses.
    """

    def __post_init__(self):
        super(TabNetMultiTaskClassifier, self).__post_init__()
        self._task = 'classification'
        self._default_loss = torch.nn.functional.cross_entropy
        self._default_metric = 'logloss'

    def prepare_target(self, y):
        """Map raw labels to integer-encoded classes, independently per task.

        Parameters
        ----------
        y : np.ndarray of shape (n_samples, n_tasks)
            Raw target labels.

        Returns
        -------
        np.ndarray
            Same shape as ``y``, with each column remapped through the
            corresponding ``target_mapper`` dict.
        """
        y_mapped = y.copy()
        for task_idx in range(y.shape[1]):
            task_mapper = self.target_mapper[task_idx]
            y_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y[:, task_idx])
        return y_mapped

    def compute_loss(self, y_pred, y_true):
        """
        Computes the loss according to network output and targets

        Parameters
        ----------
        y_pred : list of tensors
            Output of network, one logit tensor per task.
        y_true : LongTensor
            Targets label encoded, shape (batch, n_tasks).

        Returns
        -------
        loss : torch.Tensor
            Mean of the per-task losses.
        """
        loss = 0
        y_true = y_true.long()
        if isinstance(self.loss_fn, list):
            # A distinct loss function was specified for each task.
            for task_id, (task_loss, task_output) in enumerate(
                zip(self.loss_fn, y_pred)
            ):
                loss += task_loss(task_output, y_true[:, task_id])
        else:
            # The same loss function is applied to every task.
            for task_id, task_output in enumerate(y_pred):
                loss += self.loss_fn(task_output, y_true[:, task_id])
        # Average so the loss magnitude does not grow with the number of tasks.
        loss /= len(y_pred)
        return loss

    def stack_batches(self, list_y_true, list_y_score):
        """Stack per-batch targets and scores, softmax-normalizing each task."""
        y_true = np.vstack(list_y_true)
        y_score = []
        for i in range(len(self.output_dim)):
            score = np.vstack([x[i] for x in list_y_score])
            score = softmax(score, axis=1)
            y_score.append(score)
        return y_true, y_score

    def update_fit_params(self, X_train, y_train, eval_set, weights):
        """Infer per-task output dims and label mappings before fitting.

        Also validates that every eval set only contains labels seen in
        ``y_train`` and stores the (filtered) sample weights.
        """
        output_dim, train_labels = infer_multitask_output(y_train)
        for _, y in eval_set:
            for task_idx in range(y.shape[1]):
                check_output_dim(train_labels[task_idx], y[:, task_idx])
        self.output_dim = output_dim
        self.classes_ = train_labels
        # label -> index, used to encode targets before training
        self.target_mapper = [
            {class_label: index for index, class_label in enumerate(classes)}
            for classes in self.classes_
        ]
        # str(index) -> str(label), used to decode predictions
        self.preds_mapper = [
            {str(index): str(class_label) for index, class_label in enumerate(classes)}
            for classes in self.classes_
        ]
        self.updated_weights = weights
        filter_weights(self.updated_weights)

    def _make_predict_dataloader(self, X):
        """Build a non-shuffling DataLoader over X (dense array or CSR matrix)."""
        if scipy.sparse.issparse(X):
            dataset = SparsePredictDataset(X)
        else:
            dataset = PredictDataset(X)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def predict(self, X):
        """
        Make predictions on a batch (valid)

        Parameters
        ----------
        X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data

        Returns
        -------
        results : np.array
            Predictions of the most probable class, one array per task,
            mapped back to the original class labels.
        """
        self.network.eval()
        dataloader = self._make_predict_dataloader(X)
        results = {}
        for data in dataloader:
            data = data.to(self.device).float()
            output, _ = self.network(data)
            # argmax on raw logits == argmax after softmax (softmax is
            # monotonic), so the normalization step is unnecessary here.
            predictions = [
                torch.argmax(task_output, dim=1)
                .cpu()
                .detach()
                .numpy()
                .reshape(-1)
                for task_output in output
            ]
            for task_idx in range(len(self.output_dim)):
                results.setdefault(task_idx, []).append(predictions[task_idx])
        # stack all tasks individually
        results = [np.hstack(task_res) for task_res in results.values()]
        # map integer predictions back to original class labels, per task
        results = [
            np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str))
            for task_idx, task_res in enumerate(results)
        ]
        return results

    def predict_proba(self, X):
        """
        Make predictions for classification on a batch (valid)

        Parameters
        ----------
        X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data

        Returns
        -------
        res : list of np.ndarray
            One (n_samples, n_classes_task) softmax-probability array per task.
        """
        self.network.eval()
        dataloader = self._make_predict_dataloader(X)
        results = {}
        for data in dataloader:
            data = data.to(self.device).float()
            output, _ = self.network(data)
            predictions = [
                torch.nn.functional.softmax(task_output, dim=1)
                .cpu()
                .detach()
                .numpy()
                for task_output in output
            ]
            for task_idx in range(len(self.output_dim)):
                results.setdefault(task_idx, []).append(predictions[task_idx])
        res = [np.vstack(task_res) for task_res in results.values()]
        return res
|