Skip to content

Commit 97826b9

Browse files
authored
Add dependency missing from confuse (#184)
* [Add] testing for Expr and Orient behavior * [Add] dependency missing for confuse == 2.2.0
1 parent ecb689d commit 97826b9

File tree

13 files changed

+1574
-534
lines changed

13 files changed

+1574
-534
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ dependencies = [
1313
'antlr4-python3-runtime>=4.13.2',
1414
'numpy>=2',
1515
'pooch>=1.7.0',
16-
'confuse>=2.0.1',
16+
'confuse>=2.2.0',
17+
'typing-extensions>=4.15.0',
1718
'loguru>=0.7.2',
1819
'gitpython>=3.1.43',
1920
'msgspec>=0.19.0',

src/mccode_antlr/common/expression.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,16 @@ def simplify(self):
324324
f, s, t = [[x.simplify() for x in y] for y in (self.first, self.second, self.third)]
325325
if self.op == '__trinary__' and len(f) == 1:
326326
if f[0].is_value(True) and not f[0].is_value(False):
327-
return Expr(s)
327+
return s
328328
if f[0].is_value(False) and not f[0].is_value(True):
329-
return Expr(t)
329+
return t
330330
return TrinaryOp(self.data_type, self.style, self.op, f, s, t)
331331

332332
def evaluate(self, known: dict):
333-
first, second, third = [[x.evaluate(known) for x in y] for y in (self.first, self.second, self.third)]
333+
def evaluate_to_single(node):
334+
v = node.evaluate(known)
335+
return v[0] if hasattr(v, '__len__') and len(v) == 1 else v
336+
first, second, third = [[evaluate_to_single(x) for x in y] for y in (self.first, self.second, self.third)]
334337
return TrinaryOp(self.data_type, self.style, self.op, first, second, third).simplify()
335338

336339
def depends_on(self, name: str):
@@ -451,14 +454,13 @@ def simplify(self):
451454
left = [x.simplify() for x in self.left]
452455
right = [x.simplify() for x in self.right]
453456
if len(left) == 1 and ((left[0].is_zero and self.op == '+') or (left[0].is_value(1) and self.op == '*')):
454-
return Expr(right)
457+
return right
455458
if len(right) == 1 and (
456459
(right[0].is_zero and any(x == self.op for x in '+-')) or
457460
(right[0].is_value(1) and any(x == self.op for x in '*/'))
458461
):
459-
return Expr(left)
460-
if len(left) == 1 and len(right) == 1 and left[0].is_constant and right[0].is_constant\
461-
and self.op in ('+', '-', '*', '/'):
462+
return left
463+
if len(left) == 1 and len(right) == 1 and left[0].is_constant and right[0].is_constant:
462464
if self.op == '+':
463465
return left[0] + right[0]
464466
if self.op == '-':
@@ -467,11 +469,16 @@ def simplify(self):
467469
return left[0] * right[0]
468470
if self.op == '/':
469471
return left[0] / right[0]
472+
if self.op == '__pow__':
473+
return left[0] ** right[0]
470474
# punt!
471475
return BinaryOp(self.data_type, self.style, self.op, left, right)
472476

473477
def evaluate(self, known: dict):
474-
left, right = [[x.evaluate(known) for x in y] for y in (self.left, self.right)]
478+
def evaluate_to_single(node):
479+
v = node.evaluate(known)
480+
return v[0] if hasattr(v, '__len__') and len(v) == 1 else v
481+
left, right = [[evaluate_to_single(x) for x in y] for y in (self.left, self.right)]
475482
return BinaryOp(self.data_type, self.style, self.op, left, right).simplify()
476483

477484
def depends_on(self, name: str):
@@ -564,10 +571,17 @@ def simplify(self):
564571
value = [v.simplify() for v in self.value]
565572
if self.op == '__group__' and len(value) == 1 and isinstance(value[0], Value):
566573
return value[0] # Expr(value)
574+
elif self.op == '-' and len(value) == 1 and isinstance(value[0], Value):
575+
return -value[0]
576+
elif self.op == '+' and len(value) == 1 and isinstance(value[0], Value):
577+
return value[0]
567578
return UnaryOp(self.data_type, self.style, self.op, value)
568579

569580
def evaluate(self, known: dict):
570-
value = [x.evaluate(known) for x in self.value]
581+
def evaluate_to_single(node):
582+
v = node.evaluate(known)
583+
return v[0] if hasattr(v, '__len__') and len(v) == 1 else v
584+
value = [evaluate_to_single(x) for x in self.value]
571585
return UnaryOp(self.data_type, self.style, self.op, value).simplify()
572586

573587
def depends_on(self, name: str):
@@ -839,6 +853,18 @@ def __mul__(self, other):
839853
return BinaryOp(self.data_type, OpStyle.C, '*', [self], [other])
840854
return BinaryOp(self.data_type, OpStyle.C, '*', [self], [other]) if pdt.is_str else Value(self.value * other.value, pdt)
841855

856+
def __mod__(self, other):
857+
other = other if isinstance(other, (Value, Op)) else Value.best(other)
858+
pdt = self.data_type
859+
if other.is_op or self.is_id or other.is_id or pdt.is_str:
860+
return BinaryOp(self.data_type, OpStyle.C, '%', [self], [other])
861+
# This is neither consistent with the Python or C definition
862+
# Python -- the result takes the sign of the divisor
863+
# C -- the sign is always positive for positive numerator
864+
# but compiler-dependent for negative numerator
865+
# return Value(self.value - (self.value // other.value) * other.value, pdt)
866+
return Value(self.value % other.value, pdt)
867+
842868
def __truediv__(self, other):
843869
other = other if isinstance(other, (Value, Op)) else Value.best(other)
844870
pdt = self.data_type / other.data_type
@@ -919,6 +945,8 @@ def __pow__(self, power):
919945
return self
920946
if power.is_zero:
921947
return Value(1, self.data_type)
948+
if self.is_constant and power.is_constant:
949+
return Value(_value=self.value ** power.value, _data=self.data_type)
922950
return BinaryOp(self.data_type, OpStyle.C, '__pow__', [self], [power])
923951

924952
@property
@@ -991,6 +1019,9 @@ def from_dict(cls, args: dict):
9911019
def __post_init__(self):
9921020
if not isinstance(self.expr, list):
9931021
self.expr = [self.expr]
1022+
if any(not isinstance(node, ExprNodeSingular) for node in self.expr):
1023+
types = list(dict.fromkeys(type(node) for node in self.expr).keys())
1024+
raise ValueError(f"An Expr can not be a list of {types}")
9941025

9951026
def __str__(self):
9961027
return ','.join(str(x) for x in self.expr)
@@ -1157,6 +1188,9 @@ def __sub__(self, other):
11571188
def __mul__(self, other):
11581189
return Expr(self.expr[0] * self._prep_numeric_operation('multiply', other))
11591190

1191+
def __mod__(self, other):
1192+
return Expr(self.expr[0] % self._prep_numeric_operation('mod', other))
1193+
11601194
def __truediv__(self, other):
11611195
return Expr(self.expr[0] / self._prep_numeric_operation('divide', other))
11621196

@@ -1293,10 +1327,16 @@ def shape_type(self, st):
12931327

12941328
def simplify(self):
12951329
"""Perform a very basic analysis to reduce the expression complexity"""
1296-
return Expr([x.simplify() for x in self.expr])
1330+
def simplify_to_single_or_list(node):
1331+
s = node.simplify()
1332+
return s[0] if hasattr(s, '__len__') and len(s) == 1 else s
1333+
return Expr([simplify_to_single_or_list(x) for x in self.expr])
12971334

12981335
def evaluate(self, known: dict):
1299-
return Expr([x.evaluate(known) for x in self.expr]).simplify()
1336+
def evaluate_to_single_or_list(node):
1337+
s = node.evaluate(known)
1338+
return s[0] if hasattr(s, '__len__') and len(s) == 1 else s
1339+
return Expr([evaluate_to_single_or_list(x) for x in self.expr]).simplify()
13001340

13011341
def depends_on(self, name: str):
13021342
return any(x.depends_on(name) for x in self.expr)

src/mccode_antlr/common/textwrap.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,8 @@ def quoted_line(self, name: str, items: list[str], sep: str = ' ') -> str:
250250
def lines(self, items: list[str]) -> str:
251251
return '<br>'.join('<br>'.join(self.wrap(line)) for line in items)
252252

253-
@staticmethod
254-
def metadata_group(name: str, mimetype: str, item: str, value: str) -> str:
255-
return f'<b>{name}</b> <code>"{mimetype}"</code>" <var>{item}</var> %{{<pre>{value}</pre>%}}'
253+
def metadata_group(self, name: str, mimetype: str, item: str, value: str) -> str:
254+
return f'<b>{name}</b> <code>{mimetype}</code> <var>{item}</var>' + ' %{' + self.hide(f'<pre>{value}</pre>') +'%}'
256255

257256
@staticmethod
258257
def datatype(data_type: str) -> str:

src/mccode_antlr/common/visitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def visitExpressionIdentifier(obj, ctx):
3737
# check if this identifier is an InstrumentParameter name:
3838
name = str(ctx.Identifier())
3939
inst_par = obj.state.get_parameter(name, None)
40-
obj = ObjectType.parameter if inst_par is not None else ObjectType.identifier
40+
obj_type = ObjectType.parameter if inst_par is not None else ObjectType.identifier
4141
dat = inst_par.value.data_type if inst_par is not None else DataType.undefined
42-
return Expr(Value(name, _data=dat, _object=obj))
42+
return Expr(Value(name, _data=dat, _object=obj_type))
4343

4444
def visitExpressionInteger(obj, ctx):
4545
return Expr.int(str(ctx.IntegerLiteral()))

0 commit comments

Comments
 (0)