Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712973014
  • Loading branch information
jcitrin authored and Torax team committed Jan 7, 2025
1 parent afbb82a commit 9aa2340
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions torax/transport_model/qlknn_10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torax.transport_model import base_qlknn_model
from torax.transport_model import qualikiz_based_transport_model

# Internal import.
# Internal import.

# Move this to common lib.
Expand Down Expand Up @@ -95,9 +96,7 @@ def __init__(
self._model = MLP(hidden_sizes=hidden_sizes, activations=activations)

def _load_prescale(self, key: str, names: list[str]) -> np.ndarray:
return np.array([self._model_config[key][k] for k in names])[
np.newaxis, :
]
return np.array([self._model_config[key][k] for k in names])[np.newaxis, :]

def __call__(
self,
Expand Down Expand Up @@ -161,16 +160,12 @@ def predict(

model_output = {}
model_output['qi_itg'] = self.net_itgleading(inputs).clip(0)
model_output['qe_itg'] = (
self.net_itgqediv(inputs) * model_output['qi_itg']
)
model_output['qe_itg'] = self.net_itgqediv(inputs) * model_output['qi_itg']
model_output['pfe_itg'] = (
self.net_itgpfediv(inputs) * model_output['qi_itg']
)
model_output['qe_tem'] = self.net_temleading(inputs).clip(0)
model_output['qi_tem'] = (
self.net_temqidiv(inputs) * model_output['qe_tem']
)
model_output['qi_tem'] = self.net_temqidiv(inputs) * model_output['qe_tem']
model_output['pfe_tem'] = (
self.net_tempfediv(inputs) * model_output['qe_tem']
)
Expand Down

0 comments on commit 9aa2340

Please sign in to comment.