File: igmc_dataset.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (128 lines) | stat: -rw-r--r-- 4,611 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os.path as osp
from typing import Callable, Optional

import torch
from torch import Tensor

from torch_geometric.data import HeteroData, InMemoryDataset, download_url


class IGMCDataset(InMemoryDataset):
    r"""The user-item heterogeneous rating datasets :obj:`"Douban"`,
    :obj:`"Flixster"` and :obj:`"Yahoo-Music"` from the `"Inductive Matrix
    Completion Based on Graph Neural Networks"
    <https://arxiv.org/abs/1904.12058>`_ paper.

    Nodes represent users and items.
    Edges and features between users and items represent a (training) rating of
    the item given by the user.
    Test ratings are stored as :obj:`edge_label_index` / :obj:`edge_label` on
    the :obj:`("user", "rates", "item")` edge type.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (:obj:`"Douban"`,
            :obj:`"Flixster"`, :obj:`"Yahoo-Music"`). Matching is
            case-insensitive and :obj:`"-"` is treated like :obj:`"_"`.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`name` is not one of the supported datasets.
    """
    url = 'https://github.com/muhanzhang/IGMC/raw/master/raw_data'

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Normalize so that e.g. "Yahoo-Music" maps to the directory and
        # download name "yahoo_music".
        self.name = name.lower().replace('-', '_')
        # Validate explicitly rather than with ``assert``, which is stripped
        # under ``python -O`` and would let a bad name fail much later (bogus
        # download URL or an unbound feature variable in ``process``).
        if self.name not in ('flixster', 'douban', 'yahoo_music'):
            raise ValueError(f"Unknown dataset name '{name}' (expected one "
                             f"of 'Douban', 'Flixster' or 'Yahoo-Music')")

        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_dir(self) -> str:
        # Each dataset gets its own sub-directory under ``root``.
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> str:
        return 'training_test_dataset.mat'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        """Fetch ``training_test_dataset.mat`` for this dataset into
        :obj:`raw_dir`."""
        path = f'{self.url}/{self.name}/training_test_dataset.mat'
        download_url(path, self.raw_dir)

    @staticmethod
    def load_matlab_file(path_file: str, name: str) -> Tensor:
        """Read variable ``name`` from an HDF5-based (MATLAB v7.3) ``.mat``
        file and return it as a transposed ``float32`` tensor.

        Args:
            path_file: Path to the ``.mat`` file.
            name: Name of the variable/dataset inside the file.

        Returns:
            The variable converted to ``torch.float`` and transposed via
            ``.t()``. This assumes the stored variable is at most 2-D.
        """
        import h5py
        import numpy as np

        # ``with`` releases the HDF5 handle even if ``name`` is missing or the
        # conversion fails. ``np.asarray`` reads the data into memory, so the
        # returned tensor stays valid after the file is closed.
        with h5py.File(path_file, 'r') as db:
            # NOTE(review): h5py presumably exposes MATLAB's column-major
            # arrays transposed, which is why ``.t()`` is applied here.
            out = torch.from_numpy(np.asarray(db[name])).to(torch.float).t()

        return out

    def process(self) -> None:
        """Build a :class:`HeteroData` graph from the raw ``.mat`` file.

        Training entries (``Otraining``) become ``user -> item`` edges plus
        their reverse, both carrying the rating from ``M``. Test entries
        (``Otest``) become ``edge_label_index`` / ``edge_label``.
        """
        data = HeteroData()
        path = self.raw_paths[0]

        # Rating matrix: rows index users, columns index items.
        M = self.load_matlab_file(path, 'M')

        # Side features where the dataset provides them. Otherwise each node
        # gets a one-hot (identity) feature vector.
        if self.name == 'flixster':
            user_x = self.load_matlab_file(path, 'W_users')
            item_x = self.load_matlab_file(path, 'W_movies')
        elif self.name == 'douban':
            user_x = self.load_matlab_file(path, 'W_users')
            item_x = torch.eye(M.size(1))
        elif self.name == 'yahoo_music':
            user_x = torch.eye(M.size(0))
            item_x = self.load_matlab_file(path, 'W_tracks')
        else:  # Guarded in ``__init__``; fail loudly rather than fall through.
            raise ValueError(f"Unknown dataset name '{self.name}'")

        data['user'].x = user_x
        data['item'].x = item_x

        # Non-zero entries of the training mask are observed training ratings.
        train_mask = self.load_matlab_file(path, 'Otraining')
        train_mask = train_mask.to(torch.bool)

        edge_index = train_mask.nonzero().t()  # [2, num_train] (user, item)
        rating = M[edge_index[0], edge_index[1]]

        data['user', 'rates', 'item'].edge_index = edge_index
        data['user', 'rates', 'item'].rating = rating

        # Reverse edges so that messages can also flow from items to users.
        data['item', 'rated_by', 'user'].edge_index = edge_index.flip([0])
        data['item', 'rated_by', 'user'].rating = rating

        # Test ratings are kept as supervision targets, not as graph edges.
        test_mask = self.load_matlab_file(path, 'Otest')
        test_mask = test_mask.to(torch.bool)

        edge_label_index = test_mask.nonzero().t()
        edge_label = M[edge_label_index[0], edge_label_index[1]]

        data['user', 'rates', 'item'].edge_label_index = edge_label_index
        data['user', 'rates', 'item'].edge_label = edge_label

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(name={self.name})'