Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions StatTools/generators/kasdin_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,44 @@ class KasdinGenerator:
doi:10.1109/5.381848

Args:
h (float): Hurst exponent (0.5 < H < 1.5) # TODO: update docs
length (int): Maximum length of the sequence.
random_generator (Iterator[float], optional): Iterator providing random values.
Defaults is iter(np.random.randn(), None).
h (float): Hurst exponent, 0.5 <= H <= 1.5.
length (int): Length of the generated sequence (must be >= 1).
random_generator (Iterator[float], optional): Iterator providing i.i.d. normal
random values. If None, one is created from ``seed`` via
``np.random.default_rng``.
normalize (bool): If True, the output is zero-mean unit-variance.
filter_coefficients_length (int, optional): Number of filter coefficients to use.
Defaults to ``length``.
seed (int | None): Seed for the internal RNG. Ignored when
``random_generator`` is provided explicitly.
Raises:
ValueError: If length is less than 1
StopIteration('Sequence exhausted') : If maximum sequence length has been reached.
ValueError: If ``length`` is less than 1 or ``h`` is out of range.
StopIteration('Sequence exhausted'): If the iterator is advanced past the end.

Example usage:
>>> generator = KasdinGenerator(h, length)
>>> trj = list(generator)
>>> generator = KasdinGenerator(h, length, seed=42)
>>> trj = list(generator)
"""

def __init__(
self,
h: float,
length: int,
random_generator: Optional[Iterator[float]] = iter(np.random.randn, None),
random_generator: Optional[Iterator[float]] = None,
normalize=True,
filter_coefficients_length=None,
seed: Optional[int] = None,
) -> None:
if length is not None and length < 1:
raise ValueError("Length must be more than 1")
self.validate_h(h)
self._h = h
self.length = length
if random_generator is None:
rng = np.random.default_rng(seed)
random_generator = iter(rng.standard_normal, None)
self.random_generator = random_generator
self.filter_coefficients_length = filter_coefficients_length

Expand Down Expand Up @@ -104,15 +116,32 @@ def get_h(self):


class ERKasdinGenerator(KasdinGenerator):
"""Extended range version of Kasdin generator, which can be used for H < 0.5 and H > 1.5"""
"""
Extended-range Kasdin generator supporting H < 0.5 and H > 1.5.

Values outside [0.5, 1.5] are reached by repeatedly differencing (H > 1.5)
or integrating (H < 0.5) a standard KasdinGenerator sequence.

Args:
h (float): Hurst exponent. Any real value is accepted.
length (int): Length of the generated sequence.
random_generator (Iterator[float], optional): Iterator providing i.i.d. normal
random values. If None, one is created from ``seed`` via
``np.random.default_rng``.
normalize (bool): If True, the output is zero-mean unit-variance.
filter_coefficients_length (int, optional): Number of filter coefficients.
seed (int | None): Seed for the internal RNG. Ignored when
``random_generator`` is provided explicitly.
"""

def __init__(
self,
h: float,
length: int,
random_generator: Optional[Iterator[float]] = iter(np.random.randn, None),
random_generator: Optional[Iterator[float]] = None,
normalize=True,
filter_coefficients_length=None,
seed: Optional[int] = None,
) -> None:
self._effective_h = h
self.steps_count = 0
Expand All @@ -136,6 +165,7 @@ def __init__(
random_generator,
normalize,
filter_coefficients_length,
seed,
)

if self.steps_count > 0:
Expand Down Expand Up @@ -164,16 +194,21 @@ def create_kasdin_generator(
random_generator: Optional[Iterator[float]] = None,
normalize=True,
filter_coefficients_length=None,
seed: int | None = None,
seed: Optional[int] = None,
) -> KasdinGenerator | ERKasdinGenerator:
"""Fabric for creating a Kasdin generator."""
if random_generator is None:
rng = np.random.default_rng(seed)
random_generator = iter(rng.standard_normal, None)
if 0.5 <= h <= 1.5:
return KasdinGenerator(
h, length, random_generator, normalize, filter_coefficients_length
)
return ERKasdinGenerator(
h, length, random_generator, normalize, filter_coefficients_length
)
"""Factory for creating a Kasdin generator.

Args:
h (float): Hurst exponent.
length (int): Length of the generated sequence.
random_generator (Iterator[float], optional): Custom RNG iterator.
If None, one is built from ``seed``.
normalize (bool): Zero-mean unit-variance normalisation.
filter_coefficients_length (int, optional): Filter order.
seed (int | None): RNG seed. Ignored when ``random_generator`` is given.

Returns:
KasdinGenerator for 0.5 <= h <= 1.5, ERKasdinGenerator otherwise.
"""
cls = KasdinGenerator if 0.5 <= h <= 1.5 else ERKasdinGenerator
return cls(h, length, random_generator, normalize, filter_coefficients_length, seed)
8 changes: 4 additions & 4 deletions tests/test_dpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ def test_dpcca_chol2d_correlation(hurst, des_r0):
Three independent fBn tracks are generated with the given Hurst exponent
and then correlated by multiplying with the Cholesky factor of R0.
"""
length = 2**14
length = 2**15
s_list = [512, 1024, 2048]
sig_1 = generate_fbn(hurst=hurst, length=length)
sig_2 = generate_fbn(hurst=hurst, length=length)
sig_3 = generate_fbn(hurst=hurst, length=length)
sig_1 = generate_fbn(hurst=hurst, length=length, seed=4)
sig_2 = generate_fbn(hurst=hurst, length=length, seed=44)
sig_3 = generate_fbn(hurst=hurst, length=length, seed=14)

np.random.seed(42)
signal_triplet = np.vstack((sig_1, sig_2, sig_3)).T
Expand Down
5 changes: 0 additions & 5 deletions tests/test_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,6 @@ def test_adjust_covariance_stays_symmetric(self, kf_2x1):
kf_2x1.adjust(np.array([[np.random.randn()]]))
assert np.allclose(kf_2x1._P, kf_2x1._P.T, atol=1e-10)

def test_adjust_none_raises(self, kf_2x1):
"""ValueError raised when None is passed as measurement."""
with pytest.raises(ValueError, match="Do not pass None as a measurement"):
kf_2x1.adjust(None)

def test_adjust_wrong_shape_raises(self, kf_2x1):
"""ValueError raised when measurement has wrong shape."""
with pytest.raises(ValueError, match="Expected z shape \\(1, 1\\), got"):
Expand Down
Loading