File: test_git_mol.py

Package info: pytorch-geometric 2.7.0-1
  • Links: PTS, VCS
  • Area: main
  • In suites: forky, sid
  • Size: 14,172 kB
  • SLOC: Python: 144,911; sh: 247; C++: 27; Makefile: 18; JavaScript: 16
File content (24 lines) | stat: -rw-r--r-- 795 bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch

from torch_geometric.llm.models import GITMol
from torch_geometric.testing import withPackage


@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_git_mol():
    """Smoke-test one training forward pass of ``GITMol``.

    A single molecule is described through every input the model takes:
    a graph (node features, edges, edge features, batch vector), a SMILES
    string, an image tensor and a text caption. The test only checks that
    the returned loss is non-negative.
    """
    # Constructed first so the random draw for ``images`` below sees the
    # same RNG state as before.
    model = GITMol()

    num_nodes = 10
    feat_dim = 16

    # Two disjoint directed 5-cycles: 0->1->2->3->4->0 and 5->6->7->8->9->5.
    src = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    dst = [1, 2, 3, 4, 0, 6, 7, 8, 9, 5]
    edge_index = torch.tensor([src, dst])
    num_edges = edge_index.size(1)

    # Integer-valued features; presumably categorical indices embedded by
    # the model -- TODO confirm against GITMol's graph encoder.
    x = torch.ones(num_nodes, feat_dim, dtype=torch.long)
    edge_attr = torch.zeros(num_edges, feat_dim, dtype=torch.long)

    # Every node belongs to graph 0, i.e. a batch holding one graph.
    batch = torch.zeros(num_nodes, dtype=torch.long)

    # One sample per remaining modality, matching the single graph above.
    smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O']
    captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.']
    # One image with 3 channels at 224x224 (presumably RGB).
    images = torch.randn(1, 3, 224, 224)

    # Test train:
    loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
    assert loss >= 0