Skip to content

Commit

Permalink
Merge pull request #33 from boostcampaitech6/feat/32-VASP_baseline
Browse files Browse the repository at this point in the history
Feat:Implement VASP_baseline #32
  • Loading branch information
guridon authored Apr 7, 2024
2 parents 5edac62 + 3843430 commit 260638b
Show file tree
Hide file tree
Showing 2 changed files with 783 additions and 0 deletions.
224 changes: 224 additions & 0 deletions notebooks/ksj/EASE.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.sparse import csr_matrix\n",
"from datetime import datetime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 실험용"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# 데이터 로드\n",
"data_file = \"/opt/ml/input/data/train/train_ratings.csv\"\n",
"rating_df = pd.read_csv(data_file)\n",
"\n",
"# user와 item에 대한 고유한 인덱스 생성\n",
"user_list = rating_df['user'].unique()\n",
"item_list = rating_df['item'].unique()\n",
"\n",
"user_to_index = {user: index for index, user in enumerate(user_list)}\n",
"item_to_index = {item: index for index, item in enumerate(item_list)}\n",
"\n",
"# 데이터셋에서 user, item을 인덱스로 변환\n",
"rating_df['user'] = rating_df['user'].map(user_to_index)\n",
"rating_df['item'] = rating_df['item'].map(item_to_index)\n",
"\n",
"num_users = len(user_to_index)\n",
"num_items = len(item_to_index)\n",
"\n",
"# csr_matrix 생성\n",
"rows = rating_df['user'].values\n",
"cols = rating_df['item'].values\n",
"\n",
"rating_score = 0.9\n",
"data = np.full(len(rating_df), rating_score)\n",
"rating_matrix = csr_matrix((data, (rows, cols)), shape=(num_users, num_items))\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class EASE:\n",
" def __init__(self, _lambda):\n",
" self.B = None\n",
" self._lambda = _lambda\n",
"\n",
" def train(self, X):\n",
" try:\n",
" # csr_matrix --> dense\n",
" X_dense = X.toarray()\n",
" print(\"X shape:\", X_dense.shape)\n",
" \n",
" G = X_dense.T @ X_dense\n",
" print(\"G shape:\", G.shape)\n",
" \n",
" diag_indices = np.diag_indices_from(G)\n",
" G[diag_indices] += self._lambda\n",
" print(\"G after adding lambda on diag:\", G.shape)\n",
" \n",
" # dense --> pseudo-inverse 계산\n",
" P = np.linalg.pinv(G)\n",
" print(\"P shape:\", P.shape)\n",
" \n",
" self.B = P / -np.diag(P)\n",
" self.B[diag_indices] = 0\n",
" print(\"Final B shape:\", self.B.shape)\n",
" except Exception as e:\n",
" print(\"Error occurred:\", e)\n",
"\n",
" def predict(self, X):\n",
" # csr_matrix --> dense\n",
" X_dense = X.toarray() if isinstance(X, csr_matrix) else X\n",
" return X_dense @ self.B"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X shape: (31360, 6807)\n",
"G shape: (6807, 6807)\n",
"G after adding lambda on diag: (6807, 6807)\n",
"P shape: (6807, 6807)\n",
"Final B shape: (6807, 6807)\n"
]
}
],
"source": [
"# EASE 모델 학습\n",
"_lambda = 750\n",
"ease = EASE(_lambda)\n",
"ease.train(rating_matrix)\n",
"\n",
"# 예측\n",
"predictions = ease.predict(rating_matrix)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# 상위 N개 아이템 추출\n",
"\n",
"N = 10\n",
"top_n_items_per_user = []\n",
"\n",
"# 이미 평가한 아이템은 제외\n",
"predictions[rating_matrix.nonzero()] = -np.inf\n",
"\n",
"for user_idx in range(predictions.shape[0]):\n",
" user_predictions = predictions[user_idx, :]\n",
" top_n_indices = np.argpartition(user_predictions, -N)[-N:]\n",
" top_n_indices_sorted = top_n_indices[np.argsort(user_predictions[top_n_indices])[::-1]]\n",
" top_n_items_per_user.append(top_n_indices_sorted)\n",
"\n",
"# 인덱스를 실제 아이템 ID로 변환\n",
"index_to_item = {index: item for item, index in item_to_index.items()}\n",
"top_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in top_n_items_per_user]\n",
"\n",
"# 제출 파일 생성\n",
"result = []\n",
"for user_id, items in zip(user_list, top_n_items_per_user_ids):\n",
" for item_id in items:\n",
" result.append((user_id, item_id))\n",
"\n",
"# 저장\n",
"current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n",
"args_str = f\"submission_EASE-lam{_lambda}_tm{current_time}\"\n",
"submission_df = pd.DataFrame(result, columns=['user', 'item'])\n",
"submission_df.to_csv(f\"{args_str}.csv\", index=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 하위 N개 아이템 추출\n",
"\n",
"N = 80\n",
"bottom_n_items_per_user = [] # 각 사용자별로 하위 N개 아이템의 인덱스를 저장할 리스트\n",
"\n",
"for user_idx in range(predictions.shape[0]):\n",
" user_predictions = predictions[user_idx, :]\n",
" bottom_n_indices = np.argpartition(user_predictions, N-1)[:N]\n",
" bottom_n_indices_sorted = bottom_n_indices[np.argsort(user_predictions[bottom_n_indices])]\n",
" bottom_n_items_per_user.append(bottom_n_indices_sorted)\n",
"\n",
"\n",
"bottom_n_items_per_user = np.array(bottom_n_items_per_user)\n",
"\n",
"# 인덱스를 실제 아이템 ID로 변환\n",
"index_to_item = {index: item for item, index in item_to_index.items()}\n",
"bottom_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in bottom_n_items_per_user]\n",
"\n",
"# 제출 파일 생성\n",
"result = []\n",
"for user_id, items in zip(user_list, bottom_n_items_per_user_ids):\n",
" for item_id in items:\n",
" result.append((user_id, item_id))\n",
"\n",
"# 저장\n",
"current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n",
"args_str = f\"Negative_EASE-lam{_lambda}_tm{current_time}\"\n",
"submission_df = pd.DataFrame(result, columns=['user', 'item'])\n",
"submission_df.to_csv(f\"{args_str}.csv\", index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.18 ('EASE': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "78251029becf62da8b58a8ef92d6c9c1752dac293e5a0befc81af987f28f6887"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 260638b

Please sign in to comment.