| | import numpy as np |
| | import torch |
| | import dgl |
| | import dgl.function as fn |
| | import dgl.nn as dglnn |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
class RGCN(nn.Module):
    """Two-layer relational graph convolution for a heterogeneous graph.

    Each layer is a ``dglnn.HeteroGraphConv`` that holds one
    ``dglnn.GraphConv`` per relation name and combines the per-relation
    results with ``aggregate='sum'``.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        """Create both layers.

        Args:
            in_feats: Mapping indexed by relation name; ``in_feats[rel]`` is
                the input size of that relation's first-layer convolution.
            hid_feats: Output size of the first layer and input size of the
                second layer.
            out_feats: Output size of the second layer.
            rel_names: Relation names; iterated once per layer.
        """
        super().__init__()

        # All first-layer modules are created before any second-layer module,
        # so parameters are initialised in a fixed, layer-by-layer order.
        first_layer = {}
        for rel in rel_names:
            first_layer[rel] = dglnn.GraphConv(in_feats[rel], hid_feats)
        self.conv1 = dglnn.HeteroGraphConv(first_layer, aggregate='sum')

        second_layer = {}
        for rel in rel_names:
            second_layer[rel] = dglnn.GraphConv(hid_feats, out_feats)
        self.conv2 = dglnn.HeteroGraphConv(second_layer, aggregate='sum')

    def forward(self, graph, inputs):
        """Run both layers with a ReLU in between.

        Args:
            graph: Graph passed unchanged to both layers.
            inputs: Input features handed to the first layer.

        Returns:
            The dictionary produced by the second layer.
        """
        hidden = self.conv1(graph, inputs)

        # Apply the non-linearity to every entry of the first layer's output.
        activated = {}
        for key, feat in hidden.items():
            activated[key] = F.relu(feat)

        return self.conv2(graph, activated)
| |
|
class HeteroDotProductPredictor(nn.Module):
    """Score the edges of one edge type by a source/destination dot product."""

    def forward(self, graph, h, etype):
        """Compute one score per edge of ``etype``.

        Args:
            graph: Graph whose edges of type ``etype`` are scored.
            h: Node representations, assigned to ``graph.ndata['h']``.
            etype: Edge type to score.

        Returns:
            The ``'score'`` edge data of ``etype`` computed by
            ``fn.u_dot_v('h', 'h', 'score')``.
        """
        # local_scope keeps the temporary 'h' and 'score' entries from
        # persisting on the caller's graph after the block exits.
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_product = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_product, etype=etype)
            scores = graph.edges[etype].data['score']
        return scores
| |
|
| |
|
class Model(nn.Module):
    """Node encoder (``RGCN``) followed by a dot-product edge scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        """Build the encoder and the predictor.

        Args:
            in_features: Forwarded to ``RGCN`` as ``in_feats``.
            hidden_features: Forwarded to ``RGCN`` as ``hid_feats``.
            out_features: Forwarded to ``RGCN`` as ``out_feats``.
            rel_names: Forwarded to ``RGCN`` as ``rel_names``.
        """
        super().__init__()
        self.sage = RGCN(
            in_feats=in_features,
            hid_feats=hidden_features,
            out_feats=out_features,
            rel_names=rel_names,
        )
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Score the edges of ``etype`` in ``g`` and in ``neg_g``.

        Node representations are computed once, on ``g``, and reused to score
        both graphs.

        Args:
            g: Graph used for encoding and for the positive scores.
            neg_g: Graph whose ``etype`` edges are scored as negatives.
            x: Input node features passed to the encoder.
            etype: Edge type to score.

        Returns:
            A ``(positive_scores, negative_scores)`` tuple.
        """
        node_repr = self.sage(g, x)
        positive_scores = self.pred(g, node_repr, etype)
        negative_scores = self.pred(neg_g, node_repr, etype)
        return positive_scores, negative_scores
| |
|
| |
|
def construct_negative_graph(graph, k, etype):
    """Build a graph of randomly corrupted edges for one edge type.

    Every edge of ``etype`` in ``graph`` contributes ``k`` negative edges that
    keep its source node and pair it with a destination drawn uniformly at
    random from all nodes of the destination type.

    Args:
        graph: Graph to sample negatives for.
        k: Number of negative edges generated per existing edge.
        etype: ``(source type, relation, destination type)`` triple.

    Returns:
        A heterograph that contains only the sampled ``etype`` edges and has
        the same number of nodes per node type as ``graph``.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)

    # Consecutive groups of k negatives share the source of one original edge.
    neg_src = src.repeat_interleave(k)

    # torch.randint defaults to int64 on the default device; create the
    # destinations with the source ids' dtype and device so both endpoint
    # tensors agree when the graph uses int32 ids or is not on the CPU.
    neg_dst = torch.randint(
        0,
        graph.num_nodes(vtype),
        (len(src) * k,),
        dtype=src.dtype,
        device=src.device,
    )

    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
| |
|
| |
|
def compute_loss(pos_score, neg_score):
    """Return the mean margin (hinge) loss with margin 1.

    Each positive score is compared only with its own group of negative
    scores: ``max(0, 1 - pos_i + neg_ij)`` averaged over all ``i, j``.

    Args:
        pos_score: One score per positive edge, shaped ``(E,)`` or ``(E, 1)``.
        neg_score: Negative scores with ``E * k`` elements, laid out so that
            the ``k`` negatives of edge ``i`` are contiguous.

    Returns:
        A scalar tensor holding the loss.
    """
    n_edges = pos_score.shape[0]

    # Force the positives to (E, 1). Calling unsqueeze(1) on an (E, 1) input
    # would give (E, 1, 1), which broadcasts against (E, k) to (E, E, k) and
    # compares every positive with every edge's negatives.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.reshape(n_edges, -1)

    return (1 - pos + neg).clamp(min=0).mean()
| |
|
| |
|
| |
|
# --- Sizes of the synthetic dataset -----------------------------------------
n_users = 1000                 # number of 'user' nodes
n_items = 500                  # number of 'item' nodes
n_follows = 3000               # number of user -> user 'follow' edges
n_clicks = 5000                # number of user -> item 'click' edges
n_dislikes = 500               # number of user -> item 'dislike' edges
n_hetero_features_user = 10    # feature width of a user node
n_hetero_features_item = 5     # feature width of an item node
n_user_classes = 5             # exclusive upper bound of the user labels
n_max_clicks = 10              # exclusive upper bound of the click labels

# --- Random edge endpoints ---------------------------------------------------
# Endpoints are drawn independently, so duplicate edges (and self-loops among
# the user -> user edges) can occur.
follow_src = np.random.randint(low=0, high=n_users, size=n_follows)
follow_dst = np.random.randint(low=0, high=n_users, size=n_follows)
click_src = np.random.randint(low=0, high=n_users, size=n_clicks)
click_dst = np.random.randint(low=0, high=n_items, size=n_clicks)
dislike_src = np.random.randint(low=0, high=n_users, size=n_dislikes)
dislike_dst = np.random.randint(low=0, high=n_items, size=n_dislikes)
| |
|
# Every interaction is stored twice: once in its original direction and once
# reversed under a separate relation name.
_edge_lists = {
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src),
}
hetero_graph = dgl.heterograph(_edge_lists)

# Random node features, one row per node.
hetero_graph.nodes['user'].data['feature'] = torch.randn(
    n_users, n_hetero_features_user)
hetero_graph.nodes['item'].data['feature'] = torch.randn(
    n_items, n_hetero_features_item)

# Random labels: integer classes in [0, n_user_classes) for users, and float
# values drawn from the integers in [1, n_max_clicks) for 'click' edges.
hetero_graph.nodes['user'].data['label'] = torch.randint(
    0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(
    1, n_max_clicks, (n_clicks,)).float()

# Boolean training masks; each entry is sampled with probability 0.6.
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(
    n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(
    n_clicks, dtype=torch.bool).bernoulli(0.6)
| |
|
| | |
# Input size of the first-layer convolution of each relation: the feature
# width of the node type the relation starts from.
_feature_width = {
    'user': n_hetero_features_user,
    'item': n_hetero_features_item,
}
hetero_features_dims = {
    relation: _feature_width[source_type]
    for relation, source_type in (
        ('follow', 'user'),
        ('followed-by', 'user'),
        ('click', 'user'),
        ('clicked-by', 'item'),
        ('dislike', 'user'),
        ('disliked-by', 'item'),
    )
}
| |
|
# --- Training ------------------------------------------------------------------
k = 5  # negative edges sampled per 'click' edge
model = Model(hetero_features_dims, 20, 5, hetero_graph.etypes)

user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}

opt = torch.optim.Adam(model.parameters())

# Edge type whose links are being predicted.
target_etype = ('user', 'click', 'item')

for epoch in range(10):
    # Fresh negative samples are drawn every epoch.
    negative_graph = construct_negative_graph(hetero_graph, k, target_etype)
    pos_score, neg_score = model(
        hetero_graph, negative_graph, node_features, target_etype)
    loss = compute_loss(pos_score, neg_score)

    opt.zero_grad()
    loss.backward()
    opt.step()

    print(loss.item())