Skip to content

Commit a407f01

Browse files
committed
Implement #[thrust::extern_spec_fn]
1 parent 827c750 commit a407f01

File tree

5 files changed

+95
-2
lines changed

5 files changed

+95
-2
lines changed

src/analyze/annot.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] {
3333
[Symbol::intern("thrust"), Symbol::intern("callable")]
3434
}
3535

36+
pub fn extern_spec_fn_path() -> [Symbol; 2] {
37+
[Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")]
38+
}
39+
3640
/// A [`annot::Resolver`] implementation for resolving function parameters.
3741
///
3842
/// The parameter names and their sorts needs to be configured via

src/analyze/crate_.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
4848
self.trusted.insert(local_def_id.to_def_id());
4949
}
5050

51+
if analyzer.is_annotated_as_extern_spec_fn() {
52+
assert!(analyzer.is_fully_annotated());
53+
self.trusted.insert(local_def_id.to_def_id());
54+
}
55+
5156
use mir_ty::TypeVisitableExt as _;
5257
if sig.has_param() && !analyzer.is_fully_annotated() {
5358
self.ctx.register_deferred_def(local_def_id.to_def_id());
5459
} else {
5560
let expected = analyzer.expected_ty();
56-
self.ctx.register_def(local_def_id.to_def_id(), expected);
61+
let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() {
62+
analyzer.extern_spec_fn_target_def_id()
63+
} else {
64+
local_def_id.to_def_id()
65+
};
66+
self.ctx.register_def(target_def_id, expected);
5767
}
5868
}
5969

src/analyze/local_def.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet;
66
use rustc_index::IndexVec;
77
use rustc_middle::mir::{self, BasicBlock, Body, Local};
88
use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut};
9-
use rustc_span::def_id::LocalDefId;
9+
use rustc_span::def_id::{DefId, LocalDefId};
1010
use rustc_span::symbol::Ident;
1111

1212
use crate::analyze;
@@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
126126
.is_some()
127127
}
128128

129+
pub fn is_annotated_as_extern_spec_fn(&self) -> bool {
130+
self.tcx
131+
.get_attrs_by_path(
132+
self.local_def_id.to_def_id(),
133+
&analyze::annot::extern_spec_fn_path(),
134+
)
135+
.next()
136+
.is_some()
137+
}
138+
129139
// TODO: unify this logic with extraction functions above
130140
pub fn is_fully_annotated(&self) -> bool {
131141
let has_require = self
@@ -240,6 +250,46 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
240250
rty::RefinedType::unrefined(builder.build().into())
241251
}
242252

253+
pub fn extern_spec_fn_target_def_id(&self) -> DefId {
254+
struct ExtractDefId<'tcx> {
255+
tcx: TyCtxt<'tcx>,
256+
outer_def_id: LocalDefId,
257+
inner_def_id: Option<DefId>,
258+
}
259+
260+
impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> {
261+
type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies;
262+
263+
fn nested_visit_map(&mut self) -> Self::Map {
264+
self.tcx.hir()
265+
}
266+
267+
fn visit_qpath(
268+
&mut self,
269+
qpath: &rustc_hir::QPath<'tcx>,
270+
hir_id: rustc_hir::HirId,
271+
_span: rustc_span::Span,
272+
) {
273+
let typeck_result = self.tcx.typeck(self.outer_def_id);
274+
if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id)
275+
{
276+
self.inner_def_id = Some(def_id);
277+
}
278+
}
279+
}
280+
281+
use rustc_hir::intravisit::Visitor as _;
282+
let mut visitor = ExtractDefId {
283+
tcx: self.tcx,
284+
outer_def_id: self.local_def_id,
285+
inner_def_id: None,
286+
};
287+
if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) {
288+
visitor.visit_item(item);
289+
}
290+
visitor.inner_def_id.expect("invalid extern_spec_fn")
291+
}
292+
243293
fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool {
244294
let param_local = analyze::local_of_function_param(param_idx);
245295
self.body.local_decls[param_local].mutability.is_mut()

tests/ui/fail/extern_spec_take.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[thrust::extern_spec_fn]
4+
#[thrust::requires(true)]
5+
#[thrust::ensures(result == *dest && ^dest == 0)]
6+
fn _extern_spec_take(dest: &mut i32) -> i32 {
7+
std::mem::take(dest)
8+
}
9+
10+
fn main() {
11+
let mut x = 42;
12+
let old = std::mem::take(&mut x);
13+
assert!(x == 42);
14+
}

tests/ui/pass/extern_spec_take.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//@check-pass
2+
3+
#[thrust::extern_spec_fn]
4+
#[thrust::requires(true)]
5+
#[thrust::ensures(result == *dest && ^dest == 0)]
6+
fn _extern_spec_take(dest: &mut i32) -> i32 {
7+
std::mem::take(dest)
8+
}
9+
10+
fn main() {
11+
let mut x = 42;
12+
let old = std::mem::take(&mut x);
13+
assert!(old == 42);
14+
assert!(x == 0);
15+
}

0 commit comments

Comments
 (0)