1515# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
1616
1717import math
18- from typing import List , Optional , Tuple , Union
18+ from typing import List , Literal , Optional , Tuple , Union
1919
2020import numpy as np
2121import torch
@@ -36,27 +36,30 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
3636 methods the library implements for all schedulers such as loading and saving.
3737
3838 Args:
39- sigma_min (`float`, *optional*, defaults to 0.3):
39+ sigma_min (`float`, defaults to `0.3`):
4040 Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
41- sigma_max (`float`, *optional*, defaults to 500):
41+ sigma_max (`float`, defaults to `500`):
4242 Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
43- sigma_data (`float`, *optional*, defaults to 1.0):
43+ sigma_data (`float`, defaults to `1.0`):
4444 The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
45- sigma_schedule (`str`, *optional*, defaults to `exponential`):
46- Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
47- (https://huggingface.co/papers/2206.00364 ). Other acceptable value is "exponential". The exponential
48- schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl .
49- num_train_timesteps (`int`, defaults to 1000):
45+ sigma_schedule (`str`, defaults to `"exponential"`):
46+ Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential
47+ schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras
48+ schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper.
49+ num_train_timesteps (`int`, defaults to `1000`):
5050 The number of diffusion steps to train the model.
51- solver_order (`int`, defaults to 2 ):
51+ solver_order (`int`, defaults to `2`):
5252 The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
53- prediction_type (`str`, defaults to `v_prediction`, *optional* ):
54- Prediction type of the scheduler function; can be ` epsilon` (predicts the noise of the diffusion process),
55- ` sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
53+ prediction_type (`str`, defaults to `"v_prediction"`):
54+ Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
55+ process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
5656 Video](https://huggingface.co/papers/2210.02303) paper).
57- solver_type (`str`, defaults to `midpoint`):
58- Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
59- sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
57+ rho (`float`, defaults to `7.0`):
58+ The parameter for calculating the Karras sigma schedule from the EDM
59+ [paper](https://huggingface.co/papers/2206.00364).
60+ solver_type (`str`, defaults to `"midpoint"`):
61+ Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly
62+ affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`.
6063 lower_order_final (`bool`, defaults to `True`):
6164 Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
6265 stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -65,8 +68,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
6568 richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
6669 steps, but sometimes may result in blurring.
6770 final_sigmas_type (`str`, defaults to `"zero"`):
68- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
69- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71+ The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or
72+ `"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If
73+ `"zero"`, the final sigma is set to 0.
7074 """
7175
7276 _compatibles = []
@@ -78,16 +82,16 @@ def __init__(
7882 sigma_min : float = 0.3 ,
7983 sigma_max : float = 500 ,
8084 sigma_data : float = 1.0 ,
81- sigma_schedule : str = "exponential" ,
85+ sigma_schedule : Literal [ "exponential" , "karras" ] = "exponential" ,
8286 num_train_timesteps : int = 1000 ,
8387 solver_order : int = 2 ,
84- prediction_type : str = "v_prediction" ,
88+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "v_prediction" ,
8589 rho : float = 7.0 ,
86- solver_type : str = "midpoint" ,
90+ solver_type : Literal [ "midpoint" , "heun" ] = "midpoint" ,
8791 lower_order_final : bool = True ,
8892 euler_at_final : bool = False ,
89- final_sigmas_type : Optional [ str ] = "zero" , # "zero", "sigma_min"
90- ):
93+ final_sigmas_type : Literal [ "zero" , "sigma_min" ] = "zero" ,
94+ ) -> None :
9195 if solver_type not in ["midpoint" , "heun" ]:
9296 if solver_type in ["logrho" , "bh1" , "bh2" ]:
9397 self .register_to_config (solver_type = "midpoint" )
@@ -113,26 +117,40 @@ def __init__(
113117 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
114118
115119 @property
116- def init_noise_sigma (self ):
117- # standard deviation of the initial noise distribution
120+ def init_noise_sigma (self ) -> float :
121+ """
122+ The standard deviation of the initial noise distribution.
123+
124+ Returns:
125+ `float`:
126+ The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`.
127+ """
118128 return (self .config .sigma_max ** 2 + 1 ) ** 0.5
119129
120130 @property
121- def step_index (self ):
131+ def step_index (self ) -> Optional [ int ] :
122132 """
123133 The index counter for current timestep. It will increase 1 after each scheduler step.
134+
135+ Returns:
136+ `int` or `None`:
137+ The current step index, or `None` if not yet initialized.
124138 """
125139 return self ._step_index
126140
127141 @property
128- def begin_index (self ):
142+ def begin_index (self ) -> Optional [ int ] :
129143 """
130144 The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
145+
146+ Returns:
147+ `int` or `None`:
148+ The begin index, or `None` if not yet set.
131149 """
132150 return self ._begin_index
133151
134152 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
135- def set_begin_index (self , begin_index : int = 0 ):
153+ def set_begin_index (self , begin_index : int = 0 ) -> None :
136154 """
137155 Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
138156
@@ -161,7 +179,18 @@ def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Te
161179 scaled_sample = sample * c_in
162180 return scaled_sample
163181
164- def precondition_noise (self , sigma ):
182+ def precondition_noise (self , sigma : Union [float , torch .Tensor ]) -> torch .Tensor :
183+ """
184+ Precondition the noise level by computing a normalized timestep representation.
185+
186+ Args:
187+ sigma (`float` or `torch.Tensor`):
188+ The sigma (noise level) value to precondition.
189+
190+ Returns:
191+ `torch.Tensor`:
192+ The preconditioned noise value computed as `atan(sigma) / pi * 2`.
193+ """
165194 if not isinstance (sigma , torch .Tensor ):
166195 sigma = torch .tensor ([sigma ])
167196
@@ -228,12 +257,14 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
228257 self .is_scale_input_called = True
229258 return sample
230259
231- def set_timesteps (self , num_inference_steps : int = None , device : Union [str , torch .device ] = None ):
260+ def set_timesteps (
261+ self , num_inference_steps : Optional [int ] = None , device : Optional [Union [str , torch .device ]] = None
262+ ) -> None :
232263 """
233264 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
234265
235266 Args:
236- num_inference_steps (`int`):
267+ num_inference_steps (`int`, *optional* ):
237268 The number of diffusion steps used when generating samples with a pre-trained model.
238269 device (`str` or `torch.device`, *optional*):
239270 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -334,7 +365,7 @@ def _compute_exponential_sigmas(
334365 return sigmas
335366
336367 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
337- def _sigma_to_t (self , sigma , log_sigmas ) :
368+ def _sigma_to_t (self , sigma : np . ndarray , log_sigmas : np . ndarray ) -> np . ndarray :
338369 """
339370 Convert sigma values to corresponding timestep values through interpolation.
340371
@@ -370,7 +401,19 @@ def _sigma_to_t(self, sigma, log_sigmas):
370401 t = t .reshape (sigma .shape )
371402 return t
372403
373- def _sigma_to_alpha_sigma_t (self , sigma ):
404+ def _sigma_to_alpha_sigma_t (self , sigma : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
405+ """
406+ Convert sigma to alpha and sigma_t values for the diffusion process.
407+
408+ Args:
409+ sigma (`torch.Tensor`):
410+ The sigma (noise level) value.
411+
412+ Returns:
413+ `Tuple[torch.Tensor, torch.Tensor]`:
414+ A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as input
415+ sigma).
416+ """
374417 alpha_t = torch .tensor (1 ) # Inputs are pre-scaled before going into unet, so alpha_t = 1
375418 sigma_t = sigma
376419
@@ -536,7 +579,7 @@ def index_for_timestep(
536579 return step_index
537580
538581 # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
539- def _init_step_index (self , timestep ) :
582+ def _init_step_index (self , timestep : Union [ int , torch . Tensor ]) -> None :
540583 """
541584 Initialize the step_index counter for the scheduler.
542585
@@ -557,7 +600,7 @@ def step(
557600 model_output : torch .Tensor ,
558601 timestep : Union [int , torch .Tensor ],
559602 sample : torch .Tensor ,
560- generator = None ,
603+ generator : Optional [ torch . Generator ] = None ,
561604 return_dict : bool = True ,
562605 ) -> Union [SchedulerOutput , Tuple ]:
563606 """
@@ -567,20 +610,19 @@ def step(
567610 Args:
568611 model_output (`torch.Tensor`):
569612 The direct output from learned diffusion model.
570- timestep (`int`):
613+ timestep (`int` or `torch.Tensor` ):
571614 The current discrete timestep in the diffusion chain.
572615 sample (`torch.Tensor`):
573616 A current instance of a sample created by the diffusion process.
574617 generator (`torch.Generator`, *optional*):
575618 A random number generator.
576- return_dict (`bool`):
619+ return_dict (`bool`, defaults to `True` ):
577620 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
578621
579622 Returns:
580623 [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
581624 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
582625 tuple is returned where the first element is the sample tensor.
583-
584626 """
585627 if self .num_inference_steps is None :
586628 raise ValueError (
@@ -702,5 +744,12 @@ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[flo
702744 c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
703745 return c_in
704746
705- def __len__ (self ):
747+ def __len__ (self ) -> int :
748+ """
749+ Returns the number of training timesteps.
750+
751+ Returns:
752+ `int`:
753+ The number of training timesteps configured for the scheduler.
754+ """
706755 return self .config .num_train_timesteps
0 commit comments