Skip to content

Commit cdcac7c

Browse files
Pass ctx to _py2expr when it is available (#113)
1 parent 3cc7467 commit cdcac7c

File tree

1 file changed

+51
-44
lines changed

1 file changed

+51
-44
lines changed

cvc5_pythonic_api/cvc5_pythonic.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,16 @@ def _get_ctx(ctx):
213213
return ctx
214214

215215

216+
def _get_ctx2(a, b, ctx=None):
217+
if is_expr(a):
218+
return a.ctx
219+
if is_expr(b):
220+
return b.ctx
221+
if ctx is None:
222+
ctx = main_ctx()
223+
return ctx
224+
225+
216226
def get_ctx(ctx):
217227
"""
218228
Returns `ctx` if it is not `None`, and the default context otherwise.
@@ -2128,9 +2138,8 @@ def Length(s, ctx=None):
21282138
>>> simplify(l)
21292139
1
21302140
"""
2131-
s = _py2expr(s)
2132-
ctx = _get_ctx(ctx)
2133-
return ArithRef(ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), ctx)
2141+
s = _py2expr(s, ctx)
2142+
return ArithRef(s.ctx.tm.mkTerm(Kind.SEQ_LENGTH, s.ast), s.ctx)
21342143

21352144

21362145
def SubString(s, offset, length, ctx=None):
@@ -2141,16 +2150,15 @@ def SubString(s, offset, length, ctx=None):
21412150
>>> simplify(SubString(StringVal('hello'),3,2))
21422151
"lo"
21432152
"""
2144-
ctx = _get_ctx(ctx)
2145-
s = _py2expr(s)
2146-
offset = _py2expr(offset)
2147-
length = _py2expr(length)
2153+
s = _py2expr(s, ctx)
2154+
offset = _py2expr(offset, s.ctx)
2155+
length = _py2expr(length, s.ctx)
21482156
return StringRef(
2149-
ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), ctx
2157+
s.ctx.tm.mkTerm(Kind.STRING_SUBSTR, s.ast, offset.ast, length.ast), s.ctx
21502158
)
21512159

21522160

2153-
def SubSeq(s, offset, length):
2161+
def SubSeq(s, offset, length, ctx=None):
21542162
"""Extract subsequence starting at offset
21552163
21562164
>>> seq = Concat(Unit(IntVal(1)),Unit(IntVal(2)))
@@ -2159,9 +2167,9 @@ def SubSeq(s, offset, length):
21592167
>>> simplify(SubSeq(seq,1,0))
21602168
(as seq.empty (Seq Int))()
21612169
"""
2162-
s = _py2expr(s)
2163-
offset = _py2expr(offset)
2164-
length = _py2expr(length)
2170+
s = _py2expr(s, ctx)
2171+
offset = _py2expr(offset, s.ctx)
2172+
length = _py2expr(length, s.ctx)
21652173
return SeqRef(
21662174
s.ctx.tm.mkTerm(Kind.SEQ_EXTRACT, s.ast, offset.ast, length.ast), s.ctx
21672175
)
@@ -2178,7 +2186,7 @@ def SeqUpdate(s, t, i):
21782186
>>> simplify(SeqUpdate(lst,Unit(IntVal(1)),4))
21792187
(seq.++ (seq.unit 1) (seq.unit 2) (seq.unit 3))()
21802188
"""
2181-
i = _py2expr(i)
2189+
i = _py2expr(i, t.ctx)
21822190
return SeqRef(t.ctx.tm.mkTerm(Kind.SEQ_UPDATE, s.ast, i.ast, t.ast), t.ctx)
21832191

21842192

@@ -2206,9 +2214,9 @@ def Contains(a, b, ctx=None):
22062214
>>> simplify(s)
22072215
True
22082216
"""
2209-
ctx = _get_ctx(ctx)
2210-
a = _py2expr(a)
2211-
b = _py2expr(b)
2217+
ctx = _get_ctx2(a, b, ctx)
2218+
a = _py2expr(a, ctx)
2219+
b = _py2expr(b, ctx)
22122220
if is_string(a) and is_string(b):
22132221
return BoolRef(ctx.tm.mkTerm(Kind.STRING_CONTAINS, a.ast, b.ast), ctx)
22142222
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_CONTAINS, a.ast, b.ast), ctx)
@@ -2224,9 +2232,9 @@ def PrefixOf(a, b, ctx=None):
22242232
>>> simplify(s2)
22252233
False
22262234
"""
2227-
ctx = _get_ctx(ctx)
2228-
a = _py2expr(a)
2229-
b = _py2expr(b)
2235+
ctx = _get_ctx2(a, b, ctx)
2236+
a = _py2expr(a, ctx)
2237+
b = _py2expr(b, ctx)
22302238
if is_string(a) and is_string(b):
22312239
return BoolRef(ctx.tm.mkTerm(Kind.STRING_PREFIX, a.ast, b.ast), ctx)
22322240
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_PREFIX, a.ast, b.ast), ctx)
@@ -2242,9 +2250,9 @@ def SuffixOf(a, b, ctx=None):
22422250
>>> simplify(s2)
22432251
True
22442252
"""
2245-
ctx = _get_ctx(ctx)
2246-
a = _py2expr(a)
2247-
b = _py2expr(b)
2253+
ctx = _get_ctx2(a, b, ctx)
2254+
a = _py2expr(a, ctx)
2255+
b = _py2expr(b, ctx)
22482256
if is_string(a) and is_string(b):
22492257
return BoolRef(ctx.tm.mkTerm(Kind.STRING_SUFFIX, a.ast, b.ast), ctx)
22502258
return BoolRef(ctx.tm.mkTerm(Kind.SEQ_SUFFIX, a.ast, b.ast), ctx)
@@ -2260,11 +2268,11 @@ def IndexOf(s, substr, offset=None):
22602268
>>> simplify(IndexOf("abcabc", "bc", 2))
22612269
4
22622270
"""
2271+
ctx = _get_ctx2(s, substr)
22632272
if offset is None:
2264-
offset = IntVal(0)
2265-
s = _py2expr(s)
2266-
substr = _py2expr(substr)
2267-
ctx = _get_ctx(None)
2273+
offset = IntVal(0, ctx)
2274+
s = _py2expr(s, ctx)
2275+
substr = _py2expr(substr, ctx)
22682276
if _is_int(offset):
22692277
offset = IntVal(offset, ctx)
22702278
if is_string(s) and is_string(substr):
@@ -2285,10 +2293,12 @@ def Replace(s, src, dst):
22852293
>>> simplify(Replace(seq,Unit(IntVal(1)),Unit(IntVal(5))))
22862294
(seq.++ (seq.unit 5) (seq.unit 2))()
22872295
"""
2288-
s = _py2expr(s)
2289-
src = _py2expr(src)
2290-
dst = _py2expr(dst)
2291-
ctx = _get_ctx(None)
2296+
ctx = _get_ctx2(dst, s)
2297+
if ctx is None and is_expr(src):
2298+
ctx = src.ctx
2299+
s = _py2expr(s, ctx)
2300+
src = _py2expr(src, ctx)
2301+
dst = _py2expr(dst, ctx)
22922302
if is_string(s) and is_string(src) and is_string(dst):
22932303
return StringRef(
22942304
ctx.tm.mkTerm(Kind.STRING_REPLACE, s.ast, src.ast, dst.ast), ctx
@@ -2307,8 +2317,7 @@ def StrToInt(s):
23072317
123
23082318
"""
23092319
s = _py2expr(s)
2310-
ctx = _get_ctx(s.ctx)
2311-
return ArithRef(ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), ctx)
2320+
return ArithRef(s.ctx.tm.mkTerm(Kind.STRING_TO_INT, s.ast), s.ctx)
23122321

23132322

23142323
def IntToStr(s):
@@ -2399,7 +2408,7 @@ def Re(s, ctx=None):
23992408
>>> simplify(InRe('b',re))
24002409
False
24012410
"""
2402-
s = _py2expr(s)
2411+
s = _py2expr(s, ctx)
24032412
return ReRef(s.ctx.tm.mkTerm(Kind.STRING_TO_REGEXP, s.ast), s.ctx)
24042413

24052414

@@ -2426,7 +2435,9 @@ def InRe(s, re):
24262435
>>> print (simplify(InRe("c", re)))
24272436
False
24282437
"""
2429-
s = _py2expr(s)
2438+
ctx = _get_ctx2(s, re)
2439+
s = _py2expr(s, ctx)
2440+
re = _py2expr(re, ctx)
24302441
return BoolRef(s.ctx.tm.mkTerm(Kind.STRING_IN_REGEXP, s.ast, re.ast), s.ctx)
24312442

24322443

@@ -2556,6 +2567,7 @@ def Range(lo, hi, ctx=None):
25562567
>>> print(simplify(InRe("bb", range)))
25572568
False
25582569
"""
2570+
ctx = _get_ctx2(lo, hi)
25592571
lo = _py2expr(lo, ctx)
25602572
hi = _py2expr(hi, ctx)
25612573
return ReRef(lo.ctx.tm.mkTerm(Kind.REGEXP_RANGE, lo.ast, hi.ast), lo.ctx)
@@ -4766,12 +4778,8 @@ def Concat(*args):
47664778
sz = len(args)
47674779
if debugging():
47684780
_assert(sz >= 2, "At least two arguments expected.")
4769-
args = [_py2expr(s) for s in args]
4770-
ctx = _get_ctx(None)
4771-
for a in args:
4772-
if is_expr(a):
4773-
ctx = a.ctx
4774-
break
4781+
ctx = _get_ctx(_ctx_from_ast_arg_list(args))
4782+
args = [_py2expr(s, ctx) for s in args]
47754783
if debugging():
47764784
_assert(
47774785
all([is_bv(a) or is_string(a) or is_seq(a) or is_re(a) for a in args]),
@@ -5860,15 +5868,14 @@ def SetComplement(s):
58605868
return ArrayRef(ctx.tm.mkTerm(Kind.SET_COMPLEMENT, s.ast), ctx)
58615869

58625870

5863-
def Singleton(s):
5871+
def Singleton(s, ctx=None):
58645872
"""The single element set of just e
58655873
58665874
>>> Singleton(IntVal(1))
58675875
Singleton(1)
58685876
"""
5869-
s = _py2expr(s)
5870-
ctx = s.ctx
5871-
return SetRef(ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), ctx)
5877+
s = _py2expr(s, ctx)
5878+
return SetRef(s.ctx.tm.mkTerm(Kind.SET_SINGLETON, s.ast), s.ctx)
58725879

58735880

58745881
def SetDifference(a, b):

0 commit comments

Comments
 (0)