diff --git a/torax/config/plasma_composition.py b/torax/config/plasma_composition.py index ce133f85..4dfb1111 100644 --- a/torax/config/plasma_composition.py +++ b/torax/config/plasma_composition.py @@ -19,6 +19,7 @@ from collections.abc import Mapping import dataclasses import logging +import typing import chex import numpy as np @@ -61,6 +62,15 @@ def __post_init__(self): if not isinstance(self.species, Mapping): raise ValueError('species must be a Mapping') + # Iterate through species keys and check if they are in the allowed list. + allowed_symbols = typing.get_args(constants.ION_SYMBOLS) + for ion_symbol in self.species: + if ion_symbol not in allowed_symbols: + raise ValueError( + f'Invalid ion symbol: {ion_symbol}. Allowed symbols are:' + f' {allowed_symbols}' + ) + time_arrays = [] fraction_arrays = [] diff --git a/torax/config/tests/plasma_composition.py b/torax/config/tests/plasma_composition.py index 3a42014c..333c5c23 100644 --- a/torax/config/tests/plasma_composition.py +++ b/torax/config/tests/plasma_composition.py @@ -147,6 +147,7 @@ class IonMixtureTest(parameterized.TestCase): ('valid_tolerance', {'D': 0.49999999, 'T': 0.5}, False), ('invalid_tolerance', {'D': 0.4999, 'T': 0.5}, True), ('invalid_not_mapping', 'D', True), + ('invalid_ion_symbol', {'De': 0.5, 'Tr': 0.5}, True), ) def test_ion_mixture_constructor(self, input_species, should_raise): """Tests various cases of IonMixture construction."""