Showing 53 changed files with 4,551 additions and 0 deletions.

@@ -0,0 +1,80 @@
import os
import recbole
from recbole.config.configurator import Config as RecBole_Config
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import get_model, ModelType


class Config(RecBole_Config):
    def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
        """
        Args:
            model (str/AbstractRecommender): the model name or the model class, default is None. If it is None, config
                will search the parameter 'model' from the external input as the model name or model class.
            dataset (str): the dataset name, default is None. If it is None, config will search the parameter 'dataset'
                from the external input as the dataset name.
            config_file_list (list of str): the external config files; multiple config files are allowed, default is None.
            config_dict (dict): the external parameter dictionary, default is None.
        """
        if recbole.__version__ == "1.1.1":
            self.compatibility_settings()
        super(Config, self).__init__(model, dataset, config_file_list, config_dict)

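    # NumPy 1.20 deprecated (and 1.24 removed) the built-in scalar aliases such as
    # np.bool and np.float; recbole 1.1.1 still references some of them, so they are
    # restored below before the parent Config is initialized.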
    def compatibility_settings(self):
        import numpy as np
        np.bool = np.bool_
        np.int = np.int_
        np.float = np.float_
        np.complex = np.complex_
        np.object = np.object_
        np.str = np.str_
        np.long = np.int_
        np.unicode = np.unicode_

    def _get_model_and_dataset(self, model, dataset):

        if model is None:
            try:
                model = self.external_config_dict['model']
            except KeyError:
                raise KeyError(
                    'model needs to be specified in at least one of these ways: '
                    '[model variable, config file, config dict, command line] '
                )
        if not isinstance(model, str):
            final_model_class = model
            final_model = model.__name__
        else:
            final_model = model
            final_model_class = get_model(final_model)

        if dataset is None:
            try:
                final_dataset = self.external_config_dict['dataset']
            except KeyError:
                raise KeyError(
                    'dataset needs to be specified in at least one of these ways: '
                    '[dataset variable, config file, config dict, command line] '
                )
        else:
            final_dataset = dataset

        return final_model, final_model_class, final_dataset

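    # Load the package's own default YAML files: per-model settings from
    # properties/model/<model>.yaml, plus quick-start base configs for
    # sequential and social models.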
    def _load_internal_config_dict(self, model, model_class, dataset):
        super()._load_internal_config_dict(model, model_class, dataset)
        current_path = os.path.dirname(os.path.realpath(__file__))
        model_init_file = os.path.join(current_path, './properties/model/' + model + '.yaml')
        quick_start_config_path = os.path.join(current_path, './properties/quick_start_config/')
        sequential_base_init = os.path.join(quick_start_config_path, 'sequential_base.yaml')
        social_base_init = os.path.join(quick_start_config_path, 'social_base.yaml')

        if os.path.isfile(model_init_file):
            config_dict = self._update_internal_config_dict(model_init_file)

        self.internal_config_dict['MODEL_TYPE'] = model_class.type
        if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
            self._update_internal_config_dict(sequential_base_init)
        if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
            self._update_internal_config_dict(social_base_init)

Recbole_GNN/recbole_gnn/model/abstract_recommender.py
@@ -0,0 +1,30 @@
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import ModelType


class GeneralGraphRecommender(GeneralRecommender):
    """This is an abstract general graph recommender. All general graph models should inherit from this class.
    The base general graph recommender class provides the basic U-I graph dataset and parameters information.
    """
    type = RecBoleModelType.GENERAL

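    # The dataset builds the symmetrically normalized user-item adjacency. When
    # config['enable_sparse'] is on and the dataset reports is_sparse, edge_index holds
    # a sparse adjacency tensor and edge_weight is set to None; otherwise both are dense
    # tensors moved to the model device.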
    def __init__(self, config, dataset):
        super(GeneralGraphRecommender, self).__init__(config, dataset)
        self.edge_index, self.edge_weight = dataset.get_norm_adj_mat(enable_sparse=config["enable_sparse"])
        self.use_sparse = config["enable_sparse"] and dataset.is_sparse
        if self.use_sparse:
            self.edge_index, self.edge_weight = self.edge_index.to(self.device), None
        else:
            self.edge_index, self.edge_weight = self.edge_index.to(self.device), self.edge_weight.to(self.device)


class SocialRecommender(GeneralRecommender):
    """This is an abstract social recommender. All social graph models should inherit from this class.
    The base social recommender class provides the basic social graph dataset and parameters information.
    """
    type = ModelType.SOCIAL

    def __init__(self, config, dataset):
        super(SocialRecommender, self).__init__(config, dataset)

Recbole_GNN/recbole_gnn/model/general_recommender/__init__.py
@@ -0,0 +1,10 @@
from recbole_gnn.model.general_recommender.lightgcn import LightGCN
from recbole_gnn.model.general_recommender.hmlet import HMLET
from recbole_gnn.model.general_recommender.ncl import NCL
from recbole_gnn.model.general_recommender.ngcf import NGCF
from recbole_gnn.model.general_recommender.sgl import SGL
from recbole_gnn.model.general_recommender.lightgcl import LightGCL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
from recbole_gnn.model.general_recommender.directau import DirectAU
from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC

Recbole_GNN/recbole_gnn/model/general_recommender/directau.py
@@ -0,0 +1,120 @@
# r""" | ||
# DiretAU | ||
# ################################################ | ||
# Reference: | ||
# Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022. | ||
|
||
# Reference code: | ||
# https://github.com/THUwangcy/DirectAU | ||
# """ | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from recbole.model.init import xavier_normal_initialization | ||
from recbole.utils import InputType | ||
from recbole.model.general_recommender import BPR | ||
from recbole_gnn.model.general_recommender import LightGCN | ||
|
||
from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender | ||
|
||
|
||
class DirectAU(GeneralGraphRecommender): | ||
input_type = InputType.PAIRWISE | ||
|
||
def __init__(self, config, dataset): | ||
super(DirectAU, self).__init__(config, dataset) | ||
|
||
# load parameters info | ||
self.embedding_size = config['embedding_size'] | ||
self.gamma = config['gamma'] | ||
self.encoder_name = config['encoder'] | ||
|
||
# define encoder | ||
if self.encoder_name == 'MF': | ||
self.encoder = MFEncoder(config, dataset) | ||
elif self.encoder_name == 'LightGCN': | ||
self.encoder = LGCNEncoder(config, dataset) | ||
else: | ||
raise ValueError('Non-implemented Encoder.') | ||
|
||
# storage variables for full sort evaluation acceleration | ||
self.restore_user_e = None | ||
self.restore_item_e = None | ||
|
||
# parameters initialization | ||
self.apply(xavier_normal_initialization) | ||
|
||
def forward(self, user, item): | ||
user_e, item_e = self.encoder(user, item) | ||
return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1) | ||
|
||
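    # The two loss terms: alignment pulls the normalized embeddings of interacting
    # user-item pairs together, l_align = E ||u - i||^alpha with alpha=2; uniformity
    # spreads embeddings over the hypersphere, l_uniform = log E exp(-t * ||x - x'||^2),
    # computed per side and averaged.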
    @staticmethod
    def alignment(x, y, alpha=2):
        return (x - y).norm(p=2, dim=1).pow(alpha).mean()

    @staticmethod
    def uniformity(x, t=2):
        return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_e, item_e = self.forward(user, item)
        align = self.alignment(user_e, item_e)
        uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2

        return align, uniform

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        # the ego embeddings live on the encoder (MF or LightGCN), not on DirectAU itself
        user_e = self.encoder.user_embedding(user)
        item_e = self.encoder.item_embedding(item)
        return torch.mul(user_e, item_e).sum(dim=1)

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.encoder_name == 'LightGCN':
            if self.restore_user_e is None or self.restore_item_e is None:
                self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
            user_e = self.restore_user_e[user]
            all_item_e = self.restore_item_e
        else:
            user_e = self.encoder.user_embedding(user)
            all_item_e = self.encoder.item_embedding.weight
        score = torch.matmul(user_e, all_item_e.transpose(0, 1))
        return score.view(-1)


class MFEncoder(BPR):
    def __init__(self, config, dataset):
        super(MFEncoder, self).__init__(config, dataset)

    def forward(self, user_id, item_id):
        return super().forward(user_id, item_id)

    def get_all_embeddings(self):
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        return user_embeddings, item_embeddings


class LGCNEncoder(LightGCN):
    def __init__(self, config, dataset):
        super(LGCNEncoder, self).__init__(config, dataset)

    def forward(self, user_id, item_id):
        user_all_embeddings, item_all_embeddings = self.get_all_embeddings()
        u_embed = user_all_embeddings[user_id]
        i_embed = item_all_embeddings[item_id]
        return u_embed, i_embed

    def get_all_embeddings(self):
        return super().forward()

Recbole_GNN/recbole_gnn/model/general_recommender/hmlet.py
@@ -0,0 +1,229 @@
# @Time : 2022/3/21
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

r"""
HMLET
################################################
Reference:
    Taeyong Kong et al. "Linear, or Non-Linear, That is the Question!." in WSDM 2022.
Reference code:
    https://github.com/qbxlvnf11/HMLET
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.model.layers import activation_layer
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class Gating_Net(nn.Module):
    def __init__(self, embedding_dim, mlp_dims, dropout_p):
        super(Gating_Net, self).__init__()
        self.embedding_dim = embedding_dim

        fc_layers = []
        for i in range(len(mlp_dims)):
            if i == 0:
                fc = nn.Linear(embedding_dim * 2, mlp_dims[i])
                fc_layers.append(fc)
            else:
                fc = nn.Linear(mlp_dims[i - 1], mlp_dims[i])
                fc_layers.append(fc)
            if i != len(mlp_dims) - 1:
                fc_layers.append(nn.BatchNorm1d(mlp_dims[i]))
                fc_layers.append(nn.Dropout(p=dropout_p))
                fc_layers.append(nn.ReLU(inplace=True))
        self.mlp = nn.Sequential(*fc_layers)

    def gumbel_softmax(self, logits, temperature, hard):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        y = self.gumbel_softmax_sample(logits, temperature)  # e.g. (0.6, 0.2, 0.1, ..., 0.11)
        if hard:
            k = logits.size(1)  # k is the number of classes
            # y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)  # (1, 0, 0, ..., 0)
            y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y)
            # straight-through estimator: one-hot in the forward pass, soft gradients in the backward pass
            y = (y_hard - y).detach() + y
        return y

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a sample from the Gumbel-Softmax distribution"""
        noise = self.sample_gumbel(logits)
        y = (logits + noise) / temperature
        return F.softmax(y, dim=1)

    def sample_gumbel(self, logits):
        """Sample from Gumbel(0, 1)"""
        noise = torch.rand(logits.size())
        eps = 1e-20
        noise.add_(eps).log_().neg_()
        noise.add_(eps).log_().neg_()
        return torch.Tensor(noise.float()).to(logits.device)

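    # The gate maps the concatenated [linear || non-linear] embeddings to a two-way
    # Gumbel-Softmax selection over the {linear, non-linear} branches and broadcasts it
    # across the embedding dimension, so each node picks (softly during training,
    # hard one-hot at inference) which branch to keep.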
    def forward(self, feature, temperature, hard):
        x = self.mlp(feature)
        out = self.gumbel_softmax(x, temperature, hard)
        out_value = out.unsqueeze(2)
        gating_out = out_value.repeat(1, 1, self.embedding_dim)
        return gating_out


class HMLET(GeneralGraphRecommender):
    r"""HMLET combines both linear and non-linear propagation layers for general recommendation and yields better performance.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(HMLET, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of LightGCN
        self.n_layers = config['n_layers']  # int type: the layer num of LightGCN
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']  # bool type: whether to require pow when regularization
        self.gate_layer_ids = config['gate_layer_ids']  # list type: layer ids for non-linear gating
        self.gating_mlp_dims = config['gating_mlp_dims']  # list type: list of mlp dimensions in gating module
        self.dropout_ratio = config['dropout_ratio']  # dropout ratio for mlp in gating module
        self.gum_temp = config['ori_temp']
        self.logger.info(f'Model initialization, gumbel softmax temperature: {self.gum_temp}')

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.activation = nn.ELU() if config['activation_function'] == 'elu' else activation_layer(config['activation_function'])
        self.gating_nets = nn.ModuleList([
            Gating_Net(self.latent_dim, self.gating_mlp_dims, self.dropout_ratio) for _ in range(len(self.gate_layer_ids))
        ])

        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e', 'gum_temp']

        for gating in self.gating_nets:
            self._gating_freeze(gating, False)

    def _gating_freeze(self, model, freeze_flag):
        for name, child in model.named_children():
            for param in child.parameters():
                param.requires_grad = freeze_flag

    def __choosing_one(self, features, gumbel_out):
        feature = torch.sum(torch.mul(features, gumbel_out), dim=1)  # batch x embedding_dim (or batch x embedding_dim x layer_num)
        return feature

    def __where(self, idx, lst):
        for i in range(len(lst)):
            if lst[i] == idx:
                return i
        raise ValueError(f'{idx} not in {lst}.')

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

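    # Propagation: every layer applies the linear LightGCN convolution; at the layers
    # listed in gate_layer_ids, a non-linear (activated) convolution is also computed
    # and the gating network chooses per node between the two branches. The final
    # representation is the mean of all layer outputs.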
    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        non_lin_emb_list = [all_embeddings]

        for layer_idx in range(self.n_layers):
            linear_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            if layer_idx not in self.gate_layer_ids:
                all_embeddings = linear_embeddings
            else:
                non_lin_id = self.__where(layer_idx, self.gate_layer_ids)
                last_non_lin_emb = non_lin_emb_list[non_lin_id]
                non_lin_embeddings = self.activation(self.gcn_conv(last_non_lin_emb, self.edge_index, self.edge_weight))
                stack_embeddings = torch.stack([linear_embeddings, non_lin_embeddings], dim=1)
                concat_embeddings = torch.cat((linear_embeddings, non_lin_embeddings), dim=-1)
                gumbel_out = self.gating_nets[non_lin_id](concat_embeddings, self.gum_temp, not self.training)
                all_embeddings = self.__choosing_one(stack_embeddings, gumbel_out)
                non_lin_emb_list.append(all_embeddings)
            embeddings_list.append(all_embeddings)
        hmlet_all_embeddings = torch.stack(embeddings_list, dim=1)
        hmlet_all_embeddings = torch.mean(hmlet_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(hmlet_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

Recbole_GNN/recbole_gnn/model/general_recommender/lightgcl.py
@@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
# @Time : 2023/04/12
# @Author : Wanli Yang
# @Email : 2013774@mail.nankai.edu.cn

r"""
LightGCL
################################################
Reference:
    Xuheng Cai et al. "LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation" in ICLR 2023.
Reference code:
    https://github.com/HKUDS/LightGCL
"""

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
import torch.nn.functional as F


class LightGCL(GeneralRecommender):
    r"""LightGCL is a GCN-based recommender model.
    LightGCL guides graph augmentation by singular value decomposition (SVD) to not only
    distill the useful information of user-item interactions but also inject the global
    collaborative context into the representation alignment of contrastive learning.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCL, self).__init__(config, dataset)
        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        # load parameters info
        self.embed_dim = config["embedding_size"]
        self.n_layers = config["n_layers"]
        self.dropout = config["dropout"]
        self.temp = config["temp"]
        self.lambda_1 = config["lambda1"]
        self.lambda_2 = config["lambda2"]
        self.q = config["q"]
        self.act = nn.LeakyReLU(0.5)
        self.reg_loss = EmbLoss()

        # get the normalized adjacency matrix
        self.adj_norm = self.coo2tensor(self.create_adjust_matrix())

        # perform low-rank SVD reconstruction of the adjacency matrix
        svd_u, s, svd_v = torch.svd_lowrank(self.adj_norm, q=self.q)
        self.u_mul_s = svd_u @ (torch.diag(s))
        self.v_mul_s = svd_v @ (torch.diag(s))
        del s
        self.ut = svd_u.T
        self.vt = svd_v.T

        self.E_u_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_users, self.embed_dim)))
        self.E_i_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_items, self.embed_dim)))
        self.E_u_list = [None] * (self.n_layers + 1)
        self.E_i_list = [None] * (self.n_layers + 1)
        self.E_u_list[0] = self.E_u_0
        self.E_i_list[0] = self.E_i_0
        self.Z_u_list = [None] * (self.n_layers + 1)
        self.Z_i_list = [None] * (self.n_layers + 1)
        self.G_u_list = [None] * (self.n_layers + 1)
        self.G_i_list = [None] * (self.n_layers + 1)
        self.G_u_list[0] = self.E_u_0
        self.G_i_list[0] = self.E_i_0

        self.E_u = None
        self.E_i = None
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def create_adjust_matrix(self):
        r"""Get the normalized interaction matrix of users and items.
        Returns:
            coo_matrix of the normalized interaction matrix.
        """
        ratings = np.ones_like(self._user, dtype=np.float32)
        matrix = sp.csr_matrix(
            (ratings, (self._user, self._item)),
            shape=(self.n_users, self.n_items),
        ).tocoo()
        rowD = np.squeeze(np.array(matrix.sum(1)), axis=1)
        colD = np.squeeze(np.array(matrix.sum(0)), axis=0)
        for i in range(len(matrix.data)):
            matrix.data[i] = matrix.data[i] / pow(rowD[matrix.row[i]] * colD[matrix.col[i]], 0.5)
        return matrix

    def coo2tensor(self, matrix: sp.coo_matrix):
        r"""Convert coo_matrix to tensor.
        Args:
            matrix (scipy.coo_matrix): Sparse matrix to be converted.
        Returns:
            torch.sparse.FloatTensor: Transformed sparse matrix.
        """
        indices = torch.from_numpy(
            np.vstack((matrix.row, matrix.col)).astype(np.int64))
        values = torch.from_numpy(matrix.data)
        shape = torch.Size(matrix.shape)
        x = torch.sparse.FloatTensor(indices, values, shape).coalesce().to(self.device)
        return x

    def sparse_dropout(self, matrix, dropout):
        if dropout == 0.0:
            return matrix
        indices = matrix.indices()
        values = F.dropout(matrix.values(), p=dropout)
        size = matrix.size()
        return torch.sparse.FloatTensor(indices, values, size)

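    # Forward propagates the layer-0 embeddings over the (optionally dropout-perturbed)
    # normalized adjacency: Z_u^(l) = A~ E_i^(l-1) and Z_i^(l) = A~^T E_u^(l-1); the
    # final embeddings are the sum over all layers, including layer 0.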
    def forward(self):
        for layer in range(1, self.n_layers + 1):
            # GNN propagation
            self.Z_u_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout),
                                              self.E_i_list[layer - 1])
            self.Z_i_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout).transpose(0, 1),
                                              self.E_u_list[layer - 1])
            # aggregate
            self.E_u_list[layer] = self.Z_u_list[layer]
            self.E_i_list[layer] = self.Z_i_list[layer]

        # aggregate across layers
        self.E_u = sum(self.E_u_list)
        self.E_i = sum(self.E_i_list)

        return self.E_u, self.E_i

    def calculate_loss(self, interaction):
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]
        E_u_norm, E_i_norm = self.forward()
        bpr_loss = self.calc_bpr_loss(E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list)
        ssl_loss = self.calc_ssl_loss(E_u_norm, E_i_norm, user_list, pos_item_list)
        total_loss = bpr_loss + ssl_loss
        return total_loss

    def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list):
        r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.
        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of the users.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.
        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = E_u_norm[user_list]
        pi_e = E_i_norm[pos_item_list]
        ni_e = E_i_norm[neg_item_list]
        pos_scores = torch.mul(u_e, pi_e).sum(dim=1)
        neg_scores = torch.mul(u_e, ni_e).sum(dim=1)
        loss1 = -(pos_scores - neg_scores).sigmoid().log().mean()

        # reg loss
        loss_reg = 0
        for param in self.parameters():
            loss_reg += param.norm(2).square()
        loss_reg *= self.lambda_2
        return loss1 + loss_reg

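    # Contrastive (InfoNCE-style) loss between the main-view embeddings E and the
    # SVD-view embeddings G, where the SVD view replaces A~ with its rank-q
    # approximation U S V^T, so propagation per layer reduces to (U S)(V^T E).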
    def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list):
        r"""Calculate the loss of self-supervised tasks.
        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users in the original graph after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items in the original graph after forwarding.
            user_list (torch.Tensor): List of the users.
            pos_item_list (torch.Tensor): List of positive examples.
        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """
        # calculate G_u_norm & G_i_norm
        for layer in range(1, self.n_layers + 1):
            # svd_adj propagation
            vt_ei = self.vt @ self.E_i_list[layer - 1]
            self.G_u_list[layer] = self.u_mul_s @ vt_ei
            ut_eu = self.ut @ self.E_u_list[layer - 1]
            self.G_i_list[layer] = self.v_mul_s @ ut_eu

        # aggregate across layers
        G_u_norm = sum(self.G_u_list)
        G_i_norm = sum(self.G_i_list)

        neg_score = torch.log(torch.exp(G_u_norm[user_list] @ E_u_norm.T / self.temp).sum(1) + 1e-8).mean()
        neg_score += torch.log(torch.exp(G_i_norm[pos_item_list] @ E_i_norm.T / self.temp).sum(1) + 1e-8).mean()
        pos_score = (torch.clamp((G_u_norm[user_list] * E_u_norm[user_list]).sum(1) / self.temp, -5.0, 5.0)).mean() + (
            torch.clamp((G_i_norm[pos_item_list] * E_i_norm[pos_item_list]).sum(1) / self.temp, -5.0, 5.0)).mean()
        ssl_loss = -pos_score + neg_score
        return self.lambda_1 * ssl_loss

    def predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)

Recbole_GNN/recbole_gnn/model/general_recommender/lightgcn.py
@@ -0,0 +1,133 @@
# @Time : 2022/3/8
# @Author : Lanling Xu
# @Email : xulanling_sherry@163.com

r"""
LightGCN
################################################
Reference:
    Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020.
Reference code:
    https://github.com/kuandeng/LightGCN
"""

import numpy as np
import torch

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class LightGCN(GeneralGraphRecommender):
    r"""LightGCN is a GCN-based recommender model, implemented via PyG.
    LightGCN includes only the most essential component in GCN — neighborhood aggregation — for
    collaborative filtering. Specifically, LightGCN learns user and item embeddings by linearly
    propagating them on the user-item interaction graph, and uses the weighted sum of the embeddings
    learned at all layers as the final embedding.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCN, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of LightGCN
        self.n_layers = config['n_layers']  # int type: the layer num of LightGCN
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']  # bool type: whether to require pow when regularization

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

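    # Each propagation layer is a plain normalized-adjacency multiplication (no feature
    # transform, no non-linearity); the final embedding is the mean of the layer-0
    # embeddings and every propagated layer.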
    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]

        for layer_idx in range(self.n_layers):
            all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            embeddings_list.append(all_embeddings)
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

Recbole_GNN/recbole_gnn/model/general_recommender/ncl.py
@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
r"""
NCL
################################################
Reference:
    Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022.
"""

import torch
import torch.nn.functional as F

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class NCL(GeneralGraphRecommender):
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NCL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of the base model
        self.n_layers = config['n_layers']  # int type: the layer num of the base model
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization

        self.ssl_temp = config['ssl_temp']
        self.ssl_reg = config['ssl_reg']
        self.hyper_layers = config['hyper_layers']

        self.alpha = config['alpha']

        self.proto_reg = config['proto_reg']
        self.k = config['num_clusters']

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

        self.user_centroids = None
        self.user_2cluster = None
        self.item_centroids = None
        self.item_2cluster = None

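    # Prototype (semantic) neighbors: e_step clusters the current user/item embeddings
    # with faiss k-means into self.k prototypes; it is intended to be re-run periodically
    # by the trainer (e.g., at the start of each epoch) rather than on every batch.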
    def e_step(self):
        user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
        item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
        self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
        self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)

    def run_kmeans(self, x):
        """Run the K-means algorithm to get k clusters of the input tensor x.
        """
        import faiss
        kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
        kmeans.train(x)
        cluster_cents = kmeans.centroids

        _, I = kmeans.index.search(x, 1)

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(cluster_cents).to(self.device)
        centroids = F.normalize(centroids, p=2, dim=1)

        node2cluster = torch.LongTensor(I).squeeze().to(self.device)
        return centroids, node2cluster

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
            all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            embeddings_list.append(all_embeddings)

        lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings, embeddings_list

    def ProtoNCE_loss(self, node_embedding, user, item):
        user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])

        user_embeddings = user_embeddings_all[user]  # [B, e]
        norm_user_embeddings = F.normalize(user_embeddings)

        user2cluster = self.user_2cluster[user]  # [B,]
        user2centroids = self.user_centroids[user2cluster]  # [B, e]
        pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        item_embeddings = item_embeddings_all[item]
        norm_item_embeddings = F.normalize(item_embeddings)

        item2cluster = self.item_2cluster[item]  # [B, ]
        item2centroids = self.item_centroids[item2cluster]  # [B, e]
        pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
        proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
        return proto_nce_loss

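    # Structure-contrastive loss: contrasts each node's layer-0 embedding with its
    # embedding after 2 * hyper_layers propagation steps (its even-hop structural
    # neighborhood), using an InfoNCE objective with all nodes of the same type as negatives.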
    def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
        current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
        previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])

        current_user_embeddings = current_user_embeddings[user]
        previous_user_embeddings = previous_user_embeddings_all[user]
        norm_user_emb1 = F.normalize(current_user_embeddings)
        norm_user_emb2 = F.normalize(previous_user_embeddings)
        norm_all_user_emb = F.normalize(previous_user_embeddings_all)
        pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
        ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        current_item_embeddings = current_item_embeddings[item]
        previous_item_embeddings = previous_item_embeddings_all[item]
        norm_item_emb1 = F.normalize(current_item_embeddings)
        norm_item_emb2 = F.normalize(previous_item_embeddings)
        norm_all_item_emb = F.normalize(previous_item_embeddings_all)
        pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
        ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

        ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
        return ssl_loss

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        center_embedding = embeddings_list[0]
        context_embedding = embeddings_list[self.hyper_layers * 2]

        ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
        proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)

        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)

        mf_loss = self.mf_loss(pos_scores, neg_scores)

        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)

        return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e, embedding_list = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

Recbole_GNN/recbole_gnn/model/general_recommender/ngcf.py
@@ -0,0 +1,149 @@
# @Time : 2022/3/8
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
NGCF
################################################
Reference:
    Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019.
Reference code:
    https://github.com/xiangwang1223/neural_graph_collaborative_filtering
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj

from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import BiGNNConv


class NGCF(GeneralGraphRecommender):
    r"""NGCF is a model that incorporates GNNs for recommendation.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NGCF, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size_list = config['hidden_size_list']
        self.hidden_size_list = [self.embedding_size] + self.hidden_size_list
        self.node_dropout = config['node_dropout']
        self.message_dropout = config['message_dropout']
        self.reg_weight = config['reg_weight']

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.GNNlayers = torch.nn.ModuleList()
        for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]):
            self.GNNlayers.append(BiGNNConv(input_size, output_size))
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of (n_items+n_users, embedding_dim)
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

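    # Each BiGNN layer combines first-order propagation with an element-wise user-item
    # interaction term. "Node dropout" is implemented here by dropping entries of the
    # adjacency (dropout_adj) before propagation, message dropout is applied to every
    # layer's output, and the L2-normalized per-layer representations are concatenated.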
    def forward(self):
        if self.node_dropout == 0:
            edge_index, edge_weight = self.edge_index, self.edge_weight
        else:
            edge_index, edge_weight = self.edge_index, self.edge_weight
            if self.use_sparse:
                row, col, edge_weight = edge_index.t().coo()
                edge_index = torch.stack([row, col], 0)
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)
                from torch_sparse import SparseTensor
                edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
                                          sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
                edge_index = edge_index.t()
                edge_weight = None
            else:
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)

        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for gnn in self.GNNlayers:
            all_embeddings = gnn(all_embeddings, edge_index, edge_weight)
            all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
            all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
            all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
            embeddings_list += [all_embeddings]  # store the output embedding of each layer
        ngcf_all_embeddings = torch.cat(embeddings_list, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items])

        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)  # calculate BPR Loss

        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)  # L2 regularization of embeddings

        return mf_loss + self.reg_weight * reg_loss

    def predict(self, interaction):
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)

Recbole_GNN/recbole_gnn/model/general_recommender/sgl.py
@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
# @Time : 2022/3/8
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
SGL
################################################
Reference:
    Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021.
Reference code:
    https://github.com/wujcan/SGL
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class SGL(GeneralGraphRecommender):
    r"""SGL is a GCN-based recommender model.
    SGL supplements the classical supervised task of recommendation with an auxiliary
    self-supervised task, which reinforces node representation learning via self-discrimination.
    Specifically, SGL generates multiple views of a node, maximizing the agreement between
    different views of the same node compared to that of other nodes.
    SGL devises three operators to generate the views — node dropout, edge dropout, and
    random walk — that change the graph structure in different manners.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SGL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.aug_type = config["type"]
        self.drop_ratio = config["drop_ratio"]
        self.ssl_tau = config["ssl_tau"]
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]

        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]
        self.dataset = dataset

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

def train(self, mode: bool = True): | ||
r"""Override the train method of the base class. The augmented subgraphs are reconstructed each time it is called with mode=True. | ||
""" | ||
T = super().train(mode=mode) | ||
if mode: | ||
self.graph_construction() | ||
return T | ||
|
||
def graph_construction(self): | ||
r"""Generate two augmented graph views using one of three operators: node dropout, edge dropout, or random walk. | ||
""" | ||
if self.aug_type == "ND" or self.aug_type == "ED": | ||
self.sub_graph1 = [self.random_graph_augment()] * self.n_layers | ||
self.sub_graph2 = [self.random_graph_augment()] * self.n_layers | ||
elif self.aug_type == "RW": | ||
self.sub_graph1 = [self.random_graph_augment() for _ in range(self.n_layers)] | ||
self.sub_graph2 = [self.random_graph_augment() for _ in range(self.n_layers)] | ||
|
||
def random_graph_augment(self): | ||
def rand_sample(high, size=None, replace=True): | ||
return np.random.choice(np.arange(high), size=size, replace=replace) | ||
|
||
if self.aug_type == "ND": | ||
drop_user = rand_sample(self.n_users, size=int(self.n_users * self.drop_ratio), replace=False) | ||
drop_item = rand_sample(self.n_items, size=int(self.n_items * self.drop_ratio), replace=False) | ||
|
||
mask = np.isin(self._user.numpy(), drop_user) | ||
mask |= np.isin(self._item.numpy(), drop_item) | ||
keep = np.where(~mask) | ||
|
||
row = self._user[keep] | ||
col = self._item[keep] + self.n_users | ||
|
||
elif self.aug_type == "ED" or self.aug_type == "RW": | ||
keep = rand_sample(len(self._user), size=int(len(self._user) * (1 - self.drop_ratio)), replace=False) | ||
row = self._user[keep] | ||
col = self._item[keep] + self.n_users | ||
|
||
edge_index1 = torch.stack([row, col]) | ||
edge_index2 = torch.stack([col, row]) | ||
edge_index = torch.cat([edge_index1, edge_index2], dim=1) | ||
edge_weight = torch.ones(edge_index.size(1)) | ||
num_nodes = self.n_users + self.n_items | ||
|
||
if self.use_sparse: | ||
adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes) | ||
adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False) | ||
return adj_t.to(self.device), None | ||
|
||
edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False) | ||
|
||
return edge_index.to(self.device), edge_weight.to(self.device) | ||
|
||
def forward(self, graph=None): | ||
all_embeddings = torch.cat([self.user_embedding.weight, self.item_embedding.weight]) | ||
embeddings_list = [all_embeddings] | ||
|
||
if graph is None: # for the original graph | ||
for _ in range(self.n_layers): | ||
all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight) | ||
embeddings_list.append(all_embeddings) | ||
else: # for the augmented graph | ||
for graph_edge_index, graph_edge_weight in graph: | ||
all_embeddings = self.gcn_conv(all_embeddings, graph_edge_index, graph_edge_weight) | ||
embeddings_list.append(all_embeddings) | ||
|
||
embeddings_list = torch.stack(embeddings_list, dim=1) | ||
embeddings_list = torch.mean(embeddings_list, dim=1, keepdim=False) | ||
user_all_embeddings, item_all_embeddings = torch.split(embeddings_list, [self.n_users, self.n_items], dim=0) | ||
|
||
return user_all_embeddings, item_all_embeddings | ||
|
||
def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list): | ||
r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and the parameter regularization loss. | ||
Args: | ||
user_emd (torch.Tensor): Ego embedding of all users after forwarding. | ||
item_emd (torch.Tensor): Ego embedding of all items after forwarding. | ||
user_list (torch.Tensor): List of users. | ||
pos_item_list (torch.Tensor): List of positive examples. | ||
neg_item_list (torch.Tensor): List of negative examples. | ||
Returns: | ||
torch.Tensor: Loss of BPR tasks and parameter regularization. | ||
""" | ||
u_e = user_emd[user_list] | ||
pi_e = item_emd[pos_item_list] | ||
ni_e = item_emd[neg_item_list] | ||
p_scores = torch.mul(u_e, pi_e).sum(dim=1) | ||
n_scores = torch.mul(u_e, ni_e).sum(dim=1) | ||
|
||
l1 = torch.sum(-F.logsigmoid(p_scores - n_scores)) | ||
|
||
u_e_p = self.user_embedding(user_list) | ||
pi_e_p = self.item_embedding(pos_item_list) | ||
ni_e_p = self.item_embedding(neg_item_list) | ||
|
||
l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p) | ||
|
||
return l1 + l2 * self.reg_weight | ||
|
||
def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2): | ||
r"""Calculate the loss of self-supervised tasks. | ||
Args: | ||
user_list (torch.Tensor): List of users. | ||
pos_item_list (torch.Tensor): List of positive examples. | ||
user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding. | ||
user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding. | ||
item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding. | ||
item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding. | ||
Returns: | ||
torch.Tensor: Loss of self-supervised tasks. | ||
""" | ||
|
||
u_emd1 = F.normalize(user_sub1[user_list], dim=1) | ||
u_emd2 = F.normalize(user_sub2[user_list], dim=1) | ||
all_user2 = F.normalize(user_sub2, dim=1) | ||
v1 = torch.sum(u_emd1 * u_emd2, dim=1) | ||
v2 = u_emd1.matmul(all_user2.T) | ||
v1 = torch.exp(v1 / self.ssl_tau) | ||
v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1) | ||
ssl_user = -torch.sum(torch.log(v1 / v2)) | ||
|
||
i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1) | ||
i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1) | ||
all_item2 = F.normalize(item_sub2, dim=1) | ||
v3 = torch.sum(i_emd1 * i_emd2, dim=1) | ||
v4 = i_emd1.matmul(all_item2.T) | ||
v3 = torch.exp(v3 / self.ssl_tau) | ||
v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1) | ||
ssl_item = -torch.sum(torch.log(v3 / v4)) | ||
|
||
return (ssl_item + ssl_user) * self.ssl_weight | ||
|
||
def calculate_loss(self, interaction): | ||
if self.restore_user_e is not None or self.restore_item_e is not None: | ||
self.restore_user_e, self.restore_item_e = None, None | ||
|
||
user_list = interaction[self.USER_ID] | ||
pos_item_list = interaction[self.ITEM_ID] | ||
neg_item_list = interaction[self.NEG_ITEM_ID] | ||
|
||
user_emd, item_emd = self.forward() | ||
user_sub1, item_sub1 = self.forward(self.sub_graph1) | ||
user_sub2, item_sub2 = self.forward(self.sub_graph2) | ||
|
||
total_loss = self.calc_bpr_loss(user_emd, item_emd, user_list, pos_item_list, neg_item_list) + \ | ||
self.calc_ssl_loss(user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2) | ||
return total_loss | ||
|
||
def predict(self, interaction): | ||
if self.restore_user_e is None or self.restore_item_e is None: | ||
self.restore_user_e, self.restore_item_e = self.forward() | ||
|
||
user = self.restore_user_e[interaction[self.USER_ID]] | ||
item = self.restore_item_e[interaction[self.ITEM_ID]] | ||
return torch.sum(user * item, dim=1) | ||
|
||
def full_sort_predict(self, interaction): | ||
if self.restore_user_e is None or self.restore_item_e is None: | ||
self.restore_user_e, self.restore_item_e = self.forward() | ||
|
||
user = self.restore_user_e[interaction[self.USER_ID]] | ||
return user.matmul(self.restore_item_e.T) |
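As a side note, the self-supervised term computed in calc_ssl_loss above is an InfoNCE objective over the two augmented views. The following standalone sketch reproduces the same formula on toy tensors; the shapes, the 0.2 temperature, and the use of the batch as the candidate pool (rather than the full user set, as in the model) are illustrative assumptions.

import torch
import torch.nn.functional as F

def info_nce(z1, z2, tau=0.2):
    # z1, z2: [batch, dim] embeddings of the same nodes under two graph views
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    pos = torch.exp((z1 * z2).sum(dim=1) / tau)   # agreement between views of the same node
    ttl = torch.exp(z1 @ z2.T / tau).sum(dim=1)   # agreement with every node in the second view
    return -torch.log(pos / ttl).sum()

user_view1, user_view2 = torch.randn(8, 64), torch.randn(8, 64)
print(info_nce(user_view1, user_view2))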
Recbole_GNN/recbole_gnn/model/general_recommender/simgcl.py
@@ -0,0 +1,60 @@ | ||
# -*- coding: utf-8 -*- | ||
r""" | ||
SimGCL | ||
################################################ | ||
Reference: | ||
Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022. | ||
""" | ||
|
||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from recbole_gnn.model.general_recommender import LightGCN | ||
|
||
|
||
class SimGCL(LightGCN): | ||
def __init__(self, config, dataset): | ||
super(SimGCL, self).__init__(config, dataset) | ||
|
||
self.cl_rate = config['lambda'] | ||
self.eps = config['eps'] | ||
self.temperature = config['temperature'] | ||
|
||
def forward(self, perturbed=False): | ||
all_embs = self.get_ego_embeddings() | ||
embeddings_list = [] | ||
|
||
for layer_idx in range(self.n_layers): | ||
all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight) | ||
if perturbed: | ||
random_noise = torch.rand_like(all_embs, device=all_embs.device) | ||
all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps | ||
embeddings_list.append(all_embs) | ||
lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) | ||
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) | ||
|
||
user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) | ||
return user_all_embeddings, item_all_embeddings | ||
|
||
def calculate_cl_loss(self, x1, x2): | ||
x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) | ||
pos_score = (x1 * x2).sum(dim=-1) | ||
pos_score = torch.exp(pos_score / self.temperature) | ||
ttl_score = torch.matmul(x1, x2.transpose(0, 1)) | ||
ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1) | ||
return -torch.log(pos_score / ttl_score).sum() | ||
|
||
def calculate_loss(self, interaction): | ||
loss = super().calculate_loss(interaction) | ||
|
||
user = torch.unique(interaction[self.USER_ID]) | ||
pos_item = torch.unique(interaction[self.ITEM_ID]) | ||
|
||
perturbed_user_embs_1, perturbed_item_embs_1 = self.forward(perturbed=True) | ||
perturbed_user_embs_2, perturbed_item_embs_2 = self.forward(perturbed=True) | ||
|
||
user_cl_loss = self.calculate_cl_loss(perturbed_user_embs_1[user], perturbed_user_embs_2[user]) | ||
item_cl_loss = self.calculate_cl_loss(perturbed_item_embs_1[pos_item], perturbed_item_embs_2[pos_item]) | ||
|
||
return loss + self.cl_rate * (user_cl_loss + item_cl_loss) |
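The only difference from LightGCN's propagation is the noise injected when perturbed=True. A minimal sketch of that perturbation step on a toy tensor (the sizes and eps value are made up for illustration):

import torch
import torch.nn.functional as F

eps = 0.1
emb = torch.randn(5, 8)                      # one layer's output for 5 nodes
noise = torch.rand_like(emb)                 # uniform noise in [0, 1)
perturbed = emb + torch.sign(emb) * F.normalize(noise, dim=-1) * eps
# torch.sign keeps each coordinate of the perturbation in the same half-space as the
# original embedding, so the vector is nudged rather than flipped; eps sets the magnitude.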
Recbole_GNN/recbole_gnn/model/general_recommender/ssl4rec.py
@@ -0,0 +1,163 @@ | ||
r""" | ||
SSL4REC | ||
################################################ | ||
Reference: | ||
Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021. | ||
Reference code: | ||
https://github.com/Coder-Yu/SELFRec/model/graph/SSL4Rec.py | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from recbole.model.loss import EmbLoss | ||
from recbole.utils import InputType | ||
|
||
from recbole.model.init import xavier_uniform_initialization | ||
from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender | ||
|
||
|
||
class SSL4REC(GeneralGraphRecommender): | ||
input_type = InputType.PAIRWISE | ||
|
||
def __init__(self, config, dataset): | ||
super(SSL4REC, self).__init__(config, dataset) | ||
|
||
# load parameters info | ||
self.tau = config["tau"] | ||
self.reg_weight = config["reg_weight"] | ||
self.cl_rate = config["ssl_weight"] | ||
self.require_pow = config["require_pow"] | ||
|
||
self.reg_loss = EmbLoss() | ||
|
||
self.encoder = DNN_Encoder(config, dataset) | ||
|
||
# storage variables for full sort evaluation acceleration | ||
self.restore_user_e = None | ||
self.restore_item_e = None | ||
|
||
# parameters initialization | ||
self.apply(xavier_uniform_initialization) | ||
self.other_parameter_name = ['restore_user_e', 'restore_item_e'] | ||
|
||
def forward(self, user, item): | ||
user_e, item_e = self.encoder(user, item) | ||
return user_e, item_e | ||
|
||
def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature): | ||
user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1) | ||
pos_score = (user_emb * item_emb).sum(dim=-1) | ||
pos_score = torch.exp(pos_score / temperature) | ||
ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1)) | ||
ttl_score = torch.exp(ttl_score / temperature).sum(dim=1) | ||
loss = -torch.log(pos_score / ttl_score + 10e-6) | ||
return torch.mean(loss) | ||
|
||
def calculate_loss(self, interaction): | ||
# clear the storage variable when training | ||
if self.restore_user_e is not None or self.restore_item_e is not None: | ||
self.restore_user_e, self.restore_item_e = None, None | ||
|
||
user = interaction[self.USER_ID] | ||
pos_item = interaction[self.ITEM_ID] | ||
|
||
user_embeddings, item_embeddings = self.forward(user, pos_item) | ||
|
||
rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau) | ||
cl_loss = self.encoder.calculate_cl_loss(pos_item) | ||
reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow) | ||
|
||
loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss | ||
|
||
return loss | ||
|
||
def predict(self, interaction): | ||
user = interaction[self.USER_ID] | ||
item = interaction[self.ITEM_ID] | ||
|
||
user_embeddings, item_embeddings = self.forward(user, item) | ||
|
||
# forward() already returns embeddings for exactly these users and items, | ||
# so they are multiplied directly without further indexing by id | ||
scores = torch.mul(user_embeddings, item_embeddings).sum(dim=1) | ||
return scores | ||
|
||
def full_sort_predict(self, interaction): | ||
user = interaction[self.USER_ID] | ||
if self.restore_user_e is None or self.restore_item_e is None: | ||
self.restore_user_e, self.restore_item_e = self.forward(torch.arange( | ||
self.n_users, device=self.device), torch.arange(self.n_items, device=self.device)) | ||
# get user embedding from storage variable | ||
u_embeddings = self.restore_user_e[user] | ||
|
||
# dot with all item embedding to accelerate | ||
scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1)) | ||
|
||
return scores.view(-1) | ||
|
||
|
||
class DNN_Encoder(nn.Module): | ||
def __init__(self, config, dataset): | ||
super(DNN_Encoder, self).__init__() | ||
|
||
self.emb_size = config["embedding_size"] | ||
self.drop_ratio = config["drop_ratio"] | ||
self.tau = config["tau"] | ||
|
||
self.USER_ID = config["USER_ID_FIELD"] | ||
self.ITEM_ID = config["ITEM_ID_FIELD"] | ||
self.n_users = dataset.num(self.USER_ID) | ||
self.n_items = dataset.num(self.ITEM_ID) | ||
|
||
self.user_tower = nn.Sequential( | ||
nn.Linear(self.emb_size, 1024), | ||
nn.ReLU(True), | ||
nn.Linear(1024, 128), | ||
nn.Tanh() | ||
) | ||
self.item_tower = nn.Sequential( | ||
nn.Linear(self.emb_size, 1024), | ||
nn.ReLU(True), | ||
nn.Linear(1024, 128), | ||
nn.Tanh() | ||
) | ||
self.dropout = nn.Dropout(self.drop_ratio) | ||
|
||
self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size) | ||
self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size) | ||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
nn.init.xavier_uniform_(self.initial_user_emb.weight) | ||
nn.init.xavier_uniform_(self.initial_item_emb.weight) | ||
|
||
def forward(self, q, x): | ||
q_emb = self.initial_user_emb(q) | ||
i_emb = self.initial_item_emb(x) | ||
|
||
q_emb = self.user_tower(q_emb) | ||
i_emb = self.item_tower(i_emb) | ||
|
||
return q_emb, i_emb | ||
|
||
def item_encoding(self, x): | ||
i_emb = self.initial_item_emb(x) | ||
i1_emb = self.dropout(i_emb) | ||
i2_emb = self.dropout(i_emb) | ||
|
||
i1_emb = self.item_tower(i1_emb) | ||
i2_emb = self.item_tower(i2_emb) | ||
|
||
return i1_emb, i2_emb | ||
|
||
def calculate_cl_loss(self, idx): | ||
x1, x2 = self.item_encoding(idx) | ||
x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) | ||
pos_score = (x1 * x2).sum(dim=-1) | ||
pos_score = torch.exp(pos_score / self.tau) | ||
ttl_score = torch.matmul(x1, x2.transpose(0, 1)) | ||
ttl_score = torch.exp(ttl_score / self.tau).sum(dim=1) | ||
return -torch.log(pos_score / ttl_score).mean() |
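DNN_Encoder.item_encoding builds the two contrastive views simply by applying dropout twice to the same item embeddings. A toy sketch of that augmentation (sizes and dropout rate are illustrative):

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.3)     # a fresh module defaults to training mode, so dropout is active
item_emb = torch.randn(4, 16)   # initial embeddings of 4 items
view1 = dropout(item_emb)       # a different random mask on each call
view2 = dropout(item_emb)
# view1 and view2 differ element-wise, which is what makes the contrastive task non-trivial.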
Recbole_GNN/recbole_gnn/model/general_recommender/xsimgcl.py
@@ -0,0 +1,90 @@ | ||
# -*- coding: utf-8 -*- | ||
r""" | ||
XSimGCL | ||
################################################ | ||
Reference: | ||
Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023. | ||
Reference code: | ||
https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py | ||
""" | ||
|
||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from recbole_gnn.model.general_recommender import LightGCN | ||
|
||
|
||
class XSimGCL(LightGCN): | ||
def __init__(self, config, dataset): | ||
super(XSimGCL, self).__init__(config, dataset) | ||
|
||
self.cl_rate = config['lambda'] | ||
self.eps = config['eps'] | ||
self.temperature = config['temperature'] | ||
self.layer_cl = config['layer_cl'] | ||
|
||
def forward(self, perturbed=False): | ||
all_embs = self.get_ego_embeddings() | ||
all_embs_cl = all_embs | ||
embeddings_list = [] | ||
|
||
for layer_idx in range(self.n_layers): | ||
all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight) | ||
if perturbed: | ||
random_noise = torch.rand_like(all_embs, device=all_embs.device) | ||
all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps | ||
embeddings_list.append(all_embs) | ||
if layer_idx == self.layer_cl - 1: | ||
all_embs_cl = all_embs | ||
lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1) | ||
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1) | ||
|
||
user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items]) | ||
user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items]) | ||
if perturbed: | ||
return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl | ||
return user_all_embeddings, item_all_embeddings | ||
|
||
def calculate_cl_loss(self, x1, x2): | ||
x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1) | ||
pos_score = (x1 * x2).sum(dim=-1) | ||
pos_score = torch.exp(pos_score / self.temperature) | ||
ttl_score = torch.matmul(x1, x2.transpose(0, 1)) | ||
ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1) | ||
return -torch.log(pos_score / ttl_score).mean() | ||
|
||
def calculate_loss(self, interaction): | ||
# clear the storage variable when training | ||
if self.restore_user_e is not None or self.restore_item_e is not None: | ||
self.restore_user_e, self.restore_item_e = None, None | ||
|
||
user = interaction[self.USER_ID] | ||
pos_item = interaction[self.ITEM_ID] | ||
neg_item = interaction[self.NEG_ITEM_ID] | ||
|
||
user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True) | ||
u_embeddings = user_all_embeddings[user] | ||
pos_embeddings = item_all_embeddings[pos_item] | ||
neg_embeddings = item_all_embeddings[neg_item] | ||
|
||
# calculate BPR Loss | ||
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1) | ||
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1) | ||
mf_loss = self.mf_loss(pos_scores, neg_scores) | ||
|
||
# calculate regularization Loss | ||
u_ego_embeddings = self.user_embedding(user) | ||
pos_ego_embeddings = self.item_embedding(pos_item) | ||
neg_ego_embeddings = self.item_embedding(neg_item) | ||
reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow) | ||
|
||
user = torch.unique(interaction[self.USER_ID]) | ||
pos_item = torch.unique(interaction[self.ITEM_ID]) | ||
|
||
# calculate CL Loss | ||
user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user]) | ||
item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item]) | ||
|
||
return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss) |
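Compared with SimGCL, XSimGCL needs only one perturbed pass: the mean-pooled final embeddings are contrasted with a snapshot taken at layer layer_cl. A rough stand-in for that control flow; the random update below merely imitates gcn_conv plus noise and is not the real propagation.

import torch

n_layers, layer_cl = 3, 1
embs = torch.randn(6, 8)
layer_outputs, embs_cl = [], embs
for layer_idx in range(n_layers):
    embs = embs + 0.01 * torch.randn_like(embs)   # stand-in for gcn_conv + perturbation
    layer_outputs.append(embs)
    if layer_idx == layer_cl - 1:
        embs_cl = embs                            # snapshot later used for the CL loss
final = torch.stack(layer_outputs, dim=1).mean(dim=1)
# `final` plays the role of user/item_all_embeddings, `embs_cl` the *_cl embeddings.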
Recbole_GNN/recbole_gnn/model/layers.py
@@ -0,0 +1,114 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from torch_geometric.nn import MessagePassing | ||
from torch_sparse import matmul | ||
|
||
|
||
class LightGCNConv(MessagePassing): | ||
def __init__(self, dim): | ||
super(LightGCNConv, self).__init__(aggr='add') | ||
self.dim = dim | ||
|
||
def forward(self, x, edge_index, edge_weight): | ||
return self.propagate(edge_index, x=x, edge_weight=edge_weight) | ||
|
||
def message(self, x_j, edge_weight): | ||
return edge_weight.view(-1, 1) * x_j | ||
|
||
def message_and_aggregate(self, adj_t, x): | ||
return matmul(adj_t, x, reduce=self.aggr) | ||
|
||
def __repr__(self): | ||
return '{}({})'.format(self.__class__.__name__, self.dim) | ||
|
||
|
||
class BipartiteGCNConv(MessagePassing): | ||
def __init__(self, dim): | ||
super(BipartiteGCNConv, self).__init__(aggr='add') | ||
self.dim = dim | ||
|
||
def forward(self, x, edge_index, edge_weight, size): | ||
return self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) | ||
|
||
def message(self, x_j, edge_weight): | ||
return edge_weight.view(-1, 1) * x_j | ||
|
||
def __repr__(self): | ||
return '{}({})'.format(self.__class__.__name__, self.dim) | ||
|
||
|
||
class BiGNNConv(MessagePassing): | ||
r"""Propagate a layer of Bi-interaction GNN | ||
.. math:: | ||
output = (L+I)EW_1 + LE \otimes EW_2 | ||
""" | ||
|
||
def __init__(self, in_channels, out_channels): | ||
super().__init__(aggr='add') | ||
self.in_channels, self.out_channels = in_channels, out_channels | ||
self.lin1 = torch.nn.Linear(in_features=in_channels, out_features=out_channels) | ||
self.lin2 = torch.nn.Linear(in_features=in_channels, out_features=out_channels) | ||
|
||
def forward(self, x, edge_index, edge_weight): | ||
x_prop = self.propagate(edge_index, x=x, edge_weight=edge_weight) | ||
x_trans = self.lin1(x_prop + x) | ||
x_inter = self.lin2(torch.mul(x_prop, x)) | ||
return x_trans + x_inter | ||
|
||
def message(self, x_j, edge_weight): | ||
return edge_weight.view(-1, 1) * x_j | ||
|
||
def message_and_aggregate(self, adj_t, x): | ||
return matmul(adj_t, x, reduce=self.aggr) | ||
|
||
def __repr__(self): | ||
return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels) | ||
|
||
|
||
class SRGNNConv(MessagePassing): | ||
def __init__(self, dim): | ||
# mean aggregation to incorporate weight naturally | ||
super(SRGNNConv, self).__init__(aggr='mean') | ||
|
||
self.lin = torch.nn.Linear(dim, dim) | ||
|
||
def forward(self, x, edge_index): | ||
x = self.lin(x) | ||
return self.propagate(edge_index, x=x) | ||
|
||
|
||
class SRGNNCell(nn.Module): | ||
def __init__(self, dim): | ||
super(SRGNNCell, self).__init__() | ||
|
||
self.dim = dim | ||
self.incomming_conv = SRGNNConv(dim) | ||
self.outcomming_conv = SRGNNConv(dim) | ||
|
||
self.lin_ih = nn.Linear(2 * dim, 3 * dim) | ||
self.lin_hh = nn.Linear(dim, 3 * dim) | ||
|
||
self._reset_parameters() | ||
|
||
def forward(self, hidden, edge_index): | ||
input_in = self.incomming_conv(hidden, edge_index) | ||
reversed_edge_index = torch.flip(edge_index, dims=[0]) | ||
input_out = self.outcomming_conv(hidden, reversed_edge_index) | ||
inputs = torch.cat([input_in, input_out], dim=-1) | ||
|
||
gi = self.lin_ih(inputs) | ||
gh = self.lin_hh(hidden) | ||
i_r, i_i, i_n = gi.chunk(3, -1) | ||
h_r, h_i, h_n = gh.chunk(3, -1) | ||
reset_gate = torch.sigmoid(i_r + h_r) | ||
input_gate = torch.sigmoid(i_i + h_i) | ||
new_gate = torch.tanh(i_n + reset_gate * h_n) | ||
hy = (1 - input_gate) * hidden + input_gate * new_gate | ||
return hy | ||
|
||
def _reset_parameters(self): | ||
stdv = 1.0 / np.sqrt(self.dim) | ||
for weight in self.parameters(): | ||
weight.data.uniform_(-stdv, stdv) |
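For orientation, here is a hypothetical end-to-end call of the LightGCNConv defined in the listing above on a tiny user-item graph, using the same symmetric gcn_norm without self-loops as the general recommenders; the node counts, ids, and dimensions are made up.

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

n_users, n_items, dim = 2, 3, 4
num_nodes = n_users + n_items
# interactions: user 0 - item 0, user 0 - item 1, user 1 - item 2 (item ids offset by n_users)
row = torch.tensor([0, 0, 1])
col = torch.tensor([2, 3, 4])
edge_index = torch.cat([torch.stack([row, col]), torch.stack([col, row])], dim=1)
edge_weight = torch.ones(edge_index.size(1))
edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

conv = LightGCNConv(dim=dim)            # class defined in the listing above
x = torch.randn(num_nodes, dim)
out = conv(x, edge_index, edge_weight)  # one propagation step, shape [num_nodes, dim]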
Recbole_GNN/recbole_gnn/model/sequential_recommender/__init__.py
@@ -0,0 +1,7 @@ | ||
from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN | ||
from recbole_gnn.model.sequential_recommender.gcsan import GCSAN | ||
from recbole_gnn.model.sequential_recommender.lessr import LESSR | ||
from recbole_gnn.model.sequential_recommender.niser import NISER | ||
from recbole_gnn.model.sequential_recommender.sgnnhn import SGNNHN | ||
from recbole_gnn.model.sequential_recommender.srgnn import SRGNN | ||
from recbole_gnn.model.sequential_recommender.tagnn import TAGNN |
Recbole_GNN/recbole_gnn/model/sequential_recommender/gcegnn.py
@@ -0,0 +1,277 @@ | ||
# @Time : 2022/3/22 | ||
# @Author : Yupeng Hou | ||
# @Email : houyupeng@ruc.edu.cn | ||
|
||
r""" | ||
GCE-GNN | ||
################################################ | ||
Reference: | ||
Ziyang Wang et al. "Global Context Enhanced Graph Neural Networks for Session-based Recommendation." in SIGIR 2020. | ||
Reference code: | ||
https://github.com/CCIIPLab/GCE-GNN | ||
""" | ||
|
||
import numpy as np | ||
from tqdm import tqdm | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch_geometric.nn import MessagePassing | ||
from torch_geometric.utils import softmax | ||
from recbole.model.loss import BPRLoss | ||
from recbole.model.abstract_recommender import SequentialRecommender | ||
|
||
|
||
class LocalAggregator(MessagePassing): | ||
def __init__(self, dim, alpha): | ||
super().__init__(aggr='add') | ||
self.edge_emb = nn.Embedding(4, dim) | ||
self.leakyrelu = nn.LeakyReLU(alpha) | ||
|
||
def forward(self, x, edge_index, edge_attr): | ||
return self.propagate(edge_index, x=x, edge_attr=edge_attr) | ||
|
||
def message(self, x_j, x_i, edge_attr, index, ptr, size_i): | ||
x = x_j * x_i | ||
a = self.edge_emb(edge_attr) | ||
e = (x * a).sum(dim=-1) | ||
e = self.leakyrelu(e) | ||
e = softmax(e, index, ptr, size_i) | ||
return e.unsqueeze(-1) * x_j | ||
|
||
|
||
class GlobalAggregator(nn.Module): | ||
def __init__(self, dim, dropout, act=torch.relu): | ||
super(GlobalAggregator, self).__init__() | ||
self.dropout = dropout | ||
self.act = act | ||
self.dim = dim | ||
|
||
self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim)) | ||
self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1)) | ||
self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim)) | ||
self.bias = nn.Parameter(torch.Tensor(self.dim)) | ||
|
||
def forward(self, self_vectors, neighbor_vector, batch_size, masks, neighbor_weight, extra_vector=None): | ||
if extra_vector is not None: | ||
alpha = torch.matmul(torch.cat([extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1)*neighbor_vector, neighbor_weight.unsqueeze(-1)], -1), self.w_1).squeeze(-1) | ||
alpha = F.leaky_relu(alpha, negative_slope=0.2) | ||
alpha = torch.matmul(alpha, self.w_2).squeeze(-1) | ||
alpha = torch.softmax(alpha, -1).unsqueeze(-1) | ||
neighbor_vector = torch.sum(alpha * neighbor_vector, dim=-2) | ||
else: | ||
neighbor_vector = torch.mean(neighbor_vector, dim=2) | ||
# self_vectors = F.dropout(self_vectors, 0.5, training=self.training) | ||
output = torch.cat([self_vectors, neighbor_vector], -1) | ||
output = F.dropout(output, self.dropout, training=self.training) | ||
output = torch.matmul(output, self.w_3) | ||
output = output.view(batch_size, -1, self.dim) | ||
output = self.act(output) | ||
return output | ||
|
||
|
||
class GCEGNN(SequentialRecommender): | ||
def __init__(self, config, dataset): | ||
super(GCEGNN, self).__init__(config, dataset) | ||
|
||
# load parameters info | ||
self.embedding_size = config['embedding_size'] | ||
self.leakyrelu_alpha = config['leakyrelu_alpha'] | ||
self.dropout_local = config['dropout_local'] | ||
self.dropout_global = config['dropout_global'] | ||
self.dropout_gcn = config['dropout_gcn'] | ||
self.device = config['device'] | ||
self.loss_type = config['loss_type'] | ||
self.build_global_graph = config['build_global_graph'] | ||
self.sample_num = config['sample_num'] | ||
self.hop = config['hop'] | ||
self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ] | ||
|
||
# global graph construction | ||
self.global_graph = None | ||
if self.build_global_graph: | ||
self.global_adj, self.global_weight = self.construct_global_graph(dataset) | ||
|
||
# item embedding | ||
self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) | ||
self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size) | ||
|
||
# define layers and loss | ||
# Aggregator | ||
self.local_agg = LocalAggregator(self.embedding_size, self.leakyrelu_alpha) | ||
global_agg_list = [] | ||
for i in range(self.hop): | ||
global_agg_list.append(GlobalAggregator(self.embedding_size, self.dropout_gcn)) | ||
self.global_agg = nn.ModuleList(global_agg_list) | ||
|
||
self.w_1 = nn.Linear(2 * self.embedding_size, self.embedding_size, bias=False) | ||
self.w_2 = nn.Linear(self.embedding_size, 1, bias=False) | ||
self.glu1 = nn.Linear(self.embedding_size, self.embedding_size) | ||
self.glu2 = nn.Linear(self.embedding_size, self.embedding_size, bias=False) | ||
if self.loss_type == 'BPR': | ||
self.loss_fct = BPRLoss() | ||
elif self.loss_type == 'CE': | ||
self.loss_fct = nn.CrossEntropyLoss() | ||
else: | ||
raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") | ||
|
||
self.reset_parameters() | ||
self.other_parameter_name = ['global_adj', 'global_weight'] | ||
|
||
def reset_parameters(self): | ||
stdv = 1.0 / np.sqrt(self.embedding_size) | ||
for weight in self.parameters(): | ||
weight.data.uniform_(-stdv, stdv) | ||
|
||
def _add_edge(self, graph, sid, tid): | ||
if tid not in graph[sid]: | ||
graph[sid][tid] = 0 | ||
graph[sid][tid] += 1 | ||
|
||
def construct_global_graph(self, dataset): | ||
self.logger.info('Constructing global graphs.') | ||
item_id_list = dataset.inter_feat['item_id_list'] | ||
src_item_ids = item_id_list[:,:4].tolist() | ||
tgt_itme_id = dataset.inter_feat['item_id'].tolist() | ||
global_graph = [{} for _ in range(self.n_items)] | ||
for i in tqdm(range(len(tgt_itme_id)), desc='Converting: '): | ||
tid = tgt_itme_id[i] | ||
for sid in src_item_ids[i]: | ||
if sid > 0: | ||
self._add_edge(global_graph, tid, sid) | ||
self._add_edge(global_graph, sid, tid) | ||
global_adj = [[] for _ in range(self.n_items)] | ||
global_weight = [[] for _ in range(self.n_items)] | ||
for i in tqdm(range(self.n_items), desc='Sorting: '): | ||
sorted_out_edges = [v for v in sorted(global_graph[i].items(), reverse=True, key=lambda x: x[1])] | ||
global_adj[i] = [v[0] for v in sorted_out_edges[:self.sample_num]] | ||
global_weight[i] = [v[1] for v in sorted_out_edges[:self.sample_num]] | ||
if len(global_adj[i]) < self.sample_num: | ||
for j in range(self.sample_num - len(global_adj[i])): | ||
global_adj[i].append(0) | ||
global_weight[i].append(0) | ||
return torch.LongTensor(global_adj).to(self.device), torch.FloatTensor(global_weight).to(self.device) | ||
|
||
def fusion(self, hidden, mask): | ||
batch_size = hidden.shape[0] | ||
length = hidden.shape[1] | ||
pos_emb = self.pos_embedding.weight[:length] | ||
pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) | ||
|
||
hs = torch.sum(hidden * mask, -2) / torch.sum(mask, 1) | ||
hs = hs.unsqueeze(-2).expand(-1, length, -1) | ||
nh = self.w_1(torch.cat([pos_emb, hidden], -1)) | ||
nh = torch.tanh(nh) | ||
nh = torch.sigmoid(self.glu1(nh) + self.glu2(hs)) | ||
beta = self.w_2(nh) | ||
beta = beta * mask | ||
final_h = torch.sum(beta * hidden, 1) | ||
return final_h | ||
|
||
def forward(self, x, edge_index, edge_attr, alias_inputs, item_seq_len): | ||
batch_size = alias_inputs.shape[0] | ||
mask = alias_inputs.gt(0).unsqueeze(-1) | ||
h = self.item_embedding(x) | ||
|
||
# local | ||
h_local = self.local_agg(h, edge_index, edge_attr) | ||
|
||
# global | ||
item_neighbors = [F.pad(x[alias_inputs], (0, self.max_seq_length - x[alias_inputs].shape[1]), "constant", 0)] | ||
weight_neighbors = [] | ||
support_size = self.max_seq_length | ||
|
||
for i in range(self.hop): | ||
item_sample_i, weight_sample_i = self.global_adj[item_neighbors[-1].view(-1)], self.global_weight[item_neighbors[-1].view(-1)] | ||
support_size *= self.sample_num | ||
item_neighbors.append(item_sample_i.view(batch_size, support_size)) | ||
weight_neighbors.append(weight_sample_i.view(batch_size, support_size)) | ||
|
||
entity_vectors = [self.item_embedding(i) for i in item_neighbors] | ||
weight_vectors = weight_neighbors | ||
|
||
session_info = [] | ||
item_emb = h[alias_inputs] * mask | ||
|
||
# mean | ||
sum_item_emb = torch.sum(item_emb, 1) / torch.sum(mask.float(), 1) | ||
|
||
# sum | ||
# sum_item_emb = torch.sum(item_emb, 1) | ||
|
||
sum_item_emb = sum_item_emb.unsqueeze(-2) | ||
for i in range(self.hop): | ||
session_info.append(sum_item_emb.repeat(1, entity_vectors[i].shape[1], 1)) | ||
|
||
for n_hop in range(self.hop): | ||
entity_vectors_next_iter = [] | ||
shape = [batch_size, -1, self.sample_num, self.embedding_size] | ||
for hop in range(self.hop - n_hop): | ||
aggregator = self.global_agg[n_hop] | ||
vector = aggregator(self_vectors=entity_vectors[hop], | ||
neighbor_vector=entity_vectors[hop + 1].view(shape), | ||
masks=None, | ||
batch_size=batch_size, | ||
neighbor_weight=weight_vectors[hop].view(batch_size, -1, self.sample_num), | ||
extra_vector=session_info[hop]) | ||
entity_vectors_next_iter.append(vector) | ||
entity_vectors = entity_vectors_next_iter | ||
|
||
h_global = entity_vectors[0].view(batch_size, self.max_seq_length, self.embedding_size) | ||
h_global = h_global[:,:alias_inputs.shape[1],:] | ||
|
||
h_local = F.dropout(h_local, self.dropout_local, training=self.training) | ||
h_global = F.dropout(h_global, self.dropout_global, training=self.training) | ||
h_local = h_local[alias_inputs] | ||
|
||
h_session = h_local + h_global | ||
h_session = self.fusion(h_session, mask) | ||
return h_session | ||
|
||
def calculate_loss(self, interaction): | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
edge_attr = interaction['edge_attr'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len) | ||
pos_items = interaction[self.POS_ITEM_ID] | ||
if self.loss_type == 'BPR': | ||
neg_items = interaction[self.NEG_ITEM_ID] | ||
pos_items_emb = self.item_embedding(pos_items) | ||
neg_items_emb = self.item_embedding(neg_items) | ||
pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] | ||
neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] | ||
loss = self.loss_fct(pos_score, neg_score) | ||
return loss | ||
else: # self.loss_type = 'CE' | ||
test_item_emb = self.item_embedding.weight | ||
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) | ||
loss = self.loss_fct(logits, pos_items) | ||
return loss | ||
|
||
def predict(self, interaction): | ||
test_item = interaction[self.ITEM_ID] | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
edge_attr = interaction['edge_attr'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len) | ||
test_item_emb = self.item_embedding(test_item) | ||
scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] | ||
return scores | ||
|
||
def full_sort_predict(self, interaction): | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
edge_attr = interaction['edge_attr'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len) | ||
test_items_emb = self.item_embedding.weight | ||
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] | ||
return scores |
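construct_global_graph above boils down to counting item co-occurrences and keeping the sample_num strongest neighbours per item, zero-padded to a fixed width. A self-contained toy version of that bookkeeping, with made-up item ids and pair counts:

import torch

n_items, sample_num = 5, 2
pairs = [([1, 2], 3), ([2, 3], 4)]            # (recent source items, target item)
graph = [{} for _ in range(n_items)]

def add_edge(g, sid, tid):
    g[sid][tid] = g[sid].get(tid, 0) + 1

for src_items, tid in pairs:
    for sid in src_items:
        if sid > 0:                            # id 0 is padding
            add_edge(graph, tid, sid)
            add_edge(graph, sid, tid)

adj, weight = [], []
for i in range(n_items):
    top = sorted(graph[i].items(), key=lambda x: x[1], reverse=True)[:sample_num]
    adj.append([v[0] for v in top] + [0] * (sample_num - len(top)))
    weight.append([v[1] for v in top] + [0] * (sample_num - len(top)))
global_adj, global_weight = torch.LongTensor(adj), torch.FloatTensor(weight)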
Recbole_GNN/recbole_gnn/model/sequential_recommender/gcsan.py
@@ -0,0 +1,165 @@ | ||
# @Time : 2022/3/7 | ||
# @Author : Yupeng Hou | ||
# @Email : houyupeng@ruc.edu.cn | ||
|
||
r""" | ||
GCSAN | ||
################################################ | ||
Reference: | ||
Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019. | ||
""" | ||
|
||
import torch | ||
from torch import nn | ||
from recbole.model.layers import TransformerEncoder | ||
from recbole.model.loss import EmbLoss, BPRLoss | ||
from recbole.model.abstract_recommender import SequentialRecommender | ||
|
||
from recbole_gnn.model.layers import SRGNNCell | ||
|
||
|
||
class GCSAN(SequentialRecommender): | ||
r"""GCSAN captures rich local dependencies via graph neural network, | ||
and learns long-range dependencies by applying the self-attention mechanism. | ||
Note: | ||
In the original paper, the attention mechanism in the self-attention layer is single-head; | ||
for the reusability of the project code, we use a unified Transformer component. | ||
According to the experimental results, we only apply regularization to the embedding. | ||
""" | ||
|
||
def __init__(self, config, dataset): | ||
super(GCSAN, self).__init__(config, dataset) | ||
|
||
# load parameters info | ||
self.n_layers = config['n_layers'] | ||
self.n_heads = config['n_heads'] | ||
self.hidden_size = config['hidden_size'] # same as embedding_size | ||
self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer | ||
self.hidden_dropout_prob = config['hidden_dropout_prob'] | ||
self.attn_dropout_prob = config['attn_dropout_prob'] | ||
self.hidden_act = config['hidden_act'] | ||
self.layer_norm_eps = config['layer_norm_eps'] | ||
|
||
self.step = config['step'] | ||
self.device = config['device'] | ||
self.weight = config['weight'] | ||
self.reg_weight = config['reg_weight'] | ||
self.loss_type = config['loss_type'] | ||
self.initializer_range = config['initializer_range'] | ||
|
||
# item embedding | ||
self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0) | ||
|
||
# define layers and loss | ||
self.gnncell = SRGNNCell(self.hidden_size) | ||
self.self_attention = TransformerEncoder( | ||
n_layers=self.n_layers, | ||
n_heads=self.n_heads, | ||
hidden_size=self.hidden_size, | ||
inner_size=self.inner_size, | ||
hidden_dropout_prob=self.hidden_dropout_prob, | ||
attn_dropout_prob=self.attn_dropout_prob, | ||
hidden_act=self.hidden_act, | ||
layer_norm_eps=self.layer_norm_eps | ||
) | ||
self.reg_loss = EmbLoss() | ||
if self.loss_type == 'BPR': | ||
self.loss_fct = BPRLoss() | ||
elif self.loss_type == 'CE': | ||
self.loss_fct = nn.CrossEntropyLoss() | ||
else: | ||
raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") | ||
|
||
# parameters initialization | ||
self.apply(self._init_weights) | ||
|
||
def _init_weights(self, module): | ||
""" Initialize the weights """ | ||
if isinstance(module, (nn.Linear, nn.Embedding)): | ||
# Slightly different from the TF version which uses truncated_normal for initialization | ||
# cf https://github.com/pytorch/pytorch/pull/5617 | ||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | ||
elif isinstance(module, nn.LayerNorm): | ||
module.bias.data.zero_() | ||
module.weight.data.fill_(1.0) | ||
if isinstance(module, nn.Linear) and module.bias is not None: | ||
module.bias.data.zero_() | ||
|
||
def get_attention_mask(self, item_seq): | ||
"""Generate left-to-right uni-directional attention mask for multi-head attention.""" | ||
attention_mask = (item_seq > 0).long() | ||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 | ||
# mask for left-to-right unidirectional | ||
max_len = attention_mask.size(-1) | ||
attn_shape = (1, max_len, max_len) | ||
subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8 | ||
subsequent_mask = (subsequent_mask == 0).unsqueeze(1) | ||
subsequent_mask = subsequent_mask.long().to(item_seq.device) | ||
|
||
extended_attention_mask = extended_attention_mask * subsequent_mask | ||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | ||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | ||
return extended_attention_mask | ||
|
||
def forward(self, x, edge_index, alias_inputs, item_seq_len): | ||
hidden = self.item_embedding(x) | ||
for i in range(self.step): | ||
hidden = self.gnncell(hidden, edge_index) | ||
|
||
seq_hidden = hidden[alias_inputs] | ||
# fetch the last hidden state of last timestamp | ||
ht = self.gather_indexes(seq_hidden, item_seq_len - 1) | ||
|
||
attention_mask = self.get_attention_mask(alias_inputs) | ||
outputs = self.self_attention(seq_hidden, attention_mask, output_all_encoded_layers=True) | ||
output = outputs[-1] | ||
at = self.gather_indexes(output, item_seq_len - 1) | ||
seq_output = self.weight * at + (1 - self.weight) * ht | ||
return seq_output | ||
|
||
def calculate_loss(self, interaction): | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) | ||
pos_items = interaction[self.POS_ITEM_ID] | ||
if self.loss_type == 'BPR': | ||
neg_items = interaction[self.NEG_ITEM_ID] | ||
pos_items_emb = self.item_embedding(pos_items) | ||
neg_items_emb = self.item_embedding(neg_items) | ||
pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] | ||
neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] | ||
loss = self.loss_fct(pos_score, neg_score) | ||
else: # self.loss_type = 'CE' | ||
test_item_emb = self.item_embedding.weight | ||
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) | ||
loss = self.loss_fct(logits, pos_items) | ||
reg_loss = self.reg_loss(self.item_embedding.weight) | ||
total_loss = loss + self.reg_weight * reg_loss | ||
return total_loss | ||
|
||
def predict(self, interaction): | ||
test_item = interaction[self.ITEM_ID] | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) | ||
test_item_emb = self.item_embedding(test_item) | ||
scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] | ||
return scores | ||
|
||
def full_sort_predict(self, interaction): | ||
x = interaction['x'] | ||
edge_index = interaction['edge_index'] | ||
alias_inputs = interaction['alias_inputs'] | ||
item_seq_len = interaction[self.ITEM_SEQ_LEN] | ||
seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len) | ||
test_items_emb = self.item_embedding.weight | ||
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] | ||
return scores |
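To see what get_attention_mask produces, the same computation can be run on a single hypothetical padded session of length 4:

import torch

item_seq = torch.tensor([[5, 3, 0, 0]])            # one session, padded with 0
attention_mask = (item_seq > 0).long()
extended = attention_mask.unsqueeze(1).unsqueeze(2)
max_len = attention_mask.size(-1)
subsequent = torch.triu(torch.ones((1, max_len, max_len)), diagonal=1)
subsequent = (subsequent == 0).unsqueeze(1).long()
extended = extended * subsequent
extended = (1.0 - extended.float()) * -10000.0
# extended[0, 0] is a [4, 4] matrix in which entry [q, k] is 0.0 only when key k is a
# real (non-padding) item and k <= q; every padded or future position is -10000.0,
# so softmax effectively ignores it.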
Recbole_GNN/recbole_gnn/model/sequential_recommender/lessr.py
@@ -0,0 +1,257 @@ | ||
# @Time : 2022/3/11 | ||
# @Author : Yupeng Hou | ||
# @Email : houyupeng@ruc.edu.cn | ||
|
||
r""" | ||
LESSR | ||
################################################ | ||
Reference: | ||
Tianwen Chen and Raymond Chi-Wing Wong. "Handling Information Loss of Graph Neural Networks for Session-based Recommendation." in KDD 2020. | ||
Reference code: | ||
https://github.com/twchen/lessr | ||
""" | ||
|
||
import torch | ||
from torch import nn | ||
from torch_geometric.utils import softmax | ||
from torch_geometric.nn import global_add_pool | ||
from recbole.model.abstract_recommender import SequentialRecommender | ||
|
||
|
||
class EOPA(nn.Module): | ||
def __init__( | ||
self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None | ||
): | ||
super().__init__() | ||
self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None | ||
self.feat_drop = nn.Dropout(feat_drop) | ||
self.gru = nn.GRU(input_dim, input_dim, batch_first=True) | ||
self.fc_self = nn.Linear(input_dim, output_dim, bias=False) | ||
self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False) | ||
self.activation = activation | ||
|
||
def reducer(self, nodes): | ||
m = nodes.mailbox['m'] # (num_nodes, deg, d) | ||
# m[i]: the messages passed to the i-th node with in-degree equal to 'deg' | ||
# the order of messages follows the order of incoming edges | ||
# since the edges are sorted by occurrence time when the EOP multigraph is built | ||
# the messages are in the order required by EOPA | ||
_, hn = self.gru(m) # hn: (1, num_nodes, d) | ||
return {'neigh': hn.squeeze(0)} | ||
|
||
def forward(self, mg, feat): | ||
import dgl.function as fn | ||
|
||
with mg.local_scope(): | ||
if self.batch_norm is not None: | ||
feat = self.batch_norm(feat) | ||
mg.ndata['ft'] = self.feat_drop(feat) | ||
if mg.number_of_edges() > 0: | ||
mg.update_all(fn.copy_u('ft', 'm'), self.reducer) | ||
neigh = mg.ndata['neigh'] | ||
rst = self.fc_self(feat) + self.fc_neigh(neigh) | ||
else: | ||
rst = self.fc_self(feat) | ||
if self.activation is not None: | ||
rst = self.activation(rst) | ||
return rst | ||
|
||
|
||
class SGAT(nn.Module): | ||
def __init__( | ||
self, | ||
input_dim, | ||
hidden_dim, | ||
output_dim, | ||
batch_norm=True, | ||
feat_drop=0.0, | ||
activation=None, | ||
): | ||
super().__init__() | ||
self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None | ||
self.feat_drop = nn.Dropout(feat_drop) | ||
self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True) | ||
self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False) | ||
self.fc_v = nn.Linear(input_dim, output_dim, bias=False) | ||
self.fc_e = nn.Linear(hidden_dim, 1, bias=False) | ||
self.activation = activation | ||
|
||
def forward(self, sg, feat): | ||
import dgl.ops as F | ||
|
||
if self.batch_norm is not None: | ||
feat = self.batch_norm(feat) | ||
feat = self.feat_drop(feat) | ||
q = self.fc_q(feat) | ||
k = self.fc_k(feat) | ||
v = self.fc_v(feat) | ||
e = F.u_add_v(sg, q, k) | ||
e = self.fc_e(torch.sigmoid(e)) | ||
a = F.edge_softmax(sg, e) | ||
rst = F.u_mul_e_sum(sg, v, a) | ||
if self.activation is not None: | ||
rst = self.activation(rst) | ||
return rst | ||
|
||
|
||
class AttnReadout(nn.Module): | ||
def __init__( | ||
self, | ||
input_dim, | ||
hidden_dim, | ||
output_dim, | ||
batch_norm=True, | ||
feat_drop=0.0, | ||
activation=None, | ||
): | ||
super().__init__() | ||
self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None | ||
self.feat_drop = nn.Dropout(feat_drop) | ||
self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False) | ||
self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True) | ||
self.fc_e = nn.Linear(hidden_dim, 1, bias=False) | ||
self.fc_out = ( | ||
nn.Linear(input_dim, output_dim, bias=False) | ||
if output_dim != input_dim else None | ||
) | ||
self.activation = activation | ||
|
||
def forward(self, g, feat, last_nodes, batch): | ||
if self.batch_norm is not None: | ||
feat = self.batch_norm(feat) | ||
feat = self.feat_drop(feat) | ||
feat_u = self.fc_u(feat) | ||
feat_v = self.fc_v(feat[last_nodes]) | ||
feat_v = torch.index_select(feat_v, dim=0, index=batch) | ||
e = self.fc_e(torch.sigmoid(feat_u + feat_v)) | ||
alpha = softmax(e, batch) | ||
feat_norm = feat * alpha | ||
rst = global_add_pool(feat_norm, batch) | ||
if self.fc_out is not None: | ||
rst = self.fc_out(rst) | ||
if self.activation is not None: | ||
rst = self.activation(rst) | ||
return rst | ||
|
||
|
||
class LESSR(SequentialRecommender): | ||
r"""LESSR analyzes the information losses when constructing session graphs, | ||
and emphasises the lossy session encoding problem and the ineffective long-range dependency capturing problem. | ||
To solve the first problem, the authors propose a lossless encoding scheme and an edge-order preserving aggregation layer. | ||
To solve the second problem, the authors propose a shortcut graph attention layer that effectively captures long-range dependencies. | ||
Note: | ||
We follow the original implementation, which requires the DGL package. | ||
We find it difficult to implement these operations via PyG, so we retain the DGL-based version. | ||
If you would like to test this model, please install DGL. | ||
""" | ||
|
||
def __init__(self, config, dataset): | ||
super().__init__(config, dataset) | ||
|
||
embedding_dim = config['embedding_size'] | ||
self.num_layers = config['n_layers'] | ||
batch_norm = config['batch_norm'] | ||
feat_drop = config['feat_drop'] | ||
self.loss_type = config['loss_type'] | ||
|
||
self.item_embedding = nn.Embedding(self.n_items, embedding_dim, max_norm=1) | ||
self.layers = nn.ModuleList() | ||
input_dim = embedding_dim | ||
for i in range(self.num_layers): | ||
if i % 2 == 0: | ||
layer = EOPA( | ||
input_dim, | ||
embedding_dim, | ||
batch_norm=batch_norm, | ||
feat_drop=feat_drop, | ||
activation=nn.PReLU(embedding_dim), | ||
) | ||
else: | ||
layer = SGAT( | ||
input_dim, | ||
embedding_dim, | ||
embedding_dim, | ||
batch_norm=batch_norm, | ||
feat_drop=feat_drop, | ||
activation=nn.PReLU(embedding_dim), | ||
) | ||
input_dim += embedding_dim | ||
self.layers.append(layer) | ||
self.readout = AttnReadout( | ||
input_dim, | ||
embedding_dim, | ||
embedding_dim, | ||
batch_norm=batch_norm, | ||
feat_drop=feat_drop, | ||
activation=nn.PReLU(embedding_dim), | ||
) | ||
input_dim += embedding_dim | ||
self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None | ||
self.feat_drop = nn.Dropout(feat_drop) | ||
self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False) | ||
|
||
if self.loss_type == 'CE': | ||
self.loss_fct = nn.CrossEntropyLoss() | ||
else: | ||
raise NotImplementedError("Make sure 'loss_type' in ['CE']!") | ||
|
||
def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_last): | ||
import dgl | ||
|
||
mg = dgl.graph((edge_index_EOP[0], edge_index_EOP[1]), num_nodes=batch.shape[0]) | ||
sg = dgl.graph((edge_index_shortcut[0], edge_index_shortcut[1]), num_nodes=batch.shape[0]) | ||
|
||
feat = self.item_embedding(x) | ||
for i, layer in enumerate(self.layers): | ||
if i % 2 == 0: | ||
out = layer(mg, feat) | ||
else: | ||
out = layer(sg, feat) | ||
feat = torch.cat([out, feat], dim=1) | ||
sr_g = self.readout(mg, feat, is_last, batch) | ||
sr_l = feat[is_last] | ||
sr = torch.cat([sr_l, sr_g], dim=1) | ||
if self.batch_norm is not None: | ||
sr = self.batch_norm(sr) | ||
sr = self.fc_sr(self.feat_drop(sr)) | ||
return sr | ||
|
||
def calculate_loss(self, interaction): | ||
x = interaction['x'] | ||
edge_index_EOP = interaction['edge_index_EOP'] | ||
edge_index_shortcut = interaction['edge_index_shortcut'] | ||
batch = interaction['batch'] | ||
is_last = interaction['is_last'] | ||
seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) | ||
pos_items = interaction[self.POS_ITEM_ID] | ||
test_item_emb = self.item_embedding.weight | ||
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) | ||
loss = self.loss_fct(logits, pos_items) | ||
return loss | ||
|
||
def predict(self, interaction): | ||
test_item = interaction[self.ITEM_ID] | ||
x = interaction['x'] | ||
edge_index_EOP = interaction['edge_index_EOP'] | ||
edge_index_shortcut = interaction['edge_index_shortcut'] | ||
batch = interaction['batch'] | ||
is_last = interaction['is_last'] | ||
seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) | ||
test_item_emb = self.item_embedding(test_item) | ||
scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] | ||
return scores | ||
|
||
def full_sort_predict(self, interaction): | ||
x = interaction['x'] | ||
edge_index_EOP = interaction['edge_index_EOP'] | ||
edge_index_shortcut = interaction['edge_index_shortcut'] | ||
batch = interaction['batch'] | ||
is_last = interaction['is_last'] | ||
seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last) | ||
test_items_emb = self.item_embedding.weight | ||
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B, n_items] | ||
return scores |
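Unlike EOPA and SGAT, the AttnReadout defined above depends only on torch_geometric utilities, so it can be exercised without DGL. A hypothetical call on a toy batch of two sessions, assuming the class and its module-level imports from the listing above are available; dimensions and index values are illustrative.

import torch

readout = AttnReadout(input_dim=8, hidden_dim=8, output_dim=8,
                      batch_norm=False, feat_drop=0.0, activation=None)
feat = torch.randn(5, 8)               # 5 session nodes in total
batch = torch.tensor([0, 0, 0, 1, 1])  # node -> session assignment
last_nodes = torch.tensor([2, 4])      # index of the last node of each session
sr_g = readout(None, feat, last_nodes, batch)   # the graph argument is unused in forward
print(sr_g.shape)                               # torch.Size([2, 8]): one vector per session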