@@ -324,13 +324,16 @@ def simplify(self):
324324 f , s , t = [[x .simplify () for x in y ] for y in (self .first , self .second , self .third )]
325325 if self .op == '__trinary__' and len (f ) == 1 :
326326 if f [0 ].is_value (True ) and not f [0 ].is_value (False ):
327- return Expr ( s )
327+ return s
328328 if f [0 ].is_value (False ) and not f [0 ].is_value (True ):
329- return Expr ( t )
329+ return t
330330 return TrinaryOp (self .data_type , self .style , self .op , f , s , t )
331331
332332 def evaluate (self , known : dict ):
333- first , second , third = [[x .evaluate (known ) for x in y ] for y in (self .first , self .second , self .third )]
333+ def evaluate_to_single (node ):
334+ v = node .evaluate (known )
335+ return v [0 ] if hasattr (v , '__len__' ) and len (v ) == 1 else v
336+ first , second , third = [[evaluate_to_single (x ) for x in y ] for y in (self .first , self .second , self .third )]
334337 return TrinaryOp (self .data_type , self .style , self .op , first , second , third ).simplify ()
335338
336339 def depends_on (self , name : str ):
@@ -451,14 +454,13 @@ def simplify(self):
451454 left = [x .simplify () for x in self .left ]
452455 right = [x .simplify () for x in self .right ]
453456 if len (left ) == 1 and ((left [0 ].is_zero and self .op == '+' ) or (left [0 ].is_value (1 ) and self .op == '*' )):
454- return Expr ( right )
457+ return right
455458 if len (right ) == 1 and (
456459 (right [0 ].is_zero and any (x == self .op for x in '+-' )) or
457460 (right [0 ].is_value (1 ) and any (x == self .op for x in '*/' ))
458461 ):
459- return Expr (left )
460- if len (left ) == 1 and len (right ) == 1 and left [0 ].is_constant and right [0 ].is_constant \
461- and self .op in ('+' , '-' , '*' , '/' ):
462+ return left
463+ if len (left ) == 1 and len (right ) == 1 and left [0 ].is_constant and right [0 ].is_constant :
462464 if self .op == '+' :
463465 return left [0 ] + right [0 ]
464466 if self .op == '-' :
@@ -467,11 +469,16 @@ def simplify(self):
467469 return left [0 ] * right [0 ]
468470 if self .op == '/' :
469471 return left [0 ] / right [0 ]
472+ if self .op == '__pow__' :
473+ return left [0 ] ** right [0 ]
470474 # punt!
471475 return BinaryOp (self .data_type , self .style , self .op , left , right )
472476
473477 def evaluate (self , known : dict ):
474- left , right = [[x .evaluate (known ) for x in y ] for y in (self .left , self .right )]
478+ def evaluate_to_single (node ):
479+ v = node .evaluate (known )
480+ return v [0 ] if hasattr (v , '__len__' ) and len (v ) == 1 else v
481+ left , right = [[evaluate_to_single (x ) for x in y ] for y in (self .left , self .right )]
475482 return BinaryOp (self .data_type , self .style , self .op , left , right ).simplify ()
476483
477484 def depends_on (self , name : str ):
@@ -564,10 +571,17 @@ def simplify(self):
564571 value = [v .simplify () for v in self .value ]
565572 if self .op == '__group__' and len (value ) == 1 and isinstance (value [0 ], Value ):
566573 return value [0 ] # Expr(value)
574+ elif self .op == '-' and len (value ) == 1 and isinstance (value [0 ], Value ):
575+ return - value [0 ]
576+ elif self .op == '+' and len (value ) == 1 and isinstance (value [0 ], Value ):
577+ return value [0 ]
567578 return UnaryOp (self .data_type , self .style , self .op , value )
568579
569580 def evaluate (self , known : dict ):
570- value = [x .evaluate (known ) for x in self .value ]
581+ def evaluate_to_single (node ):
582+ v = node .evaluate (known )
583+ return v [0 ] if hasattr (v , '__len__' ) and len (v ) == 1 else v
584+ value = [evaluate_to_single (x ) for x in self .value ]
571585 return UnaryOp (self .data_type , self .style , self .op , value ).simplify ()
572586
573587 def depends_on (self , name : str ):
@@ -839,6 +853,18 @@ def __mul__(self, other):
839853 return BinaryOp (self .data_type , OpStyle .C , '*' , [self ], [other ])
840854 return BinaryOp (self .data_type , OpStyle .C , '*' , [self ], [other ]) if pdt .is_str else Value (self .value * other .value , pdt )
841855
856+ def __mod__ (self , other ):
857+ other = other if isinstance (other , (Value , Op )) else Value .best (other )
858+ pdt = self .data_type
859+ if other .is_op or self .is_id or other .is_id or pdt .is_str :
860+ return BinaryOp (self .data_type , OpStyle .C , '%' , [self ], [other ])
861+ # This is neither consistent with the Python or C definition
862+ # Python -- the result takes the sign of the divisor
863+ # C -- the sign is always positive for positive numerator
864+ # but compiler-dependent for negative numerator
865+ # return Value(self.value - (self.value // other.value) * other.value, pdt)
866+ return Value (self .value % other .value , pdt )
867+
842868 def __truediv__ (self , other ):
843869 other = other if isinstance (other , (Value , Op )) else Value .best (other )
844870 pdt = self .data_type / other .data_type
@@ -919,6 +945,8 @@ def __pow__(self, power):
919945 return self
920946 if power .is_zero :
921947 return Value (1 , self .data_type )
948+ if self .is_constant and power .is_constant :
949+ return Value (_value = self .value ** power .value , _data = self .data_type )
922950 return BinaryOp (self .data_type , OpStyle .C , '__pow__' , [self ], [power ])
923951
924952 @property
@@ -991,6 +1019,9 @@ def from_dict(cls, args: dict):
9911019 def __post_init__ (self ):
9921020 if not isinstance (self .expr , list ):
9931021 self .expr = [self .expr ]
1022+ if any (not isinstance (node , ExprNodeSingular ) for node in self .expr ):
1023+ types = list (dict .fromkeys (type (node ) for node in self .expr ).keys ())
1024+ raise ValueError (f"An Expr can not be a list of { types } " )
9941025
9951026 def __str__ (self ):
9961027 return ',' .join (str (x ) for x in self .expr )
@@ -1157,6 +1188,9 @@ def __sub__(self, other):
11571188 def __mul__ (self , other ):
11581189 return Expr (self .expr [0 ] * self ._prep_numeric_operation ('multiply' , other ))
11591190
1191+ def __mod__ (self , other ):
1192+ return Expr (self .expr [0 ] % self ._prep_numeric_operation ('mod' , other ))
1193+
11601194 def __truediv__ (self , other ):
11611195 return Expr (self .expr [0 ] / self ._prep_numeric_operation ('divide' , other ))
11621196
@@ -1293,10 +1327,16 @@ def shape_type(self, st):
12931327
12941328 def simplify (self ):
12951329 """Perform a very basic analysis to reduce the expression complexity"""
1296- return Expr ([x .simplify () for x in self .expr ])
1330+ def simplify_to_single_or_list (node ):
1331+ s = node .simplify ()
1332+ return s [0 ] if hasattr (s , '__len__' ) and len (s ) == 1 else s
1333+ return Expr ([simplify_to_single_or_list (x ) for x in self .expr ])
12971334
12981335 def evaluate (self , known : dict ):
1299- return Expr ([x .evaluate (known ) for x in self .expr ]).simplify ()
1336+ def evaluate_to_single_or_list (node ):
1337+ s = node .evaluate (known )
1338+ return s [0 ] if hasattr (s , '__len__' ) and len (s ) == 1 else s
1339+ return Expr ([evaluate_to_single_or_list (x ) for x in self .expr ]).simplify ()
13001340
13011341 def depends_on (self , name : str ):
13021342 return any (x .depends_on (name ) for x in self .expr )
0 commit comments