Skip to content

Commit

Permalink
Feat: add recbole_gnn base files
Browse files Browse the repository at this point in the history
SeungahP committed Feb 22, 2024
1 parent e1044cb commit 2ca9622
Showing 53 changed files with 4,551 additions and 0 deletions.
80 changes: 80 additions & 0 deletions Recbole_GNN/recbole_gnn/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import recbole
from recbole.config.configurator import Config as RecBole_Config
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import get_model, ModelType


class Config(RecBole_Config):
    """Configurator for recbole_gnn, extending RecBole's ``Config``.

    Adds three behaviors on top of the base class:

    * patches numpy scalar aliases when running against RecBole 1.1.1;
    * resolves model classes through :func:`recbole_gnn.utils.get_model`
      so models living in this package are found;
    * overlays extra default yaml files shipped under ``properties/``.
    """

    def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=None):
        """
        Args:
            model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config
                will search the parameter 'model' from the external input as the model name or model class.
            dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset'
                from the external input as the dataset name.
            config_file_list (list of str): the external config file, it allows multiple config files, default is None.
            config_dict (dict): the external parameter dictionaries, default is None.
        """
        if recbole.__version__ == "1.1.1":
            self.compatibility_settings()
        super(Config, self).__init__(model, dataset, config_file_list, config_dict)

    def compatibility_settings(self):
        """Re-create deprecated numpy scalar aliases that RecBole 1.1.1 still uses.

        NOTE(review): the right-hand-side names (``np.float_``, ``np.unicode_``,
        ...) were removed in numpy 2.0, so this shim only works on numpy < 2 --
        confirm the pinned numpy version.
        """
        import numpy as np
        np.bool = np.bool_
        np.int = np.int_
        np.float = np.float_
        np.complex = np.complex_
        np.object = np.object_
        np.str = np.str_
        np.long = np.int_
        np.unicode = np.unicode_

    def _get_model_and_dataset(self, model, dataset):
        """Resolve the final model name, model class and dataset name.

        Falls back to ``self.external_config_dict`` when an argument is None.

        Returns:
            tuple: ``(final_model, final_model_class, final_dataset)``.

        Raises:
            KeyError: if the model or dataset is specified nowhere.
        """
        if model is None:
            try:
                model = self.external_config_dict['model']
            except KeyError:
                raise KeyError(
                    'model need to be specified in at least one of these ways: '
                    '[model variable, config file, config dict, command line] '
                )
        if not isinstance(model, str):
            # a model class was passed in directly
            final_model_class = model
            final_model = model.__name__
        else:
            # a model name was passed in; look up the class in this package
            final_model = model
            final_model_class = get_model(final_model)

        if dataset is None:
            try:
                final_dataset = self.external_config_dict['dataset']
            except KeyError:
                raise KeyError(
                    'dataset need to be specified in at least one of these ways: '
                    '[dataset variable, config file, config dict, command line] '
                )
        else:
            final_dataset = dataset

        return final_model, final_model_class, final_dataset

    def _load_internal_config_dict(self, model, model_class, dataset):
        """Load RecBole's internal defaults, then overlay this package's yaml files.

        Applies, in order: the per-model yaml (if present), then the
        sequential/social base yaml matching the resolved model type.
        """
        super()._load_internal_config_dict(model, model_class, dataset)
        current_path = os.path.dirname(os.path.realpath(__file__))
        model_init_file = os.path.join(current_path, './properties/model/' + model + '.yaml')
        quick_start_config_path = os.path.join(current_path, './properties/quick_start_config/')
        sequential_base_init = os.path.join(quick_start_config_path, 'sequential_base.yaml')
        social_base_init = os.path.join(quick_start_config_path, 'social_base.yaml')

        if os.path.isfile(model_init_file):
            # _update_internal_config_dict merges in place; its return value was
            # previously bound to an unused local, which is dropped here
            self._update_internal_config_dict(model_init_file)

        self.internal_config_dict['MODEL_TYPE'] = model_class.type
        if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
            self._update_internal_config_dict(sequential_base_init)
        if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
            self._update_internal_config_dict(social_base_init)
30 changes: 30 additions & 0 deletions Recbole_GNN/recbole_gnn/model/abstract_recommender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import ModelType as RecBoleModelType

from recbole_gnn.utils import ModelType


class GeneralGraphRecommender(GeneralRecommender):
    """Abstract base class for general graph recommenders.

    Builds the normalized user-item adjacency once and moves it to the target
    device: either as a single sparse tensor (weights folded in, so no separate
    weight tensor is kept) or as an ``(edge_index, edge_weight)`` pair.
    """
    type = RecBoleModelType.GENERAL

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        enable_sparse = config["enable_sparse"]
        edge_index, edge_weight = dataset.get_norm_adj_mat(enable_sparse=enable_sparse)
        self.use_sparse = enable_sparse and dataset.is_sparse
        self.edge_index = edge_index.to(self.device)
        # the sparse adjacency already carries the edge weights
        self.edge_weight = None if self.use_sparse else edge_weight.to(self.device)


class SocialRecommender(GeneralRecommender):
    """Abstract base class for social recommenders.

    Marks models as ``ModelType.SOCIAL`` so the social dataset/config path is
    selected; adds no state beyond the general recommender base class.
    """
    type = ModelType.SOCIAL

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
10 changes: 10 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from recbole_gnn.model.general_recommender.lightgcn import LightGCN
from recbole_gnn.model.general_recommender.hmlet import HMLET
from recbole_gnn.model.general_recommender.ncl import NCL
from recbole_gnn.model.general_recommender.ngcf import NGCF
from recbole_gnn.model.general_recommender.sgl import SGL
from recbole_gnn.model.general_recommender.lightgcl import LightGCL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
from recbole_gnn.model.general_recommender.directau import DirectAU
from recbole_gnn.model.general_recommender.ssl4rec import SSL4REC
120 changes: 120 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/directau.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# r"""
# DirectAU
# ################################################
# Reference:
# Chenyang Wang et al. "Towards Representation Alignment and Uniformity in Collaborative Filtering." in KDD 2022.

# Reference code:
# https://github.com/THUwangcy/DirectAU
# """

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.general_recommender import BPR
from recbole_gnn.model.general_recommender import LightGCN

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class DirectAU(GeneralGraphRecommender):
    r"""DirectAU directly optimizes alignment and uniformity of the normalized
    user/item representations (Wang et al., KDD 2022) on top of an MF or
    LightGCN encoder selected by ``config['encoder']``.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(DirectAU, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.gamma = config['gamma']  # weight of the uniformity term
        self.encoder_name = config['encoder']

        # define encoder
        if self.encoder_name == 'MF':
            self.encoder = MFEncoder(config, dataset)
        elif self.encoder_name == 'LightGCN':
            self.encoder = LGCNEncoder(config, dataset)
        else:
            raise ValueError('Non-implemented Encoder.')

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def forward(self, user, item):
        """Return L2-normalized user and item embeddings from the encoder."""
        user_e, item_e = self.encoder(user, item)
        return F.normalize(user_e, dim=-1), F.normalize(item_e, dim=-1)

    @staticmethod
    def alignment(x, y, alpha=2):
        """Mean ``alpha``-powered L2 distance between paired rows of x and y."""
        return (x - y).norm(p=2, dim=1).pow(alpha).mean()

    @staticmethod
    def uniformity(x, t=2):
        """Log of the mean Gaussian-kernel pairwise similarity (lower = more uniform)."""
        return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

    def calculate_loss(self, interaction):
        """Return the (alignment, gamma-weighted uniformity) loss pair."""
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_e, item_e = self.forward(user, item)
        align = self.alignment(user_e, item_e)
        uniform = self.gamma * (self.uniformity(user_e) + self.uniformity(item_e)) / 2

        return align, uniform

    def predict(self, interaction):
        """Score the (user, item) pairs in ``interaction``.

        Bug fix: the previous implementation read ``self.user_embedding`` /
        ``self.item_embedding``, which are never defined on DirectAU (the
        embedding tables live on ``self.encoder``) and raised AttributeError.
        Scores are now computed from the encoder via :meth:`forward`.
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]
        user_e, item_e = self.forward(user, item)
        return torch.mul(user_e, item_e).sum(dim=1)

    def full_sort_predict(self, interaction):
        """Score every item for each user; caches LightGCN propagation output."""
        user = interaction[self.USER_ID]
        if self.encoder_name == 'LightGCN':
            if self.restore_user_e is None or self.restore_item_e is None:
                self.restore_user_e, self.restore_item_e = self.encoder.get_all_embeddings()
            user_e = self.restore_user_e[user]
            all_item_e = self.restore_item_e
        else:
            user_e = self.encoder.user_embedding(user)
            all_item_e = self.encoder.item_embedding.weight
        score = torch.matmul(user_e, all_item_e.transpose(0, 1))
        return score.view(-1)


class MFEncoder(BPR):
    """Matrix-factorization encoder for DirectAU; a thin wrapper over BPR."""

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

    def forward(self, user_id, item_id):
        # delegate the pairwise embedding lookup to BPR
        return super().forward(user_id, item_id)

    def get_all_embeddings(self):
        """Return the full (user, item) embedding weight tables."""
        return self.user_embedding.weight, self.item_embedding.weight


class LGCNEncoder(LightGCN):
    """LightGCN encoder for DirectAU; indexes propagated embeddings per id."""

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

    def forward(self, user_id, item_id):
        users, items = self.get_all_embeddings()
        return users[user_id], items[item_id]

    def get_all_embeddings(self):
        """Run LightGCN propagation and return all user/item embeddings."""
        return super().forward()
229 changes: 229 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/hmlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# @Time : 2022/3/21
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

r"""
HMLET
################################################
Reference:
Taeyong Kong et al. "Linear, or Non-Linear, That is the Question!." in WSDM 2022.
Reference code:
https://github.com/qbxlvnf11/HMLET
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.model.layers import activation_layer
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class Gating_Net(nn.Module):
    """Gumbel-softmax gating network used by HMLET.

    An MLP maps concatenated linear/non-linear embeddings
    (``2 * embedding_dim`` features) to class logits; the logits are turned
    into a (optionally hard one-hot) Gumbel-softmax sample and repeated along
    the embedding dimension so it can gate stacked embeddings elementwise.
    """

    def __init__(self, embedding_dim, mlp_dims, dropout_p):
        """
        Args:
            embedding_dim (int): dimensionality of each gated embedding.
            mlp_dims (list of int): hidden/output sizes of the gating MLP;
                the last entry is the number of gate classes.
            dropout_p (float): dropout probability used between hidden layers.
        """
        super(Gating_Net, self).__init__()
        self.embedding_dim = embedding_dim

        fc_layers = []
        in_dim = embedding_dim * 2  # concatenated [linear, non-linear] features
        for i, out_dim in enumerate(mlp_dims):
            fc_layers.append(nn.Linear(in_dim, out_dim))
            if i != len(mlp_dims) - 1:
                # BN + dropout + ReLU after every hidden layer, not on the logits
                fc_layers.append(nn.BatchNorm1d(out_dim))
                fc_layers.append(nn.Dropout(p=dropout_p))
                fc_layers.append(nn.ReLU(inplace=True))
            in_dim = out_dim
        self.mlp = nn.Sequential(*fc_layers)

    def gumbel_softmax(self, logits, temperature, hard):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        y = self.gumbel_softmax_sample(logits, temperature)  # e.g. (0.6, 0.2, 0.1, ..., 0.11)
        if hard:
            # straight-through estimator: forward pass is one-hot,
            # gradient flows through the soft sample y
            y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y)
            y = (y_hard - y).detach() + y
        return y

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a soft sample from the Gumbel-Softmax distribution."""
        noise = self.sample_gumbel(logits)
        return F.softmax((logits + noise) / temperature, dim=1)

    def sample_gumbel(self, logits):
        """Sample Gumbel(0, 1) noise shaped like ``logits``.

        Applies -log(-log(u + eps) + eps) to uniform noise u. The noise is
        drawn directly on ``logits.device`` (the previous version sampled on
        CPU and then made a redundant ``torch.Tensor`` copy before moving it).
        """
        eps = 1e-20
        noise = torch.rand(logits.size(), dtype=torch.float32, device=logits.device)
        noise.add_(eps).log_().neg_()
        noise.add_(eps).log_().neg_()
        return noise

    def forward(self, feature, temperature, hard):
        """Compute gate values.

        Args:
            feature: [batch, 2 * embedding_dim] concatenated embeddings.
            temperature: Gumbel-softmax temperature.
            hard (bool): whether to discretize the sample to one-hot.
        Returns:
            [batch, n_class, embedding_dim] gate, repeated along the embedding dim.
        """
        logits = self.mlp(feature)
        gate = self.gumbel_softmax(logits, temperature, hard)
        return gate.unsqueeze(2).repeat(1, 1, self.embedding_dim)


class HMLET(GeneralGraphRecommender):
    r"""HMLET combines both linear and non-linear propagation layers for general recommendation and yields better performance.

    At each layer id listed in ``gate_layer_ids`` a Gumbel-softmax gating
    network chooses, per node, between the linear LightGCN propagation and a
    non-linearly activated propagation; all other layers are purely linear.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(HMLET, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type:the embedding size of lightGCN
        self.n_layers = config['n_layers']  # int type:the layer num of lightGCN
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization
        self.require_pow = config['require_pow']  # bool type: whether to require pow when regularization
        self.gate_layer_ids = config['gate_layer_ids']  # list type: layer ids for non-linear gating
        self.gating_mlp_dims = config['gating_mlp_dims']  # list type: list of mlp dimensions in gating module
        self.dropout_ratio = config['dropout_ratio']  # dropout ratio for mlp in gating module
        # initial gumbel-softmax temperature; presumably annealed by the
        # training loop (it is persisted via other_parameter_name below) -- confirm
        self.gum_temp = config['ori_temp']
        self.logger.info(f'Model initialization, gumbel softmax temperature: {self.gum_temp}')

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.activation = nn.ELU() if config['activation_function'] == 'elu' else activation_layer(config['activation_function'])
        # one gating network per gated layer
        self.gating_nets = nn.ModuleList([
            Gating_Net(self.latent_dim, self.gating_mlp_dims, self.dropout_ratio) for _ in range(len(self.gate_layer_ids))
        ])

        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        # gum_temp is included so the annealed temperature is saved/restored
        self.other_parameter_name = ['restore_user_e', 'restore_item_e', 'gum_temp']

        # gating networks start frozen; presumably unfrozen later in training -- confirm
        for gating in self.gating_nets:
            self._gating_freeze(gating, False)

    def _gating_freeze(self, model, freeze_flag):
        """Set ``requires_grad`` of all child-module parameters to ``freeze_flag``."""
        for name, child in model.named_children():
            for param in child.parameters():
                param.requires_grad = freeze_flag

    def __choosing_one(self, features, gumbel_out):
        """Select one embedding per node by weighting the stacked candidates
        with the (one-hot or soft) gumbel gate and summing over the choice dim."""
        feature = torch.sum(torch.mul(features, gumbel_out), dim=1)  # batch x embedding_dim (or batch x embedding_dim x layer_num)
        return feature

    def __where(self, idx, lst):
        """Return the position of ``idx`` in ``lst``; raise ValueError if absent."""
        for i in range(len(lst)):
            if lst[i] == idx:
                return i
        raise ValueError(f'{idx} not in {lst}.')

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        """Propagate embeddings through n_layers, gating linear vs non-linear
        branches at the configured layer ids, and mean-pool across layers.

        Returns:
            (user_all_embeddings, item_all_embeddings) split from the pooled matrix.
        """
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        # inputs for the non-linear branch; entry i is the embedding fed to gate i
        non_lin_emb_list = [all_embeddings]

        for layer_idx in range(self.n_layers):
            linear_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            if layer_idx not in self.gate_layer_ids:
                # ungated layer: plain LightGCN propagation
                all_embeddings = linear_embeddings
            else:
                non_lin_id = self.__where(layer_idx, self.gate_layer_ids)
                last_non_lin_emb = non_lin_emb_list[non_lin_id]
                non_lin_embeddings = self.activation(self.gcn_conv(last_non_lin_emb, self.edge_index, self.edge_weight))
                stack_embeddings = torch.stack([linear_embeddings, non_lin_embeddings], dim=1)
                concat_embeddings = torch.cat((linear_embeddings, non_lin_embeddings), dim=-1)
                # hard (one-hot) gating at eval time, soft sampling during training
                gumbel_out = self.gating_nets[non_lin_id](concat_embeddings, self.gum_temp, not self.training)
                all_embeddings = self.__choosing_one(stack_embeddings, gumbel_out)
                non_lin_emb_list.append(all_embeddings)
            embeddings_list.append(all_embeddings)
        hmlet_all_embeddings = torch.stack(embeddings_list, dim=1)
        hmlet_all_embeddings = torch.mean(hmlet_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(hmlet_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        """BPR loss on propagated embeddings plus L2 regularization on ego embeddings."""
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss on the raw (layer-0) embeddings
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        """Score the (user, item) pairs in ``interaction``."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        """Score every item for each user, caching the propagated embeddings."""
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
226 changes: 226 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/lightgcl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
# @Time : 2023/04/12
# @Author : Wanli Yang
# @Email : 2013774@mail.nankai.edu.cn

r"""
LightGCL
################################################
Reference:
Xuheng Cai et al. "LightGCL: Simple Yet Effective Graph Contrastive Learning for Recommendation" in ICLR 2023.
Reference code:
https://github.com/HKUDS/LightGCL
"""

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
import torch.nn.functional as F


class LightGCL(GeneralRecommender):
    r"""LightGCL is a GCN-based recommender model.
    LightGCL guides graph augmentation by singular value decomposition (SVD) to not only
    distill the useful information of user-item interactions but also inject the global
    collaborative context into the representation alignment of contrastive learning.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCL, self).__init__(config, dataset)
        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        # load parameters info
        self.embed_dim = config["embedding_size"]
        self.n_layers = config["n_layers"]
        self.dropout = config["dropout"]     # edge-dropout rate on the adjacency
        self.temp = config["temp"]           # temperature of the InfoNCE terms
        self.lambda_1 = config["lambda1"]    # weight of the SSL loss
        self.lambda_2 = config["lambda2"]    # weight of the parameter regularization
        self.q = config["q"]                 # rank of the low-rank SVD
        self.act = nn.LeakyReLU(0.5)
        self.reg_loss = EmbLoss()

        # get the normalized adjacency matrix (sparse, users x items)
        self.adj_norm = self.coo2tensor(self.create_adjust_matrix())

        # perform svd reconstruction
        svd_u, s, svd_v = torch.svd_lowrank(self.adj_norm, q=self.q)
        self.u_mul_s = svd_u @ (torch.diag(s))
        self.v_mul_s = svd_v @ (torch.diag(s))
        del s
        self.ut = svd_u.T
        self.vt = svd_v.T

        # layer-0 (ego) embeddings
        self.E_u_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_users, self.embed_dim)))
        self.E_i_0 = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.n_items, self.embed_dim)))
        # per-layer embedding slots: E = aggregated, Z = propagated, G = svd-view
        self.E_u_list = [None] * (self.n_layers + 1)
        self.E_i_list = [None] * (self.n_layers + 1)
        self.E_u_list[0] = self.E_u_0
        self.E_i_list[0] = self.E_i_0
        self.Z_u_list = [None] * (self.n_layers + 1)
        self.Z_i_list = [None] * (self.n_layers + 1)
        self.G_u_list = [None] * (self.n_layers + 1)
        self.G_i_list = [None] * (self.n_layers + 1)
        self.G_u_list[0] = self.E_u_0
        self.G_i_list[0] = self.E_i_0

        self.E_u = None
        self.E_i = None
        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def create_adjust_matrix(self):
        r"""Get the normalized interaction matrix of users and items.
        Returns:
            coo_matrix of the normalized interaction matrix.
        """
        ratings = np.ones_like(self._user, dtype=np.float32)
        matrix = sp.csr_matrix(
            (ratings, (self._user, self._item)),
            shape=(self.n_users, self.n_items),
        ).tocoo()
        rowD = np.squeeze(np.array(matrix.sum(1)), axis=1)
        colD = np.squeeze(np.array(matrix.sum(0)), axis=0)
        # symmetric normalization: a_ij / sqrt(deg_u(i) * deg_i(j))
        for i in range(len(matrix.data)):
            matrix.data[i] = matrix.data[i] / pow(rowD[matrix.row[i]] * colD[matrix.col[i]], 0.5)
        return matrix

    def coo2tensor(self, matrix: sp.coo_matrix):
        r"""Convert coo_matrix to tensor.
        Args:
            matrix (scipy.coo_matrix): Sparse matrix to be converted.
        Returns:
            torch.sparse.FloatTensor: Transformed sparse matrix.
        """
        indices = torch.from_numpy(
            np.vstack((matrix.row, matrix.col)).astype(np.int64))
        values = torch.from_numpy(matrix.data)
        shape = torch.Size(matrix.shape)
        x = torch.sparse.FloatTensor(indices, values, shape).coalesce().to(self.device)
        return x

    def sparse_dropout(self, matrix, dropout):
        r"""Apply edge dropout to a sparse matrix.

        Bug fix: ``F.dropout`` defaults to ``training=True``, so the previous
        version dropped (and rescaled) edges during evaluation as well;
        passing ``training=self.training`` disables dropout in eval mode.
        """
        if dropout == 0.0:
            return matrix
        indices = matrix.indices()
        values = F.dropout(matrix.values(), p=dropout, training=self.training)
        size = matrix.size()
        return torch.sparse.FloatTensor(indices, values, size)

    def forward(self):
        """Propagate layer-0 embeddings over the (dropout-masked) adjacency and
        sum across layers.

        Returns:
            (E_u, E_i): aggregated user and item embeddings.
        """
        for layer in range(1, self.n_layers + 1):
            # GNN propagation
            self.Z_u_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout),
                                              self.E_i_list[layer - 1])
            self.Z_i_list[layer] = torch.spmm(self.sparse_dropout(self.adj_norm, self.dropout).transpose(0, 1),
                                              self.E_u_list[layer - 1])
            # aggregate
            self.E_u_list[layer] = self.Z_u_list[layer]
            self.E_i_list[layer] = self.Z_i_list[layer]

        # aggregate across layer
        self.E_u = sum(self.E_u_list)
        self.E_i = sum(self.E_i_list)

        return self.E_u, self.E_i

    def calculate_loss(self, interaction):
        """BPR + regularization loss plus the SVD-guided contrastive (SSL) loss."""
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]
        E_u_norm, E_i_norm = self.forward()
        bpr_loss = self.calc_bpr_loss(E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list)
        ssl_loss = self.calc_ssl_loss(E_u_norm, E_i_norm, user_list, pos_item_list)
        total_loss = bpr_loss + ssl_loss
        return total_loss

    def calc_bpr_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list, neg_item_list):
        r"""Calculate the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.
        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.
        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = E_u_norm[user_list]
        pi_e = E_i_norm[pos_item_list]
        ni_e = E_i_norm[neg_item_list]
        pos_scores = torch.mul(u_e, pi_e).sum(dim=1)
        neg_scores = torch.mul(u_e, ni_e).sum(dim=1)
        loss1 = -(pos_scores - neg_scores).sigmoid().log().mean()

        # reg loss over all model parameters, weighted by lambda_2
        loss_reg = 0
        for param in self.parameters():
            loss_reg += param.norm(2).square()
        loss_reg *= self.lambda_2
        return loss1 + loss_reg

    def calc_ssl_loss(self, E_u_norm, E_i_norm, user_list, pos_item_list):
        r"""Calculate the loss of self-supervised tasks.
        Args:
            E_u_norm (torch.Tensor): Ego embedding of all users in the original graph after forwarding.
            E_i_norm (torch.Tensor): Ego embedding of all items in the original graph after forwarding.
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """
        # calculate G_u_norm & G_i_norm: the SVD-reconstructed view
        for layer in range(1, self.n_layers + 1):
            # svd_adj propagation
            vt_ei = self.vt @ self.E_i_list[layer - 1]
            self.G_u_list[layer] = self.u_mul_s @ vt_ei
            ut_eu = self.ut @ self.E_u_list[layer - 1]
            self.G_i_list[layer] = self.v_mul_s @ ut_eu

        # aggregate across layer
        G_u_norm = sum(self.G_u_list)
        G_i_norm = sum(self.G_i_list)

        # InfoNCE between the two views; clamping the positive term for stability
        neg_score = torch.log(torch.exp(G_u_norm[user_list] @ E_u_norm.T / self.temp).sum(1) + 1e-8).mean()
        neg_score += torch.log(torch.exp(G_i_norm[pos_item_list] @ E_i_norm.T / self.temp).sum(1) + 1e-8).mean()
        pos_score = (torch.clamp((G_u_norm[user_list] * E_u_norm[user_list]).sum(1) / self.temp, -5.0, 5.0)).mean() + (
            torch.clamp((G_i_norm[pos_item_list] * E_i_norm[pos_item_list]).sum(1) / self.temp, -5.0, 5.0)).mean()
        ssl_loss = -pos_score + neg_score
        return self.lambda_1 * ssl_loss

    def predict(self, interaction):
        """Score the (user, item) pairs in ``interaction``, caching the forward pass."""
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        """Score every item for each user, caching the forward pass."""
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)
133 changes: 133 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/lightgcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# @Time : 2022/3/8
# @Author : Lanling Xu
# @Email : xulanling_sherry@163.com

r"""
LightGCN
################################################
Reference:
Xiangnan He et al. "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation." in SIGIR 2020.
Reference code:
https://github.com/kuandeng/LightGCN
"""

import numpy as np
import torch

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class LightGCN(GeneralGraphRecommender):
    r"""PyG-based implementation of LightGCN (He et al., SIGIR 2020).

    LightGCN keeps only neighborhood aggregation from GCN: user/item
    embeddings are propagated linearly over the interaction graph and the
    mean of all layers' embeddings is used as the final representation.
    Training is pairwise (BPR).
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(LightGCN, self).__init__(config, dataset)

        # hyper-parameters
        self.latent_dim = config['embedding_size']  # embedding size
        self.n_layers = config['n_layers']          # number of propagation layers
        self.reg_weight = config['reg_weight']      # L2 regularization weight
        self.require_pow = config['require_pow']    # whether EmbLoss squares norms

        # embedding tables, propagation operator and losses
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # cache of propagated embeddings for fast full-sort evaluation
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Concatenate the user and item embedding tables.
        Returns:
            Tensor of shape [n_users + n_items, embedding_dim].
        """
        return torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)

    def forward(self):
        """Propagate embeddings and mean-pool over all layers.

        Returns:
            (user_all_embeddings, item_all_embeddings).
        """
        layer_emb = self.get_ego_embeddings()
        all_layers = [layer_emb]
        for _ in range(self.n_layers):
            layer_emb = self.gcn_conv(layer_emb, self.edge_index, self.edge_weight)
            all_layers.append(layer_emb)
        pooled = torch.stack(all_layers, dim=1).mean(dim=1)
        return torch.split(pooled, [self.n_users, self.n_items])

    def calculate_loss(self, interaction):
        """BPR loss on propagated embeddings plus L2 regularization on ego embeddings."""
        # invalidate the evaluation cache once training resumes
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        users = interaction[self.USER_ID]
        pos_items = interaction[self.ITEM_ID]
        neg_items = interaction[self.NEG_ITEM_ID]

        all_user_emb, all_item_emb = self.forward()
        u_emb = all_user_emb[users]
        pos_emb = all_item_emb[pos_items]
        neg_emb = all_item_emb[neg_items]

        # pairwise ranking loss
        pos_scores = (u_emb * pos_emb).sum(dim=1)
        neg_scores = (u_emb * neg_emb).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # regularize the raw (layer-0) embeddings, not the propagated ones
        reg_loss = self.reg_loss(
            self.user_embedding(users),
            self.item_embedding(pos_items),
            self.item_embedding(neg_items),
            require_pow=self.require_pow,
        )
        return mf_loss + self.reg_weight * reg_loss

    def predict(self, interaction):
        """Score the (user, item) pairs in ``interaction``."""
        users = interaction[self.USER_ID]
        items = interaction[self.ITEM_ID]
        all_user_emb, all_item_emb = self.forward()
        return (all_user_emb[users] * all_item_emb[items]).sum(dim=1)

    def full_sort_predict(self, interaction):
        """Score every item for each user, reusing cached propagated embeddings."""
        users = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        scores = self.restore_user_e[users] @ self.restore_item_e.t()
        return scores.view(-1)
222 changes: 222 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/ncl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
r"""
NCL
################################################
Reference:
Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022.
"""

import torch
import torch.nn.functional as F

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class NCL(GeneralGraphRecommender):
    r"""NCL (Neighborhood-enriched Contrastive Learning).

    Extends LightGCN-style propagation with two contrastive objectives: a
    structural one between a node's initial (center) embedding and its
    ``hyper_layers * 2``-hop context embedding, and a prototype one between a
    node and its K-means cluster centroid.

    Reference:
        Zihan Lin et al. "Improving Graph Collaborative Filtering with
        Neighborhood-enriched Contrastive Learning." WWW 2022.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NCL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config['embedding_size']  # int type: the embedding size of the base model
        self.n_layers = config['n_layers']  # int type: the layer num of the base model
        self.reg_weight = config['reg_weight']  # float32 type: the weight decay for l2 normalization

        # structural (layer-wise) contrastive-loss hyper-parameters
        self.ssl_temp = config['ssl_temp']
        self.ssl_reg = config['ssl_reg']
        self.hyper_layers = config['hyper_layers']

        # balance between the user-side and item-side structural loss
        self.alpha = config['alpha']

        # prototype (cluster) contrastive-loss hyper-parameters
        self.proto_reg = config['proto_reg']
        self.k = config['num_clusters']

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

        # cluster state is filled lazily by e_step(); until then the
        # prototype loss cannot be computed
        self.user_centroids = None
        self.user_2cluster = None
        self.item_centroids = None
        self.item_2cluster = None

    def e_step(self):
        """Re-cluster the current (ego) user/item embeddings.

        NOTE(review): presumably called by the trainer once per epoch before
        ``ProtoNCE_loss`` is evaluated — confirm against the training loop.
        """
        user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
        item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
        self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
        self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)

    def run_kmeans(self, x):
        """Run K-means algorithm to get k clusters of the input tensor x

        Returns:
            tuple: L2-normalized centroids of shape [k, latent_dim] and a
            LongTensor mapping each row of ``x`` to its nearest centroid.
        """
        import faiss
        # NOTE(review): gpu=True requires a faiss-gpu build, and faiss expects
        # contiguous float32 input — confirm both hold in the environment.
        kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
        kmeans.train(x)
        cluster_cents = kmeans.centroids

        # nearest-centroid assignment for every row of x
        _, I = kmeans.index.search(x, 1)

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(cluster_cents).to(self.device)
        centroids = F.normalize(centroids, p=2, dim=1)

        node2cluster = torch.LongTensor(I).squeeze().to(self.device)
        return centroids, node2cluster

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        """Propagate through max(n_layers, hyper_layers * 2) graph layers.

        Only the first ``n_layers + 1`` layer outputs are averaged into the
        final embeddings; the extra layers exist so that the structural
        contrastive loss can read the ``hyper_layers * 2``-hop embedding out
        of ``embeddings_list``.

        Returns:
            tuple: (user_all_embeddings, item_all_embeddings, embeddings_list)
        """
        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
            all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
            embeddings_list.append(all_embeddings)

        # average only layers 0..n_layers (LightGCN-style readout)
        lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings, embeddings_list

    def ProtoNCE_loss(self, node_embedding, user, item):
        """Prototype contrastive loss: pull each node towards its own cluster
        centroid, push it away from all other centroids (InfoNCE over the k
        centroids, temperature ``ssl_temp``, weighted by ``proto_reg``).

        Requires ``e_step()`` to have populated the cluster state.
        """
        user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])

        user_embeddings = user_embeddings_all[user]  # [B, e]
        norm_user_embeddings = F.normalize(user_embeddings)

        user2cluster = self.user_2cluster[user]  # [B,]
        user2centroids = self.user_centroids[user2cluster]  # [B, e]
        pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        item_embeddings = item_embeddings_all[item]
        norm_item_embeddings = F.normalize(item_embeddings)

        item2cluster = self.item_2cluster[item]  # [B, ]
        item2centroids = self.item_centroids[item2cluster]  # [B, e]
        pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
        proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
        return proto_nce_loss

    def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
        """Structural contrastive loss between two propagation depths.

        For each batch node, its ``current`` (context) embedding is the
        positive of its ``previous`` (center) embedding; all other nodes of
        the same type act as negatives. Item loss is scaled by ``alpha``.
        """
        current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
        previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])

        current_user_embeddings = current_user_embeddings[user]
        previous_user_embeddings = previous_user_embeddings_all[user]
        norm_user_emb1 = F.normalize(current_user_embeddings)
        norm_user_emb2 = F.normalize(previous_user_embeddings)
        norm_all_user_emb = F.normalize(previous_user_embeddings_all)
        pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
        ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        current_item_embeddings = current_item_embeddings[item]
        previous_item_embeddings = previous_item_embeddings_all[item]
        norm_item_emb1 = F.normalize(current_item_embeddings)
        norm_item_emb2 = F.normalize(previous_item_embeddings)
        norm_all_item_emb = F.normalize(previous_item_embeddings_all)
        pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
        ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

        ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
        return ssl_loss

    def calculate_loss(self, interaction):
        """Return the three training losses as a tuple:
        (BPR + L2 regularization, structural SSL loss, prototype loss).

        NOTE(review): the multi-loss tuple return implies the trainer sums
        them — confirm against the training loop used with this model.
        """
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        # center = layer 0 (ego), context = even hop (hyper_layers * 2)
        center_embedding = embeddings_list[0]
        context_embedding = embeddings_list[self.hyper_layers * 2]

        ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
        proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)

        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)

        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # regularize the ego embeddings, not the propagated ones
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)

        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)

        return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss

    def predict(self, interaction):
        """Dot-product score for each (user, item) pair in the batch."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch, caching the forward
        pass (the third forward output, the per-layer list, is not needed
        for inference and is discarded)."""
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e, embedding_list = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
149 changes: 149 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/ngcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# @Time : 2022/3/8
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
NGCF
################################################
Reference:
Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019.
Reference code:
https://github.com/xiangwang1223/neural_graph_collaborative_filtering
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj

from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import BiGNNConv


class NGCF(GeneralGraphRecommender):
    r"""NGCF is a model that incorporate GNN for recommendation.
    We implement the model following the original author with a pairwise training mode.

    Reference:
        Xiang Wang et al. "Neural Graph Collaborative Filtering." SIGIR 2019.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(NGCF, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size_list = config['hidden_size_list']
        # prepend the input size so consecutive pairs give each layer's (in, out) dims
        self.hidden_size_list = [self.embedding_size] + self.hidden_size_list
        self.node_dropout = config['node_dropout']      # edge-level dropout probability on the graph
        self.message_dropout = config['message_dropout']  # dropout on propagated messages
        self.reg_weight = config['reg_weight']

        # define layers and loss
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
        self.GNNlayers = torch.nn.ModuleList()
        for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]):
            self.GNNlayers.append(BiGNNConv(input_size, output_size))
        self.mf_loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def get_ego_embeddings(self):
        r"""Get the embedding of users and items and combine to an embedding matrix.
        Returns:
            Tensor of the embedding matrix. Shape of (n_items+n_users, embedding_dim)
        """
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        """Propagate embeddings through all BiGNN layers and concatenate the
        per-layer outputs (including layer 0) along the feature dimension.

        Returns:
            tuple: (user_all_embeddings, item_all_embeddings), each with
            feature size sum(hidden_size_list).
        """
        if self.node_dropout == 0:
            edge_index, edge_weight = self.edge_index, self.edge_weight
        else:
            edge_index, edge_weight = self.edge_index, self.edge_weight
            if self.use_sparse:
                # SparseTensor path: unpack to COO, drop edges, then rebuild.
                # NOTE(review): dropout_adj is deprecated in newer PyG
                # releases (replaced by dropout_edge) — confirm the pinned
                # torch_geometric version still provides it.
                row, col, edge_weight = edge_index.t().coo()
                edge_index = torch.stack([row, col], 0)
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)
                from torch_sparse import SparseTensor
                edge_index = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight,
                                          sparse_sizes=(self.n_users + self.n_items, self.n_users + self.n_items))
                edge_index = edge_index.t()
                edge_weight = None
            else:
                edge_index, edge_weight = dropout_adj(edge_index=edge_index, edge_attr=edge_weight,
                                                      p=self.node_dropout, training=self.training)

        all_embeddings = self.get_ego_embeddings()
        embeddings_list = [all_embeddings]
        for gnn in self.GNNlayers:
            all_embeddings = gnn(all_embeddings, edge_index, edge_weight)
            all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
            all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
            all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
            embeddings_list += [all_embeddings]  # storage output embedding of each layer
        ngcf_all_embeddings = torch.cat(embeddings_list, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items])

        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction):
        """BPR loss on the propagated embeddings plus L2 regularization of the
        same (propagated) embeddings, weighted by ``reg_weight``."""
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)  # calculate BPR Loss

        reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings)  # L2 regularization of embeddings

        return mf_loss + self.reg_weight * reg_loss

    def predict(self, interaction):
        """Dot-product score for each (user, item) pair in the batch."""
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()

        u_embeddings = user_all_embeddings[user]
        i_embeddings = item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch, caching the forward
        pass across evaluation batches."""
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)
240 changes: 240 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/sgl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
# @Time : 2022/3/8
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
SGL
################################################
Reference:
Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021.
Reference code:
https://github.com/wujcan/SGL
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType

from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender
from recbole_gnn.model.layers import LightGCNConv


class SGL(GeneralGraphRecommender):
    r"""SGL is a GCN-based recommender model.
    SGL supplements the classical supervised task of recommendation with an auxiliary
    self-supervised task, which reinforces node representation learning via
    self-discrimination. Specifically, SGL generates multiple views of a node,
    maximizing the agreement between different views of the same node compared
    to that of other nodes.
    SGL devises three operators to generate the views — node dropout, edge dropout, and
    random walk — that change the graph structure in different manners.
    We implement the model following the original author with a pairwise training mode.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SGL, self).__init__(config, dataset)

        # load parameters info
        self.latent_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.aug_type = config["type"]          # "ND" (node drop), "ED" (edge drop) or "RW" (random walk)
        self.drop_ratio = config["drop_ratio"]  # fraction of nodes/edges removed per augmentation
        self.ssl_tau = config["ssl_tau"]        # InfoNCE temperature
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]

        # raw interaction id lists used to rebuild augmented graphs
        # NOTE(review): assumed to be torch tensors (recbole inter_feat columns) — verify
        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]
        self.dataset = dataset

        # define layers and loss
        self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
        self.gcn_conv = LightGCNConv(dim=self.latent_dim)
        self.reg_loss = EmbLoss()

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def train(self, mode: bool = True):
        r"""Override train method of base class. The subgraph is reconstructed each time it is called.
        """
        T = super().train(mode=mode)
        if mode:
            self.graph_construction()
        return T

    def graph_construction(self):
        r"""Devise three operators to generate the views — node dropout, edge dropout, and random walk of a node.

        For ND/ED one shared subgraph is used across all layers; for RW a
        fresh subgraph is sampled per layer.
        """
        if self.aug_type == "ND" or self.aug_type == "ED":
            self.sub_graph1 = [self.random_graph_augment()] * self.n_layers
            self.sub_graph2 = [self.random_graph_augment()] * self.n_layers
        elif self.aug_type == "RW":
            self.sub_graph1 = [self.random_graph_augment() for _ in range(self.n_layers)]
            self.sub_graph2 = [self.random_graph_augment() for _ in range(self.n_layers)]

    def random_graph_augment(self):
        """Sample one augmented, symmetrized, GCN-normalized user-item graph.

        Returns:
            tuple: (edge_index, edge_weight) on ``self.device``; in the sparse
            path the first element is an adj_t SparseTensor and the second None.
        """
        def rand_sample(high, size=None, replace=True):
            return np.random.choice(np.arange(high), size=size, replace=replace)

        if self.aug_type == "ND":
            # drop whole nodes: remove every interaction touching a dropped user/item
            drop_user = rand_sample(self.n_users, size=int(self.n_users * self.drop_ratio), replace=False)
            drop_item = rand_sample(self.n_items, size=int(self.n_items * self.drop_ratio), replace=False)

            mask = np.isin(self._user.numpy(), drop_user)
            mask |= np.isin(self._item.numpy(), drop_item)
            keep = np.where(~mask)

            row = self._user[keep]
            col = self._item[keep] + self.n_users

        elif self.aug_type == "ED" or self.aug_type == "RW":
            # drop individual edges (RW uses the same sampler, once per layer)
            keep = rand_sample(len(self._user), size=int(len(self._user) * (1 - self.drop_ratio)), replace=False)
            row = self._user[keep]
            col = self._item[keep] + self.n_users

        # symmetrize: items are offset by n_users in the joint node space
        edge_index1 = torch.stack([row, col])
        edge_index2 = torch.stack([col, row])
        edge_index = torch.cat([edge_index1, edge_index2], dim=1)
        edge_weight = torch.ones(edge_index.size(1))
        num_nodes = self.n_users + self.n_items

        if self.use_sparse:
            adj_t = self.dataset.edge_index_to_adj_t(edge_index, edge_weight, num_nodes, num_nodes)
            adj_t = gcn_norm(adj_t, None, num_nodes, add_self_loops=False)
            return adj_t.to(self.device), None

        edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, add_self_loops=False)

        return edge_index.to(self.device), edge_weight.to(self.device)

    def forward(self, graph=None):
        """LightGCN propagation with mean readout over layers 0..n_layers.

        Args:
            graph (list or None): per-layer (edge_index, edge_weight) pairs of
                an augmented view; None means the original full graph.

        Returns:
            tuple: (user_all_embeddings, item_all_embeddings)
        """
        all_embeddings = torch.cat([self.user_embedding.weight, self.item_embedding.weight])
        embeddings_list = [all_embeddings]

        if graph is None:  # for the original graph
            for _ in range(self.n_layers):
                all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
                embeddings_list.append(all_embeddings)
        else:  # for the augmented graph
            for graph_edge_index, graph_edge_weight in graph:
                all_embeddings = self.gcn_conv(all_embeddings, graph_edge_index, graph_edge_weight)
                embeddings_list.append(all_embeddings)

        embeddings_list = torch.stack(embeddings_list, dim=1)
        embeddings_list = torch.mean(embeddings_list, dim=1, keepdim=False)
        user_all_embeddings, item_all_embeddings = torch.split(embeddings_list, [self.n_users, self.n_items], dim=0)

        return user_all_embeddings, item_all_embeddings

    def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list):
        r"""Calculate the the pairwise Bayesian Personalized Ranking (BPR) loss and parameter regularization loss.
        Args:
            user_emd (torch.Tensor): Ego embedding of all users after forwarding.
            item_emd (torch.Tensor): Ego embedding of all items after forwarding.
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
            neg_item_list (torch.Tensor): List of negative examples.
        Returns:
            torch.Tensor: Loss of BPR tasks and parameter regularization.
        """
        u_e = user_emd[user_list]
        pi_e = item_emd[pos_item_list]
        ni_e = item_emd[neg_item_list]
        p_scores = torch.mul(u_e, pi_e).sum(dim=1)
        n_scores = torch.mul(u_e, ni_e).sum(dim=1)

        # summed (not mean) BPR loss
        l1 = torch.sum(-F.logsigmoid(p_scores - n_scores))

        # regularize the ego embeddings, not the propagated ones
        u_e_p = self.user_embedding(user_list)
        pi_e_p = self.item_embedding(pos_item_list)
        ni_e_p = self.item_embedding(neg_item_list)

        l2 = self.reg_loss(u_e_p, pi_e_p, ni_e_p)

        return l1 + l2 * self.reg_weight

    def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2):
        r"""Calculate the loss of self-supervised tasks.
        Args:
            user_list (torch.Tensor): List of the user.
            pos_item_list (torch.Tensor): List of positive examples.
            user_sub1 (torch.Tensor): Ego embedding of all users in the first subgraph after forwarding.
            user_sub2 (torch.Tensor): Ego embedding of all users in the second subgraph after forwarding.
            item_sub1 (torch.Tensor): Ego embedding of all items in the first subgraph after forwarding.
            item_sub2 (torch.Tensor): Ego embedding of all items in the second subgraph after forwarding.
        Returns:
            torch.Tensor: Loss of self-supervised tasks.
        """

        # InfoNCE over users: the same user in the other view is the positive,
        # all users in view 2 are the candidate set
        u_emd1 = F.normalize(user_sub1[user_list], dim=1)
        u_emd2 = F.normalize(user_sub2[user_list], dim=1)
        all_user2 = F.normalize(user_sub2, dim=1)
        v1 = torch.sum(u_emd1 * u_emd2, dim=1)
        v2 = u_emd1.matmul(all_user2.T)
        v1 = torch.exp(v1 / self.ssl_tau)
        v2 = torch.sum(torch.exp(v2 / self.ssl_tau), dim=1)
        ssl_user = -torch.sum(torch.log(v1 / v2))

        # same construction over the positive items
        i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1)
        i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1)
        all_item2 = F.normalize(item_sub2, dim=1)
        v3 = torch.sum(i_emd1 * i_emd2, dim=1)
        v4 = i_emd1.matmul(all_item2.T)
        v3 = torch.exp(v3 / self.ssl_tau)
        v4 = torch.sum(torch.exp(v4 / self.ssl_tau), dim=1)
        ssl_item = -torch.sum(torch.log(v3 / v4))

        return (ssl_item + ssl_user) * self.ssl_weight

    def calculate_loss(self, interaction):
        """BPR + regularization on the full graph, plus the contrastive loss
        between the two augmented views built in ``graph_construction``."""
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]

        user_emd, item_emd = self.forward()
        user_sub1, item_sub1 = self.forward(self.sub_graph1)
        user_sub2, item_sub2 = self.forward(self.sub_graph2)

        total_loss = self.calc_bpr_loss(user_emd, item_emd, user_list, pos_item_list, neg_item_list) + \
            self.calc_ssl_loss(user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2)
        return total_loss

    def predict(self, interaction):
        """Dot-product score for each (user, item) pair, with cached forward."""
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch, with cached forward."""
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)
60 changes: 60 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/simgcl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
r"""
SimGCL
################################################
Reference:
Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022.
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class SimGCL(LightGCN):
    r"""SimGCL builds on LightGCN and replaces graph augmentation with
    embedding-level random noise to create contrastive views.

    Reference:
        Junliang Yu et al. "Are Graph Augmentations Necessary? Simple Graph
        Contrastive Learning for Recommendation." SIGIR 2022.
    """

    def __init__(self, config, dataset):
        super(SimGCL, self).__init__(config, dataset)

        # weight of the contrastive term, noise magnitude, InfoNCE temperature
        self.cl_rate = config['lambda']
        self.eps = config['eps']
        self.temperature = config['temperature']

    def forward(self, perturbed=False):
        """LightGCN propagation; when ``perturbed`` each layer output gets
        sign-aligned uniform noise of magnitude ``eps`` added.

        Note that, unlike vanilla LightGCN, the layer-0 (ego) embeddings are
        not included in the mean readout.
        """
        layer_emb = self.get_ego_embeddings()
        per_layer = []

        for _ in range(self.n_layers):
            layer_emb = self.gcn_conv(layer_emb, self.edge_index, self.edge_weight)
            if perturbed:
                noise = torch.rand_like(layer_emb, device=layer_emb.device)
                layer_emb = layer_emb + torch.sign(layer_emb) * F.normalize(noise, dim=-1) * self.eps
            per_layer.append(layer_emb)

        final_emb = torch.stack(per_layer, dim=1).mean(dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(final_emb, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        """Summed InfoNCE between two views; rows of x1/x2 are aligned positives."""
        x1 = F.normalize(x1, dim=-1)
        x2 = F.normalize(x2, dim=-1)
        pos = torch.exp((x1 * x2).sum(dim=-1) / self.temperature)
        ttl = torch.exp(torch.matmul(x1, x2.transpose(0, 1)) / self.temperature).sum(dim=1)
        return -torch.log(pos / ttl).sum()

    def calculate_loss(self, interaction):
        """LightGCN's BPR + reg loss plus the weighted contrastive loss over
        two independently perturbed forward passes."""
        base_loss = super().calculate_loss(interaction)

        # deduplicate ids so each node contributes once to the CL term
        users = torch.unique(interaction[self.USER_ID])
        pos_items = torch.unique(interaction[self.ITEM_ID])

        u_view1, i_view1 = self.forward(perturbed=True)
        u_view2, i_view2 = self.forward(perturbed=True)

        cl_loss = self.calculate_cl_loss(u_view1[users], u_view2[users]) \
            + self.calculate_cl_loss(i_view1[pos_items], i_view2[pos_items])

        return base_loss + self.cl_rate * cl_loss
163 changes: 163 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/ssl4rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
r"""
SSL4REC
################################################
Reference:
Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item Recommendations." in CIKM 2021.
Reference code:
https://github.com/Coder-Yu/SELFRec/model/graph/SSL4Rec.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from recbole.model.loss import EmbLoss
from recbole.utils import InputType

from recbole.model.init import xavier_uniform_initialization
from recbole_gnn.model.abstract_recommender import GeneralGraphRecommender


class SSL4REC(GeneralGraphRecommender):
    r"""SSL4Rec: a two-tower DNN retrieval model trained with an in-batch
    softmax loss, augmented with a dropout-based contrastive task on items.

    Reference:
        Tiansheng Yao et al. "Self-supervised Learning for Large-scale Item
        Recommendations." CIKM 2021.
    """
    input_type = InputType.PAIRWISE

    def __init__(self, config, dataset):
        super(SSL4REC, self).__init__(config, dataset)

        # load parameters info
        self.tau = config["tau"]                  # softmax / InfoNCE temperature
        self.reg_weight = config["reg_weight"]    # L2 regularization weight
        self.cl_rate = config["ssl_weight"]       # weight of the contrastive task
        self.require_pow = config["require_pow"]

        self.reg_loss = EmbLoss()

        self.encoder = DNN_Encoder(config, dataset)

        # storage variables for full sort evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # parameters initialization
        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ['restore_user_e', 'restore_item_e']

    def forward(self, user, item):
        """Encode the given id tensors through the two towers.

        Returns:
            tuple: (user_e, item_e), each row-aligned with ``user`` / ``item``.
        """
        user_e, item_e = self.encoder(user, item)
        return user_e, item_e

    def calculate_batch_softmax_loss(self, user_emb, item_emb, temperature):
        """In-batch sampled-softmax loss: the i-th item is the positive of the
        i-th user; all other items in the batch act as negatives."""
        user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
        pos_score = (user_emb * item_emb).sum(dim=-1)
        pos_score = torch.exp(pos_score / temperature)
        ttl_score = torch.matmul(user_emb, item_emb.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
        loss = -torch.log(pos_score / ttl_score + 10e-6)  # epsilon guards log(0)
        return torch.mean(loss)

    def calculate_loss(self, interaction):
        """Batch-softmax recommendation loss + weighted item contrastive loss
        + weighted L2 regularization of the tower outputs."""
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]

        user_embeddings, item_embeddings = self.forward(user, pos_item)

        rec_loss = self.calculate_batch_softmax_loss(user_embeddings, item_embeddings, self.tau)
        cl_loss = self.encoder.calculate_cl_loss(pos_item)
        reg_loss = self.reg_loss(user_embeddings, item_embeddings, require_pow=self.require_pow)

        loss = rec_loss + self.cl_rate * cl_loss + self.reg_weight * reg_loss

        return loss

    def predict(self, interaction):
        """Dot-product score for each (user, item) pair in the batch.

        Fix: ``forward(user, item)`` already returns embeddings row-aligned
        with the batch, so they must be scored directly. The previous code
        indexed them AGAIN with the raw ids (``user_embeddings[user]``), a
        LightGCN-style pattern that mis-aligns rows and can index out of
        range whenever an id exceeds the batch size.
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        user_embeddings, item_embeddings = self.forward(user, item)

        scores = torch.mul(user_embeddings, item_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        """Score every item for each user in the batch; the full user/item
        tables are encoded once and cached."""
        user = interaction[self.USER_ID]
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(torch.arange(
                self.n_users, device=self.device), torch.arange(self.n_items, device=self.device))
        # get user embedding from storage variable
        u_embeddings = self.restore_user_e[user]

        # dot with all item embedding to accelerate
        scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

        return scores.view(-1)


class DNN_Encoder(nn.Module):
    """Two-tower DNN encoder used by SSL4REC.

    Each tower maps an id embedding of size ``embedding_size`` through
    Linear(emb, 1024) -> ReLU -> Linear(1024, 128) -> Tanh. Item dropout
    produces two stochastic views for the contrastive task.
    """

    def __init__(self, config, dataset):
        super(DNN_Encoder, self).__init__()

        self.emb_size = config["embedding_size"]
        self.drop_ratio = config["drop_ratio"]
        self.tau = config["tau"]  # InfoNCE temperature for the item CL task

        self.USER_ID = config["USER_ID_FIELD"]
        self.ITEM_ID = config["ITEM_ID_FIELD"]
        self.n_users = dataset.num(self.USER_ID)
        self.n_items = dataset.num(self.ITEM_ID)

        def build_tower():
            # identical architecture for the user and the item tower
            return nn.Sequential(
                nn.Linear(self.emb_size, 1024),
                nn.ReLU(True),
                nn.Linear(1024, 128),
                nn.Tanh(),
            )

        self.user_tower = build_tower()
        self.item_tower = build_tower()
        self.dropout = nn.Dropout(self.drop_ratio)

        self.initial_user_emb = nn.Embedding(self.n_users, self.emb_size)
        self.initial_item_emb = nn.Embedding(self.n_items, self.emb_size)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize both id embedding tables."""
        for table in (self.initial_user_emb, self.initial_item_emb):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, q, x):
        """Encode user ids ``q`` and item ids ``x``.

        Returns:
            tuple: (user_emb, item_emb), each of shape [batch, 128].
        """
        user_out = self.user_tower(self.initial_user_emb(q))
        item_out = self.item_tower(self.initial_item_emb(x))
        return user_out, item_out

    def item_encoding(self, x):
        """Produce two independently dropout-perturbed views of the items."""
        base = self.initial_item_emb(x)
        view_a = self.item_tower(self.dropout(base))
        view_b = self.item_tower(self.dropout(base))
        return view_a, view_b

    def calculate_cl_loss(self, idx):
        """Mean InfoNCE between the two item views at temperature ``tau``."""
        view_a, view_b = self.item_encoding(idx)
        view_a = F.normalize(view_a, dim=-1)
        view_b = F.normalize(view_b, dim=-1)
        pos = torch.exp((view_a * view_b).sum(dim=-1) / self.tau)
        ttl = torch.exp(torch.matmul(view_a, view_b.transpose(0, 1)) / self.tau).sum(dim=1)
        return -torch.log(pos / ttl).mean()
90 changes: 90 additions & 0 deletions Recbole_GNN/recbole_gnn/model/general_recommender/xsimgcl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
r"""
XSimGCL
################################################
Reference:
Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023.
Reference code:
https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class XSimGCL(LightGCN):
    r"""XSimGCL: a simplified SimGCL that contrasts the final embeddings with
    an intermediate layer's embeddings within a single perturbed forward pass.

    Reference:
        Junliang Yu et al. "XSimGCL: Towards Extremely Simple Graph
        Contrastive Learning for Recommendation." TKDE 2023.
    """

    def __init__(self, config, dataset):
        super(XSimGCL, self).__init__(config, dataset)

        self.cl_rate = config['lambda']          # weight of the contrastive term
        self.eps = config['eps']                 # noise magnitude
        self.temperature = config['temperature']  # InfoNCE temperature
        self.layer_cl = config['layer_cl']       # 1-based layer whose output is the CL view

    def forward(self, perturbed=False):
        """LightGCN propagation with optional per-layer noise.

        When ``perturbed``, also returns the (noisy) embeddings captured at
        layer ``layer_cl`` as the contrastive view. If ``layer_cl`` exceeds
        ``n_layers`` the capture never fires and the ego embeddings are used.
        """
        all_embs = self.get_ego_embeddings()
        all_embs_cl = all_embs
        embeddings_list = []

        for layer_idx in range(self.n_layers):
            all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
            if perturbed:
                # sign-aligned uniform noise of magnitude eps
                random_noise = torch.rand_like(all_embs, device=all_embs.device)
                all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
            embeddings_list.append(all_embs)
            if layer_idx == self.layer_cl - 1:
                all_embs_cl = all_embs
        # NOTE: layer-0 (ego) embeddings are not part of the mean readout
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
        user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items])
        if perturbed:
            return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        """Mean InfoNCE between two views; rows of x1/x2 are aligned positives."""
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        pos_score = (x1 * x2).sum(dim=-1)
        pos_score = torch.exp(pos_score / self.temperature)
        ttl_score = torch.matmul(x1, x2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
        return -torch.log(pos_score / ttl_score).mean()

    def calculate_loss(self, interaction):
        """Return (BPR loss, weighted reg loss, weighted CL loss) as a tuple.

        NOTE(review): the tuple return implies the trainer sums the parts —
        confirm against the training loop used with this model.
        """
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True)
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR Loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization Loss
        u_ego_embeddings = self.user_embedding(user)
        pos_ego_embeddings = self.item_embedding(pos_item)
        neg_ego_embeddings = self.item_embedding(neg_item)
        reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)

        # deduplicate ids so each node contributes once to the CL term
        user = torch.unique(interaction[self.USER_ID])
        pos_item = torch.unique(interaction[self.ITEM_ID])

        # calculate CL Loss
        user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user])
        item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item])

        return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss)
114 changes: 114 additions & 0 deletions Recbole_GNN/recbole_gnn/model/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_sparse import matmul


class LightGCNConv(MessagePassing):
    """One LightGCN propagation layer: a weighted sum over neighbor
    embeddings, with no feature transform and no nonlinearity."""

    def __init__(self, dim):
        super(LightGCNConv, self).__init__(aggr='add')
        self.dim = dim  # embedding dimensionality (used only in repr)

    def forward(self, x, edge_index, edge_weight):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        # Scale each neighbor embedding by its (normalized) edge weight.
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        # Fused sparse path taken when edge_index is given as a SparseTensor.
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.dim})'


class BipartiteGCNConv(MessagePassing):
    """Weighted-sum convolution on a bipartite graph, where source and
    target node sets may differ in size (hence the explicit ``size``)."""

    def __init__(self, dim):
        super(BipartiteGCNConv, self).__init__(aggr='add')
        self.dim = dim  # embedding dimensionality (used only in repr)

    def forward(self, x, edge_index, edge_weight, size):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

    def message(self, x_j, edge_weight):
        # Weight each source-node embedding by its edge weight.
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self):
        return f'{self.__class__.__name__}({self.dim})'


class BiGNNConv(MessagePassing):
    r"""Propagate a layer of Bi-interaction GNN
    .. math::
        output = (L+I)EW_1 + LE \otimes EW_2
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # lin1 acts on the sum term (L+I)E, lin2 on the element-wise product LE ⊗ E.
        self.lin1 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)
        self.lin2 = torch.nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, x, edge_index, edge_weight):
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        sum_term = self.lin1(aggregated + x)      # (L+I)E W_1
        prod_term = self.lin2(aggregated * x)     # (LE ⊗ E) W_2
        return sum_term + prod_term

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        # Fused sparse path for SparseTensor adjacency.
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels},{self.out_channels})'


class SRGNNConv(MessagePassing):
    """One direction of SR-GNN propagation: linear transform followed by
    mean aggregation over incoming edges."""

    def __init__(self, dim):
        # mean aggregation to incorporate weight naturally
        super(SRGNNConv, self).__init__(aggr='mean')
        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=self.lin(x))


class SRGNNCell(nn.Module):
    """GRU-style session-graph cell from SR-GNN.

    Node states are updated from aggregations over both incoming edges and
    outgoing edges (via the reversed edge index), fed through GRU gates.
    """

    def __init__(self, dim):
        super(SRGNNCell, self).__init__()
        self.dim = dim
        # Separate convolutions over original and reversed edge directions.
        self.incomming_conv = SRGNNConv(dim)
        self.outcomming_conv = SRGNNConv(dim)
        # GRU gate projections; the input side sees [in, out] concatenated.
        self.lin_ih = nn.Linear(2 * dim, 3 * dim)
        self.lin_hh = nn.Linear(dim, 3 * dim)
        self._reset_parameters()

    def forward(self, hidden, edge_index):
        msg_in = self.incomming_conv(hidden, edge_index)
        msg_out = self.outcomming_conv(hidden, torch.flip(edge_index, dims=[0]))
        gate_input = torch.cat([msg_in, msg_out], dim=-1)

        i_r, i_i, i_n = self.lin_ih(gate_input).chunk(3, -1)
        h_r, h_i, h_n = self.lin_hh(hidden).chunk(3, -1)
        reset = torch.sigmoid(i_r + h_r)
        update = torch.sigmoid(i_i + h_i)
        candidate = torch.tanh(i_n + reset * h_n)
        # Convex combination of the old state and the candidate state.
        return (1 - update) * hidden + update * candidate

    def _reset_parameters(self):
        # Uniform init in [-1/sqrt(dim), 1/sqrt(dim)], as in the reference code.
        bound = 1.0 / np.sqrt(self.dim)
        for weight in self.parameters():
            weight.data.uniform_(-bound, bound)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from recbole_gnn.model.sequential_recommender.gcegnn import GCEGNN
from recbole_gnn.model.sequential_recommender.gcsan import GCSAN
from recbole_gnn.model.sequential_recommender.lessr import LESSR
from recbole_gnn.model.sequential_recommender.niser import NISER
from recbole_gnn.model.sequential_recommender.sgnnhn import SGNNHN
from recbole_gnn.model.sequential_recommender.srgnn import SRGNN
from recbole_gnn.model.sequential_recommender.tagnn import TAGNN
277 changes: 277 additions & 0 deletions Recbole_GNN/recbole_gnn/model/sequential_recommender/gcegnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
# @Time : 2022/3/22
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

r"""
GCE-GNN
################################################
Reference:
Ziyang Wang et al. "Global Context Enhanced Graph Neural Networks for Session-based Recommendation." in SIGIR 2020.
Reference code:
https://github.com/CCIIPLab/GCE-GNN
"""

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from recbole.model.loss import BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender


class LocalAggregator(MessagePassing):
    """Attention aggregation over the session's local graph; the 4 edge
    relation types are embedded and used to score each neighbor."""

    def __init__(self, dim, alpha):
        super().__init__(aggr='add')
        self.edge_emb = nn.Embedding(4, dim)  # one vector per edge relation
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, x_i, edge_attr, index, ptr, size_i):
        pair = x_j * x_i
        score = self.leakyrelu((pair * self.edge_emb(edge_attr)).sum(dim=-1))
        # Normalize attention over each target node's incoming edges.
        attn = softmax(score, index, ptr, size_i)
        return attn.unsqueeze(-1) * x_j


class GlobalAggregator(nn.Module):
    """Aggregates sampled global-graph neighbors into each session item.

    When ``extra_vector`` (session context) is given, neighbors are combined
    with learned attention over (context ⊙ neighbor, edge weight); otherwise
    a plain mean over the neighbor axis is used.
    """

    def __init__(self, dim, dropout, act=torch.relu):
        super(GlobalAggregator, self).__init__()
        self.dropout = dropout
        self.act = act
        self.dim = dim

        # Raw parameters; initialized externally (GCEGNN.reset_parameters).
        self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim))
        self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
        self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
        self.bias = nn.Parameter(torch.Tensor(self.dim))  # NOTE(review): unused in forward

    def forward(self, self_vectors, neighbor_vector, batch_size, masks, neighbor_weight, extra_vector=None):
        if extra_vector is not None:
            # Attention logits built from (context ⊙ neighbor) features plus edge weight.
            context = extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1)
            feats = torch.cat([context * neighbor_vector, neighbor_weight.unsqueeze(-1)], -1)
            logits = torch.matmul(feats, self.w_1).squeeze(-1)
            logits = F.leaky_relu(logits, negative_slope=0.2)
            logits = torch.matmul(logits, self.w_2).squeeze(-1)
            attn = torch.softmax(logits, -1).unsqueeze(-1)
            neighbor_vector = torch.sum(attn * neighbor_vector, dim=-2)
        else:
            neighbor_vector = torch.mean(neighbor_vector, dim=2)
        combined = torch.cat([self_vectors, neighbor_vector], -1)
        combined = F.dropout(combined, self.dropout, training=self.training)
        combined = torch.matmul(combined, self.w_3)
        combined = combined.view(batch_size, -1, self.dim)
        return self.act(combined)


class GCEGNN(SequentialRecommender):
    r"""GCE-GNN session recommender.

    Combines two item representations: a *local* one aggregated over the
    current session graph, and a *global* one aggregated over an item
    co-occurrence graph built once from the whole training set.
    """

    def __init__(self, config, dataset):
        super(GCEGNN, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.leakyrelu_alpha = config['leakyrelu_alpha']
        self.dropout_local = config['dropout_local']
        self.dropout_global = config['dropout_global']
        self.dropout_gcn = config['dropout_gcn']
        self.device = config['device']
        self.loss_type = config['loss_type']
        self.build_global_graph = config['build_global_graph']
        self.sample_num = config['sample_num']  # neighbors kept per item in the global graph
        self.hop = config['hop']  # number of global aggregation hops
        self.max_seq_length = dataset.field2seqlen[self.ITEM_SEQ]

        # global graph construction
        self.global_graph = None
        if self.build_global_graph:
            self.global_adj, self.global_weight = self.construct_global_graph(dataset)

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.pos_embedding = nn.Embedding(self.max_seq_length, self.embedding_size)

        # define layers and loss
        # Aggregator
        self.local_agg = LocalAggregator(self.embedding_size, self.leakyrelu_alpha)
        global_agg_list = []
        for i in range(self.hop):
            global_agg_list.append(GlobalAggregator(self.embedding_size, self.dropout_gcn))
        self.global_agg = nn.ModuleList(global_agg_list)

        # projections used by the attention readout in self.fusion()
        self.w_1 = nn.Linear(2 * self.embedding_size, self.embedding_size, bias=False)
        self.w_2 = nn.Linear(self.embedding_size, 1, bias=False)
        self.glu1 = nn.Linear(self.embedding_size, self.embedding_size)
        self.glu2 = nn.Linear(self.embedding_size, self.embedding_size, bias=False)
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.reset_parameters()
        # persist the global-graph tensors together with the model parameters
        self.other_parameter_name = ['global_adj', 'global_weight']

    def reset_parameters(self):
        """Uniformly initialize every parameter in [-stdv, stdv]."""
        stdv = 1.0 / np.sqrt(self.embedding_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def _add_edge(self, graph, sid, tid):
        # ``graph`` is a list of dicts: graph[sid][tid] counts co-occurrences.
        if tid not in graph[sid]:
            graph[sid][tid] = 0
        graph[sid][tid] += 1

    def construct_global_graph(self, dataset):
        """Build the global co-occurrence graph from training interactions.

        Returns ``(global_adj, global_weight)``: for every item, the ids and
        counts of its top-``sample_num`` co-occurring items, zero-padded so
        every row has exactly ``sample_num`` entries.
        """
        self.logger.info('Constructing global graphs.')
        item_id_list = dataset.inter_feat['item_id_list']
        # NOTE(review): only the first 4 items of each sequence contribute
        # edges here — confirm this context-window size is intended.
        src_item_ids = item_id_list[:,:4].tolist()
        tgt_itme_id = dataset.inter_feat['item_id'].tolist()
        global_graph = [{} for _ in range(self.n_items)]
        for i in tqdm(range(len(tgt_itme_id)), desc='Converting: '):
            tid = tgt_itme_id[i]
            for sid in src_item_ids[i]:
                if sid > 0:  # skip padding item 0
                    self._add_edge(global_graph, tid, sid)
                    self._add_edge(global_graph, sid, tid)
        global_adj = [[] for _ in range(self.n_items)]
        global_weight = [[] for _ in range(self.n_items)]
        for i in tqdm(range(self.n_items), desc='Sorting: '):
            # keep the sample_num most frequent neighbors of item i
            sorted_out_edges = [v for v in sorted(global_graph[i].items(), reverse=True, key=lambda x: x[1])]
            global_adj[i] = [v[0] for v in sorted_out_edges[:self.sample_num]]
            global_weight[i] = [v[1] for v in sorted_out_edges[:self.sample_num]]
            if len(global_adj[i]) < self.sample_num:
                # pad with item 0 / weight 0 to a fixed row length
                for j in range(self.sample_num - len(global_adj[i])):
                    global_adj[i].append(0)
                    global_weight[i].append(0)
        return torch.LongTensor(global_adj).to(self.device), torch.FloatTensor(global_weight).to(self.device)

    def fusion(self, hidden, mask):
        """Soft-attention readout over the fused sequence representation."""
        batch_size = hidden.shape[0]
        length = hidden.shape[1]
        pos_emb = self.pos_embedding.weight[:length]
        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)

        # hs: mean over valid (non-padding) positions, broadcast back over the sequence
        hs = torch.sum(hidden * mask, -2) / torch.sum(mask, 1)
        hs = hs.unsqueeze(-2).expand(-1, length, -1)
        nh = self.w_1(torch.cat([pos_emb, hidden], -1))
        nh = torch.tanh(nh)
        nh = torch.sigmoid(self.glu1(nh) + self.glu2(hs))
        beta = self.w_2(nh)
        beta = beta * mask  # zero attention on padding positions
        final_h = torch.sum(beta * hidden, 1)
        return final_h

    def forward(self, x, edge_index, edge_attr, alias_inputs, item_seq_len):
        """Return one session representation per batch row.

        Args:
            x: unique item ids of the batched session graphs.
            edge_index / edge_attr: local session-graph edges and relation types.
            alias_inputs: per-session indices into ``x`` (0 marks padding).
            item_seq_len: sequence lengths (not used here; kept for interface parity).
        """
        batch_size = alias_inputs.shape[0]
        mask = alias_inputs.gt(0).unsqueeze(-1)
        h = self.item_embedding(x)

        # local
        h_local = self.local_agg(h, edge_index, edge_attr)

        # global: hop-by-hop neighbor sampling, sequences padded to max_seq_length first
        item_neighbors = [F.pad(x[alias_inputs], (0, self.max_seq_length - x[alias_inputs].shape[1]), "constant", 0)]
        weight_neighbors = []
        support_size = self.max_seq_length

        for i in range(self.hop):
            item_sample_i, weight_sample_i = self.global_adj[item_neighbors[-1].view(-1)], self.global_weight[item_neighbors[-1].view(-1)]
            support_size *= self.sample_num
            item_neighbors.append(item_sample_i.view(batch_size, support_size))
            weight_neighbors.append(weight_sample_i.view(batch_size, support_size))

        entity_vectors = [self.item_embedding(i) for i in item_neighbors]
        weight_vectors = weight_neighbors

        session_info = []
        item_emb = h[alias_inputs] * mask

        # mean
        sum_item_emb = torch.sum(item_emb, 1) / torch.sum(mask.float(), 1)

        # sum
        # sum_item_emb = torch.sum(item_emb, 1)

        sum_item_emb = sum_item_emb.unsqueeze(-2)
        for i in range(self.hop):
            session_info.append(sum_item_emb.repeat(1, entity_vectors[i].shape[1], 1))

        # iteratively fold the outermost hop into the next-inner hop
        for n_hop in range(self.hop):
            entity_vectors_next_iter = []
            shape = [batch_size, -1, self.sample_num, self.embedding_size]
            for hop in range(self.hop - n_hop):
                aggregator = self.global_agg[n_hop]
                vector = aggregator(self_vectors=entity_vectors[hop],
                                    neighbor_vector=entity_vectors[hop + 1].view(shape),
                                    masks=None,
                                    batch_size=batch_size,
                                    neighbor_weight=weight_vectors[hop].view(batch_size, -1, self.sample_num),
                                    extra_vector=session_info[hop])
                entity_vectors_next_iter.append(vector)
            entity_vectors = entity_vectors_next_iter

        h_global = entity_vectors[0].view(batch_size, self.max_seq_length, self.embedding_size)
        h_global = h_global[:,:alias_inputs.shape[1],:]  # trim back to the batch's padded length

        h_local = F.dropout(h_local, self.dropout_local, training=self.training)
        h_global = F.dropout(h_global, self.dropout_global, training=self.training)
        h_local = h_local[alias_inputs]  # scatter node features back to sequence positions

        h_session = h_local + h_global
        h_session = self.fusion(h_session, mask)
        return h_session

    def calculate_loss(self, interaction):
        """BPR or cross-entropy training loss over the session output."""
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
            return loss
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
            return loss

    def predict(self, interaction):
        """Score each (session, candidate item) pair."""
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every item in the catalog for each session."""
        x = interaction['x']
        edge_index = interaction['edge_index']
        edge_attr = interaction['edge_attr']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, edge_attr, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
165 changes: 165 additions & 0 deletions Recbole_GNN/recbole_gnn/model/sequential_recommender/gcsan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# @Time : 2022/3/7
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

r"""
GCSAN
################################################
Reference:
Chengfeng Xu et al. "Graph Contextualized Self-Attention Network for Session-based Recommendation." in IJCAI 2019.
"""

import torch
from torch import nn
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import EmbLoss, BPRLoss
from recbole.model.abstract_recommender import SequentialRecommender

from recbole_gnn.model.layers import SRGNNCell


class GCSAN(SequentialRecommender):
    r"""GCSAN captures rich local dependencies via graph neural network,
    and learns long-range dependencies by applying the self-attention mechanism.
    Note:
        In the original paper, the attention mechanism in the self-attention layer is a single head,
        for the reusability of the project code, we use a unified transformer component.
        According to the experimental results, we only applied regularization to embedding.
    """

    def __init__(self, config, dataset):
        super(GCSAN, self).__init__(config, dataset)

        # load parameters info
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']  # same as embedding_size
        self.inner_size = config['inner_size']  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']

        self.step = config['step']  # number of GNN propagation steps
        self.device = config['device']
        self.weight = config['weight']  # mixing coefficient between attention and GNN outputs
        self.reg_weight = config['reg_weight']
        self.loss_type = config['loss_type']
        self.initializer_range = config['initializer_range']

        # item embedding
        self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0)

        # define layers and loss
        self.gnncell = SRGNNCell(self.hidden_size)
        self.self_attention = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )
        self.reg_loss = EmbLoss()
        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention."""
        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
        # mask for left-to-right unidirectional
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)  # torch.uint8
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long().to(item_seq.device)

        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # allowed positions get 0, masked positions get -10000 (added to logits)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, x, edge_index, alias_inputs, item_seq_len):
        """Return the session representation: a ``weight``-mix of the
        self-attention output and the GNN's last hidden state."""
        hidden = self.item_embedding(x)
        # run the SR-GNN cell 'step' times over the session graph
        for i in range(self.step):
            hidden = self.gnncell(hidden, edge_index)

        # scatter node states back to sequence positions
        seq_hidden = hidden[alias_inputs]
        # fetch the last hidden state of last timestamp
        ht = self.gather_indexes(seq_hidden, item_seq_len - 1)

        attention_mask = self.get_attention_mask(alias_inputs)
        outputs = self.self_attention(seq_hidden, attention_mask, output_all_encoded_layers=True)
        output = outputs[-1]
        at = self.gather_indexes(output, item_seq_len - 1)
        seq_output = self.weight * at + (1 - self.weight) * ht
        return seq_output

    def calculate_loss(self, interaction):
        """BPR or cross-entropy loss plus embedding regularization."""
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)
        # regularization applied to the item embedding only (see class docstring)
        reg_loss = self.reg_loss(self.item_embedding.weight)
        total_loss = loss + self.reg_weight * reg_loss
        return total_loss

    def predict(self, interaction):
        """Score each (session, candidate item) pair."""
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every item in the catalog for each session."""
        x = interaction['x']
        edge_index = interaction['edge_index']
        alias_inputs = interaction['alias_inputs']
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(x, edge_index, alias_inputs, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
257 changes: 257 additions & 0 deletions Recbole_GNN/recbole_gnn/model/sequential_recommender/lessr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# @Time : 2022/3/11
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

r"""
LESSR
################################################
Reference:
Tianwen Chen and Raymond Chi-Wing Wong. "Handling Information Loss of Graph Neural Networks for Session-based Recommendation." in KDD 2020.
Reference code:
https://github.com/twchen/lessr
"""

import torch
from torch import nn
from torch_geometric.utils import softmax
from torch_geometric.nn import global_add_pool
from recbole.model.abstract_recommender import SequentialRecommender


class EOPA(nn.Module):
    """Edge-Order Preserving Aggregation layer (LESSR).

    Neighbor messages arrive in edge-insertion order (the EOP multigraph is
    built that way), so running a GRU over each node's mailbox realizes an
    order-aware aggregation.
    """

    def __init__(
        self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.gru = nn.GRU(input_dim, input_dim, batch_first=True)
        self.fc_self = nn.Linear(input_dim, output_dim, bias=False)
        self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False)
        self.activation = activation

    def reducer(self, nodes):
        # mailbox 'm' has shape (num_nodes, deg, d); messages are ordered by
        # edge-creation time, which is exactly the order EOPA requires.
        _, final_state = self.gru(nodes.mailbox['m'])  # (1, num_nodes, d)
        return {'neigh': final_state.squeeze(0)}

    def forward(self, mg, feat):
        import dgl.function as fn

        with mg.local_scope():
            if self.batch_norm is not None:
                feat = self.batch_norm(feat)
            mg.ndata['ft'] = self.feat_drop(feat)
            if mg.number_of_edges() > 0:
                mg.update_all(fn.copy_u('ft', 'm'), self.reducer)
                rst = self.fc_self(feat) + self.fc_neigh(mg.ndata['neigh'])
            else:
                # no edges: only the self term contributes
                rst = self.fc_self(feat)
            return rst if self.activation is None else self.activation(rst)


class SGAT(nn.Module):
    """Shortcut Graph Attention layer (LESSR): single-head additive
    attention over shortcut edges to capture long-range dependencies."""

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        batch_norm=True,
        feat_drop=0.0,
        activation=None,
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc_v = nn.Linear(input_dim, output_dim, bias=False)
        self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
        self.activation = activation

    def forward(self, sg, feat):
        import dgl.ops as F

        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        feat = self.feat_drop(feat)
        query = self.fc_q(feat)
        key = self.fc_k(feat)
        value = self.fc_v(feat)
        # Per-edge attention logit: e_ij = w^T sigmoid(q_i + k_j).
        logits = self.fc_e(torch.sigmoid(F.u_add_v(sg, query, key)))
        weights = F.edge_softmax(sg, logits)
        rst = F.u_mul_e_sum(sg, value, weights)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst


class AttnReadout(nn.Module):
    """Attention readout over session nodes, queried by each session's last
    node, producing one graph-level embedding per session."""

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        batch_norm=True,
        feat_drop=0.0,
        activation=None,
    ):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc_e = nn.Linear(hidden_dim, 1, bias=False)
        # Project the pooled vector only when output size differs from input.
        self.fc_out = (
            nn.Linear(input_dim, output_dim, bias=False)
            if output_dim != input_dim else None
        )
        self.activation = activation

    def forward(self, g, feat, last_nodes, batch):
        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        feat = self.feat_drop(feat)
        node_key = self.fc_u(feat)
        # Broadcast each session's last-node query to all of that session's nodes.
        last_query = torch.index_select(self.fc_v(feat[last_nodes]), dim=0, index=batch)
        logits = self.fc_e(torch.sigmoid(node_key + last_query))
        attn = softmax(logits, batch)
        rst = global_add_pool(feat * attn, batch)
        if self.fc_out is not None:
            rst = self.fc_out(rst)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst


class LESSR(SequentialRecommender):
    r"""LESSR analyzes the information losses when constructing session graphs,
    and emphasises lossy session encoding problem and the ineffective long-range dependency capturing problem.
    To solve the first problem, authors propose a lossless encoding scheme and an edge-order preserving aggregation layer.
    To solve the second problem, authors propose a shortcut graph attention layer that effectively captures long-range dependencies.
    Note:
        We follow the original implementation, which requires DGL package.
        We find it difficult to implement these functions via PyG, so we remain them.
        If you would like to test this model, please install DGL.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        embedding_dim = config['embedding_size']
        self.num_layers = config['n_layers']
        batch_norm = config['batch_norm']
        feat_drop = config['feat_drop']
        self.loss_type = config['loss_type']

        # max_norm=1 keeps item embeddings on/inside the unit ball
        self.item_embedding = nn.Embedding(self.n_items, embedding_dim, max_norm=1)
        self.layers = nn.ModuleList()
        input_dim = embedding_dim
        # Alternate EOPA (even layers) and SGAT (odd layers); each layer's
        # output is concatenated to its input, so input_dim grows per layer.
        for i in range(self.num_layers):
            if i % 2 == 0:
                layer = EOPA(
                    input_dim,
                    embedding_dim,
                    batch_norm=batch_norm,
                    feat_drop=feat_drop,
                    activation=nn.PReLU(embedding_dim),
                )
            else:
                layer = SGAT(
                    input_dim,
                    embedding_dim,
                    embedding_dim,
                    batch_norm=batch_norm,
                    feat_drop=feat_drop,
                    activation=nn.PReLU(embedding_dim),
                )
            input_dim += embedding_dim
            self.layers.append(layer)
        self.readout = AttnReadout(
            input_dim,
            embedding_dim,
            embedding_dim,
            batch_norm=batch_norm,
            feat_drop=feat_drop,
            activation=nn.PReLU(embedding_dim),
        )
        input_dim += embedding_dim
        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.feat_drop = nn.Dropout(feat_drop)
        # final projection of [local, global] concatenation to embedding_dim
        self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False)

        if self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

    def forward(self, x, edge_index_EOP, edge_index_shortcut, batch, is_last):
        """Return session representations.

        Args:
            x: item ids of the batched graph nodes.
            edge_index_EOP: edges of the edge-order-preserving multigraph.
            edge_index_shortcut: edges of the shortcut graph.
            batch: session id per node (for pooling).
            is_last: indices/mask of each session's last node.
        """
        import dgl

        # Build the two DGL graphs over the same node set.
        mg = dgl.graph((edge_index_EOP[0], edge_index_EOP[1]), num_nodes=batch.shape[0])
        sg = dgl.graph((edge_index_shortcut[0], edge_index_shortcut[1]), num_nodes=batch.shape[0])

        feat = self.item_embedding(x)
        for i, layer in enumerate(self.layers):
            if i % 2 == 0:
                out = layer(mg, feat)
            else:
                out = layer(sg, feat)
            # dense-style concatenation: keep all previous features
            feat = torch.cat([out, feat], dim=1)
        sr_g = self.readout(mg, feat, is_last, batch)  # global (pooled) representation
        sr_l = feat[is_last]                           # local (last-node) representation
        sr = torch.cat([sr_l, sr_g], dim=1)
        if self.batch_norm is not None:
            sr = self.batch_norm(sr)
        sr = self.fc_sr(self.feat_drop(sr))
        return sr

    def calculate_loss(self, interaction):
        """Cross-entropy loss over all items."""
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        pos_items = interaction[self.POS_ITEM_ID]
        test_item_emb = self.item_embedding.weight
        logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
        loss = self.loss_fct(logits, pos_items)
        return loss

    def predict(self, interaction):
        """Score each (session, candidate item) pair."""
        test_item = interaction[self.ITEM_ID]
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every item in the catalog for each session."""
        x = interaction['x']
        edge_index_EOP = interaction['edge_index_EOP']
        edge_index_shortcut = interaction['edge_index_shortcut']
        batch = interaction['batch']
        is_last = interaction['is_last']
        seq_output = self.forward(x, edge_index_EOP, edge_index_shortcut, batch, is_last)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores
Loading

0 comments on commit 2ca9622

Please sign in to comment.