Skip to content

Commit 3ea75f5

Browse files
jcitrinTorax team
authored andcommitted
Move calculate_psidot_from_psi_sources to new psidot_calculations
Could not move it to psi_calculations due to circular dependency. Nevertheless it should move to the physics package and not remain in ohmic_heat_source PiperOrigin-RevId: 734464091
1 parent 8d3eda4 commit 3ea75f5

File tree

6 files changed

+269
-93
lines changed

6 files changed

+269
-93
lines changed

torax/core_profiles/initialization.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from torax.geometry import geometry
2929
from torax.geometry import standard_geometry
3030
from torax.physics import psi_calculations
31-
from torax.sources import ohmic_heat_source
3231
from torax.sources import source_models as source_models_lib
32+
from torax.sources import source_operations
3333
from torax.sources import source_profile_builders
3434
from torax.sources import source_profiles as source_profiles_lib
3535

@@ -500,8 +500,13 @@ def _init_psi_psidot_vloop_and_current(
500500
# psidot calculated here with phibdot=0 in geo, since this is initial
501501
# conditions and we don't yet have information on geo_t_plus_dt for the
502502
# phibdot calculation.
503-
psidot = ohmic_heat_source.calculate_psidot_from_psi_sources(
504-
source_profiles=source_profiles,
503+
psi_sources = source_operations.sum_sources_psi(geo, source_profiles)
504+
sigma = source_profiles.j_bootstrap.sigma
505+
sigma_face = source_profiles.j_bootstrap.sigma_face
506+
psidot = psi_calculations.calculate_psidot_from_psi_sources(
507+
psi_sources=psi_sources,
508+
sigma=sigma,
509+
sigma_face=sigma_face,
505510
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
506511
psi=psi,
507512
geo=geo,

torax/orchestration/step_function.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from torax.geometry import geometry_provider as geometry_provider_lib
3232
from torax.pedestal_model import pedestal_model as pedestal_model_lib
3333
from torax.physics import psi_calculations
34-
from torax.sources import ohmic_heat_source
34+
from torax.sources import source_operations
3535
from torax.sources import source_profile_builders
3636
from torax.sources import source_profiles as source_profiles_lib
3737
from torax.stepper import stepper as stepper_lib
@@ -543,7 +543,7 @@ def finalize_output(
543543
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt,
544544
geo=geo_t_plus_dt,
545545
core_profiles=output_state.core_profiles,
546-
core_sources=output_state.core_sources,
546+
source_profiles=output_state.core_sources,
547547
)
548548
output_state = post_processing.make_outputs(
549549
sim_state=output_state,
@@ -691,13 +691,18 @@ def _update_psidot(
691691
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
692692
geo: geometry.Geometry,
693693
core_profiles: state.CoreProfiles,
694-
core_sources: source_profiles_lib.SourceProfiles,
694+
source_profiles: source_profiles_lib.SourceProfiles,
695695
) -> state.CoreProfiles:
696696
"""Update psidot based on new core_profiles."""
697+
psi_sources = source_operations.sum_sources_psi(geo, source_profiles)
698+
sigma = source_profiles.j_bootstrap.sigma
699+
sigma_face = source_profiles.j_bootstrap.sigma_face
697700
psidot = dataclasses.replace(
698701
core_profiles.psidot,
699-
value=ohmic_heat_source.calculate_psidot_from_psi_sources(
700-
source_profiles=core_sources,
702+
value=psi_calculations.calculate_psidot_from_psi_sources(
703+
psi_sources=psi_sources,
704+
sigma=sigma,
705+
sigma_face=sigma_face,
701706
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
702707
psi=core_profiles.psi,
703708
geo=geo,

torax/physics/psi_calculations.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
from torax import jax_utils
4141
from torax import state
4242
from torax.fvm import cell_variable
43+
from torax.fvm import convection_terms
44+
from torax.fvm import diffusion_terms
4345
from torax.geometry import geometry
4446

4547
_trapz = jax.scipy.integrate.trapezoid
@@ -318,3 +320,79 @@ def calculate_psi_grad_constraint_from_Ip_tot(
318320
* (16 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.Phib)
319321
/ (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1])
320322
)
323+
324+
325+
def calculate_psidot_from_psi_sources(
326+
*,
327+
psi_sources: array_typing.ArrayFloat,
328+
sigma: array_typing.ArrayFloat,
329+
sigma_face: array_typing.ArrayFloat,
330+
resistivity_multiplier: float,
331+
psi: cell_variable.CellVariable,
332+
geo: geometry.Geometry,
333+
) -> jax.Array:
334+
"""Calculates psidot (loop voltage) from the sum of the psi sources."""
335+
336+
# Calculate transient term
337+
consts = constants.CONSTANTS
338+
toc_psi = (
339+
1.0
340+
/ resistivity_multiplier
341+
* geo.rho_norm
342+
* sigma
343+
* consts.mu0
344+
* 16
345+
* jnp.pi**2
346+
* geo.Phib**2
347+
/ geo.F**2
348+
)
349+
# Calculate diffusion term coefficient
350+
d_face_psi = geo.g2g3_over_rhon_face
351+
# Add phibdot terms to poloidal flux convection
352+
v_face_psi = (
353+
-8.0
354+
* jnp.pi**2
355+
* consts.mu0
356+
* geo.Phibdot
357+
* geo.Phib
358+
* sigma_face
359+
* geo.rho_face_norm**2
360+
/ geo.F_face**2
361+
)
362+
363+
# Add effective phibdot poloidal flux source term
364+
ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
365+
sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
366+
)
367+
368+
psi_sources += (
369+
-8.0
370+
* jnp.pi**2
371+
* consts.mu0
372+
* geo.Phibdot
373+
* geo.Phib
374+
* ddrnorm_sigma_rnorm2_over_f2
375+
)
376+
377+
diffusion_mat, diffusion_vec = diffusion_terms.make_diffusion_terms(
378+
d_face_psi, psi
379+
)
380+
381+
# Set the psi convection term for psidot used in ohmic power, always with
382+
# the default 'ghost' mode. Impact of different modes would mildly impact
383+
# Ohmic power at the LCFS which has negligible impact on simulations.
384+
# Allowing it to be configurable introduces more complexity in the code by
385+
# needing to pass in the mode from the static_runtime_params across multiple
386+
# functions.
387+
conv_mat, conv_vec = convection_terms.make_convection_terms(
388+
v_face_psi, d_face_psi, psi
389+
)
390+
391+
c_mat = diffusion_mat + conv_mat
392+
c = diffusion_vec + conv_vec
393+
394+
c += psi_sources
395+
396+
psidot = (jnp.dot(c_mat, psi.value) + c) / toc_psi
397+
398+
return psidot

torax/physics/tests/psi_calculations_test.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818
import jax
1919
import numpy as np
2020
from torax import constants
21+
from torax.config import build_runtime_params
2122
from torax.core_profiles import initialization
2223
from torax.geometry import pydantic_model as geometry_pydantic_model
2324
from torax.geometry import standard_geometry
2425
from torax.physics import psi_calculations
26+
from torax.sources import runtime_params as source_runtime_params
27+
from torax.sources import source_models as source_models_lib
28+
from torax.sources import source_profile_builders
29+
from torax.sources import source_profiles as source_profiles_lib
2530
from torax.tests.test_lib import torax_refs
2631

27-
2832
_trapz = jax.scipy.integrate.trapezoid
2933

3034

@@ -111,6 +115,79 @@ def test_calc_s(self, references_getter: Callable[[], torax_refs.References]):
111115

112116
np.testing.assert_allclose(s, references.s, rtol=1e-5)
113117

118+
@parameterized.parameters([
119+
dict(references_getter=torax_refs.circular_references),
120+
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
121+
dict(
122+
references_getter=torax_refs.chease_references_Ip_from_runtime_params
123+
),
124+
])
125+
def test_calc_psidot(
126+
self, references_getter: Callable[[], torax_refs.References]
127+
):
128+
references = references_getter()
129+
130+
runtime_params = references.runtime_params
131+
source_models_builder = source_models_lib.SourceModelsBuilder()
132+
source_models_builder.runtime_params['generic_current_source'].mode = (
133+
source_runtime_params.Mode.MODEL_BASED
134+
)
135+
source_models = source_models_builder()
136+
dynamic_runtime_params_slice, geo = (
137+
torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry(
138+
runtime_params,
139+
references.geometry_provider,
140+
sources=source_models_builder.runtime_params,
141+
)
142+
)
143+
source_profiles = source_profiles_lib.SourceProfiles(
144+
j_bootstrap=source_profiles_lib.BootstrapCurrentProfile.zero_profile(
145+
geo
146+
),
147+
qei=source_profiles_lib.QeiInfo.zeros(geo),
148+
)
149+
static_slice = build_runtime_params.build_static_runtime_params_slice(
150+
runtime_params=runtime_params,
151+
source_runtime_params=source_models_builder.runtime_params,
152+
torax_mesh=geo.torax_mesh,
153+
)
154+
initial_core_profiles = initialization.initial_core_profiles(
155+
static_slice,
156+
dynamic_runtime_params_slice,
157+
geo,
158+
source_models=source_models,
159+
)
160+
# Updates the calculated source profiles with the standard source profiles.
161+
source_profile_builders.build_standard_source_profiles(
162+
static_runtime_params_slice=static_slice,
163+
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
164+
geo=geo,
165+
core_profiles=initial_core_profiles,
166+
source_models=source_models,
167+
psi_only=True,
168+
calculate_anyway=True,
169+
calculated_source_profiles=source_profiles,
170+
)
171+
bootstrap_profiles = source_models.j_bootstrap.get_bootstrap(
172+
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
173+
static_runtime_params_slice=static_slice,
174+
geo=geo,
175+
core_profiles=initial_core_profiles,
176+
)
177+
178+
psidot_calculated = psi_calculations.calculate_psidot_from_psi_sources(
179+
psi_sources=sum(source_profiles.psi.values()),
180+
sigma=bootstrap_profiles.sigma,
181+
sigma_face=bootstrap_profiles.sigma_face,
182+
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
183+
psi=references.psi,
184+
geo=geo,
185+
)
186+
187+
psidot_expected = references.psidot
188+
189+
np.testing.assert_allclose(psidot_calculated, psidot_expected, rtol=1e-5)
190+
114191
# pylint: disable=invalid-name
115192
def test_calc_Wpol(self):
116193
# Small inverse aspect ratio limit of circular geometry, such that we

torax/sources/ohmic_heat_source.py

Lines changed: 10 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,9 @@
1919
from typing import ClassVar, Literal
2020

2121
import chex
22-
import jax
2322
import jax.numpy as jnp
24-
from torax import constants
2523
from torax import state
2624
from torax.config import runtime_params_slice
27-
from torax.fvm import cell_variable
28-
from torax.fvm import convection_terms
29-
from torax.fvm import diffusion_terms
3025
from torax.geometry import geometry
3126
from torax.physics import psi_calculations
3227
from torax.sources import runtime_params as runtime_params_lib
@@ -35,83 +30,6 @@
3530
from torax.sources import source_profiles as source_profiles_lib
3631

3732

38-
def calculate_psidot_from_psi_sources(
39-
*,
40-
source_profiles: source_profiles_lib.SourceProfiles,
41-
resistivity_multiplier: float,
42-
psi: cell_variable.CellVariable,
43-
geo: geometry.Geometry,
44-
) -> jax.Array:
45-
"""Calculates psidot (loop voltage) from precalculated sources."""
46-
psi_sources = source_operations.sum_sources_psi(geo, source_profiles)
47-
sigma = source_profiles.j_bootstrap.sigma
48-
sigma_face = source_profiles.j_bootstrap.sigma_face
49-
50-
# Calculate transient term
51-
consts = constants.CONSTANTS
52-
toc_psi = (
53-
1.0
54-
/ resistivity_multiplier
55-
* geo.rho_norm
56-
* sigma
57-
* consts.mu0
58-
* 16
59-
* jnp.pi**2
60-
* geo.Phib**2
61-
/ geo.F**2
62-
)
63-
# Calculate diffusion term coefficient
64-
d_face_psi = geo.g2g3_over_rhon_face
65-
# Add phibdot terms to poloidal flux convection
66-
v_face_psi = (
67-
-8.0
68-
* jnp.pi**2
69-
* consts.mu0
70-
* geo.Phibdot
71-
* geo.Phib
72-
* sigma_face
73-
* geo.rho_face_norm**2
74-
/ geo.F_face**2
75-
)
76-
77-
# Add effective phibdot poloidal flux source term
78-
ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient(
79-
sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm
80-
)
81-
82-
psi_sources += (
83-
-8.0
84-
* jnp.pi**2
85-
* consts.mu0
86-
* geo.Phibdot
87-
* geo.Phib
88-
* ddrnorm_sigma_rnorm2_over_f2
89-
)
90-
91-
diffusion_mat, diffusion_vec = diffusion_terms.make_diffusion_terms(
92-
d_face_psi, psi
93-
)
94-
95-
# Set the psi convection term for psidot used in ohmic power, always with
96-
# the default 'ghost' mode. Impact of different modes would mildly impact
97-
# Ohmic power at the LCFS which has negligible impact on simulations.
98-
# Allowing it to be configurable introduces more complexity in the code by
99-
# needing to pass in the mode from the static_runtime_params across multiple
100-
# functions.
101-
conv_mat, conv_vec = convection_terms.make_convection_terms(
102-
v_face_psi, d_face_psi, psi
103-
)
104-
105-
c_mat = diffusion_mat + conv_mat
106-
c = diffusion_vec + conv_vec
107-
108-
c += psi_sources
109-
110-
psidot = (jnp.dot(c_mat, psi.value) + c) / toc_psi
111-
112-
return psidot
113-
114-
11533
def ohmic_model_func(
11634
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
11735
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
@@ -132,8 +50,15 @@ def ohmic_model_func(
13250
geo,
13351
core_profiles.psi,
13452
)
135-
psidot = calculate_psidot_from_psi_sources(
136-
source_profiles=calculated_source_profiles,
53+
psi_sources = source_operations.sum_sources_psi(
54+
geo, calculated_source_profiles
55+
)
56+
sigma = calculated_source_profiles.j_bootstrap.sigma
57+
sigma_face = calculated_source_profiles.j_bootstrap.sigma_face
58+
psidot = psi_calculations.calculate_psidot_from_psi_sources(
59+
psi_sources=psi_sources,
60+
sigma=sigma,
61+
sigma_face=sigma_face,
13762
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
13863
psi=core_profiles.psi,
13964
geo=geo,
@@ -144,6 +69,7 @@ def ohmic_model_func(
14469

14570
class OhmicHeatSourceConfig(runtime_params_lib.SourceModelBase):
14671
"""Configuration for the OhmicHeatSource."""
72+
14773
source_name: Literal['ohmic_heat_source'] = 'ohmic_heat_source'
14874
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED
14975

0 commit comments

Comments
 (0)