
Commit

Merge branch 'feature/add_enable_hive_support_param' into 'main'
remove hive support from get_spark_session function

See merge request ai-lab-pmo/mltools/recsys/RePlay!233
monkey0head committed Nov 6, 2024
2 parents 5ee3c03 + 6f5c4fc commit 0016eea
Showing 2 changed files with 51 additions and 46 deletions.
90 changes: 48 additions & 42 deletions examples/01_replay_basics.ipynb
@@ -53,7 +53,9 @@
"outputs": [],
"source": [
"import warnings\n",
"\n",
"from optuna.exceptions import ExperimentalWarning\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
"warnings.filterwarnings(\"ignore\", category=ExperimentalWarning)"
]
@@ -65,7 +67,6 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from pyspark.sql import SparkSession\n",
"\n",
"from replay.metrics import Coverage, HitRate, NDCG, MAP, Experiment, OfflineMetrics\n",
@@ -76,8 +77,12 @@
"\n",
"from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType\n",
"from replay.data.dataset_utils import DatasetLabelEncoder\n",
"\n",
"from replay.models import ALSWrap, ItemKNN, SLIM\n"
"from replay.metrics import MAP, NDCG, Coverage, Experiment, HitRate, OfflineMetrics\n",
"from replay.models import SLIM, ALSWrap, ItemKNN\n",
"from replay.splitters import TwoStageSplitter\n",
"from replay.utils.model_handler import load, load_encoder, save, save_encoder\n",
"from replay.utils.session_handler import State, get_spark_session\n",
"from replay.utils.spark_utils import convert2spark, get_log_info"
]
},
{
@@ -87,7 +92,7 @@
"outputs": [],
"source": [
"K = 5\n",
"SEED=42"
"SEED = 42"
]
},
{
@@ -105,7 +110,7 @@
"\n",
"- Option 1: use default RePlay `SparkSession`\n",
"- You can pass you own session to RePlay. Class `State` stores current session. Here you also have two options: \n",
" - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions \n",
" - Option 2: call `get_spark_session` to use default RePlay `SparkSession` with the custom driver memory and number of partitions\n",
" - Option 3: create `SparkSession` from scratch\n"
]
},
@@ -126,7 +131,7 @@
"outputs": [],
"source": [
"spark = State().session\n",
"spark.sparkContext.setLogLevel('ERROR')\n",
"spark.sparkContext.setLogLevel(\"ERROR\")\n",
"spark"
]
},
@@ -140,7 +145,7 @@
" # get current spark session configuration:\n",
" conf = session.sparkContext.getConf().getAll()\n",
" # get num partitions\n",
" print(f'{conf_name}: {dict(conf)[conf_name]}')"
" print(f\"{conf_name}: {dict(conf)[conf_name]}\")"
]
},
{
@@ -157,7 +162,7 @@
}
],
"source": [
"print_config_param(spark, 'spark.sql.shuffle.partitions')"
"print_config_param(spark, \"spark.sql.shuffle.partitions\")"
]
},
{
@@ -193,7 +198,7 @@
}
],
"source": [
"print_config_param(spark, 'spark.sql.shuffle.partitions')"
"print_config_param(spark, \"spark.sql.shuffle.partitions\")"
]
},
{
@@ -220,16 +225,16 @@
"source": [
"spark.stop()\n",
"session = (\n",
" SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n",
" .config(\"spark.sql.shuffle.partitions\", \"50\")\n",
" .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n",
" .config(\"spark.driver.host\", \"localhost\")\n",
" .master(\"local[*]\")\n",
" .enableHiveSupport()\n",
" .getOrCreate()\n",
" )\n",
" SparkSession.builder.config(\"spark.driver.memory\", \"8g\")\n",
" .config(\"spark.sql.shuffle.partitions\", \"50\")\n",
" .config(\"spark.driver.bindAddress\", \"127.0.0.1\")\n",
" .config(\"spark.driver.host\", \"localhost\")\n",
" .master(\"local[*]\")\n",
" .enableHiveSupport()\n",
" .getOrCreate()\n",
")\n",
"spark = State(session).session\n",
"print_config_param(spark, 'spark.sql.shuffle.partitions')"
"print_config_param(spark, \"spark.sql.shuffle.partitions\")"
]
},
{
@@ -281,7 +286,7 @@
"source": [
"spark.stop()\n",
"spark = State(get_spark_session()).session\n",
"spark.sparkContext.setLogLevel('ERROR')\n",
"spark.sparkContext.setLogLevel(\"ERROR\")\n",
"spark"
]
},
@@ -549,7 +554,7 @@
],
"source": [
"filtered_df = filter_out_low_ratings(df_spark, value=3, rating_column=\"rating\")\n",
"get_log_info(filtered_df, user_col='user_id', item_col='item_id')"
"get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")"
]
},
{
@@ -576,8 +581,8 @@
}
],
"source": [
"filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by='user_id')\n",
"get_log_info(filtered_df, user_col='user_id', item_col='item_id')"
"filtered_df = filter_by_min_count(filtered_df, num_entries=5, group_by=\"user_id\")\n",
"get_log_info(filtered_df, user_col=\"user_id\", item_col=\"item_id\")"
]
},
{
@@ -649,7 +654,7 @@
" second_divide_size=K,\n",
" first_divide_size=500,\n",
" seed=SEED,\n",
" shuffle=True\n",
" shuffle=True,\n",
")\n",
"train, test = splitter.split(filtered_df)\n",
"print(train.count(), test.count())"
@@ -935,7 +940,9 @@
"print(\n",
" NDCG(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
" MAP(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
" HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
" HitRate([1, K], query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(\n",
" recs, test_dataset.interactions\n",
" ),\n",
" Coverage(K, query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\")(recs, test_dataset.interactions),\n",
")"
]
@@ -985,7 +992,9 @@
"source": [
"offline_metrics = OfflineMetrics(\n",
" [NDCG(K), MAP(K), HitRate([1, K]), Coverage(K)],\n",
" query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n",
" query_column=\"user_id\",\n",
" item_column=\"item_id\",\n",
" rating_column=\"rating\",\n",
")\n",
"offline_metrics(recs, test_dataset.interactions, train_dataset.interactions)"
]
@@ -1005,7 +1014,9 @@
" ],\n",
" test_dataset.interactions,\n",
" train_dataset.interactions,\n",
" query_column=\"user_id\", item_column=\"item_id\", rating_column=\"rating\"\n",
" query_column=\"user_id\",\n",
" item_column=\"item_id\",\n",
" rating_column=\"rating\",\n",
")"
]
},
@@ -1149,12 +1160,7 @@
"def fit_predict_evaluate(model, experiment, name):\n",
" model.fit(train_dataset)\n",
"\n",
" recs = model.predict(\n",
" dataset=train_dataset,\n",
" k=K,\n",
" queries=test_dataset.query_ids,\n",
" filter_seen_items=False\n",
" )\n",
" recs = model.predict(dataset=train_dataset, k=K, queries=test_dataset.query_ids, filter_seen_items=False)\n",
"\n",
" experiment.add_result(name, recs)\n",
" return recs"
@@ -1167,9 +1173,9 @@
"outputs": [],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, 'SLIM_optimized')\n",
"recs = fit_predict_evaluate(SLIM(**best_params, seed=SEED), metrics, \"SLIM_optimized\")\n",
"recs.cache() #caching for further processing\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
"metrics.results.sort_values(\"NDCG@5\", ascending=False)"
]
},
{
@@ -1308,7 +1314,7 @@
],
"source": [
"%%time\n",
"recs.write.parquet(path='./slim_recs.parquet', mode='overwrite')"
"recs.write.parquet(path=\"./slim_recs.parquet\", mode=\"overwrite\")"
]
},
{
@@ -1369,8 +1375,8 @@
],
"source": [
"%%time\n",
"save(slim, path='./slim_best_params')\n",
"slim_loaded = load('./slim_best_params')"
"save(slim, path=\"./slim_best_params\")\n",
"slim_loaded = load(\"./slim_best_params\")"
]
},
{
@@ -1559,8 +1565,8 @@
],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, 'ALS')\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
"recs = fit_predict_evaluate(ALSWrap(rank=100, seed=SEED), metrics, \"ALS\")\n",
"metrics.results.sort_values(\"NDCG@5\", ascending=False)"
]
},
{
@@ -1662,8 +1668,8 @@
],
"source": [
"%%time\n",
"recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, 'ItemKNN')\n",
"metrics.results.sort_values('NDCG@5', ascending=False)"
"recs = fit_predict_evaluate(ItemKNN(num_neighbours=100), metrics, \"ItemKNN\")\n",
"metrics.results.sort_values(\"NDCG@5\", ascending=False)"
]
},
{
@@ -1795,7 +1801,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.19"
},
"name": "movielens_nmf.ipynb",
"pycharm": {
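The notebook changes above are largely mechanical reformatting (double-quoted strings, sorted imports, Black-style argument wrapping); the behavioral change is that the default RePlay session no longer enables Hive support. A caller who still needs the Hive catalog follows the notebook's Option 3: build the session explicitly and register it through `State`. A minimal sketch of that pattern (the config values mirror the notebook cell and are illustrative rather than required):

from pyspark.sql import SparkSession
from replay.utils.session_handler import State

# Build a Hive-enabled session manually, then hand it to RePlay's State
# so the library treats it as the current session.
session = (
    SparkSession.builder.config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "50")
    .master("local[*]")
    .enableHiveSupport()
    .getOrCreate()
)
spark = State(session).session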
7 changes: 3 additions & 4 deletions replay/utils/session_handler.py
@@ -71,7 +71,7 @@ def get_spark_session(
shuffle_partitions = os.cpu_count() * 3
driver_memory = f"{spark_memory}g"
user_home = os.environ["HOME"]
spark = (
spark_session_builder = (
SparkSession.builder.config("spark.driver.memory", driver_memory)
.config(
"spark.driver.extraJavaOptions",
@@ -87,10 +87,9 @@
.config("spark.kryoserializer.buffer.max", "256m")
.config("spark.files.overwrite", "true")
.master(f"local[{'*' if core_count == -1 else core_count}]")
.enableHiveSupport()
.getOrCreate()
)
return spark

return spark_session_builder.getOrCreate()


def logger_with_settings() -> logging.Logger:
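With this refactor, `get_spark_session` assembles the same builder configuration but calls `getOrCreate()` only in the return statement and no longer enables Hive support, so the session it returns uses Spark's in-memory catalog. A short usage sketch, assuming the no-argument call still works as the notebook shows; `spark.sql.catalogImplementation` is a standard Spark setting, not part of RePlay:

from replay.utils.session_handler import State, get_spark_session

# The helper now returns a session built without enableHiveSupport();
# register it as RePlay's current session, as the notebook does.
spark = State(get_spark_session()).session

# Reading the catalog implementation should report the in-memory catalog
# rather than "hive" once Hive support is no longer enabled.
print(spark.conf.get("spark.sql.catalogImplementation"))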
