Skip to content

Commit

Permalink
Make generic_current_source on cell grid like other sources.
Browse files Browse the repository at this point in the history
Very minor changes to sim integration tests on order of O(1e-3)

Also reinstate generic_current_source in output. Fixes broken plotting.

PiperOrigin-RevId: 714474563
  • Loading branch information
jcitrin authored and Torax team committed Jan 11, 2025
1 parent e5c3741 commit 3cf5f1b
Show file tree
Hide file tree
Showing 62 changed files with 136 additions and 156 deletions.
1 change: 1 addition & 0 deletions torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ToraxSimOutputs:
# Add `core_profiles` prefix here to avoid name clash with
# core_sources.generic_current.
CORE_PROFILES_EXTERNAL_CURRENT = "external_current_source"
GENERIC_CURRENT_SOURCE = "generic_current_source"
J_BOOTSTRAP = "j_bootstrap"
J_BOOTSTRAP_FACE = "j_bootstrap_face"
I_BOOTSTRAP = "I_bootstrap"
Expand Down
5 changes: 5 additions & 0 deletions torax/plotting/plotruns_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PlotData:
johm: np.ndarray # [MA/m^2]
j_bootstrap: np.ndarray # [MA/m^2]
j_ecrh: np.ndarray # [MA/m^2]
generic_current_source: np.ndarray # [MA/m^2]
external_current_source: np.ndarray # [MA/m^2]
q: np.ndarray # Dimensionless
s: np.ndarray # Dimensionless
Expand Down Expand Up @@ -170,6 +171,7 @@ def _transform_data(ds: xr.Dataset):
output.JOHM: 1e6, # A/m^2 to MA/m^2
output.J_BOOTSTRAP: 1e6, # A/m^2 to MA/m^2
output.CORE_PROFILES_EXTERNAL_CURRENT: 1e6, # A/m^2 to MA/m^2
output.GENERIC_CURRENT_SOURCE: 1e6, # A/m^2 to MA/m^2
output.I_BOOTSTRAP: 1e6, # A to MA
output.IP_PROFILE_FACE: 1e6, # A to MA
'electron_cyclotron_source_j': 1e6, # A/m^2 to MA/m^2
Expand Down Expand Up @@ -229,6 +231,9 @@ def _transform_data(ds: xr.Dataset):
j_ecrh=get_optional_data(
core_sources_dataset, 'electron_cyclotron_source_j', 'cell'
),
generic_current_source=get_optional_data(
core_sources_dataset, output.GENERIC_CURRENT_SOURCE, 'cell'
),
q=core_profiles_dataset[output.Q_FACE].to_numpy(),
s=core_profiles_dataset[output.S_FACE].to_numpy(),
chi_i=core_transport_dataset[output.CHI_FACE_ION].to_numpy(),
Expand Down
2 changes: 0 additions & 2 deletions torax/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ def _calculate_integrated_sources(
# index 1 corresponds to the current source profile.
if key == 'electron_cyclotron_source':
profile = core_sources.profiles[key][1, :]
elif key == 'generic_current_source':
profile = geometry.face_to_cell(core_sources.profiles[key])
else:
profile = core_sources.profiles[key]
integrated[f'{value}'] = math_utils.cell_integration(
Expand Down
49 changes: 15 additions & 34 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,16 @@
import chex
import jax
from jax import numpy as jnp
from jax.scipy import integrate
import jaxtyping as jt
from torax import array_typing
from torax import interpolated_param
from torax import jax_utils
from torax import math_utils
from torax import state
from torax.config import base
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from typing_extensions import override
# pylint: disable=invalid-name


Expand All @@ -56,7 +54,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams):

@property
def grid_type(self) -> base.GridType:
return base.GridType.FACE
return base.GridType.CELL

def make_provider(
self,
Expand Down Expand Up @@ -100,12 +98,9 @@ def __post_init__(self):
self.sanity_check()


_trapz = integrate.trapezoid


# pytype bug: does not treat 'source_models.SourceModels' as a forward reference
# pytype: disable=name-error
def calculate_generic_current_face(
def calculate_generic_current(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand All @@ -126,7 +121,7 @@ def calculate_generic_current_face(
is present to adhere to the source API.
Returns:
External current density profile along the face grid.
External current density profile along the cell grid.
"""
del (
static_runtime_params_slice,
Expand All @@ -142,22 +137,22 @@ def calculate_generic_current_face(
dynamic_runtime_params_slice,
dynamic_source_runtime_params,
)
# form of external current on face grid
generic_current_form_face = jnp.exp(
-((geo.rho_face_norm - dynamic_source_runtime_params.rext) ** 2)
# form of external current on cell grid
generic_current_form = jnp.exp(
-((geo.rho_norm - dynamic_source_runtime_params.rext) ** 2)
/ (2 * dynamic_source_runtime_params.wext**2)
)

Cext = (
Iext
* 1e6
/ _trapz(generic_current_form_face * geo.spr_face, geo.rho_face_norm)
/ math_utils.cell_integration(generic_current_form * geo.spr_cell, geo)
)

generic_current_face = (
Cext * generic_current_form_face
) # external current profile
return generic_current_face
generic_current_profile = (
Cext * generic_current_form
)
return generic_current_profile


def _calculate_Iext(
Expand All @@ -180,8 +175,8 @@ class GenericCurrentSource(source.Source):
"""A generic current density source profile."""

SOURCE_NAME: ClassVar[str] = 'generic_current_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face'
model_func: source.SourceProfileFunction = calculate_generic_current_face
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current'
model_func: source.SourceProfileFunction = calculate_generic_current

@property
def source_name(self) -> str:
Expand All @@ -193,18 +188,4 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:

@property
def output_shape_getter(self) -> source.SourceOutputShapeFunction:
return source.ProfileType.FACE.get_profile_shape

@override
def get_source_profile_for_affected_core_profile(
self,
profile: jt.Float[jt.Array, 'rhon_face'],
affected_core_profile: int,
geo: geometry.Geometry,
) -> jt.Float[jt.Array, 'rhon']:
return jnp.where(
affected_core_profile in self.affected_core_profiles_ints,
# Source profiles are always on cell grid so cast to cell grid.
geometry.face_to_cell(profile),
jnp.zeros_like(geo.rho),
)
return source.ProfileType.CELL.get_profile_shape
2 changes: 1 addition & 1 deletion torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class SupportedSource:
source_class=generic_current_source.GenericCurrentSource,
model_functions={
generic_current_source.GenericCurrentSource.DEFAULT_MODEL_FUNCTION_NAME: ModelFunction(
source_profile_function=generic_current_source.calculate_generic_current_face,
source_profile_function=generic_current_source.calculate_generic_current,
runtime_params_class=generic_current_source.RuntimeParams,
)
},
Expand Down
8 changes: 3 additions & 5 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class Source(abc.ABC):
source if another isn't specified.
runtime_params: Input dataclass containing all the source-specific runtime
parameters. At runtime, the parameters here are interpolated to a specific
time t and then passed to the model_func or formula, depending on the mode
this source is running in.
time t and then passed to the model_func, depending on the mode this
source is running in.
affected_core_profiles: Core profiles affected by this source's profile(s).
This attribute defines which equations the source profiles are terms for.
By default, the number of affected core profiles should equal the rank of
Expand All @@ -123,8 +123,6 @@ class Source(abc.ABC):
by this source.
model_func: The function used when the the runtime type is set to
"MODEL_BASED". If not provided, then it defaults to returning zeros.
formula: The prescribed formula used when the runtime type is set to
"FORMULA_BASED". If not provided, then it defaults to returning zeros.
affected_core_profiles_ints: Derived property from the
affected_core_profiles. Integer values of those enums.
"""
Expand Down Expand Up @@ -282,7 +280,7 @@ def get_source_profiles(
) -> chex.ArrayTree:
"""Returns source profiles requested by the runtime_params_lib.
This function handles MODEL_BASED, FORMULA_BASED, PRESCRIBED and ZERO sources.
This function handles MODEL_BASED, PRESCRIBED and ZERO sources.
All other source types will be ignored.
This function exists to simplify the creation of the profile to a set of
jnp.where calls.
Expand Down
2 changes: 1 addition & 1 deletion torax/sources/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def __init__(
] = source_lib.make_source_builder(
generic_current_source.GenericCurrentSource,
runtime_params_type=generic_current_source.RuntimeParams,
model_func=generic_current_source.calculate_generic_current_face,
model_func=generic_current_source.calculate_generic_current,
)()
source_builders[
generic_current_source.GenericCurrentSource.SOURCE_NAME
Expand Down
14 changes: 7 additions & 7 deletions torax/sources/tests/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ def setUpClass(cls):
source_class=generic_current_source.GenericCurrentSource,
runtime_params_class=generic_current_source.RuntimeParams,
source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME,
model_func=generic_current_source.calculate_generic_current_face,
model_func=generic_current_source.calculate_generic_current,
)

def test_profile_is_on_face_grid(self):
"""Tests that the profile is given on the face grid."""
def test_profile_is_on_cell_grid(self):
"""Tests that the profile is given on the cell grid."""
geo = geometry.build_circular_geometry()
source_builder = self._source_class_builder()
source = source_builder()
self.assertEqual(
source.output_shape_getter(geo),
source_lib.ProfileType.FACE.get_profile_shape(geo),
source_lib.ProfileType.CELL.get_profile_shape(geo),
)
runtime_params = general_runtime_params.GeneralRuntimeParams()
dynamic_runtime_params_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider(
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_profile_is_on_face_grid(self):
geo,
core_profiles=None,
).shape,
source_lib.ProfileType.FACE.get_profile_shape(geo),
source_lib.ProfileType.CELL.get_profile_shape(geo),
)

@parameterized.named_parameters(
Expand All @@ -100,11 +100,11 @@ def test_get_source_profile_for_affected_core_profile_with(

# Build a face profile with 3 values on a 2-cell grid.
geo = geometry.build_circular_geometry(n_rho=2)
face_profile = np.array([1, 2, 3])
cell_profile = np.array([1.5, 2.5])

np.testing.assert_allclose(
source.get_source_profile_for_affected_core_profile(
face_profile,
cell_profile,
affected_core_profile.value,
geo,
),
Expand Down
5 changes: 1 addition & 4 deletions torax/sources/tests/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,14 @@ def source_name(self) -> str:
source_models=source_models,
static_runtime_params_slice=static_runtime_params_slice,
)
expected_generic_current_source_face = source_models.psi_sources[
expected_generic_current_source = source_models.psi_sources[
generic_current_source.GenericCurrentSource.SOURCE_NAME
].get_value(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
)
expected_generic_current_source = geometry.face_to_cell(
expected_generic_current_source_face
)

external_current_source = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
3 changes: 1 addition & 2 deletions torax/tests/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def setUp(self):
)
# Make some dummy source profiles.
ones = np.ones_like(geo.rho)
ones_face = np.ones_like(geo.rho_face)
self.source_profiles = source_profiles_lib.SourceProfiles(
j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
Expand All @@ -64,7 +63,7 @@ def setUp(self):
profiles={
'bremsstrahlung_heat_sink': -ones,
'ohmic_heat_source': ones * 5,
'generic_current_source': ones_face * 2,
'generic_current_source': ones * 2,
'fusion_heat_source': np.stack([ones, ones]),
'generic_ion_el_heat_source': np.stack([2 * ones, 3 * ones]),
'electron_cyclotron_source': np.stack([7 * ones, 2 * ones]),
Expand Down
2 changes: 1 addition & 1 deletion torax/tests/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class SimTest(sim_test_case.SimTestCase):
'test_iterhybrid_predictor_corrector_zi2',
'test_iterhybrid_predictor_corrector_zi2.py',
_ALL_PROFILES,
1e-5,
5e-5,
),
# Predictor-corrector solver with ECCD Lin Liu model.
(
Expand Down
Binary file modified torax/tests/test_data/test_all_transport_crank_nicolson.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_all_transport_fusion_qlknn.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_bohmgyrobohm_all.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_bootstrap.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_bremsstrahlung.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_bremsstrahlung_time_dependent_Zimp.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_cgmheat.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_changing_config_after.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_changing_config_before.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_chease.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_eqdsk.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_explicit.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_fixed_dt.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_fusion_power.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_implicit.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_implicit_short_optimizer.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterbaseline_mockup.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterhybrid_mockup.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterhybrid_newton.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterhybrid_predictor_corrector.nc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterhybrid_predictor_corrector_zi2.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_iterhybrid_rampup.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_ne_qlknn_deff_veff.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_ne_qlknn_defromchie.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_ohmic_power.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_particle_sources_cgm.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_particle_sources_constant.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_pc_method_ne.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_pedestal.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_prescribed_generic_current_source.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_prescribed_timedependent_ne.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psi_and_heat.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psi_heat_dens.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psichease_ip_chease.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psichease_ip_parameters.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psichease_prescribed_johm.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psichease_prescribed_jtot.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_psiequation.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_qei.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_qei_chease_highdens.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_qlknnheat.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_semiimplicit_convection.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_time_dependent_circular_geo.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_timedependence.nc
Binary file not shown.
Loading

0 comments on commit 3cf5f1b

Please sign in to comment.