Skip to content

Commit d11565d

Browse files
authored
Merge pull request #10 from coord-e/closures
Support closures (with simple usage of traits)
2 parents b41df50 + 9b2c3ed commit d11565d

18 files changed

+370
-26
lines changed

src/analyze.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ impl<'tcx> Analyzer<'tcx> {
341341

342342
/// Computes the signature of the local function.
343343
///
344-
/// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
344+
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
345345
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
346346
/// reflects potential type instantiations happened after `optimized_mir`.
347347
pub fn local_fn_sig_with_body(
@@ -364,4 +364,14 @@ impl<'tcx> Analyzer<'tcx> {
364364
sig.abi,
365365
)
366366
}
367+
368+
/// Computes the signature of the local function.
369+
///
370+
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
371+
/// but extracts parameter and return types directly from [`mir::Body`] to obtain a signature that
372+
/// reflects the actual type of lifted closure functions.
373+
pub fn local_fn_sig(&self, local_def_id: LocalDefId) -> mir_ty::FnSig<'tcx> {
374+
let body = self.tcx.optimized_mir(local_def_id);
375+
self.local_fn_sig_with_body(local_def_id, body)
376+
}
367377
}

src/analyze/basic_block.rs

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,55 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
8484
) -> Vec<chc::Clause> {
8585
let mut clauses = Vec::new();
8686

87-
if expected_args.is_empty() {
88-
// elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl)
89-
expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous());
90-
}
9187
tracing::debug!(
9288
got = %got.display(),
9389
expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(),
9490
"fn_sub_type"
9591
);
9692

93+
match got.abi {
94+
rty::FunctionAbi::Rust => {
95+
if expected_args.is_empty() {
96+
// elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl)
97+
expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous());
98+
}
99+
}
100+
rty::FunctionAbi::RustCall => {
101+
// &Closure, { v: (own i32, own bool) | v = (<0>, <false>) }
102+
// =>
103+
// &Closure, { v: i32 | (<v>, _) = (<0>, <false>) }, { v: bool | (_, <v>) = (<0>, <false>) }
104+
105+
let rty::RefinedType { ty, mut refinement } =
106+
expected_args.pop().expect("rust-call last arg");
107+
let ty = ty.into_tuple().expect("rust-call last arg is tuple");
108+
let mut replacement_tuple = Vec::new(); // will be (<v>, _) or (_, <v>)
109+
for elem in &ty.elems {
110+
let existential = refinement.existentials.push(elem.ty.to_sort());
111+
replacement_tuple.push(chc::Term::var(rty::RefinedTypeVar::Existential(
112+
existential,
113+
)));
114+
}
115+
116+
for (i, elem) in ty.elems.into_iter().enumerate() {
117+
// all tuple elements are boxed during the translation to rty::Type
118+
let mut param_ty = elem.deref();
119+
param_ty
120+
.refinement
121+
.push_conj(refinement.clone().subst_value_var(|| {
122+
let mut value_elems = replacement_tuple.clone();
123+
value_elems[i] = chc::Term::var(rty::RefinedTypeVar::Value).boxed();
124+
chc::Term::tuple(value_elems)
125+
}));
126+
expected_args.push(param_ty);
127+
}
128+
129+
tracing::info!(
130+
expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(),
131+
"rust-call expanded",
132+
);
133+
}
134+
}
135+
97136
// TODO: check sty and length is equal
98137
let mut builder = self.env.build_clause();
99138
for (param_idx, param_rty) in got.params.iter_enumerated() {
@@ -175,6 +214,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
175214
chc::Term::bool(val.try_to_bool().unwrap()),
176215
)
177216
}
217+
(mir_ty::TyKind::Tuple(tys), _) if tys.is_empty() => {
218+
PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![]))
219+
}
220+
(mir_ty::TyKind::Closure(_, args), _) if args.as_closure().upvar_tys().is_empty() => {
221+
PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![]))
222+
}
178223
(
179224
mir_ty::TyKind::Ref(_, elem, Mutability::Not),
180225
ConstValue::Scalar(Scalar::Ptr(ptr, _)),
@@ -568,12 +613,25 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
568613
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
569614
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
570615
}
571-
Some((def_id, args)) => self
572-
.ctx
573-
.def_ty_with_args(def_id, args)
574-
.expect("unknown def")
575-
.ty
576-
.vacuous(),
616+
Some((def_id, args)) => {
617+
let param_env = self.tcx.param_env(self.local_def_id);
618+
let instance =
619+
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
620+
let resolved_def_id = if let Some(instance) = instance {
621+
instance.def_id()
622+
} else {
623+
def_id
624+
};
625+
if def_id != resolved_def_id {
626+
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
627+
}
628+
629+
self.ctx
630+
.def_ty_with_args(resolved_def_id, args)
631+
.expect("unknown def")
632+
.ty
633+
.vacuous()
634+
}
577635
_ => self.operand_type(func.clone()).ty,
578636
};
579637
let expected_args: IndexVec<_, _> = args

src/analyze/crate_.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
3939

4040
#[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(local_def_id)))]
4141
fn refine_fn_def(&mut self, local_def_id: LocalDefId) {
42+
let sig = self.ctx.local_fn_sig(local_def_id);
43+
4244
let mut analyzer = self.ctx.local_def_analyzer(local_def_id);
4345

4446
if analyzer.is_annotated_as_trusted() {
4547
assert!(analyzer.is_fully_annotated());
4648
self.trusted.insert(local_def_id.to_def_id());
4749
}
4850

49-
let sig = self
50-
.tcx
51-
.fn_sig(local_def_id)
52-
.instantiate_identity()
53-
.skip_binder();
5451
use mir_ty::TypeVisitableExt as _;
5552
if sig.has_param() && !analyzer.is_fully_annotated() {
5653
self.ctx.register_deferred_def(local_def_id.to_def_id());

src/chc.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,10 @@ impl<V> Term<V> {
623623
Term::Mut(Box::new(t1), Box::new(t2))
624624
}
625625

626+
pub fn boxed(self) -> Self {
627+
Term::Box(Box::new(self))
628+
}
629+
626630
pub fn box_current(self) -> Self {
627631
Term::BoxCurrent(Box::new(self))
628632
}

src/refine/template.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ impl<'tcx> TypeBuilder<'tcx> {
163163
unimplemented!("unsupported ADT: {:?}", ty);
164164
}
165165
}
166+
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
166167
kind => unimplemented!("unrefined_ty: {:?}", kind),
167168
}
168169
}
@@ -183,6 +184,11 @@ impl<'tcx> TypeBuilder<'tcx> {
183184
registry: &'a mut R,
184185
sig: mir_ty::FnSig<'tcx>,
185186
) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
187+
let abi = match sig.abi {
188+
rustc_target::spec::abi::Abi::Rust => rty::FunctionAbi::Rust,
189+
rustc_target::spec::abi::Abi::RustCall => rty::FunctionAbi::RustCall,
190+
_ => unimplemented!("unsupported function ABI: {:?}", sig.abi),
191+
};
186192
FunctionTemplateTypeBuilder {
187193
inner: self.clone(),
188194
registry,
@@ -198,6 +204,7 @@ impl<'tcx> TypeBuilder<'tcx> {
198204
param_rtys: Default::default(),
199205
param_refinement: None,
200206
ret_rty: None,
207+
abi,
201208
}
202209
}
203210
}
@@ -282,6 +289,7 @@ where
282289
unimplemented!("unsupported ADT: {:?}", ty);
283290
}
284291
}
292+
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
285293
kind => unimplemented!("ty: {:?}", kind),
286294
}
287295
}
@@ -301,9 +309,12 @@ where
301309
where
302310
I: IntoIterator<Item = (Local, mir_ty::TypeAndMut<'tcx>)>,
303311
{
312+
// this is necessary for local_def::Analyzer::elaborate_unused_args
313+
let mut live_locals: Vec<_> = live_locals.into_iter().collect();
314+
live_locals.sort_by_key(|(local, _)| *local);
315+
304316
let mut locals = IndexVec::<rty::FunctionParamIdx, _>::new();
305317
let mut tys = Vec::new();
306-
// TODO: avoid two iteration and assumption of FunctionParamIdx match between locals and ty
307318
for (local, ty) in live_locals {
308319
locals.push((local, ty.mutbl));
309320
tys.push(ty);
@@ -316,6 +327,7 @@ where
316327
param_rtys: Default::default(),
317328
param_refinement: None,
318329
ret_rty: None,
330+
abi: Default::default(),
319331
}
320332
.build();
321333
BasicBlockType { ty, locals }
@@ -331,6 +343,7 @@ pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> {
331343
param_refinement: Option<rty::Refinement<rty::FunctionParamIdx>>,
332344
param_rtys: HashMap<rty::FunctionParamIdx, rty::RefinedType<rty::FunctionParamIdx>>,
333345
ret_rty: Option<rty::RefinedType<rty::FunctionParamIdx>>,
346+
abi: rty::FunctionAbi,
334347
}
335348

336349
impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
@@ -439,6 +452,6 @@ where
439452
.with_scope(&builder)
440453
.build_refined(self.ret_ty)
441454
});
442-
rty::FunctionType::new(param_rtys, ret_rty)
455+
rty::FunctionType::new(param_rtys, ret_rty).with_abi(self.abi)
443456
}
444457
}

src/rty.rs

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,36 @@ where
8383
}
8484
}
8585

86+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
87+
pub enum FunctionAbi {
88+
#[default]
89+
Rust,
90+
RustCall,
91+
}
92+
93+
impl std::fmt::Display for FunctionAbi {
94+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
95+
f.write_str(self.name())
96+
}
97+
}
98+
99+
impl FunctionAbi {
100+
pub fn name(&self) -> &'static str {
101+
match self {
102+
FunctionAbi::Rust => "rust",
103+
FunctionAbi::RustCall => "rust-call",
104+
}
105+
}
106+
107+
pub fn is_rust(&self) -> bool {
108+
matches!(self, FunctionAbi::Rust)
109+
}
110+
111+
pub fn is_rust_call(&self) -> bool {
112+
matches!(self, FunctionAbi::RustCall)
113+
}
114+
}
115+
86116
/// A function type.
87117
///
88118
/// In Thrust, function types are closed. Because of that, function types, thus its parameters and
@@ -92,6 +122,7 @@ where
92122
pub struct FunctionType {
93123
pub params: IndexVec<FunctionParamIdx, RefinedType<FunctionParamIdx>>,
94124
pub ret: Box<RefinedType<FunctionParamIdx>>,
125+
pub abi: FunctionAbi,
95126
}
96127

97128
impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b FunctionType
@@ -100,15 +131,25 @@ where
100131
D::Doc: Clone,
101132
{
102133
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> {
134+
let abi = match self.abi {
135+
FunctionAbi::Rust => allocator.nil(),
136+
abi => allocator
137+
.text("extern")
138+
.append(allocator.space())
139+
.append(allocator.as_string(abi))
140+
.append(allocator.space()),
141+
};
103142
let separator = allocator.text(",").append(allocator.line());
104-
allocator
105-
.intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator)
106-
.parens()
107-
.append(allocator.space())
108-
.append(allocator.text("→"))
109-
.append(allocator.line())
110-
.append(self.ret.pretty(allocator))
111-
.group()
143+
abi.append(
144+
allocator
145+
.intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator)
146+
.parens(),
147+
)
148+
.append(allocator.space())
149+
.append(allocator.text("→"))
150+
.append(allocator.line())
151+
.append(self.ret.pretty(allocator))
152+
.group()
112153
}
113154
}
114155

@@ -120,9 +161,15 @@ impl FunctionType {
120161
FunctionType {
121162
params,
122163
ret: Box::new(ret),
164+
abi: FunctionAbi::Rust,
123165
}
124166
}
125167

168+
pub fn with_abi(mut self, abi: FunctionAbi) -> Self {
169+
self.abi = abi;
170+
self
171+
}
172+
126173
/// Because function types are always closed in Thrust, we can convert this into
127174
/// [`Type<Closed>`].
128175
pub fn into_closed_ty(self) -> Type<Closed> {
@@ -1304,6 +1351,32 @@ impl<FV> RefinedType<FV> {
13041351
RefinedType { ty, refinement }
13051352
}
13061353

1354+
/// Returns a dereferenced type of the immutable reference or owned pointer.
1355+
///
1356+
/// e.g. `{ v: Box<T> | φ } --> { v: T | φ[box v/v] }`
1357+
pub fn deref(self) -> Self {
1358+
let RefinedType {
1359+
ty,
1360+
refinement: outer_refinement,
1361+
} = self;
1362+
let inner_ty = ty.into_pointer().expect("invalid deref");
1363+
if inner_ty.is_mut() {
1364+
// losing info about proph
1365+
panic!("invalid deref");
1366+
}
1367+
let RefinedType {
1368+
ty: inner_ty,
1369+
refinement: mut inner_refinement,
1370+
} = *inner_ty.elem;
1371+
inner_refinement.push_conj(
1372+
outer_refinement.subst_value_var(|| chc::Term::var(RefinedTypeVar::Value).boxed()),
1373+
);
1374+
RefinedType {
1375+
ty: inner_ty,
1376+
refinement: inner_refinement,
1377+
}
1378+
}
1379+
13071380
pub fn subst_var<F, W>(self, mut f: F) -> RefinedType<W>
13081381
where
13091382
F: FnMut(FV) -> chc::Term<W>,

tests/ui/fail/closure_mut.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
fn main() {
5+
let mut x = 1;
6+
let mut incr = |by: i32| {
7+
x += by;
8+
};
9+
incr(5);
10+
incr(5);
11+
assert!(x == 10);
12+
}

tests/ui/fail/closure_mut_0.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
fn main() {
5+
let mut x = 1;
6+
x += 1;
7+
let mut incr = || {
8+
x += 1;
9+
};
10+
incr();
11+
assert!(x == 2);
12+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
fn main() {
5+
let incr = |x| {
6+
x + 1
7+
};
8+
assert!(incr(2) == 2);
9+
}

0 commit comments

Comments
 (0)