Skip to content

Commit de0d635

Browse files
committed
fixed
1 parent 4ed9c3f commit de0d635

4 files changed

Lines changed: 50 additions & 55 deletions

File tree

src/mixtures/abstract_mixture.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ class AbstractMixtures(metaclass=ABCMeta):
1919
def __init__(
2020
self,
2121
mixture_form: str,
22-
integrator_cls: Type[Integrator] = RQMCIntegrator,
23-
integrator_params: Dict[str, Any] = None,
2422
**kwargs: Any
2523
) -> None:
2624
"""
@@ -31,8 +29,7 @@ def __init__(
3129
**kwargs: Parameters of Mixture (alpha, gamma, etc.)
3230
"""
3331
self.mixture_form = mixture_form
34-
self.integrator_cls = integrator_cls
35-
self.integrator_params = integrator_params or {}
32+
3633

3734
if mixture_form == "classical":
3835
self.params = self._params_validation(self._classical_collector, kwargs)
@@ -42,70 +39,74 @@ def __init__(
4239
raise AssertionError(f"Unknown mixture form: {mixture_form}")
4340

4441
@abstractmethod
45-
def _compute_moment(self, n: int) -> Tuple[float, float]:
42+
def _compute_moment(self, n: int, integrator: Integrator) -> Tuple[float, float]:
4643
...
4744

4845
def compute_moment(
4946
self,
50-
x: Union[List[int], int, NDArray[np.float64]]
47+
x: Union[List[int], int, NDArray[np.float64]],
48+
integrator: Integrator
5149
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
5250
if isinstance(x, np.ndarray):
53-
return np.array([self._compute_moment(xp) for xp in x], dtype=object)
51+
return np.array([self._compute_moment(xp, integrator) for xp in x], dtype=object)
5452
elif isinstance(x, list):
55-
return [self._compute_moment(xp) for xp in x]
53+
return [self._compute_moment(xp, integrator) for xp in x]
5654
elif isinstance(x, int):
57-
return self._compute_moment(x)
55+
return self._compute_moment(x, integrator)
5856
else:
5957
raise TypeError(f"Unsupported type for x: {type(x)}")
6058

6159
@abstractmethod
62-
def _compute_pdf(self, x: float) -> Tuple[float, float]:
60+
def _compute_pdf(self, x: float, integrator: Integrator) -> Tuple[float, float]:
6361
...
6462

6563
def compute_pdf(
6664
self,
67-
x: Union[List[float], float, NDArray[np.float64]]
65+
x: Union[List[float], float, NDArray[np.float64]],
66+
integrator: Integrator
6867
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
6968
if isinstance(x, np.ndarray):
70-
return np.array([self._compute_pdf(xp) for xp in x], dtype=object)
69+
return np.array([self._compute_pdf(xp, integrator) for xp in x], dtype=object)
7170
elif isinstance(x, list):
72-
return [self._compute_pdf(xp) for xp in x]
71+
return [self._compute_pdf(xp, integrator) for xp in x]
7372
elif isinstance(x, float):
74-
return self._compute_pdf(x)
73+
return self._compute_pdf(x, integrator)
7574
else:
7675
raise TypeError(f"Unsupported type for x: {type(x)}")
7776

7877
@abstractmethod
79-
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
78+
def _compute_logpdf(self, x: float, integrator: Integrator) -> Tuple[float, float]:
8079
...
8180

8281
def compute_logpdf(
8382
self,
84-
x: Union[List[float], float, NDArray[np.float64]]
83+
x: Union[List[float], float, NDArray[np.float64]],
84+
integrator: Integrator
8585
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
8686
if isinstance(x, np.ndarray):
87-
return np.array([self._compute_logpdf(xp) for xp in x], dtype=object)
87+
return np.array([self._compute_logpdf(xp, integrator) for xp in x], dtype=object)
8888
elif isinstance(x, list):
89-
return [self._compute_logpdf(xp) for xp in x]
89+
return [self._compute_logpdf(xp, integrator) for xp in x]
9090
elif isinstance(x, float):
91-
return self._compute_logpdf(x)
91+
return self._compute_logpdf(x, integrator)
9292
else:
9393
raise TypeError(f"Unsupported type for x: {type(x)}")
9494

9595
@abstractmethod
96-
def _compute_cdf(self, x: float) -> Tuple[float, float]:
96+
def _compute_cdf(self, x: float, integrator: Integrator) -> Tuple[float, float]:
9797
...
9898

9999
def compute_cdf(
100100
self,
101-
x: Union[List[float], float, NDArray[np.float64]]
101+
x: Union[List[float], float, NDArray[np.float64]],
102+
integrator: Integrator
102103
) -> Union[List[Tuple[float, float]], Tuple[float, float], NDArray[Any]]:
103104
if isinstance(x, np.ndarray):
104-
return np.array([self._compute_cdf(xp) for xp in x], dtype=object)
105+
return np.array([self._compute_cdf(xp, integrator) for xp in x], dtype=object)
105106
elif isinstance(x, list):
106-
return [self._compute_cdf(xp) for xp in x]
107+
return [self._compute_cdf(xp, integrator) for xp in x]
107108
elif isinstance(x, float):
108-
return self._compute_cdf(x)
109+
return self._compute_cdf(x, integrator)
109110
else:
110111
raise TypeError(f"Unsupported type for x: {type(x)}")
111112

src/mixtures/nm_mixture.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from src.procedures.support.integrator import Integrator
1010
from src.procedures.support.rqmc import RQMCIntegrator
1111
from src.procedures.support.log_rqmc import LogRQMC
12+
from src.procedures.support.quad_integrator import QuadIntegrator
1213
from src.mixtures.abstract_mixture import AbstractMixtures
1314

1415
@dataclass
@@ -30,11 +31,9 @@ class NormalMeanMixtures(AbstractMixtures):
3031
def __init__(
3132
self,
3233
mixture_form: str,
33-
integrator_cls: Type[Integrator] = RQMCIntegrator,
34-
integrator_params: Dict[str, Any] = None,
3534
**kwargs: Any
3635
) -> None:
37-
super().__init__(mixture_form, integrator_cls=integrator_cls, integrator_params=integrator_params, **kwargs)
36+
super().__init__(mixture_form, **kwargs)
3837

3938
def _params_validation(self, data_collector: Any, params: dict[str, float | rv_continuous | rv_frozen]) -> Any:
4039
data_class = super()._params_validation(data_collector, params)
@@ -44,7 +43,7 @@ def _params_validation(self, data_collector: Any, params: dict[str, float | rv_c
4443
raise ValueError("Gamma can't be zero")
4544
return data_class
4645

47-
def _compute_moment(self, n: int) -> Tuple[float, float]:
46+
def _compute_moment(self, n: int, integrator: Integrator=QuadIntegrator) -> Tuple[float, float]:
4847
mixture_moment = 0.0
4948
error = 0.0
5049
if self.mixture_form == "classical":
@@ -55,20 +54,21 @@ def mix(u: float) -> float:
5554
return (
5655
self.params.distribution.ppf(u) ** (k - l)
5756
)
58-
res = self.integrator_cls(**(self.integrator_params or {})).compute(mix)
57+
58+
res = integrator.compute(mix)
5959
mixture_moment += coeff * (self.params.beta ** (k - l)) * (self.params.gamma ** l) * (self.params.alpha ** (n - k)) * res.value * norm.moment(l)
6060
error += coeff * (self.params.beta ** (k - l)) * (self.params.gamma ** l) * (self.params.alpha ** (n - k)) * res.error * norm.moment(l)
6161
else:
6262
for k in range(n + 1):
6363
coeff = binom(n, n - k)
6464
def mix(u: float) -> float:
6565
return self.params.distribution.ppf(u) ** (n - k)
66-
res = self.integrator_cls(**(self.integrator_params or {})).compute(mix)
66+
res = integrator.compute(mix)
6767
mixture_moment += coeff * (self.params.sigma ** k) * res.value * norm.moment(k)
6868
error += coeff * (self.params.sigma ** k) * res.error * norm.moment(k)
6969
return mixture_moment, error
7070

71-
def _compute_cdf(self, x: float) -> Tuple[float, float]:
71+
def _compute_cdf(self, x: float, integrator: Integrator=RQMCIntegrator) -> Tuple[float, float]:
7272
if self.mixture_form == "classical":
7373
def fn(u: float) -> float:
7474
p = self.params.distribution.ppf(u)
@@ -77,10 +77,10 @@ def fn(u: float) -> float:
7777
def fn(u: float) -> float:
7878
p = self.params.distribution.ppf(u)
7979
return norm.cdf((x - p) / abs(self.params.sigma))
80-
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
80+
res = integrator.compute(fn)
8181
return res.value, res.error
8282

83-
def _compute_pdf(self, x: float) -> Tuple[float, float]:
83+
def _compute_pdf(self, x: float, integrator: Integrator=RQMCIntegrator) -> Tuple[float, float]:
8484
if self.mixture_form == "classical":
8585
def fn(u: float) -> float:
8686
p = self.params.distribution.ppf(u)
@@ -89,10 +89,10 @@ def fn(u: float) -> float:
8989
def fn(u: float) -> float:
9090
p = self.params.distribution.ppf(u)
9191
return (1 / abs(self.params.sigma)) * norm.pdf((x - p) / abs(self.params.sigma))
92-
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
92+
res = integrator.compute(fn)
9393
return res.value, res.error
9494

95-
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
95+
def _compute_logpdf(self, x: float, integrator: Integrator=LogRQMC) -> Tuple[float, float]:
9696
if self.mixture_form == "classical":
9797
def fn(u: float) -> float:
9898
p = self.params.distribution.ppf(u)
@@ -101,5 +101,5 @@ def fn(u: float) -> float:
101101
def fn(u: float) -> float:
102102
p = self.params.distribution.ppf(u)
103103
return np.log(1 / abs(self.params.sigma)) + norm.logpdf((x - p) / abs(self.params.sigma))
104-
res = self.integrator_cls(**(self.integrator_params or {})).compute(fn)
104+
res = integrator.compute(fn)
105105
return res.value, res.error

src/mixtures/nmv_mixture.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@ class NormalMeanVarianceMixtures(AbstractMixtures):
3232
def __init__(
3333
self,
3434
mixture_form: str,
35-
integrator_cls: Type[Integrator] = RQMCIntegrator,
36-
integrator_params: Dict[str, Any] = None,
3735
**kwargs: Any
3836
) -> None:
39-
super().__init__(mixture_form, integrator_cls=integrator_cls, integrator_params=integrator_params, **kwargs)
37+
super().__init__(mixture_form, **kwargs)
4038

41-
def _compute_moment(self, n: int) -> Tuple[float, float]:
39+
def _compute_moment(self, n: int, integrator: Integrator=RQMCIntegrator) -> Tuple[float, float]:
4240
gamma = getattr(self.params, 'gamma', None)
4341

4442
def integrand(u: float) -> float:
@@ -65,20 +63,20 @@ def integrand(u: float) -> float:
6563
s += term
6664
return s
6765

68-
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
66+
res = integrator.compute(integrand)
6967
return res.value, res.error
7068

71-
def _compute_cdf(self, x: float) -> Tuple[float, float]:
69+
def _compute_cdf(self, x: float, integrator: Integrator=RQMCIntegrator) -> Tuple[float, float]:
7270
def integrand(u: float) -> float:
7371
p = self.params.distribution.ppf(u)
7472
if self.mixture_form == 'classical':
7573
return norm.cdf((x - self.params.alpha) / (np.sqrt(p) * self.params.gamma))
7674
return norm.cdf((x - self.params.alpha) / np.sqrt(p) - self.params.mu * np.sqrt(p))
7775

78-
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
76+
res = integrator.compute(integrand)
7977
return res.value, res.error
8078

81-
def _compute_pdf(self, x: float) -> Tuple[float, float]:
79+
def _compute_pdf(self, x: float, integrator: Integrator=RQMCIntegrator) -> Tuple[float, float]:
8280
def integrand(u: float) -> float:
8381
p = self.params.distribution.ppf(u)
8482
if self.mixture_form == 'classical':
@@ -91,21 +89,21 @@ def integrand(u: float) -> float:
9189
* np.exp(-((x - self.params.alpha) ** 2 + self.params.mu ** 2 * p ** 2) / (2 * p))
9290
)
9391

94-
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
92+
res = integrator.compute(integrand)
9593
if self.mixture_form == 'classical':
9694
val = np.exp(self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2) * res.value
9795
else:
9896
val = np.exp(self.params.mu * (x - self.params.alpha)) * res.value
9997
return val, res.error
10098

101-
def _compute_logpdf(self, x: float) -> Tuple[float, float]:
99+
def _compute_logpdf(self, x: float, integrator: Integrator=LogRQMC) -> Tuple[float, float]:
102100
def integrand(u: float) -> float:
103101
p = self.params.distribution.ppf(u)
104102
if self.mixture_form == 'classical':
105103
return -((x - self.params.alpha) ** 2 + p ** 2 * self.params.beta ** 2 + p * self.params.gamma ** 2 * np.log(2 * np.pi * p * self.params.gamma ** 2)) / (2 * p * self.params.gamma ** 2)
106104
return -((x - self.params.alpha) ** 2 + p ** 2 * self.params.mu ** 2 + p * np.log(2 * np.pi * p)) / (2 * p)
107105

108-
res = self.integrator_cls(**(self.integrator_params or {})).compute(integrand)
106+
res = integrator.compute(integrand)
109107
if self.mixture_form == 'classical':
110108
val = self.params.beta * (x - self.params.alpha) / self.params.gamma ** 2 + res.value
111109
else:

src/mixtures/nv_mixture.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self.integrator_cls = integrator_cls
4141
self.integrator_params = integrator_params or {}
4242

43-
def _compute_moment(self, n: int) -> tuple[float, float]:
43+
def _compute_moment(self, n: int, integrator: Integrator=QuadIntegrator) -> tuple[float, float]:
4444
gamma = getattr(self.params, 'gamma', 1)
4545

4646
def integrand(u: float) -> float:
@@ -53,40 +53,36 @@ def integrand(u: float) -> float:
5353
for k in range(n + 1)
5454
)
5555

56-
integrator = self.integrator_cls(**self.integrator_params)
5756
result = integrator.compute(integrand)
5857
return result.value, result.error
5958

60-
def _compute_cdf(self, x: float) -> tuple[float, float]:
59+
def _compute_cdf(self, x: float, integrator: Integrator=QuadIntegrator) -> tuple[float, float]:
6160
gamma = getattr(self.params, 'gamma', 1)
6261
param_norm = norm(0, gamma)
6362

6463
def integrand(u: float) -> float:
6564
return param_norm.cdf((x - self.params.alpha) / np.sqrt(self.params.distribution.ppf(u)))
6665

67-
integrator = self.integrator_cls(**self.integrator_params)
6866
result = integrator.compute(integrand)
6967
return result.value, result.error
7068

71-
def _compute_pdf(self, x: float) -> tuple[float, float]:
69+
def _compute_pdf(self, x: float, integrator: Integrator=QuadIntegrator) -> tuple[float, float]:
7270
gamma = getattr(self.params, 'gamma', 1)
7371
d = (x - self.params.alpha) ** 2 / gamma ** 2
7472

7573
def integrand(u: float) -> float:
7674
return self._integrand_func(u, d, gamma)
7775

78-
integrator = self.integrator_cls(**self.integrator_params)
7976
result = integrator.compute(integrand)
8077
return result.value, result.error
8178

82-
def _compute_logpdf(self, x: float) -> tuple[float, float]:
79+
def _compute_logpdf(self, x: float, integrator: Integrator=LogRQMC) -> tuple[float, float]:
8380
gamma = getattr(self.params, 'gamma', 1)
8481
d = (x - self.params.alpha) ** 2 / gamma ** 2
8582

8683
def integrand(u: float) -> float:
8784
return self._log_integrand_func(u, d, gamma)
8885

89-
integrator = self.integrator_cls(**self.integrator_params)
9086
result = integrator.compute(integrand)
9187
return result.value, result.error
9288

0 commit comments

Comments
 (0)