import os.path as osp
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from torch_geometric.data import HeteroData, download_url, extract_zip
from torch_geometric.transforms import RandomLinkSplit, ToUndirected

url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
root = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')
extract_zip(download_url(url, root), root)
movie_path = osp.join(root, 'ml-latest-small', 'movies.csv')
rating_path = osp.join(root, 'ml-latest-small', 'ratings.csv')
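
# Optional: a quick peek at the raw CSV schemas before building the graph.
# This is purely an illustrative sanity check, not part of the original
# pipeline; it only assumes the standard MovieLens columns used below
# ('movieId', 'title', 'genres' and 'userId', 'movieId', 'rating').
print(pd.read_csv(movie_path, nrows=5))
print(pd.read_csv(rating_path, nrows=5))

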
def load_node_csv(path, index_col, encoders=None, **kwargs):
    # Loads a *.csv file, maps each raw index value to a consecutive node
    # index, and optionally encodes the given columns into a feature matrix.
    df = pd.read_csv(path, index_col=index_col, **kwargs)
    mapping = {index: i for i, index in enumerate(df.index.unique())}

    x = None
    if encoders is not None:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    return x, mapping


def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,
                  encoders=None, **kwargs):
    # Loads a *.csv file and maps source/destination index values to the node
    # indices created by `load_node_csv`, optionally encoding edge attributes.
    df = pd.read_csv(path, **kwargs)

    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    edge_index = torch.tensor([src, dst])

    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr


class SequenceEncoder:
    # The 'SequenceEncoder' encodes raw column strings into embeddings.
    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        x = self.model.encode(df.values, show_progress_bar=True,
                              convert_to_tensor=True, device=self.device)
        return x.cpu()


class GenresEncoder:
    # The 'GenresEncoder' splits the raw column strings by 'sep' and converts
    # individual elements to categorical labels.
    def __init__(self, sep='|'):
        self.sep = sep

    def __call__(self, df):
        genres = {g for col in df.values for g in col.split(self.sep)}
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, col in enumerate(df.values):
            for genre in col.split(self.sep):
                x[i, mapping[genre]] = 1
        return x


class IdentityEncoder:
    # The 'IdentityEncoder' takes the raw column values and converts them to
    # PyTorch tensors.
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, df):
        return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)

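
# A minimal sanity check of the lightweight encoders on toy data (purely
# illustrative; 'SequenceEncoder' is skipped here to avoid loading the
# sentence-transformer model just for a smoke test):
_toy_genres = pd.Series(['Comedy|Romance', 'Action'])
_toy_ratings = pd.Series([4.0, 3.5])
assert GenresEncoder()(_toy_genres).shape == (2, 3)  # 3 distinct genres
assert IdentityEncoder(dtype=torch.long)(_toy_ratings).shape == (2, 1)
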
user_x, user_mapping = load_node_csv(rating_path, index_col='userId')
movie_x, movie_mapping = load_node_csv(
    movie_path, index_col='movieId', encoders={
        'title': SequenceEncoder(),
        'genres': GenresEncoder(),
    })

edge_index, edge_label = load_edge_csv(
    rating_path,
    src_index_col='userId',
    src_mapping=user_mapping,
    dst_index_col='movieId',
    dst_mapping=movie_mapping,
    encoders={'rating': IdentityEncoder(dtype=torch.long)},
)

data = HeteroData()
data['user'].num_nodes = len(user_mapping) # Users do not have any features.
data['movie'].x = movie_x
data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_label
print(data)
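
# Optional consistency check of the assembled graph (a sketch; the asserts
# only rely on attributes set above, while `HeteroData.validate()` assumes a
# reasonably recent PyG release that provides it):
rel = data['user', 'rates', 'movie']
assert rel.edge_index.size(1) == rel.edge_label.size(0)
assert int(rel.edge_index[0].max()) < data['user'].num_nodes
assert int(rel.edge_index[1].max()) < data['movie'].num_nodes
data.validate()
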
# We can now convert `data` into an appropriate format for training a
# graph-based machine learning model:
# 1. Add a reverse ('movie', 'rev_rates', 'user') relation for message passing.
data = ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label.
# 2. Perform a link-level split into training, validation, and test edges.
transform = RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)
train_data, val_data, test_data = transform(data)
print(train_data)
print(val_data)
print(test_data)
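
# The split keeps message-passing edges in `edge_index` and the supervision
# edges (with their rating labels) in `edge_label_index`/`edge_label`.
# A brief look at the training split (attribute names follow
# `RandomLinkSplit`'s convention; shapes depend on the ratios chosen above):
train_rel = train_data['user', 'rates', 'movie']
print(train_rel.edge_index.shape)        # edges used for message passing
print(train_rel.edge_label_index.shape)  # edges used as supervision targets
print(train_rel.edge_label.shape)        # the corresponding rating labels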