1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
from typing import Callable, List, Optional
import torch
from torch_geometric.data import HeteroData, InMemoryDataset
class HM(InMemoryDataset):
    r"""The heterogeneous H&M dataset from the `Kaggle H&M Personalized Fashion
    Recommendations
    <https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations>`_
    challenge.
    The task is to develop product recommendations based on data from previous
    transactions, as well as from customer and product meta data.

    The raw files are not downloaded automatically: :meth:`download` raises an
    error asking the user to place them into the raw directory manually.

    Args:
        root (str): Root directory where the dataset should be saved.
        use_all_tables_as_node_types (bool, optional): If set to :obj:`True`,
            will use the transaction table as a distinct node type.
            Otherwise, transactions become edges between customers and
            articles. (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    url = ('https://www.kaggle.com/competitions/'
           'h-and-m-personalized-fashion-recommendations/data')

    def __init__(
        self,
        root: str,
        use_all_tables_as_node_types: bool = False,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Must be set before `super().__init__()`, since
        # `processed_file_names` reads this flag.
        self.use_all_tables_as_node_types = use_all_tables_as_node_types
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw archives expected in the raw directory."""
        return [
            'customers.csv.zip', 'articles.csv.zip',
            'transactions_train.csv.zip'
        ]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file; it differs per graph layout so that
        both variants can be stored side by side."""
        if self.use_all_tables_as_node_types:
            return 'data.pt'
        else:
            return 'data_merged.pt'

    def download(self) -> None:
        """Always raises, since the raw files need to be obtained manually.

        Raises:
            RuntimeError: Always; the message names the required files, the
                download URL and the target directory.
        """
        raise RuntimeError(
            f"Dataset not found. Please download {self.raw_file_names} from "
            f"'{self.url}' and move it to '{self.raw_dir}'")

    @staticmethod
    def _one_hot(column) -> torch.Tensor:
        """One-hot encodes a single data frame column as a float tensor."""
        import pandas as pd

        x = pd.get_dummies(column).values
        return torch.from_numpy(x).to(torch.float)

    def process(self) -> None:
        """Builds the :class:`HeteroData` graph from the raw CSV tables and
        saves it to the processed path."""
        import pandas as pd

        data = HeteroData()

        # Process customer data ###############################################
        df = pd.read_csv(self.raw_paths[0], index_col='customer_id')
        # Maps raw customer IDs to consecutive node indices:
        customer_map = {idx: i for i, idx in enumerate(df.index)}

        xs = [
            self._one_hot(df[name]) for name in [
                'Active', 'FN', 'club_member_status', 'fashion_news_frequency'
            ]
        ]

        age = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
        # Impute missing ages with the mean. `nan_to_num` documents `nan` as a
        # Python number, so convert the 0-dim tensor explicitly:
        age = age.nan_to_num(nan=float(age.nanmean()))
        xs.append(age / age.max())  # Scale by the maximum age.

        data['customer'].x = torch.cat(xs, dim=-1)

        # Process article data ################################################
        df = pd.read_csv(self.raw_paths[1], index_col='article_id')
        # Maps raw article IDs to consecutive node indices:
        article_map = {idx: i for i, idx in enumerate(df.index)}

        xs = [
            self._one_hot(df[name])
            for name in [  # We drop a few columns that are high cardinality.
                # 'product_code',  # Drop.
                # 'prod_name',  # Drop.
                'product_type_no',
                'product_type_name',
                'product_group_name',
                'graphical_appearance_no',
                'graphical_appearance_name',
                'colour_group_code',
                'colour_group_name',
                'perceived_colour_value_id',
                'perceived_colour_value_name',
                'perceived_colour_master_id',
                'perceived_colour_master_name',
                # 'department_no',  # Drop.
                # 'department_name',  # Drop.
                'index_code',
                'index_name',
                'index_group_no',
                'index_group_name',
                'section_no',
                'section_name',
                'garment_group_no',
                'garment_group_name',
                # 'detail_desc',  # Drop.
            ]
        ]

        data['article'].x = torch.cat(xs, dim=-1)

        # Process transaction data ############################################
        df = pd.read_csv(self.raw_paths[2], parse_dates=['t_dat'])

        x1 = self._one_hot(df['sales_channel_id'])
        x2 = torch.from_numpy(df['price'].values).to(torch.float).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        # Convert timestamps to whole days since the Unix epoch. Casting via
        # `datetime64[D]` is independent of the datetime unit of the parsed
        # column and of the platform-dependent width of `int`:
        days = df['t_dat'].values.astype('datetime64[D]').astype('int64')
        time = torch.from_numpy(days)

        src = torch.tensor([customer_map[idx] for idx in df['customer_id']])
        dst = torch.tensor([article_map[idx] for idx in df['article_id']])

        if self.use_all_tables_as_node_types:
            # Every transaction is a node linked to its customer and article.
            data['transaction'].x = x
            data['transaction'].time = time

            transaction = torch.arange(len(df))

            edge_index = torch.stack([src, transaction], dim=0)
            data['customer', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'customer'].edge_index = edge_index

            edge_index = torch.stack([dst, transaction], dim=0)
            data['article', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'article'].edge_index = edge_index
        else:
            # Every transaction is a customer-article edge carrying the
            # transaction features and timestamp (in both directions).
            edge_index = torch.stack([src, dst], dim=0)
            data['customer', 'to', 'article'].edge_index = edge_index
            data['customer', 'to', 'article'].time = time
            data['customer', 'to', 'article'].edge_attr = x

            edge_index = edge_index.flip([0])
            data['article', 'rev_to', 'customer'].edge_index = edge_index
            data['article', 'rev_to', 'customer'].time = time
            data['article', 'rev_to', 'customer'].edge_attr = x

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])
|