diff --git a/StatTools/generators/kasdin_generator.py b/StatTools/generators/kasdin_generator.py index 4787d1b..f0b5bf4 100644 --- a/StatTools/generators/kasdin_generator.py +++ b/StatTools/generators/kasdin_generator.py @@ -15,32 +15,44 @@ class KasdinGenerator: doi:10.1109/5.381848 Args: - h (float): Hurst exponent (0.5 < H < 1.5) # TODO: update docs - length (int): Maximum length of the sequence. - random_generator (Iterator[float], optional): Iterator providing random values. - Defaults is iter(np.random.randn(), None). + h (float): Hurst exponent, 0.5 <= H <= 1.5. + length (int): Length of the generated sequence (must be >= 1). + random_generator (Iterator[float], optional): Iterator providing i.i.d. normal + random values. If None, one is created from ``seed`` via + ``np.random.default_rng``. + normalize (bool): If True, the output is zero-mean unit-variance. + filter_coefficients_length (int, optional): Number of filter coefficients to use. + Defaults to ``length``. + seed (int | None): Seed for the internal RNG. Ignored when + ``random_generator`` is provided explicitly. Raises: - ValueError: If length is less than 1 - StopIteration('Sequence exhausted') : If maximum sequence length has been reached. + ValueError: If ``length`` is less than 1 or ``h`` is out of range. + StopIteration('Sequence exhausted'): If the iterator is advanced past the end. Example usage: >>> generator = KasdinGenerator(h, length) >>> trj = list(generator) + >>> generator = KasdinGenerator(h, length, seed=42) + >>> trj = list(generator) """ def __init__( self, h: float, length: int, - random_generator: Optional[Iterator[float]] = iter(np.random.randn, None), + random_generator: Optional[Iterator[float]] = None, normalize=True, filter_coefficients_length=None, + seed: Optional[int] = None, ) -> None: if length is not None and length < 1: raise ValueError("Length must be more than 1") self.validate_h(h) self._h = h self.length = length + if random_generator is None: + rng = np.random.default_rng(seed) + random_generator = iter(rng.standard_normal, None) self.random_generator = random_generator self.filter_coefficients_length = filter_coefficients_length @@ -104,15 +116,32 @@ def get_h(self): class ERKasdinGenerator(KasdinGenerator): - """Extended range version of Kasdin generator, which can be used for H < 0.5 and H > 1.5""" + """ + Extended-range Kasdin generator supporting H < 0.5 and H > 1.5. + + Values outside [0.5, 1.5] are reached by repeatedly differencing (H > 1.5) + or integrating (H < 0.5) a standard KasdinGenerator sequence. + + Args: + h (float): Hurst exponent. Any real value is accepted. + length (int): Length of the generated sequence. + random_generator (Iterator[float], optional): Iterator providing i.i.d. normal + random values. If None, one is created from ``seed`` via + ``np.random.default_rng``. + normalize (bool): If True, the output is zero-mean unit-variance. + filter_coefficients_length (int, optional): Number of filter coefficients. + seed (int | None): Seed for the internal RNG. Ignored when + ``random_generator`` is provided explicitly. + """ def __init__( self, h: float, length: int, - random_generator: Optional[Iterator[float]] = iter(np.random.randn, None), + random_generator: Optional[Iterator[float]] = None, normalize=True, filter_coefficients_length=None, + seed: Optional[int] = None, ) -> None: self._effective_h = h self.steps_count = 0 @@ -136,6 +165,7 @@ def __init__( random_generator, normalize, filter_coefficients_length, + seed, ) if self.steps_count > 0: @@ -164,16 +194,21 @@ def create_kasdin_generator( random_generator: Optional[Iterator[float]] = None, normalize=True, filter_coefficients_length=None, - seed: int | None = None, + seed: Optional[int] = None, ) -> KasdinGenerator | ERKasdinGenerator: - """Fabric for creating a Kasdin generator.""" - if random_generator is None: - rng = np.random.default_rng(seed) - random_generator = iter(rng.standard_normal, None) - if 0.5 <= h <= 1.5: - return KasdinGenerator( - h, length, random_generator, normalize, filter_coefficients_length - ) - return ERKasdinGenerator( - h, length, random_generator, normalize, filter_coefficients_length - ) + """Factory for creating a Kasdin generator. + + Args: + h (float): Hurst exponent. + length (int): Length of the generated sequence. + random_generator (Iterator[float], optional): Custom RNG iterator. + If None, one is built from ``seed``. + normalize (bool): Zero-mean unit-variance normalisation. + filter_coefficients_length (int, optional): Filter order. + seed (int | None): RNG seed. Ignored when ``random_generator`` is given. + + Returns: + KasdinGenerator for 0.5 <= h <= 1.5, ERKasdinGenerator otherwise. + """ + cls = KasdinGenerator if 0.5 <= h <= 1.5 else ERKasdinGenerator + return cls(h, length, random_generator, normalize, filter_coefficients_length, seed) diff --git a/tests/test_dpcca.py b/tests/test_dpcca.py index 197cecd..4cc941f 100644 --- a/tests/test_dpcca.py +++ b/tests/test_dpcca.py @@ -224,11 +224,11 @@ def test_dpcca_chol2d_correlation(hurst, des_r0): Three independent fBn tracks are generated with the given Hurst exponent and then correlated by multiplying with the Cholesky factor of R0. """ - length = 2**14 + length = 2**15 s_list = [512, 1024, 2048] - sig_1 = generate_fbn(hurst=hurst, length=length) - sig_2 = generate_fbn(hurst=hurst, length=length) - sig_3 = generate_fbn(hurst=hurst, length=length) + sig_1 = generate_fbn(hurst=hurst, length=length, seed=4) + sig_2 = generate_fbn(hurst=hurst, length=length, seed=44) + sig_3 = generate_fbn(hurst=hurst, length=length, seed=14) np.random.seed(42) signal_triplet = np.vstack((sig_1, sig_2, sig_3)).T diff --git a/tests/test_kalman_filter.py b/tests/test_kalman_filter.py index b163045..7165874 100644 --- a/tests/test_kalman_filter.py +++ b/tests/test_kalman_filter.py @@ -175,11 +175,6 @@ def test_adjust_covariance_stays_symmetric(self, kf_2x1): kf_2x1.adjust(np.array([[np.random.randn()]])) assert np.allclose(kf_2x1._P, kf_2x1._P.T, atol=1e-10) - def test_adjust_none_raises(self, kf_2x1): - """ValueError raised when None is passed as measurement.""" - with pytest.raises(ValueError, match="Do not pass None as a measurement"): - kf_2x1.adjust(None) - def test_adjust_wrong_shape_raises(self, kf_2x1): """ValueError raised when measurement has wrong shape.""" with pytest.raises(ValueError, match="Expected z shape \\(1, 1\\), got"):