From 37012c8c1f5a4ddc35f42689f9c1a10999cfd275 Mon Sep 17 00:00:00 2001 From: guridon Date: Mon, 19 Feb 2024 17:30:17 +0000 Subject: [PATCH] EASE_baseline --- general/EASE/EASE.ipynb | 223 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 general/EASE/EASE.ipynb diff --git a/general/EASE/EASE.ipynb b/general/EASE/EASE.ipynb new file mode 100644 index 0000000..84952af --- /dev/null +++ b/general/EASE/EASE.ipynb @@ -0,0 +1,223 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from scipy.sparse import csr_matrix\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 실험용" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# 데이터 로드\n", + "data_file = \"/opt/ml/input/data/train/train_ratings.csv\"\n", + "rating_df = pd.read_csv(data_file)\n", + "\n", + "# user와 item에 대한 고유한 인덱스 생성\n", + "user_list = rating_df['user'].unique()\n", + "item_list = rating_df['item'].unique()\n", + "\n", + "user_to_index = {user: index for index, user in enumerate(user_list)}\n", + "item_to_index = {item: index for index, item in enumerate(item_list)}\n", + "\n", + "# 데이터셋에서 user, item을 인덱스로 변환\n", + "rating_df['user'] = rating_df['user'].map(user_to_index)\n", + "rating_df['item'] = rating_df['item'].map(item_to_index)\n", + "\n", + "num_users = len(user_to_index)\n", + "num_items = len(item_to_index)\n", + "\n", + "# csr_matrix 생성\n", + "rows = rating_df['user'].values\n", + "cols = rating_df['item'].values\n", + "# data = np.ones(len(rating_df)) --> 0.1599\n", + "data = np.full(len(rating_df), 0.9)\n", + "rating_matrix = csr_matrix((data, (rows, cols)), shape=(num_users, num_items))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "class EASE:\n", + " def __init__(self, _lambda):\n", + " self.B = None\n", + " self._lambda = _lambda\n", + "\n", + " def train(self, X):\n", + " try:\n", + " # csr_matrix --> dense\n", + " X_dense = X.toarray()\n", + " print(\"X shape:\", X_dense.shape)\n", + " \n", + " G = X_dense.T @ X_dense\n", + " print(\"G shape:\", G.shape)\n", + " \n", + " diag_indices = np.diag_indices_from(G)\n", + " G[diag_indices] += self._lambda\n", + " print(\"G after adding lambda on diag:\", G.shape)\n", + " \n", + " # dense --> pseudo-inverse 계산\n", + " P = np.linalg.pinv(G)\n", + " print(\"P shape:\", P.shape)\n", + " \n", + " self.B = P / -np.diag(P)\n", + " self.B[diag_indices] = 0\n", + " print(\"Final B shape:\", self.B.shape)\n", + " except Exception as e:\n", + " print(\"Error occurred:\", e)\n", + "\n", + " def predict(self, X):\n", + " # csr_matrix --> dense\n", + " X_dense = X.toarray() if isinstance(X, csr_matrix) else X\n", + " return X_dense @ self.B" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X shape: (31360, 6807)\n", + "G shape: (6807, 6807)\n", + "G after adding lambda on diag: (6807, 6807)\n", + "P shape: (6807, 6807)\n", + "Final B shape: (6807, 6807)\n" + ] + } + ], + "source": [ + "# EASE 모델 학습\n", + "_lambda = 500\n", + "ease = EASE(_lambda)\n", + "ease.train(rating_matrix)\n", + "\n", + "# 예측\n", + "predictions = ease.predict(rating_matrix)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# 상위 N개 아이템 추출\n", + "\n", + "N = 10\n", + "top_n_items_per_user = []\n", + "\n", + "# 이미 평가한 아이템은 제외\n", + "predictions[rating_matrix.nonzero()] = -np.inf\n", + "\n", + "for user_idx in range(predictions.shape[0]):\n", + " user_predictions = predictions[user_idx, :]\n", + " top_n_indices = np.argpartition(user_predictions, -N)[-N:]\n", + " top_n_indices_sorted = top_n_indices[np.argsort(user_predictions[top_n_indices])[::-1]]\n", + " top_n_items_per_user.append(top_n_indices_sorted)\n", + "\n", + "# 인덱스를 실제 아이템 ID로 변환\n", + "index_to_item = {index: item for item, index in item_to_index.items()}\n", + "top_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in top_n_items_per_user]\n", + "\n", + "# 제출 파일 생성\n", + "result = []\n", + "for user_id, items in zip(user_list, top_n_items_per_user_ids):\n", + " for item_id in items:\n", + " result.append((user_id, item_id))\n", + "\n", + "# 저장\n", + "current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n", + "args_str = f\"submission_EASE-lam{_lambda}_tm{current_time}\"\n", + "submission_df = pd.DataFrame(result, columns=['user', 'item'])\n", + "submission_df.to_csv(f\"{args_str}.csv\", index=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# 하위 N개 아이템 추출\n", + "\n", + "N = 80\n", + "bottom_n_items_per_user = [] # 각 사용자별로 하위 N개 아이템의 인덱스를 저장할 리스트\n", + "\n", + "for user_idx in range(predictions.shape[0]):\n", + " user_predictions = predictions[user_idx, :]\n", + " bottom_n_indices = np.argpartition(user_predictions, N-1)[:N]\n", + " bottom_n_indices_sorted = bottom_n_indices[np.argsort(user_predictions[bottom_n_indices])]\n", + " bottom_n_items_per_user.append(bottom_n_indices_sorted)\n", + "\n", + "\n", + "bottom_n_items_per_user = np.array(bottom_n_items_per_user)\n", + "\n", + "# 인덱스를 실제 아이템 ID로 변환\n", + "index_to_item = {index: item for item, index in item_to_index.items()}\n", + "bottom_n_items_per_user_ids = [[index_to_item[idx] for idx in user_items] for user_items in bottom_n_items_per_user]\n", + "\n", + "# 제출 파일 생성\n", + "result = []\n", + "for user_id, items in zip(user_list, bottom_n_items_per_user_ids):\n", + " for item_id in items:\n", + " result.append((user_id, item_id))\n", + "\n", + "# 저장\n", + "current_time = datetime.now().strftime('%Y%m%d-%H%M%S')\n", + "args_str = f\"Negative_EASE-lam{_lambda}_tm{current_time}\"\n", + "submission_df = pd.DataFrame(result, columns=['user', 'item'])\n", + "submission_df.to_csv(f\"{args_str}.csv\", index=False)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.18 ('EASE': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "78251029becf62da8b58a8ef92d6c9c1752dac293e5a0befc81af987f28f6887" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}