Skip to content

Commit

Permalink
Remove QLKKNN model path flag
Browse files Browse the repository at this point in the history
This flag is not necessary, the model path can be set using the `TORAX_QLKNN_MODEL_PATH` environment variable instead.

This simplifies the logic to update the config and build the sim, and removes dependencies on qlknn from the main interface.

PiperOrigin-RevId: 713417605
  • Loading branch information
hamelphi authored and Torax team committed Jan 8, 2025
1 parent e536221 commit fd91404
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 84 deletions.
14 changes: 1 addition & 13 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ TORAX_QLKNN_MODEL_PATH
^^^^^^^^^^^^^^^^^^^^^^^
Path to the QuaLiKiz-neural-network parameters. The path specified here
will be ignored if the ``model_path`` field in the ``qlknn_params`` section of
the run config file or the ``qlknn_model_path`` flag are set.
the run config file is set.

.. code-block:: console
Expand Down Expand Up @@ -105,18 +105,6 @@ Provide a reference run to compare against in post-simulation plotting.
--config='torax.examples.basic_config' \
--reference_run=<path_to_reference_run>
qlknn_model_path
^^^^^^^^^^^^^^^^
Provide a path to load the QLKNN model from. This flag supersedes
the path set in the config file and the ``TORAX_QLKNN_MODEL_PATH`` environment
variable.

.. code-block:: console
python3 run_simulation_main.py \
--config='torax.examples.basic_config' \
--qlknn_model_path=<path_to_qlknn_model>
output_dir
^^^^^^^^^^
Override the default output directory. If not provided, it will be set to
Expand Down
14 changes: 1 addition & 13 deletions docs/running.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TORAX_QLKNN_MODEL_PATH
^^^^^^^^^^^^^^^^^^^^^^^
Path to the QuaLiKiz-neural-network parameters. The path specified here
will be ignored if the ``model_path`` field in the ``qlknn_params`` section of
the run config file or the ``qlknn_model_path`` flag are set.
the run config file is set.

.. code-block:: console
Expand Down Expand Up @@ -104,18 +104,6 @@ Provide a reference run to compare against in post-simulation plotting.
--config='torax.examples.basic_config' \
--reference_run=<path_to_reference_run>
qlknn_model_path
^^^^^^^^^^^^^^^^
Provide a path to load the QLKNN model from. This flag supersedes
the path set in the config file and the ``TORAX_QLKNN_MODEL_PATH`` environment
variable.

.. code-block:: console
python3 run_simulation_main.py \
--config='torax.examples.basic_config' \
--qlknn_model_path=<path_to_qlknn_model>
output_dir
^^^^^^^^^^
Override the default output directory. If not provided, it will be set to
Expand Down
30 changes: 5 additions & 25 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from torax.config import config_loader
from torax.config import runtime_params
from torax.plotting import plotruns_lib
from torax.transport_model import qlknn_transport_model


# String used when prompting the user to make a choice of command
Expand Down Expand Up @@ -101,17 +100,6 @@
'If True, quits after the first operation (no interactive mode).',
)

_QLKNN_MODEL_PATH = flags.DEFINE_string(
'qlknn_model_path',
None,
'Path to the qlknn model network parameters (if using a QLKNN transport'
' model). If not set, then it will use the value from the config in the'
' "model_path" field in the qlknn_params. If that is not set, it will look'
f' for the "{qlknn_transport_model.MODEL_PATH_ENV_VAR}" env variable.'
' Finally, if this is also not set, it uses a hardcoded default path'
f' "{qlknn_transport_model.DEFAULT_MODEL_PATH}".',
)

_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
None,
Expand Down Expand Up @@ -204,7 +192,6 @@ def maybe_update_config_module(
def change_config(
sim: sim_lib.Sim,
config_module_str: str,
qlknn_model_path: str | None,
) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams] | None:
"""Returns a new Sim with the updated config but same SimulationStepFn.
Expand All @@ -219,7 +206,6 @@ def change_config(
Args:
sim: Sim object used in the previous run.
config_module_str: Config module being used.
qlknn_model_path: QLKNN model path set by flag.
Returns:
Tuple with:
Expand Down Expand Up @@ -257,9 +243,6 @@ def change_config(
if hasattr(config_module, 'CONFIG'):
# Assume that the config module uses the basic config dict to build Sim.
sim_config = config_module.CONFIG
config_loader.maybe_update_config_with_qlknn_model_path(
sim_config, qlknn_model_path
)
new_runtime_params = build_sim.build_runtime_params_from_config(
sim_config['runtime_params']
)
Expand Down Expand Up @@ -316,7 +299,7 @@ def change_config(


def change_sim_obj(
config_module_str: str, qlknn_model_path: str | None
config_module_str: str
) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams, str]:
"""Builds a new Sim from the config module.
Expand All @@ -327,7 +310,6 @@ def change_sim_obj(
Args:
config_module_str: Config module used previously. User will have the
opportunity to update which module to load.
qlknn_model_path: QLKNN model path set by flag.
Returns:
Tuple with:
Expand All @@ -344,7 +326,7 @@ def change_sim_obj(
input('Press Enter when done changing the module.')
sim, new_runtime_params = (
config_loader.build_sim_and_runtime_params_from_config_module(
config_module_str, qlknn_model_path, _PYTHON_CONFIG_PACKAGE.value
config_module_str, _PYTHON_CONFIG_PACKAGE.value
)
)
return sim, new_runtime_params, config_module_str
Expand Down Expand Up @@ -482,15 +464,14 @@ def main(_):
log_sim_progress = _LOG_SIM_PROGRESS.value
plot_sim_progress = _PLOT_SIM_PROGRESS.value
log_sim_output = _LOG_SIM_OUTPUT.value
qlknn_model_path = _QLKNN_MODEL_PATH.value
sim = None
new_runtime_params = None
output_files = []
try:
start_time = time.time()
sim, new_runtime_params = (
config_loader.build_sim_and_runtime_params_from_config_module(
config_module_str, qlknn_model_path, _PYTHON_CONFIG_PACKAGE.value
config_module_str, _PYTHON_CONFIG_PACKAGE.value
)
)
output_dir = (
Expand Down Expand Up @@ -573,8 +554,7 @@ def main(_):
try:
start_time = time.time()
sim_and_runtime_params_or_none = change_config(
sim, config_module_str, qlknn_model_path
)
sim, config_module_str)
if sim_and_runtime_params_or_none is not None:
sim, new_runtime_params = sim_and_runtime_params_or_none
config_change_time = time.time() - start_time
Expand All @@ -595,7 +575,7 @@ def main(_):
try:
start_time = time.time()
sim, new_runtime_params, config_module_str = change_sim_obj(
config_module_str, qlknn_model_path
config_module_str
)
sim_change_time = time.time() - start_time
simulation_app.log_to_stdout(
Expand Down
33 changes: 0 additions & 33 deletions torax/config/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import importlib
import logging
from typing import Any

from torax import sim
from torax.config import build_sim
Expand All @@ -28,16 +27,13 @@

def build_sim_and_runtime_params_from_config_module(
config_module_str: str,
qlknn_model_path: str | None,
config_package: str | None = None,
) -> tuple[sim.Sim, runtime_params.GeneralRuntimeParams]:
"""Returns a Sim and RuntimeParams from the config module.
Args:
config_module_str: Python package path to config module. E.g.
torax.examples.iterhybrid_predictor_corrector.
qlknn_model_path: QLKNN model path set by flag. See qlknn_model_path flag
docs.
config_package: Optional, base package config is imported from. See
config_package flag docs.
"""
Expand All @@ -46,7 +42,6 @@ def build_sim_and_runtime_params_from_config_module(
# The module likely uses the "basic" config setup which has a single CONFIG
# dictionary defining the full simulation.
config = config_module.CONFIG
maybe_update_config_with_qlknn_model_path(config, qlknn_model_path)
new_runtime_params = build_sim.build_runtime_params_from_config(
config['runtime_params']
)
Expand All @@ -56,8 +51,6 @@ def build_sim_and_runtime_params_from_config_module(
):
# The module is likely using the "advances", more Python-forward
# configuration setup.
if qlknn_model_path is not None:
logging.warning('Cannot override qlknn model for this type of config.')
new_runtime_params = config_module.get_runtime_params()
simulator = config_module.get_sim()
else:
Expand All @@ -68,32 +61,6 @@ def build_sim_and_runtime_params_from_config_module(
return simulator, new_runtime_params


def maybe_update_config_with_qlknn_model_path(
config: dict[str, Any], qlknn_model_path: str | None
) -> None:
"""Sets the qlknn_model_path in the config if needed."""
if qlknn_model_path is None:
return
if (
'transport' not in config
or 'transport_model' not in config['transport']
or config['transport']['transport_model'] != 'qlknn'
):
return
qlknn_params = config['transport'].get('qlknn_params', {})
config_model_path = qlknn_params.get('model_path', '')
if config_model_path:
logging.info(
'Overriding QLKNN model path from "%s" to "%s"',
config_model_path,
qlknn_model_path,
)
else:
logging.info('Setting QLKNN model path to "%s".', qlknn_model_path)
qlknn_params['model_path'] = qlknn_model_path
config['transport']['qlknn_params'] = qlknn_params


def import_module(module_name: str, config_package: str | None = None):
"""Imports a module."""
try:
Expand Down

0 comments on commit fd91404

Please sign in to comment.