File: hm.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (166 lines) | stat: -rw-r--r-- 6,763 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from typing import Callable, List, Optional

import torch

from torch_geometric.data import HeteroData, InMemoryDataset


class HM(InMemoryDataset):
    r"""The heterogeneous H&M dataset from the `Kaggle H&M Personalized Fashion
    Recommendations
    <https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations>`_
    challenge.
    The task is to develop product recommendations based on data from previous
    transactions, as well as from customer and product meta data.

    The processed :class:`~torch_geometric.data.HeteroData` object holds
    :obj:`"customer"` nodes (one-hot encoded attributes plus a normalized
    age) and :obj:`"article"` nodes (one-hot encoded attributes).
    Every transaction carries a one-hot encoded sales channel together with
    its price as features, and a :obj:`time` attribute counting whole days
    since the Unix epoch.

    Args:
        root (str): Root directory where the dataset should be saved.
        use_all_tables_as_node_types (bool, optional): If set to :obj:`True`,
            will use the transaction table as a distinct node type, linked to
            customers and articles via :obj:`"to"`/:obj:`"rev_to"` edges.
            Otherwise, transactions are stored as time-stamped, attributed
            :obj:`"customer"`/:obj:`"article"` edges (plus their reverse).
            (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    url = ('https://www.kaggle.com/competitions/'
           'h-and-m-personalized-fashion-recommendations/data')

    def __init__(
        self,
        root: str,
        use_all_tables_as_node_types: bool = False,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Must be set before ``super().__init__``, which consults
        # ``processed_file_names`` (and may call ``process``).
        self.use_all_tables_as_node_types = use_all_tables_as_node_types
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        return [
            'customers.csv.zip', 'articles.csv.zip',
            'transactions_train.csv.zip'
        ]

    @property
    def processed_file_names(self) -> str:
        # Each graph layout is cached under its own file name.
        if self.use_all_tables_as_node_types:
            return 'data.pt'
        else:
            return 'data_merged.pt'

    def download(self) -> None:
        # The raw files sit behind a Kaggle login and cannot be fetched here.
        raise RuntimeError(
            f"Dataset not found. Please download {self.raw_file_names} from "
            f"'{self.url}' and move it to '{self.raw_dir}'")

    @staticmethod
    def _map_ids(ids, mapping) -> torch.Tensor:
        r"""Translates the raw identifiers in :obj:`ids` (a
        :class:`pandas.Series`) into node indices using :obj:`mapping`.

        Returns:
            A :obj:`torch.long` tensor with one node index per entry of
            :obj:`ids`.

        Raises:
            KeyError: If an identifier is not a key of :obj:`mapping`.
        """
        # ``Series.map`` with a dict lets pandas perform the lookup instead
        # of one Python-level dict access per (transaction) row.
        index = ids.map(mapping)
        unknown = index.isna()  # Ids absent from ``mapping`` become NaN.
        if unknown.any():
            # Mirror the ``KeyError`` a plain dictionary lookup would raise.
            raise KeyError(ids[unknown].iloc[0])
        return torch.from_numpy(index.to_numpy(dtype='int64', copy=True))

    def process(self) -> None:
        import pandas as pd

        def one_hot(column) -> torch.Tensor:
            # One float column per distinct category value. NaN entries are
            # ignored (all-zero row) since ``dummy_na`` keeps its default.
            x = pd.get_dummies(column).values
            return torch.from_numpy(x).to(torch.float)

        data = HeteroData()

        # Process customer data ###############################################
        df = pd.read_csv(self.raw_paths[0], index_col='customer_id')
        customer_map = {idx: i for i, idx in enumerate(df.index)}

        xs = [
            one_hot(df[name]) for name in
            ['Active', 'FN', 'club_member_status', 'fashion_news_frequency']
        ]

        # Age: impute missing values with the mean, then scale by the maximum.
        x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
        x = x.nan_to_num(nan=x.nanmean().item())
        xs.append(x / x.max())

        data['customer'].x = torch.cat(xs, dim=-1)

        # Process article data ################################################
        df = pd.read_csv(self.raw_paths[1], index_col='article_id')
        article_map = {idx: i for i, idx in enumerate(df.index)}

        xs = [
            one_hot(df[name]) for name in [
                # We drop a few columns here that are high cardinality.
                # 'product_code',  # Drop.
                # 'prod_name',  # Drop.
                'product_type_no',
                'product_type_name',
                'product_group_name',
                'graphical_appearance_no',
                'graphical_appearance_name',
                'colour_group_code',
                'colour_group_name',
                'perceived_colour_value_id',
                'perceived_colour_value_name',
                'perceived_colour_master_id',
                'perceived_colour_master_name',
                # 'department_no',  # Drop.
                # 'department_name',  # Drop.
                'index_code',
                'index_name',
                'index_group_no',
                'index_group_name',
                'section_no',
                'section_name',
                'garment_group_no',
                'garment_group_name',
                # 'detail_desc',  # Drop.
            ]
        ]

        data['article'].x = torch.cat(xs, dim=-1)

        # Process transaction data ############################################
        df = pd.read_csv(self.raw_paths[2], parse_dates=['t_dat'])

        price = torch.from_numpy(df['price'].values).to(torch.float)
        x = torch.cat([one_hot(df['sales_channel_id']), price.view(-1, 1)],
                      dim=-1)

        # Truncate timestamps to whole days counted from the Unix epoch.
        # Casting through ``datetime64[D]`` does not depend on the resolution
        # pandas parsed the column with, and the explicit ``int64`` avoids
        # ``astype(int)``'s platform-dependent width (32-bit on Windows with
        # NumPy < 2), which would silently wrap nanosecond timestamps.
        days = df['t_dat'].values.astype('datetime64[D]').astype('int64')
        time = torch.from_numpy(days)

        src = self._map_ids(df['customer_id'], customer_map)
        dst = self._map_ids(df['article_id'], article_map)

        if self.use_all_tables_as_node_types:
            # One node per transaction row, linked to its customer/article.
            data['transaction'].x = x
            data['transaction'].time = time

            edge_index = torch.stack([src, torch.arange(len(df))], dim=0)
            data['customer', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'customer'].edge_index = edge_index

            edge_index = torch.stack([dst, torch.arange(len(df))], dim=0)
            data['article', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'article'].edge_index = edge_index
        else:
            # One customer -> article edge per transaction, carrying the
            # transaction features and timestamp; mirrored in reverse.
            edge_index = torch.stack([src, dst], dim=0)
            data['customer', 'to', 'article'].edge_index = edge_index
            data['customer', 'to', 'article'].time = time
            data['customer', 'to', 'article'].edge_attr = x

            edge_index = edge_index.flip([0])
            data['article', 'rev_to', 'customer'].edge_index = edge_index
            data['article', 'rev_to', 'customer'].time = time
            data['article', 'rev_to', 'customer'].edge_attr = x

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])