-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from boostcampaitech6/feat/32-VASP_baseline
Feat:Implement VASP_baseline #32
- Loading branch information
Showing
2 changed files
with
783 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"from scipy.sparse import csr_matrix\n", | ||
"from datetime import datetime" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# 실험용" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 데이터 로드\n", | ||
"data_file = \"/opt/ml/input/data/train/train_ratings.csv\"\n", | ||
"rating_df = pd.read_csv(data_file)\n", | ||
"\n", | ||
"# user와 item에 대한 고유한 인덱스 생성\n", | ||
"user_list = rating_df['user'].unique()\n", | ||
"item_list = rating_df['item'].unique()\n", | ||
"\n", | ||
"user_to_index = {user: index for index, user in enumerate(user_list)}\n", | ||
"item_to_index = {item: index for index, item in enumerate(item_list)}\n", | ||
"\n", | ||
"# 데이터셋에서 user, item을 인덱스로 변환\n", | ||
"rating_df['user'] = rating_df['user'].map(user_to_index)\n", | ||
"rating_df['item'] = rating_df['item'].map(item_to_index)\n", | ||
"\n", | ||
"num_users = len(user_to_index)\n", | ||
"num_items = len(item_to_index)\n", | ||
"\n", | ||
"# csr_matrix 생성\n", | ||
"rows = rating_df['user'].values\n", | ||
"cols = rating_df['item'].values\n", | ||
"\n", | ||
"rating_score = 0.9\n", | ||
"data = np.full(len(rating_df), rating_score)\n", | ||
"rating_matrix = csr_matrix((data, (rows, cols)), shape=(num_users, num_items))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class EASE:\n", | ||
" def __init__(self, _lambda):\n", | ||
" self.B = None\n", | ||
" self._lambda = _lambda\n", | ||
"\n", | ||
" def train(self, X):\n", | ||
" try:\n", | ||
" # csr_matrix --> dense\n", | ||
" X_dense = X.toarray()\n", | ||
" print(\"X shape:\", X_dense.shape)\n", | ||
" \n", | ||
" G = X_dense.T @ X_dense\n", | ||
" print(\"G shape:\", G.shape)\n", | ||
" \n", | ||
" diag_indices = np.diag_indices_from(G)\n", | ||
" G[diag_indices] += self._lambda\n", | ||
" print(\"G after adding lambda on diag:\", G.shape)\n", | ||
" \n", | ||
" # dense --> pseudo-inverse 계산\n", | ||
" P = np.linalg.pinv(G)\n", | ||
" print(\"P shape:\", P.shape)\n", | ||
" \n", | ||
" self.B = P / -np.diag(P)\n", | ||
" self.B[diag_indices] = 0\n", | ||
" print(\"Final B shape:\", self.B.shape)\n", | ||
" except Exception as e:\n", | ||
" print(\"Error occurred:\", e)\n", | ||
"\n", | ||
" def predict(self, X):\n", | ||
" # csr_matrix --> dense\n", | ||
" X_dense = X.toarray() if isinstance(X, csr_matrix) else X\n", | ||
" return X_dense @ self.B" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"X shape: (31360, 6807)\n", | ||
"G shape: (6807, 6807)\n", | ||
"G after adding lambda on diag: (6807, 6807)\n", | ||
"P shape: (6807, 6807)\n", | ||
"Final B shape: (6807, 6807)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# EASE 모델 학습\n", | ||
"_lambda = 750\n", | ||
"ease = EASE(_lambda)\n", | ||
"ease.train(rating_matrix)\n", | ||
"\n", | ||
"# 예측\n", | ||
"predictions = ease.predict(rating_matrix)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 상위 N개 아이템 추출\n", | ||
"\n", | ||
"N = 10\n", | ||
"top_n_items_per_user = []\n", | ||
"\n", | ||
"# 이미 평가한 아이템은 제외\n", | ||
"predictions[rating_matrix.nonzero()] = -np.inf\n", | ||
"\n", | ||
"for user_idx in range(predictions.shape[0]):\n", | ||
" user_predictions = predictions[user_idx, :]\n", | ||
" top_n_indices = np.argpartition(user_predictions, -N)[-N:]\n", | ||
" top_n_indices_sorted = top_n_indices[np.argsort(user_predictions[top_n_indices])[::-1]]\n", | ||
" top_n_items_per_user.append(top_n_indices_sorted)\n", | ||
"\n", | ||
"# 인덱스를 실제 아이템 ID로 변환\n", | ||
"index_to_item = {index: item for item, index in item_to_index.items()}\n", | ||
"top_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in top_n_items_per_user]\n", | ||
"\n", | ||
"# 제출 파일 생성\n", | ||
"result = []\n", | ||
"for user_id, items in zip(user_list, top_n_items_per_user_ids):\n", | ||
" for item_id in items:\n", | ||
" result.append((user_id, item_id))\n", | ||
"\n", | ||
"# 저장\n", | ||
"current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n", | ||
"args_str = f\"submission_EASE-lam{_lambda}_tm{current_time}\"\n", | ||
"submission_df = pd.DataFrame(result, columns=['user', 'item'])\n", | ||
"submission_df.to_csv(f\"{args_str}.csv\", index=False)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 하위 N개 아이템 추출\n", | ||
"\n", | ||
"N = 80\n", | ||
"bottom_n_items_per_user = [] # 각 사용자별로 하위 N개 아이템의 인덱스를 저장할 리스트\n", | ||
"\n", | ||
"for user_idx in range(predictions.shape[0]):\n", | ||
" user_predictions = predictions[user_idx, :]\n", | ||
" bottom_n_indices = np.argpartition(user_predictions, N-1)[:N]\n", | ||
" bottom_n_indices_sorted = bottom_n_indices[np.argsort(user_predictions[bottom_n_indices])]\n", | ||
" bottom_n_items_per_user.append(bottom_n_indices_sorted)\n", | ||
"\n", | ||
"\n", | ||
"bottom_n_items_per_user = np.array(bottom_n_items_per_user)\n", | ||
"\n", | ||
"# 인덱스를 실제 아이템 ID로 변환\n", | ||
"index_to_item = {index: item for item, index in item_to_index.items()}\n", | ||
"bottom_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in bottom_n_items_per_user]\n", | ||
"\n", | ||
"# 제출 파일 생성\n", | ||
"result = []\n", | ||
"for user_id, items in zip(user_list, bottom_n_items_per_user_ids):\n", | ||
" for item_id in items:\n", | ||
" result.append((user_id, item_id))\n", | ||
"\n", | ||
"# 저장\n", | ||
"current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n", | ||
"args_str = f\"Negative_EASE-lam{_lambda}_tm{current_time}\"\n", | ||
"submission_df = pd.DataFrame(result, columns=['user', 'item'])\n", | ||
"submission_df.to_csv(f\"{args_str}.csv\", index=False)\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.8.18 ('EASE': conda)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.18" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "78251029becf62da8b58a8ef92d6c9c1752dac293e5a0befc81af987f28f6887" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.