Skip to content

Commit 7120012

Browse files
author
Michael Neumann
committed
Stabilize ML doctests for parallel pytest (Fixes #13919)
1 parent 453d105 commit 7120012

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

machine_learning/k_means_clust.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,11 @@ def revise_centroids(
8787
"""Recompute centroids as the mean of the assigned samples.
8888
8989
>>> data = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
90-
>>> revise_centroids(data, 2, np.array([0, 0, 1]))
91-
array([[0. , 0.5],\n [5. , 5. ]])
90+
>>> np.allclose(
91+
... revise_centroids(data, 2, np.array([0, 0, 1])),
92+
... np.array([[0.0, 0.5], [5.0, 5.0]]),
93+
... )
94+
True
9295
"""
9396
new_centroids: list[NDArray[np.floating]] = []
9497
for i in range(k):
@@ -110,7 +113,7 @@ def compute_heterogeneity(
110113
111114
>>> data = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
112115
>>> centroids = np.array([[0.0, 0.5], [5.0, 5.0]])
113-
>>> compute_heterogeneity(data, 2, centroids, np.array([0, 0, 1]))
116+
>>> float(compute_heterogeneity(data, 2, centroids, np.array([0, 0, 1])))
114117
0.5
115118
"""
116119
heterogeneity = 0.0
@@ -184,7 +187,7 @@ def kmeans(
184187
... )
185188
>>> labels.tolist()
186189
[0, 0, 1]
187-
>>> [round(value, 3) for value in heterogeneity]
190+
>>> [round(float(value), 3) for value in heterogeneity]
188191
[0.5]
189192
>>> np.allclose(centroids, np.array([[0.0, 0.5], [5.0, 5.0]]))
190193
True
@@ -264,7 +267,7 @@ def report_generator(
264267
... {'spend': [0.0, 50.0, 100.0], 'Cluster': [0, 0, 1]}
265268
... )
266269
>>> report = report_generator(predicted, clustering_variables=['spend'])
267-
>>> report.loc[report['Features'] == '# of Customers', 0].iloc[0]
270+
>>> float(report.loc[report['Features'] == '# of Customers', 0].iloc[0])
268271
2.0
269272
>>> float(report.loc[report['Features'] == '% of Customers', 1])
270273
0.3333333333333333

machine_learning/linear_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def sum_of_square_error(
7272
Example:
7373
>>> vc_x = np.array([[1.1], [2.1], [3.1]])
7474
>>> vc_y = np.array([1.2, 2.2, 3.2])
75-
>>> round(sum_of_square_error(vc_x, vc_y, 3, np.array([1])), 3)
75+
>>> float(round(sum_of_square_error(vc_x, vc_y, 3, np.array([1])), 3))
7676
0.005
7777
"""
7878
prod = np.dot(theta, data_x.transpose())

0 commit comments

Comments
 (0)