2626import operator
2727import functools
2828import itertools as it
29- from .deprecate import deprecated
3029from .types import (
3130 ArrayLike , MatrixLike , EmptyShape , VectorShape , MatrixShape , TensorShape , ArrayShape , VectorLike ,
3231 TensorLike , Array , Matrix , Tensor , Vector , VectorBool , MatrixBool , TensorBool , MatrixInt , ArrayType , VectorInt , # noqa: F401
3332 Shape , DimHints , SupportsFloatOrInt
3433)
35- from typing import Callable , Sequence , Iterator , Any , Iterable , overload
34+ from typing import Callable , Sequence , Iterator , Any , Iterable , overload , cast
3635
3736EPS = sys .float_info .epsilon
3837RTOL = 4 * EPS
4342MIN_FLOAT = sys .float_info .min
4443
4544# Keeping for backwards compatibility
46- prod = math .prod
4745_all = builtins .all
4846_any = builtins .any
4947
8684################################
8785# General math
8886################################
89- def sign (x : float ) -> float :
87+ def sgn (x : SupportsFloatOrInt ) -> SupportsFloatOrInt :
9088 """Return the sign of a given value."""
9189
92- if x and x == x :
93- return x / abs ( x )
94- return x
90+ if isinstance ( x , int ) :
91+ return 1 if x > 0 else - 1 if x < 0 else 0
92+ return 1.0 if x > 0.0 else - 1.0 if x < 0 else x
9593
9694
9795def order (x : float ) -> int :
@@ -1118,7 +1116,7 @@ def _cross_pad(a: ArrayLike, s: ArrayShape) -> Array:
11181116 m = acopy (a )
11191117
11201118 # Initialize indexes so we can properly write our data
1121- total = prod (s [:- 1 ])
1119+ total = math . prod (s [:- 1 ])
11221120 idx = [0 ] * (len (s ) - 1 )
11231121
11241122 for c in range (total ):
@@ -1783,8 +1781,8 @@ def multi_dot(arrays: Sequence[ArrayLike]) -> Any:
17831781 # We can easily calculate three with less complexity and in less time. Anything
17841782 # greater than three becomes a headache.
17851783 if count == 3 :
1786- pa = prod (shapes [0 ])
1787- pc = prod (shapes [2 ])
1784+ pa = math . prod (shapes [0 ])
1785+ pc = math . prod (shapes [2 ])
17881786 cost1 = pa * shapes [2 ][0 ] + pc * shapes [0 ][0 ]
17891787 cost2 = pc * shapes [0 ][1 ] + pa * shapes [2 ][1 ] # type: ignore[misc]
17901788 if cost1 < cost2 :
@@ -1845,7 +1843,7 @@ def __init__(self, array: ArrayLike | float, old: Shape, new: Shape) -> None:
18451843 elif self .different :
18461844 # Calculate the shape of the data.
18471845 if len (old ) > 1 :
1848- self .amount = prod (old [:- 1 ])
1846+ self .amount = math . prod (old [:- 1 ])
18491847 self .length = old [- 1 ]
18501848 else :
18511849 # Vectors have to be handled a bit special as they only have 1-D
@@ -1856,7 +1854,7 @@ def __init__(self, array: ArrayLike | float, old: Shape, new: Shape) -> None:
18561854 # We need to flip them based on whether the original shape has an even or odd number of
18571855 # dimensions.
18581856 diff = [int (x / y ) if y else y for x , y in zip (new , old )]
1859- repeat = prod (diff [:- 1 ]) if len (old ) > 1 else 1
1857+ repeat = math . prod (diff [:- 1 ]) if len (old ) > 1 else 1
18601858 expand = diff [- 1 ]
18611859 if len (diff ) > 1 and diff [- 2 ] > 1 :
18621860 self .repeat = expand
@@ -2110,7 +2108,7 @@ def __init__(self, *arrays: ArrayLike | float) -> None:
21102108 # But shouldn't matter for what we do.
21112109 self .shape = common
21122110 self .ndims = max_dims
2113- self .size = prod (common )
2111+ self .size = math . prod (common )
21142112 self ._init ()
21152113
21162114 def _init (self ) -> None :
@@ -2378,7 +2376,7 @@ def __call__(
23782376 # Apply math to two N-D matrices
23792377 if dims_a == dims_b :
23802378 empty = (not shape_a or 0 in shape_a ) and (not shape_b or 0 in shape_b )
2381- if not empty and prod (shape_a ) != prod (shape_b ): # pragma: no cover
2379+ if not empty and math . prod (shape_a ) != math . prod (shape_b ): # pragma: no cover
23822380 raise ValueError (f'Shape { shape_a } does not match the data total of { shape_b } ' )
23832381 with ArrayBuilder (m , shape_a ) as build :
23842382 for x , y in zip (flatiter (a ), flatiter (b )):
@@ -2658,13 +2656,6 @@ def vectorize2(
26582656 raise ValueError ("'vectorize2' does not support dimensions greater than 2 or less than 1" )
26592657
26602658
2661- @deprecated ("'vectorize1' is deprecated, use 'vectorize2(func, doc, params=1)' for the equivalent" )
2662- def vectorize1 (pyfunc : Callable [..., Any ], doc : str | None = None ) -> Callable [..., Any ]: # pragma: no cover
2663- """An optimized version of vectorize that is hard coded to broadcast only the first input."""
2664-
2665- return vectorize2 (pyfunc , doc , params = 1 )
2666-
2667-
26682659@overload
26692660def linspace (start : float , stop : float , num : int = ..., endpoint : bool = ...) -> Vector :
26702661 ...
@@ -2805,6 +2796,38 @@ def isnan(a: TensorLike, *, dims: DimHints = ..., **kwargs: Any) -> TensorBool:
28052796isnan = vectorize2 (math .isnan , doc = "Test if a value or values in an array are NaN." , params = 1 )
28062797
28072798
2799+ @overload # type: ignore[no-overload-impl]
2800+ def sign (a : float , * , dims : DimHints = ..., ** kwargs : Any ) -> float :
2801+ ...
2802+
2803+
2804+ @overload
2805+ def sign (a : VectorLike , * , dims : DimHints = ..., ** kwargs : Any ) -> Vector :
2806+ ...
2807+
2808+
2809+ @overload
2810+ def sign (a : MatrixLike , * , dims : DimHints = ..., ** kwargs : Any ) -> Matrix :
2811+ ...
2812+
2813+
2814+ @overload
2815+ def sign (a : TensorLike , * , dims : DimHints = ..., ** kwargs : Any ) -> Tensor :
2816+ ...
2817+
2818+
2819+ sign = vectorize2 (sgn , doc = "Return the sign of a number." , params = 1 )
2820+
2821+
2822+ def prod (a : ArrayLike | float ) -> float :
2823+ """Return the product."""
2824+
2825+ l = len (shape (a ))
2826+ if l == 0 :
2827+ return float (math .prod ([a ])) # type: ignore[list-item]
2828+ return float (math .prod (flatiter (a ) if l > 1 else a )) # type: ignore[arg-type]
2829+
2830+
28082831def allclose (a : ArrayType , b : ArrayType , ** kwargs : Any ) -> bool :
28092832 """Test if all are close."""
28102833
@@ -3120,14 +3143,14 @@ def full(array_shape: int | Shape, fill_value: float | ArrayLike) -> Array | flo
31203143 if not isinstance (fill_value , Sequence ):
31213144 return fill_value
31223145 _s = shape (fill_value )
3123- if prod (_s ) == 1 :
3146+ if math . prod (_s ) == 1 :
31243147 return ravel (fill_value )[0 ]
31253148
31263149 # Normalize `fill_value` to be an array.
31273150 elif not isinstance (fill_value , Sequence ):
31283151 m = [] # type: Array
31293152 with ArrayBuilder (m , s ) as build :
3130- for v in [fill_value ] * prod (s ):
3153+ for v in [fill_value ] * math . prod (s ):
31313154 next (build ).append (v )
31323155 return m
31333156
@@ -3358,9 +3381,9 @@ def transpose(array: ArrayLike | float) -> float | Array:
33583381 # N x M matrix
33593382 if s and s [0 ] == 0 :
33603383 s = s [1 :] + (0 ,)
3361- total = prod (s [:- 1 ])
3384+ total = math . prod (s [:- 1 ])
33623385 else :
3363- total = prod (s )
3386+ total = math . prod (s )
33643387
33653388 # Create the array
33663389 m = [] # type: Array
@@ -3455,8 +3478,8 @@ def reshape(array: ArrayLike | float, new_shape: int | Shape) -> float | Array:
34553478 empty = (not new_shape or 0 in new_shape ) and (not current_shape or 0 in current_shape )
34563479
34573480 # Make sure we can actually reshape.
3458- total = prod (new_shape ) if not empty else prod ( new_shape [:- 1 ])
3459- if not empty and total != prod (current_shape ):
3481+ total = math . prod (new_shape if not empty else new_shape [:- 1 ])
3482+ if not empty and total != math . prod (current_shape ):
34603483 raise ValueError (f'Shape { new_shape } does not match the data total of { shape (array )} ' )
34613484
34623485 # Create the array
@@ -4290,7 +4313,7 @@ def _qr(a: Matrix, m: int, n: int, mode: str = 'reduced') -> Any:
42904313 for k in range (0 , m - 1 if not tall else n ):
42914314 # Calculate the householder reflections
42924315 norm = math .sqrt (sum ([r [i ][k ] ** 2 for i in range (k , m )]))
4293- sig = - sign (r [k ][k ])
4316+ sig = - sgn (r [k ][k ])
42944317 u0 = r [k ][k ] - sig * norm
42954318 w = [[(r [i ][k ] / u0 ) if u0 else 1 ] for i in range (k , m )]
42964319 w [0 ][0 ] = 1
@@ -4479,7 +4502,7 @@ def solve(a: MatrixLike | TensorLike, b: ArrayLike) -> Array:
44794502 p , l , u = lu (a , p_indices = True , _shape = s )
44804503
44814504 # If determinant is zero, we can't solve. Really small determinant may give bad results.
4482- if prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
4505+ if math . prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
44834506 raise ValueError ('Matrix is singular' )
44844507
44854508 # Solve for x using forward substitution on U and back substitution on L
@@ -4521,7 +4544,7 @@ def solve(a: MatrixLike | TensorLike, b: ArrayLike) -> Array:
45214544
45224545 p , l , u = lu (ma , p_indices = True , _shape = m_shape )
45234546
4524- if prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 : # pragma: no cover
4547+ if math . prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 : # pragma: no cover
45254548 raise ValueError ('Matrix is singular' )
45264549
45274550 next (build ).append (_back_sub_vector (u , _forward_sub_vector (l , [b [i ] for i in p ], size ), size )) # type: ignore[misc]
@@ -4543,7 +4566,7 @@ def solve(a: MatrixLike | TensorLike, b: ArrayLike) -> Array:
45434566
45444567 p , l , u = lu (ma , p_indices = True , _shape = s [- 2 :]) # type: ignore[misc]
45454568
4546- if prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
4569+ if math . prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
45474570 raise ValueError ('Matrix is singular' )
45484571
45494572 bi = [[* mb [i ]] for i in p ]
@@ -4578,8 +4601,8 @@ def det(array: MatrixLike | TensorLike) -> float | Vector:
45784601 size = s [0 ]
45794602 p , l , u = lu (array , _shape = s )
45804603 swaps = size - trace (p )
4581- sign = (- 1 ) ** (swaps - 1 ) if swaps else 1
4582- dt = sign * prod (l [i ][i ] * u [i ][i ] for i in range (size ))
4604+ _sign = (- 1 ) ** (swaps - 1 ) if swaps else 1
4605+ dt = _sign * math . prod (l [i ][i ] * u [i ][i ] for i in range (size ))
45834606 return 0.0 if not dt else dt
45844607 else :
45854608 last = s [- 2 :] # type: ignore[misc]
@@ -4625,7 +4648,7 @@ def inv(matrix: MatrixLike | TensorLike) -> Matrix | Tensor:
46254648 # Floating point math will produce very small, non-zero determinants for singular matrices.
46264649 # This occurs with Numpy as well.
46274650 # Don't bother calculating sign as we only care about how close to zero we are.
4628- if prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
4651+ if math . prod (l [i ][i ] * u [i ][i ] for i in range (size )) == 0.0 :
46294652 raise ValueError ('Matrix is singular' )
46304653
46314654 # Solve for the identity matrix (will give us inverse)
@@ -4722,7 +4745,7 @@ def vstack(arrays: Sequence[ArrayLike | float]) -> Matrix | Tensor:
47224745 raise ValueError ('All the input array dimensions except for the concatenation axis must match exactly' )
47234746
47244747 # Stack the arrays
4725- m .extend (reshape (a , (prod (s [:1 - dims ]),) + s [1 - dims :- 1 ] + s [- 1 :])) # type: ignore[arg-type, misc]
4748+ m .extend (reshape (a , (math . prod (s [:1 - dims ]),) + s [1 - dims :- 1 ] + s [- 1 :])) # type: ignore[arg-type, misc]
47264749
47274750 # Update the last array tracker
47284751 if not last or len (last ) > len (s ):
@@ -4740,7 +4763,7 @@ def _hstack_extract(a: ArrayLike | float, s: ArrayShape) -> Iterator[Array]:
47404763 """Extract data from the second axis."""
47414764
47424765 data = flatiter (a )
4743- length = prod (s [1 :])
4766+ length = math . prod (s [1 :])
47444767
47454768 for _ in range (s [0 ]):
47464769 yield [next (data ) for _ in range (length )]
@@ -5203,9 +5226,9 @@ def roll(
52035226 if axis is None :
52045227 if not isinstance (shift , int ):
52055228 shift = sum (shift )
5206- p = prod (s )
5207- sgn = sign (shift )
5208- shift = int ( shift % (p * sgn )) if p and sgn else 0
5229+ p = math . prod (s )
5230+ _sign = sgn (shift )
5231+ shift = shift % (p * _sign ) if p and _sign else 0
52095232 flat = ravel (a ) if len (s ) != 1 else [* a ] # type: ignore[misc]
52105233 sh = - shift
52115234 flat [:] = flat [sh :] + flat [:sh ]
@@ -5221,11 +5244,12 @@ def roll(
52215244 new_shift = [] # type: VectorInt
52225245 new_axes = [] # type: VectorInt
52235246 for i , j in broadcast (shift , axes ):
5247+ i , j = cast (int , i ), cast (int , j )
52245248 if j < 0 :
52255249 j = l + j
5226- sgn = sign (i )
5227- new_shift .append (int (i % (s [j ] * sgn )) if s [j ] and sgn else 0 ) # type: ignore[call-overload]
5228- new_axes .append (j ) # type: ignore[arg-type]
5250+ _sign = sgn (i )
5251+ new_shift .append ((i % (s [j ] * _sign )) if s [j ] and _sign else 0 )
5252+ new_axes .append (j )
52295253
52305254 # Perform the roll across the specified axes
52315255 for idx in ndindex (s [:- 1 ] + (1 ,)): # type: ignore[misc]
0 commit comments