Skip to content

Commit

Permalink
Refactor: dataset & evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
ChangZero committed Feb 14, 2024
1 parent 64d0d34 commit 6da717f
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
12 changes: 8 additions & 4 deletions sequential/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import torch
# import wandb
import torch.nn as nn
import wandb

import argparse

from module.trainer import trainer
from module.utils import set_seed, parse_args_boolean
from module.utils import set_seed, parse_args_boolean, logging_conf, get_logger

logger = get_logger(logger_conf=logging_conf)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def parse_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -41,4 +44,5 @@ def main():
trainer(config=args)

if __name__ == "__main__":
main()
main()

36 changes: 36 additions & 0 deletions sequential/module/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,42 @@ def __getitem__(self, user):

return torch.LongTensor(tokens), torch.LongTensor(labels)

def random_neg_sampling(self, rated_item : list, num_item_sample : int):
nge_samples = random.sample(list(self._all_items - set(rated_item)), num_item_sample)
return nge_samples


class SASRecDataSet(Dataset):
    """Per-user training dataset for SASRec.

    Each element is one user's interaction history turned into three
    equally long, left-padded id sequences:

    * ``seq`` -- the input items (history minus its final element),
    * ``pos`` -- the positive targets (history shifted left by one),
    * ``neg`` -- uniformly sampled items the user never interacted with.

    Item ids are assumed to start at 1; id 0 is reserved for padding.
    """

    def __init__(self, user_train, max_len, num_user, num_item):
        self.user_train = user_train  # mapping: user -> ordered list of item ids
        self.max_len = max_len        # length every returned sequence is padded to
        self.num_user = num_user
        self.num_item = num_item
        # Full catalogue of item ids; also the negative-sampling pool.
        self._all_items = set(range(1, self.num_item + 1))

    def __len__(self):
        """Number of users == number of training sequences."""
        return self.num_user

    def __getitem__(self, user):
        """Build ``(seq, pos, neg)`` int32 arrays for one user."""
        history = self.user_train[user]
        history_len = len(history)

        # Input sequence: everything except the last interaction.
        seq = history[-history_len:-1][-self.max_len:]
        # Positive targets: the history shifted one step forward.
        # NOTE(review): for a single-interaction history this yields the whole
        # list (``-0:`` slice), so pos is one longer than seq — preserved as-is.
        pos = history[-(history_len - 1):][-self.max_len:]
        # One negative per positive, drawn from items this user never rated.
        neg = random.sample(list(self._all_items - set(history)), len(pos))

        def left_pad(ids):
            # Prepend zeros so every sequence has exactly max_len entries.
            return [0] * (self.max_len - len(ids)) + ids

        return (
            np.array(left_pad(seq), dtype=np.int32),
            np.array(left_pad(pos), dtype=np.int32),
            np.array(left_pad(neg), dtype=np.int32),
        )

    def random_neg_sampling(self, rated_item : list, num_item_sample : int):
        """Sample ``num_item_sample`` items the user has not interacted with."""
        candidates = list(self._all_items - set(rated_item))
        return random.sample(candidates, num_item_sample)
5 changes: 0 additions & 5 deletions sequential/module/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
def inference(model, user_train, user_valid, max_len, make_sequence_dataset, exp_name):
model.eval()

NDCG = 0.0 # NDCG@10
HIT = 0.0 # HIT@10


users = [user for user in range(make_sequence_dataset.num_user)]
result = []

Expand All @@ -34,4 +30,3 @@ def inference(model, user_train, user_valid, max_len, make_sequence_dataset, exp

submit_df = pd.DataFrame(data={'user': submit_user, 'item': submit_item}, columns=['user', 'item'])
return submit_df

8 changes: 4 additions & 4 deletions sequential/module/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def train(model, criterion, optimizer, data_loader, device):
return loss_val


def evaluate(model, user_train, user_valid, max_len, bert4rec_dataset, make_sequence_dataset):
def evaluate(model, user_train, user_valid, max_len, dataset, make_sequence_dataset):
model.eval()

ndcg = 0.0 # NDCG@10
Expand All @@ -49,7 +49,7 @@ def evaluate(model, user_train, user_valid, max_len, bert4rec_dataset, make_sequ
for user in users:
seq = (user_train[user] + [make_sequence_dataset.num_item + 1])[-max_len:]
rated = user_train[user] + user_valid[user]
items = user_valid[user] + bert4rec_dataset.random_neg_sampling(rated_item = rated, num_item_sample = num_item_sample)
items = user_valid[user] + dataset.random_neg_sampling(rated_item = rated, num_item_sample = num_item_sample)

with torch.no_grad():
predictions = -model(np.array([seq]))
Expand Down Expand Up @@ -134,7 +134,7 @@ def trainer(config):
user_train = user_train,
user_valid = user_valid,
max_len = config.max_len,
bert4rec_dataset = bert4rec_dataset,
dataset = dataset,
make_sequence_dataset = make_sequence_dataset,
)

Expand Down Expand Up @@ -163,4 +163,4 @@ def trainer(config):
submission_artifact = wandb.Artifact(f'{exp_name}_submission', type='output')
submission_artifact.add_file(local_path=f"outputs/{exp_name}_submission.csv")
wandb.log_artifact(submission_artifact)
wandb.finish()
wandb.finish()

0 comments on commit 6da717f

Please sign in to comment.