From 6f5c4fc4229cf7029cec3b72d4f6263223561585 Mon Sep 17 00:00:00 2001 From: Ruslan Izmaylov <lord.rik@yandex.ru> Date: Wed, 6 Nov 2024 16:20:14 +0000 Subject: [PATCH] remove hive support from get_spark_session function --- examples/01_replay_basics.ipynb | 90 ++++++++++++++++++--------------- replay/utils/session_handler.py | 7 ++- 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/examples/01_replay_basics.ipynb b/examples/01_replay_basics.ipynb index 3f139d7e6..d99cf9f29 100644 --- a/examples/01_replay_basics.ipynb +++ b/examples/01_replay_basics.ipynb @@ -53,7 +53,9 @@ "outputs": [], "source": [ "import warnings\n", + "\n", "from optuna.exceptions import ExperimentalWarning\n", + "\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "warnings.filterwarnings(\"ignore\", category=ExperimentalWarning)" ] @@ -65,7 +67,6 @@ "outputs": [], "source": [ "import pandas as pd\n", - "\n", "from pyspark.sql import SparkSession\n", "\n", "from replay.metrics import Coverage, HitRate, NDCG, MAP, Experiment, OfflineMetrics\n", @@ -76,8 +77,12 @@ "\n", "from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType\n", "from replay.data.dataset_utils import DatasetLabelEncoder\n", - "\n", - "from replay.models import ALSWrap, ItemKNN, SLIM\n" + "from replay.metrics import MAP, NDCG, Coverage, Experiment, HitRate, OfflineMetrics\n", + "from replay.models import SLIM, ALSWrap, ItemKNN\n", + "from replay.splitters import TwoStageSplitter\n", + "from replay.utils.model_handler import load, load_encoder, save, save_encoder\n", + "from replay.utils.session_handler import State, get_spark_session\n", + "from replay.utils.spark_utils import convert2spark, get_log_info" ] }, { @@ -87,7 +92,7 @@ "outputs": [], "source": [ "K = 5\n", - "SEED=42" + "SEED = 42" ] }, { @@ -105,7 +110,7 @@ "\n", "- Option 1: use default RePlay `SparkSession`\n", "- You can pass you own session to RePlay. Class `State` stores current session. 
Here you also have two options: \n", - " - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions \n", + " - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions\n", " - Option 3: create `SparkSession` from scratch\n" ] }, @@ -126,7 +131,7 @@ "outputs": [], "source": [ "spark = State().session\n", - "spark.sparkContext.setLogLevel('ERROR')\n", + "spark.sparkContext.setLogLevel(\"ERROR\")\n", "spark" ] }, @@ -140,7 +145,7 @@ " # get current spark session configuration:\n", " conf = session.sparkContext.getConf().getAll()\n", " # get num partitions\n", - " print(f'{conf_name}: {dict(conf)[conf_name]}')" + " print(f\"{conf_name}: {dict(conf)[conf_name]}\")" ] }, { @@ -157,7 +162,7 @@ } ], "source": [ - "print_config_param(spark, 'spark.sql.shuffle.partitions')" + "print_config_param(spark, \"spark.sql.shuffle.partitions\")" ] }, { @@ -193,7 +198,7 @@ } ], "source": [ - "print_config_param(spark, 'spark.sql.shuffle.partitions')" + "print_config_param(spark, \"spark.sql.shuffle.partitions\")" ] }, { @@ -220,16 +225,16 @@ "source": [ "spark.stop()\n", "session = (\n", - " SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n", - " .config(\"spark.sql.shuffle.partitions\", \"50\")\n", - " .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n", - " .config(\"spark.driver.host\", \"localhost\")\n", - " .master(\"local[*]\")\n", - " .enableHiveSupport()\n", - " .getOrCreate()\n", - " )\n", + " SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n", + " .config(\"spark.sql.shuffle.partitions\", \"50\")\n", + " .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n", + " .config(\"spark.driver.host\", \"localhost\")\n", + " .master(\"local[*]\")\n", + " .enableHiveSupport()\n", + " .getOrCreate()\n", + ")\n", "spark = State(session).session\n", - "print_config_param(spark, 'spark.sql.shuffle.partitions')" + "print_config_param(spark, \"spark.sql.shuffle.partitions\")" ] }, { @@ -281,7 +286,7 @@ "source": [ "spark.stop()\n", "spark = State(get_spark_session()).session\n", - "spark.sparkContext.setLogLevel('ERROR')\n", + "spark.sparkContext.setLogLevel(\"ERROR\")\n", "spark" ] }, @@ -549,7 +554,7 @@ ], "source": [ "filtered_df = filter_out_low_ratings(df_spark, value=3, rating_column=\"rating\")\n", - "get_log_info(filtered_df, user_col='user_id', item_col='item_id')" + "get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")" ] }, { @@ -576,8 +581,8 @@ } ], "source": [ - "filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by='user_id')\n", - "get_log_info(filtered_df, user_col='user_id', item_col='item_id')" + "filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by=\"user_id\")\n", + "get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")" ] }, { @@ -649,7 +654,7 @@ " second_divide_size=K,\n", " first_divide_size=500,\n", " seed=SEED,\n", - " shuffle=True\n", + " shuffle=True,\n", ")\n", "train, test = splitter.split(filtered_df)\n", "print(train.count(), test.count())" @@ -935,7 +940,9 @@ "print(\n", " NDCG(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n", " MAP(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n", - " HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n", + 
" HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(\n", + " recs, test_dataset.interactions\n", + " ),\n", " Coverage(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n", ")" ] @@ -985,7 +992,9 @@ "source": [ "offline_metrics = OfflineMetrics(\n", " [NDCG(K), MAP(K), HitRate([1, K]), Coverage(K)],\n", - " query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n", + " query_column=\"user_id\",\n", + " item_column=\"item_id\",\n", + " rating_column=\"rating\",\n", ")\n", "offline_metrics(recs, test_dataset.interactions, train_dataset.interactions)" ] @@ -1005,7 +1014,9 @@ " ],\n", " test_dataset.interactions,\n", " train_dataset.interactions,\n", - " query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n", + " query_column=\"user_id\",\n", + " item_column=\"item_id\",\n", + " rating_column=\"rating\",\n", ")" ] }, @@ -1149,12 +1160,7 @@ "def fit_predict_evaluate(model, experiment, name):\n", " model.fit(train_dataset)\n", "\n", - " recs = model.predict(\n", - " dataset=train_dataset,\n", - " k=K,\n", - " queries=test_dataset.query_ids,\n", - " filter_seen_items=False\n", - " )\n", + " recs = model.predict(dataset=train_dataset, k=K, queries=test_dataset.query_ids, filter_seen_items=False)\n", "\n", " experiment.add_result(name, recs)\n", " return recs" @@ -1167,9 +1173,9 @@ "outputs": [], "source": [ "%%time\n", - "recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, 'SLIM_optimized')\n", + "recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, \"SLIM_optimized\")\n", "recs.cache() #caching for further processing\n", - "metrics.results.sort_values('NDCG@5', ascending=False)" + "metrics.results.sort_values(\"NDCG@5\", ascending=False)" ] }, { @@ -1308,7 +1314,7 @@ ], "source": [ "%%time\n", - "recs.write.parquet(path='./slim_recs.parquet', mode='overwrite')" + "recs.write.parquet(path=\"./slim_recs.parquet\", mode=\"overwrite\")" ] }, { @@ -1369,8 +1375,8 @@ ], "source": [ "%%time\n", - "save(slim, path='./slim_best_params')\n", - "slim_loaded = load('./slim_best_params')" + "save(slim, path=\"./slim_best_params\")\n", + "slim_loaded = load(\"./slim_best_params\")" ] }, { @@ -1559,8 +1565,8 @@ ], "source": [ "%%time\n", - "recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, 'ALS')\n", - "metrics.results.sort_values('NDCG@5', ascending=False)" + "recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, \"ALS\")\n", + "metrics.results.sort_values(\"NDCG@5\", ascending=False)" ] }, { @@ -1662,8 +1668,8 @@ ], "source": [ "%%time\n", - "recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, 'ItemKNN')\n", - "metrics.results.sort_values('NDCG@5', ascending=False)" + "recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, \"ItemKNN\")\n", + "metrics.results.sort_values(\"NDCG@5\", ascending=False)" ] }, { @@ -1795,7 +1801,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.9.19" }, "name": "movielens_nmf.ipynb", "pycharm": { diff --git a/replay/utils/session_handler.py b/replay/utils/session_handler.py index b2be25e20..08226bf15 100644 --- a/replay/utils/session_handler.py +++ b/replay/utils/session_handler.py @@ -71,7 +71,7 @@ def get_spark_session( shuffle_partitions = os.cpu_count() * 3 driver_memory = f"{spark_memory}g" user_home = os.environ["HOME"] - spark = ( + spark_session_builder = ( 
SparkSession.builder.config("spark.driver.memory", driver_memory) .config( "spark.driver.extraJavaOptions", @@ -87,10 +87,9 @@ def get_spark_session( .config("spark.kryoserializer.buffer.max", "256m") .config("spark.files.overwrite", "true") .master(f"local[{'*' if core_count == -1 else core_count}]") - .enableHiveSupport() - .getOrCreate() ) - return spark + + return spark_session_builder.getOrCreate() def logger_with_settings() -> logging.Logger:
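Since get_spark_session() no longer calls enableHiveSupport(), callers that still need Hive-backed tables have to build their own session and hand it to RePlay through State, exactly as the notebook's "Option 3" cell does. A minimal sketch of that pattern follows; the driver memory and shuffle-partition values are illustrative only, not library defaults.

    # Hypothetical caller-side sketch: re-enable Hive explicitly and register
    # the custom session with RePlay via State (mirrors the notebook's "Option 3").
    from pyspark.sql import SparkSession

    from replay.utils.session_handler import State

    session = (
        SparkSession.builder.config("spark.driver.memory", "8g")  # illustrative value
        .config("spark.sql.shuffle.partitions", "50")              # illustrative value
        .master("local[*]")
        .enableHiveSupport()  # the helper no longer enables Hive for you
        .getOrCreate()
    )
    spark = State(session).session  # RePlay components now pick up this session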