File: model.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (22 lines) | stat: -rw-r--r-- 748 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification


class TransformerModel(nn.Module):
    """Sequence classifier wrapping a Hugging Face ``AutoModelForSequenceClassification``.

    ``forward`` returns only the ``"logits"`` entry of the wrapped model's
    output; every other output (hidden states, attentions, ...) is discarded.

    Args:
        model_name: model identifier or local path, passed to both
            ``AutoConfig.from_pretrained`` and
            ``AutoModelForSequenceClassification.from_pretrained``.
        model_dir: cache directory used for both the config and the weights
            (forwarded as ``cache_dir``).
        dropout: value stored as the config's ``classifier_dropout``.
        n_fc: only its truthiness is used: it sets the config's
            ``output_hidden_states`` flag. NOTE(review): the name suggests a
            number of fully-connected layers, but none are built here; kept
            for interface compatibility.
        n_classes: number of labels (the config's ``num_labels``).
        output_attentions: whether the wrapped model should also compute and
            return attention weights. Defaults to ``False`` because
            ``forward`` never returns them; pass ``True`` to restore the
            previous behaviour if code reads ``self.transformer`` outputs
            directly.
    """

    def __init__(self, model_name, model_dir, dropout, n_fc, n_classes, *, output_attentions=False):
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            model_name,
            # Use the same cache as the weights below so config and model
            # are resolved from one place (previously the config ignored it).
            cache_dir=model_dir,
            num_labels=n_classes,
            # The flag is boolean; coerce so an int n_fc is not stored as-is.
            output_hidden_states=bool(n_fc),
            classifier_dropout=dropout,
            # Attention maps are never returned by forward(); requesting them
            # only costs memory/time. NOTE(review): in some transformers
            # versions it can also force the eager attention path.
            output_attentions=output_attentions,
        )
        self.transformer = AutoModelForSequenceClassification.from_pretrained(
            model_name, cache_dir=model_dir, config=self.config
        )

    def forward(self, inputs):
        """Return the classification logits for a batch of inputs.

        Args:
            inputs: mapping unpacked as keyword arguments into the wrapped
                transformer (presumably tokenizer output such as
                ``input_ids`` / ``attention_mask`` — confirm against caller).

        Returns:
            The ``"logits"`` entry of the transformer's output.
        """
        output = self.transformer(**inputs)["logits"]

        return output