Skip to content

Commit 2ea4b0c

Browse files
authored
fix TypeMustMatch is not respected for PyArrayLike1 (#520)
1 parent 78d5e8d commit 2ea4b0c

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# Changelog
2+
- v0.28.0
3+
- Fix mismatched behavior between `PyArrayLike1` and `PyArrayLike2` when used with floats ([#520](https://github.com/PyO3/rust-numpy/pull/520))
4+
25
- v0.27.1
36
- Bump ndarray dependency to v0.17. ([#516](https://github.com/PyO3/rust-numpy/pull/516))
47

src/array_like.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ use pyo3::{
1010
};
1111

1212
use crate::array::PyArrayMethods;
13-
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
13+
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
1414

1515
pub trait Coerce: Sealed {
16-
const VAL: bool;
16+
const ALLOW_TYPE_CHANGE: bool;
1717
}
1818

1919
mod sealed {
@@ -29,7 +29,7 @@ pub struct TypeMustMatch;
2929
impl Sealed for TypeMustMatch {}
3030

3131
impl Coerce for TypeMustMatch {
32-
const VAL: bool = false;
32+
const ALLOW_TYPE_CHANGE: bool = false;
3333
}
3434

3535
/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
@@ -39,7 +39,7 @@ pub struct AllowTypeChange;
3939
impl Sealed for AllowTypeChange {}
4040

4141
impl Coerce for AllowTypeChange {
42-
const VAL: bool = true;
42+
const ALLOW_TYPE_CHANGE: bool = true;
4343
}
4444

4545
/// Receiver for arrays or array-like types.
@@ -151,7 +151,11 @@ where
151151

152152
let py = ob.py();
153153

154-
if matches!(D::NDIM, None | Some(1)) {
154+
// If the input is already an ndarray and `TypeMustMatch` is used then no type conversion
155+
// should be performed.
156+
if (C::ALLOW_TYPE_CHANGE || ob.cast::<PyUntypedArray>().is_err())
157+
&& matches!(D::NDIM, None | Some(1))
158+
{
155159
if let Ok(vec) = ob.extract::<Vec<T>>() {
156160
let array = Array1::from(vec)
157161
.into_dimensionality()
@@ -170,7 +174,7 @@ where
170174
})?
171175
.bind(py);
172176

173-
let kwargs = if C::VAL {
177+
let kwargs = if C::ALLOW_TYPE_CHANGE {
174178
let kwargs = PyDict::new(py);
175179
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
176180
Some(kwargs)

tests/array_like.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,50 @@ fn unsafe_cast_shall_fail() {
132132
});
133133
}
134134

135+
#[test]
136+
fn extract_1d_array_of_different_float_types_fail() {
137+
Python::attach(|py| {
138+
let locals = get_np_locals(py);
139+
let py_list = py
140+
.eval(
141+
c_str!("np.array([1, 2, 3, 4], dtype='float64')"),
142+
Some(&locals),
143+
None,
144+
)
145+
.unwrap();
146+
let extracted_array_f32 = py_list.extract::<PyArrayLike1<'_, f32>>();
147+
let extracted_array_f64 = py_list.extract::<PyArrayLike1<'_, f64>>().unwrap();
148+
149+
assert!(extracted_array_f32.is_err());
150+
assert_eq!(
151+
array![1_f64, 2_f64, 3_f64, 4_f64],
152+
extracted_array_f64.as_array()
153+
);
154+
});
155+
}
156+
157+
#[test]
158+
fn extract_2d_array_of_different_float_types_fail() {
159+
Python::attach(|py| {
160+
let locals = get_np_locals(py);
161+
let py_list = py
162+
.eval(
163+
c_str!("np.array([[1, 2], [3, 4]], dtype='float64')"),
164+
Some(&locals),
165+
None,
166+
)
167+
.unwrap();
168+
let extracted_array_f32 = py_list.extract::<PyArrayLike2<'_, f32>>();
169+
let extracted_array_f64 = py_list.extract::<PyArrayLike2<'_, f64>>().unwrap();
170+
171+
assert!(extracted_array_f32.is_err());
172+
assert_eq!(
173+
array![[1_f64, 2_f64], [3_f64, 4_f64]],
174+
extracted_array_f64.as_array()
175+
);
176+
});
177+
}
178+
135179
#[test]
136180
fn unsafe_cast_with_coerce_works() {
137181
Python::attach(|py| {

0 commit comments

Comments
 (0)