import copy
from unittest.mock import patch

import torch
import torchtext
from torch.nn import functional as torch_F

from ..common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):
    def test_roberta_bundler_build_model(self) -> None:
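        """Check RobertaBundle.build_model checkpoint handling: loading an encoder-only
        state dict, loading a full classifier state dict with and without
        override_checkpoint_head, and combining encoder-only weights with a provided head."""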
        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle

        dummy_encoder_conf = RobertaEncoderConf(
            vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
        )

        # case: user provides an encoder checkpoint state dict
        dummy_encoder = RobertaModel(dummy_encoder_conf)
        model = RobertaBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
        self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

        # case: user provides a classifier checkpoint state dict while a head is given and
        # override_checkpoint_head is False (the default); the checkpoint head is kept
        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
        model = RobertaBundle.build_model(
            encoder_conf=dummy_encoder_conf,
            head=another_dummy_classifier_head,
            checkpoint=dummy_classifier.state_dict(),
        )
        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

        # case: user provides a classifier checkpoint state dict while a head is given and
        # override_checkpoint_head is set to True; the provided head replaces the checkpoint head
        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        model = RobertaBundle.build_model(
            encoder_conf=dummy_encoder_conf,
            head=another_dummy_classifier_head,
            checkpoint=dummy_classifier.state_dict(),
            override_checkpoint_head=True,
        )
        self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())

        # case: user provides only an encoder checkpoint state dict while a head is given
        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
        encoder_state_dict = {}
        for k, v in dummy_classifier.encoder.state_dict().items():
            encoder_state_dict["encoder." + k] = v
        model = torchtext.models.RobertaBundle.build_model(
            encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict
        )
        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
    def test_roberta_bundler_train(self) -> None:
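        """Run one SGD step on models built with freeze_encoder=False and freeze_encoder=True,
        and check that encoder weights change only in the unfrozen case while the head
        weights change in both."""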
        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
        from torch.optim import SGD

        dummy_encoder_conf = RobertaEncoderConf(
            vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
        )

        def _train(model):
            # run a single optimization step on a toy batch
            optim = SGD(model.parameters(), lr=1)
            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
            target = torch.tensor([0])
            logits = model(model_input)
            loss = torch_F.cross_entropy(logits, target)
            loss.backward()
            optim.step()

        # case: encoder not frozen; both encoder and head weights should change after one step
        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
        model = RobertaBundle.build_model(
            encoder_conf=dummy_encoder_conf,
            head=dummy_classifier_head,
            freeze_encoder=False,
            checkpoint=dummy_classifier.state_dict(),
        )

        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
        head_current_state_dict = copy.deepcopy(model.head.state_dict())

        _train(model)

        self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)

        # case: encoder frozen; encoder weights should stay fixed while head weights change
        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
        model = RobertaBundle.build_model(
            encoder_conf=dummy_encoder_conf,
            head=dummy_classifier_head,
            freeze_encoder=True,
            checkpoint=dummy_classifier.state_dict(),
        )

        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
        head_current_state_dict = copy.deepcopy(model.head.state_dict())

        _train(model)

        self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
    @patch("logging.Logger.warning")
    def test_roberta_bundler_get_model(self, mock):
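        """Check that get_model(load_weights=False, freeze_encoder=True) logs the expected
        warning about freezing an encoder that has no pre-trained weights."""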
        from torchtext.models import RobertaEncoderConf, RobertaBundle

        dummy_encoder_conf = RobertaEncoderConf(
            vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
        )
        model_bundle = RobertaBundle(dummy_encoder_conf)
        model_bundle.get_model(load_weights=False, freeze_encoder=True)
        mock.assert_called_with(
            "The encoder is not loaded with pre-trained weights. Setting freeze_encoder to True will hinder encoder from learning appropriate weights."
        )
    def test_roberta_bundler_raise_checkpoint(self) -> None:
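        """Check that build_model raises TypeError when the checkpoint argument has an
        unsupported type (an int here instead of a state dict)."""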
        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaBundle

        with self.assertRaises(TypeError):
            dummy_encoder_conf = RobertaEncoderConf(
                vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
            )
            dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
            RobertaBundle.build_model(
                encoder_conf=dummy_encoder_conf,
                head=dummy_classifier_head,
                freeze_encoder=True,
                checkpoint=1,
            )
    def test_roberta_bundler_encoder_conf_property(self) -> None:
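        """Check that RobertaBundle exposes its configuration via the encoderConf property."""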
        from torchtext.models import RobertaEncoderConf, RobertaBundle

        dummy_encoder_conf = RobertaEncoderConf(
            vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
        )
        model_bundle = RobertaBundle(dummy_encoder_conf)
        self.assertIsInstance(model_bundle.encoderConf, RobertaEncoderConf)