From 6f5c4fc4229cf7029cec3b72d4f6263223561585 Mon Sep 17 00:00:00 2001
From: Ruslan Izmaylov <lord.rik@yandex.ru>
Date: Wed, 6 Nov 2024 16:20:14 +0000
Subject: [PATCH] Remove Hive support from get_spark_session
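
get_spark_session now returns a plain local session: the builder no
longer calls enableHiveSupport(), so the default RePlay session is
created without a Hive metastore. Callers that still need Hive can
build their own SparkSession and hand it to State, as the updated
notebook cell shows. A minimal sketch (memory and master settings are
illustrative only):

    from pyspark.sql import SparkSession
    from replay.utils.session_handler import State

    session = (
        SparkSession.builder.config("spark.driver.memory", "8g")
        .master("local[*]")
        .enableHiveSupport()  # opt back into Hive explicitly
        .getOrCreate()
    )
    spark = State(session).session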

---
 examples/01_replay_basics.ipynb | 90 ++++++++++++++++++---------------
 replay/utils/session_handler.py |  7 ++-
 2 files changed, 51 insertions(+), 46 deletions(-)

diff --git a/examples/01_replay_basics.ipynb b/examples/01_replay_basics.ipynb
index 3f139d7e6..d99cf9f29 100644
--- a/examples/01_replay_basics.ipynb
+++ b/examples/01_replay_basics.ipynb
@@ -53,7 +53,9 @@
    "outputs": [],
    "source": [
     "import warnings\n",
+    "\n",
     "from optuna.exceptions import ExperimentalWarning\n",
+    "\n",
     "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
     "warnings.filterwarnings(\"ignore\", category=ExperimentalWarning)"
    ]
@@ -65,7 +67,6 @@
    "outputs": [],
    "source": [
     "import pandas as pd\n",
-    "\n",
     "from pyspark.sql import SparkSession\n",
     "\n",
     "from replay.metrics import Coverage, HitRate, NDCG, MAP, Experiment, OfflineMetrics\n",
@@ -76,8 +77,12 @@
     "\n",
     "from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType\n",
     "from replay.data.dataset_utils import DatasetLabelEncoder\n",
-    "\n",
-    "from replay.models import ALSWrap, ItemKNN, SLIM\n"
+    "from replay.metrics import MAP, NDCG, Coverage, Experiment, HitRate, OfflineMetrics\n",
+    "from replay.models import SLIM, ALSWrap, ItemKNN\n",
+    "from replay.splitters import TwoStageSplitter\n",
+    "from replay.utils.model_handler import load, load_encoder, save, save_encoder\n",
+    "from replay.utils.session_handler import State, get_spark_session\n",
+    "from replay.utils.spark_utils import convert2spark, get_log_info"
    ]
   },
   {
@@ -87,7 +92,7 @@
    "outputs": [],
    "source": [
     "K = 5\n",
-    "SEED=42"
+    "SEED = 42"
    ]
   },
   {
@@ -105,7 +110,7 @@
     "\n",
     "- Option 1: use default RePlay `SparkSession`\n",
     "- You can pass you own session to RePlay. Class `State` stores current session. Here you also have two options: \n",
-    "    - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions \n",
+    "    - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions\n",
     "    - Option 3: create `SparkSession` from scratch\n"
    ]
   },
@@ -126,7 +131,7 @@
    "outputs": [],
    "source": [
     "spark = State().session\n",
-    "spark.sparkContext.setLogLevel('ERROR')\n",
+    "spark.sparkContext.setLogLevel(\"ERROR\")\n",
     "spark"
    ]
   },
@@ -140,7 +145,7 @@
     "    # get current spark session configuration:\n",
     "    conf = session.sparkContext.getConf().getAll()\n",
     "    # get num partitions\n",
-    "    print(f'{conf_name}: {dict(conf)[conf_name]}')"
+    "    print(f\"{conf_name}: {dict(conf)[conf_name]}\")"
    ]
   },
   {
@@ -157,7 +162,7 @@
     }
    ],
    "source": [
-    "print_config_param(spark, 'spark.sql.shuffle.partitions')"
+    "print_config_param(spark, \"spark.sql.shuffle.partitions\")"
    ]
   },
   {
@@ -193,7 +198,7 @@
     }
    ],
    "source": [
-    "print_config_param(spark, 'spark.sql.shuffle.partitions')"
+    "print_config_param(spark, \"spark.sql.shuffle.partitions\")"
    ]
   },
   {
@@ -220,16 +225,16 @@
    "source": [
     "spark.stop()\n",
     "session = (\n",
-    "        SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n",
-    "        .config(\"spark.sql.shuffle.partitions\", \"50\")\n",
-    "        .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n",
-    "        .config(\"spark.driver.host\", \"localhost\")\n",
-    "        .master(\"local[*]\")\n",
-    "        .enableHiveSupport()\n",
-    "        .getOrCreate()\n",
-    "    )\n",
+    "    SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n",
+    "    .config(\"spark.sql.shuffle.partitions\", \"50\")\n",
+    "    .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n",
+    "    .config(\"spark.driver.host\", \"localhost\")\n",
+    "    .master(\"local[*]\")\n",
+    "    .enableHiveSupport()\n",
+    "    .getOrCreate()\n",
+    ")\n",
     "spark = State(session).session\n",
-    "print_config_param(spark, 'spark.sql.shuffle.partitions')"
+    "print_config_param(spark, \"spark.sql.shuffle.partitions\")"
    ]
   },
   {
@@ -281,7 +286,7 @@
    "source": [
     "spark.stop()\n",
     "spark = State(get_spark_session()).session\n",
-    "spark.sparkContext.setLogLevel('ERROR')\n",
+    "spark.sparkContext.setLogLevel(\"ERROR\")\n",
     "spark"
    ]
   },
@@ -549,7 +554,7 @@
    ],
    "source": [
     "filtered_df = filter_out_low_ratings(df_spark, value=3, rating_column=\"rating\")\n",
-    "get_log_info(filtered_df, user_col='user_id', item_col='item_id')"
+    "get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")"
    ]
   },
   {
@@ -576,8 +581,8 @@
     }
    ],
    "source": [
-    "filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by='user_id')\n",
-    "get_log_info(filtered_df, user_col='user_id', item_col='item_id')"
+    "filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by=\"user_id\")\n",
+    "get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")"
    ]
   },
   {
@@ -649,7 +654,7 @@
     "    second_divide_size=K,\n",
     "    first_divide_size=500,\n",
     "    seed=SEED,\n",
-    "    shuffle=True\n",
+    "    shuffle=True,\n",
     ")\n",
     "train, test = splitter.split(filtered_df)\n",
     "print(train.count(), test.count())"
@@ -935,7 +940,9 @@
     "print(\n",
     "    NDCG(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
     "    MAP(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
-    "    HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
+    "    HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(\n",
+    "        recs, test_dataset.interactions\n",
+    "    ),\n",
     "    Coverage(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
     ")"
    ]
@@ -985,7 +992,9 @@
    "source": [
     "offline_metrics = OfflineMetrics(\n",
     "    [NDCG(K), MAP(K), HitRate([1, K]), Coverage(K)],\n",
-    "    query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n",
+    "    query_column=\"user_id\",\n",
+    "    item_column=\"item_id\",\n",
+    "    rating_column=\"rating\",\n",
     ")\n",
     "offline_metrics(recs, test_dataset.interactions, train_dataset.interactions)"
    ]
@@ -1005,7 +1014,9 @@
     "    ],\n",
     "    test_dataset.interactions,\n",
     "    train_dataset.interactions,\n",
-    "    query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n",
+    "    query_column=\"user_id\",\n",
+    "    item_column=\"item_id\",\n",
+    "    rating_column=\"rating\",\n",
     ")"
    ]
   },
@@ -1149,12 +1160,7 @@
     "def fit_predict_evaluate(model, experiment, name):\n",
     "    model.fit(train_dataset)\n",
     "\n",
-    "    recs = model.predict(\n",
-    "        dataset=train_dataset,\n",
-    "        k=K,\n",
-    "        queries=test_dataset.query_ids,\n",
-    "        filter_seen_items=False\n",
-    "    )\n",
+    "    recs = model.predict(dataset=train_dataset, k=K, queries=test_dataset.query_ids, filter_seen_items=False)\n",
     "\n",
     "    experiment.add_result(name, recs)\n",
     "    return recs"
@@ -1167,9 +1173,9 @@
    "outputs": [],
    "source": [
     "%%time\n",
-    "recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, 'SLIM_optimized')\n",
+    "recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, \"SLIM_optimized\")\n",
     "recs.cache() #caching for further processing\n",
-    "metrics.results.sort_values('NDCG@5', ascending=False)"
+    "metrics.results.sort_values(\"NDCG@5\", ascending=False)"
    ]
   },
   {
@@ -1308,7 +1314,7 @@
    ],
    "source": [
     "%%time\n",
-    "recs.write.parquet(path='./slim_recs.parquet', mode='overwrite')"
+    "recs.write.parquet(path=\"./slim_recs.parquet\", mode=\"overwrite\")"
    ]
   },
   {
@@ -1369,8 +1375,8 @@
    ],
    "source": [
     "%%time\n",
-    "save(slim, path='./slim_best_params')\n",
-    "slim_loaded = load('./slim_best_params')"
+    "save(slim, path=\"./slim_best_params\")\n",
+    "slim_loaded = load(\"./slim_best_params\")"
    ]
   },
   {
@@ -1559,8 +1565,8 @@
    ],
    "source": [
     "%%time\n",
-    "recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, 'ALS')\n",
-    "metrics.results.sort_values('NDCG@5', ascending=False)"
+    "recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, \"ALS\")\n",
+    "metrics.results.sort_values(\"NDCG@5\", ascending=False)"
    ]
   },
   {
@@ -1662,8 +1668,8 @@
    ],
    "source": [
     "%%time\n",
-    "recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, 'ItemKNN')\n",
-    "metrics.results.sort_values('NDCG@5', ascending=False)"
+    "recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, \"ItemKNN\")\n",
+    "metrics.results.sort_values(\"NDCG@5\", ascending=False)"
    ]
   },
   {
@@ -1795,7 +1801,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.18"
+   "version": "3.9.19"
   },
   "name": "movielens_nmf.ipynb",
   "pycharm": {
diff --git a/replay/utils/session_handler.py b/replay/utils/session_handler.py
index b2be25e20..08226bf15 100644
--- a/replay/utils/session_handler.py
+++ b/replay/utils/session_handler.py
@@ -71,7 +71,7 @@ def get_spark_session(
         shuffle_partitions = os.cpu_count() * 3
     driver_memory = f"{spark_memory}g"
     user_home = os.environ["HOME"]
-    spark = (
+    spark_session_builder = (
         SparkSession.builder.config("spark.driver.memory", driver_memory)
         .config(
             "spark.driver.extraJavaOptions",
@@ -87,10 +87,9 @@ def get_spark_session(
         .config("spark.kryoserializer.buffer.max", "256m")
         .config("spark.files.overwrite", "true")
         .master(f"local[{'*' if core_count == -1 else core_count}]")
-        .enableHiveSupport()
-        .getOrCreate()
     )
-    return spark
+
+    return spark_session_builder.getOrCreate()
 
 
 def logger_with_settings() -> logging.Logger: