1414
1515"""Calculates Block1DCoeffs for a time step."""
1616
17+ import dataclasses
1718import functools
1819import jax
1920import jax .numpy as jnp
2829from torax ._src .fvm import cell_variable
2930from torax ._src .geometry import geometry
3031from torax ._src .internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
32+ from torax ._src .pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
3133from torax ._src .pedestal_model import runtime_params as pedestal_runtime_params_lib
3234from torax ._src .sources import source_profile_builders
3335from torax ._src .sources import source_profiles as source_profiles_lib
@@ -72,6 +74,9 @@ def __call__(
7274 # Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
7375 # should be called
7476 explicit_call : bool = False ,
77+ pedestal_transition_state : (
78+ pedestal_transition_state_lib .PedestalTransitionState | None
79+ ) = None ,
7580 ) -> block_1d_coeffs .Block1DCoeffs :
7681 """Returns coefficients given a state x.
7782
@@ -84,8 +89,8 @@ def __call__(
8489 state x.
8590 geo: The geometry of the system at this time step.
8691 core_profiles: The core profiles of the system at this time step.
87- prev_core_profiles: The core profiles of the system at the previous
88- time step.
92+ prev_core_profiles: The core profiles of the system at the previous time
93+ step.
8994 dt: The time step size.
9095 x: The state with cell-grid values of the evolving variables.
9196 explicit_source_profiles: Precomputed explicit source profiles. These
@@ -104,6 +109,9 @@ def __call__(
104109 explicit_call: If True, then if theta_implicit=1, only a reduced
105110 Block1DCoeffs is calculated since most explicit coefficients will not be
106111 used.
112+ pedestal_transition_state: State for tracking pedestal L-H and H-L
113+ transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
114+ use_formation_model_with_adaptive_source=True. None otherwise.
107115
108116 Returns:
109117 coeffs: The diffusion, convection, etc. coefficients for this state.
@@ -133,6 +141,7 @@ def __call__(
133141 evolving_names = self .evolving_names ,
134142 use_pereverzev = use_pereverzev ,
135143 explicit_call = explicit_call ,
144+ pedestal_transition_state = pedestal_transition_state ,
136145 )
137146
138147
@@ -145,6 +154,9 @@ def calc_coeffs(
145154 evolving_names : tuple [str , ...],
146155 use_pereverzev : bool = False ,
147156 explicit_call : bool = False ,
157+ pedestal_transition_state : (
158+ pedestal_transition_state_lib .PedestalTransitionState | None
159+ ) = None ,
148160) -> block_1d_coeffs .Block1DCoeffs :
149161 """Calculates Block1DCoeffs for the time step described by `core_profiles`.
150162
@@ -170,6 +182,9 @@ def calc_coeffs(
170182 explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
171183 theta_implicit=1. This saves computation for the default fully implicit
172184 implementation.
185+ pedestal_transition_state: State for tracking pedestal L-H and H-L
186+ transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
187+ use_formation_model_with_adaptive_source=True. None otherwise.
173188
174189 Returns:
175190 coeffs: Block1DCoeffs containing the coefficients at this time step.
@@ -192,6 +207,7 @@ def calc_coeffs(
192207 physics_models = physics_models ,
193208 evolving_names = evolving_names ,
194209 use_pereverzev = use_pereverzev ,
210+ pedestal_transition_state = pedestal_transition_state ,
195211 )
196212
197213
@@ -210,6 +226,9 @@ def _calc_coeffs_full(
210226 physics_models : physics_models_lib .PhysicsModels ,
211227 evolving_names : tuple [str , ...],
212228 use_pereverzev : bool = False ,
229+ pedestal_transition_state : (
230+ pedestal_transition_state_lib .PedestalTransitionState | None
231+ ) = None ,
213232) -> block_1d_coeffs .Block1DCoeffs :
214233 """See `calc_coeffs` for details."""
215234
@@ -268,6 +287,7 @@ def _calc_coeffs_full(
268287 core_profiles ,
269288 merged_source_profiles ,
270289 use_pereverzev ,
290+ pedestal_transition_state = pedestal_transition_state ,
271291 )
272292 )
273293
@@ -415,26 +435,75 @@ def _calc_coeffs_full(
415435 runtime_params .pedestal .mode
416436 == pedestal_runtime_params_lib .Mode .ADAPTIVE_SOURCE
417437 ):
438+ # Get the pedestal-top target values from the pedestal model.
439+ pedestal_top_values = (
440+ pedestal_model_output .to_internal_boundary_conditions (geo )
441+ )
442+
443+ # Apply ramp scaling if use_formation_model_with_adaptive_source is
444+ # enabled.
445+ if runtime_params .pedestal .use_formation_model_with_adaptive_source :
446+ assert pedestal_transition_state is not None , (
447+ 'pedestal_transition_state must not be None when'
448+ ' use_formation_model_with_adaptive_source is True.'
449+ )
450+ # Scale the pedestal-top values from the pedestal model by the ramp
451+ # fraction. Will be a no-op in H-mode following the transition_time_width.
452+ internal_boundary_conditions = _apply_transition_ramp_scaling (
453+ pedestal_top_values = pedestal_top_values ,
454+ pedestal_transition_state = pedestal_transition_state ,
455+ runtime_params = runtime_params ,
456+ )
457+ # If in L-mode and the H->L ramp has completed (fraction >= 1.0), skip
458+ # the adaptive source entirely to revert to standard L-mode modeling.
459+ # ramp_fraction will be 1.0 if simulation initialized in L-mode and has
460+ # remained in L-mode, since initial transition_start_time is -inf.
461+ ramp_fraction = _compute_ramp_fraction (
462+ pedestal_transition_state = pedestal_transition_state ,
463+ transition_time_width = runtime_params .pedestal .transition_time_width ,
464+ t = runtime_params .t ,
465+ )
466+ # Skip adaptive source if in L-mode and the H->L ramp has completed.
467+ skip_adaptive_source = ~ pedestal_transition_state .in_H_mode & (
468+ ramp_fraction >= 1.0
469+ )
470+ else :
471+ internal_boundary_conditions = pedestal_top_values
472+ skip_adaptive_source = jnp .bool_ (False )
473+
474+ def _apply_source ():
475+ return internal_boundary_conditions_lib .apply_adaptive_source (
476+ source_T_i = source_i ,
477+ source_T_e = source_e ,
478+ source_n_e = source_n_e ,
479+ source_mat_ii = source_mat_ii ,
480+ source_mat_ee = source_mat_ee ,
481+ source_mat_nn = source_mat_nn ,
482+ runtime_params = runtime_params ,
483+ internal_boundary_conditions = internal_boundary_conditions ,
484+ )
485+
486+ def _skip_source ():
487+ return (
488+ source_i ,
489+ source_e ,
490+ source_n_e ,
491+ source_mat_ii ,
492+ source_mat_ee ,
493+ source_mat_nn ,
494+ )
495+
418496 (
419497 source_i ,
420498 source_e ,
421499 source_n_e ,
422500 source_mat_ii ,
423501 source_mat_ee ,
424502 source_mat_nn ,
425- ) = internal_boundary_conditions_lib .apply_adaptive_source (
426- source_T_i = source_i ,
427- source_T_e = source_e ,
428- source_n_e = source_n_e ,
429- source_mat_ii = source_mat_ii ,
430- source_mat_ee = source_mat_ee ,
431- source_mat_nn = source_mat_nn ,
432- runtime_params = runtime_params ,
433- # Pedestal contributes an internal boundary condition to the source
434- # terms at the pedestal top.
435- internal_boundary_conditions = pedestal_model_output .to_internal_boundary_conditions (
436- geo
437- ),
503+ ) = jax .lax .cond (
504+ skip_adaptive_source ,
505+ _skip_source ,
506+ _apply_source ,
438507 )
439508
440509 # --- Build arguments to solver --- #
@@ -539,3 +608,93 @@ def _calc_coeffs_reduced(
539608 transient_in_cell = transient_in_cell ,
540609 )
541610 return coeffs
611+
612+
613+ def _compute_ramp_fraction (
614+ pedestal_transition_state : pedestal_transition_state_lib .PedestalTransitionState ,
615+ transition_time_width : array_typing .FloatScalar ,
616+ t : array_typing .FloatScalar ,
617+ ) -> array_typing .FloatScalar :
618+ """Computes the ramp fraction for a pedestal transition.
619+
620+ Returns a value in [0, 1] representing the progress of the current
621+ transition. 0 means the transition just started, 1 means it is complete.
622+
623+ Args:
624+ pedestal_transition_state: Current transition state.
625+ transition_time_width: Duration of the transition ramp.
626+ t: Current simulation time (i.e. t + dt when called from the solver).
627+
628+ Returns:
629+ Ramp fraction clipped to [0, 1].
630+ """
631+ elapsed = t - pedestal_transition_state .transition_start_time
632+ fraction = elapsed / transition_time_width
633+ return jnp .clip (fraction , 0.0 , 1.0 )
634+
635+
636+ def _apply_transition_ramp_scaling (
637+ pedestal_top_values : internal_boundary_conditions_lib .InternalBoundaryConditions ,
638+ pedestal_transition_state : pedestal_transition_state_lib .PedestalTransitionState ,
639+ runtime_params : runtime_params_lib .RuntimeParams ,
640+ ) -> internal_boundary_conditions_lib .InternalBoundaryConditions :
641+ """Applies ramp scaling to internal boundary conditions during transitions.
642+
643+ During an L-H transition, linearly ramps from L-mode values to the H-mode
644+ targets. During an H-L transition, ramps from the H-mode targets back to
645+ the L-mode values.
646+
647+ The L-mode values are stored in the pedestal_transition_state (captured
648+ at the start of an L->H transition). The H-mode targets are the full
649+ pedestal model output.
650+
651+ Args:
652+ pedestal_top_values: Pedestal-top target internal boundary conditions from
653+ the pedestal model.
654+ pedestal_transition_state: Current transition state containing L-mode
655+ baseline values.
656+ runtime_params: Runtime parameters (provides time t and pedestal config).
657+
658+ Returns:
659+ Scaled internal boundary conditions.
660+ """
661+ ramp_fraction = _compute_ramp_fraction (
662+ pedestal_transition_state = pedestal_transition_state ,
663+ transition_time_width = runtime_params .pedestal .transition_time_width ,
664+ t = runtime_params .t ,
665+ )
666+
667+ # Extract the nonzero pedestal-top values from the IBC. The IBC arrays are
668+ # cell-grid sized with a single nonzero element at the pedestal top. We use
669+ # jnp.max to extract the nonzero value.
670+ h_mode_T_i_ped = jnp .max (pedestal_top_values .T_i )
671+ h_mode_T_e_ped = jnp .max (pedestal_top_values .T_e )
672+ h_mode_n_e_ped = jnp .max (pedestal_top_values .n_e )
673+
674+ l_mode_T_i_ped = pedestal_transition_state .T_i_ped_L_mode
675+ l_mode_T_e_ped = pedestal_transition_state .T_e_ped_L_mode
676+ l_mode_n_e_ped = pedestal_transition_state .n_e_ped_L_mode
677+
678+ # In H-mode: ramp from L-mode to H-mode (L + fraction * (H - L))
679+ # In L-mode (H->L ramp): ramp from H-mode to L-mode (H + fraction * (L - H))
680+ def _lerp (l_val , h_val , frac , in_h_mode ):
681+ return jnp .where (
682+ in_h_mode ,
683+ l_val + frac * (h_val - l_val ), # L->H ramp
684+ h_val + frac * (l_val - h_val ), # H->L ramp
685+ )
686+
687+ in_h_mode = pedestal_transition_state .in_H_mode
688+ scaled_T_i = _lerp (l_mode_T_i_ped , h_mode_T_i_ped , ramp_fraction , in_h_mode )
689+ scaled_T_e = _lerp (l_mode_T_e_ped , h_mode_T_e_ped , ramp_fraction , in_h_mode )
690+ scaled_n_e = _lerp (l_mode_n_e_ped , h_mode_n_e_ped , ramp_fraction , in_h_mode )
691+
692+ # Reconstruct IBC with scaled values at the same pedestal-top location.
693+ # The nonzero mask from the original pedestal_top_values gives us the
694+ # location.
695+ return dataclasses .replace (
696+ pedestal_top_values ,
697+ T_i = jnp .where (pedestal_top_values .T_i != 0.0 , scaled_T_i , 0.0 ),
698+ T_e = jnp .where (pedestal_top_values .T_e != 0.0 , scaled_T_e , 0.0 ),
699+ n_e = jnp .where (pedestal_top_values .n_e != 0.0 , scaled_n_e , 0.0 ),
700+ )
0 commit comments