Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.env.drop_local(local);
}

/// Schedules `local` to be implicitly dropped after this block's terminator,
/// in addition to the liveness-derived drop points.
fn drop_after_terminator(&mut self, local: Local) {
self.drop_points.insert_after_terminator(local);
}

fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) {
let ty = self.type_builder.build(ty);
let temp_var = self.env.push_temp_var(ty.vacuous());
Expand Down Expand Up @@ -1149,7 +1155,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
targets.clone(),
outer_fn_param_vars,
|a, target| {
for local in a.drop_points.after_terminator(&target).iter() {
for local in a.drop_points.after_terminator(&target) {
tracing::info!(?local, ?target, "implicitly dropped for target");
a.drop_local(local);
}
Expand All @@ -1158,15 +1164,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
TerminatorKind::Call { target, .. } => {
if let Some(target) = target {
for local in self.drop_points.after_terminator(target).iter() {
for local in self.drop_points.after_terminator(target) {
tracing::info!(?local, "implicitly dropped after call");
self.drop_local(local);
}
self.type_goto(*target, outer_fn_param_vars);
}
}
TerminatorKind::Drop { target, .. } => {
for local in self.drop_points.after_terminator(target).iter() {
for local in self.drop_points.after_terminator(target) {
tracing::info!(?local, "dropped");
self.drop_local(local);
}
Expand All @@ -1178,7 +1184,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
target,
..
} => {
for local in self.drop_points.after_terminator(target).iter() {
for local in self.drop_points.after_terminator(target) {
tracing::info!(?local, "dropped");
self.drop_local(local);
}
Expand Down
26 changes: 20 additions & 6 deletions src/analyze/basic_block/drop_point.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};

use rustc_index::bit_set::DenseBitSet;
use rustc_middle::mir::{self, BasicBlock, Body, Local};
Expand All @@ -10,6 +10,10 @@ pub struct DropPoints {
pub before_statements: Vec<Local>,
after_statements: Vec<DenseBitSet<Local>>,
after_terminator: HashMap<BasicBlock, DenseBitSet<Local>>,
/// Locals dropped after the terminator regardless of the target, in
/// addition to the liveness-derived sets above. A set, since the same local
/// must not be dropped twice; ordered by index to keep drops deterministic.
after_terminator_extra: BTreeSet<Local>,
}

impl DropPoints {
Expand All @@ -25,13 +29,16 @@ impl DropPoints {
.iter()
.position(|s| s.contains(local))
.or_else(|| {
self.after_terminator
.values()
.any(|s| s.contains(local))
self.is_after_terminator(local)
.then_some(self.after_statements.len())
})
}

fn is_after_terminator(&self, local: Local) -> bool {
self.after_terminator.values().any(|s| s.contains(local))
|| self.after_terminator_extra.contains(&local)
}

pub fn remove_after_statement(&mut self, statement_index: usize, local: Local) -> bool {
self.after_statements[statement_index].remove(local)
}
Expand All @@ -44,10 +51,16 @@ impl DropPoints {
self.after_statements[statement_index].clone()
}

pub fn after_terminator(&self, target: &BasicBlock) -> DenseBitSet<Local> {
pub fn insert_after_terminator(&mut self, local: Local) {
self.after_terminator_extra.insert(local);
}

pub fn after_terminator(&self, target: &BasicBlock) -> Vec<Local> {
let mut t = self.after_terminator[target].clone();
t.union(self.after_statements.last().unwrap());
t
t.iter()
.chain(self.after_terminator_extra.iter().copied())
.collect()
}
}

Expand Down Expand Up @@ -197,6 +210,7 @@ impl<'mir, 'tcx> DropPointsBuilder<'mir, 'tcx> {
before_statements: Default::default(),
after_statements,
after_terminator,
after_terminator_extra: Default::default(),
}
}
}
30 changes: 29 additions & 1 deletion src/analyze/basic_block/visitor/rust_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ pub struct RustCallVisitor<'a, 'tcx, 'ctx> {

// TODO: consolidate logic with ReborrowVisitor
impl<'tcx> RustCallVisitor<'_, 'tcx, '_> {
fn insert_mut_borrow(&mut self, place: mir::Place<'tcx>, inner_ty: mir_ty::Ty<'tcx>) -> Local {
let r = mir_ty::Region::new_from_kind(self.tcx, mir_ty::RegionKind::ReErased);
let ty = mir_ty::Ty::new_mut_ref(self.tcx, r, inner_ty);
let decl = mir::LocalDecl::new(ty, Default::default()).immutable();
let new_local = self.analyzer.local_decls.push(decl);
let new_local_ty = self.analyzer.borrow_place_(place, inner_ty);
self.analyzer.bind_local(new_local, new_local_ty);
tracing::info!(old_place = ?place, ?new_local, "implicitly (mut-)borrowed");
new_local
}

fn insert_immut_borrow(
&mut self,
place: mir::Place<'tcx>,
Expand Down Expand Up @@ -105,7 +116,24 @@ impl<'a, 'tcx, 'ctx> mir::visit::MutVisitor<'tcx> for RustCallVisitor<'a, 'tcx,
if arg_closure_ty.ref_mutability().is_none()
) {
// case 3: {closure} -> &mut {closure}
unimplemented!();
let borrowed_closure_local =
self.insert_mut_borrow(arg_closure_place, arg_closure_ty);
args[0].node = mir::Operand::Move(borrowed_closure_local.into());
// FnOnce::call_once consumes the closure, but the resolved function
// only borrows it: drop the borrow and the environment after the
// call to resolve the prophecies of the captured mutable borrows.
self.analyzer.drop_after_terminator(borrowed_closure_local);
// The original MIR moves the closure into the call, so `moved_locals`
// dropped its drop obligation, expecting the callee to consume it; we
// must restore it. `moved_locals` only steals whole-local moves, so we
// only restore those: with a projection the obligation was never stolen
// and the normal drop machinery still handles it (re-adding it would
// double-drop). In practice a non-`Copy` closure (the only kind reaching
// this case) is always moved through a projection-less temporary.
if arg_closure_place.projection.is_empty() {
self.analyzer.drop_after_terminator(arg_closure_place.local);
}
tracing::debug!("applied mut-borrow for closure argument");
Comment thread
coord-e marked this conversation as resolved.
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,37 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self.super_operand(operand, location);
}

fn visit_terminator(
&mut self,
terminator: &mir::Terminator<'tcx>,
location: mir::Location,
) {
// calling an FnMut closure via FnOnce::call_once resolves to the closure
// function taking &mut {closure}, so RustCallVisitor mutably borrows the
// closure argument; see analyze::basic_block::visitor
if let mir::TerminatorKind::Call { func, args, .. } = &terminator.kind {
if let Some((def_id, generic_args)) = func.const_fn_def() {
let trait_did = self
.tcx
.opt_associated_item(def_id)
.and_then(|item| item.trait_container(self.tcx));
if trait_did.is_some() && trait_did == self.tcx.lang_items().fn_once_trait()
{
if let mir_ty::TyKind::Closure(_, closure_args) =
generic_args.type_at(0).kind()
{
if closure_args.as_closure().kind() == mir_ty::ClosureKind::FnMut {
if let Some(place) = args[0].node.place() {
self.locals.insert(place.local);
}
}
}
}
}
}
self.super_terminator(terminator, location);
}

fn visit_assign(
&mut self,
place: &mir::Place<'tcx>,
Expand Down
19 changes: 19 additions & 0 deletions tests/ui/fail/closure_param_weaken_3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

fn apply<F>(f: F) -> i32
where
F: FnOnce(i32) -> i32,
{
f(1)
}

fn main() {
let mut x = 1;
let closure = |y: i32| {
x += 1;
y + x
};
let result = apply(closure);
assert!(result == 3 && x == 1);
}
19 changes: 19 additions & 0 deletions tests/ui/pass/closure_param_weaken_3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

fn apply<F>(f: F) -> i32
where
F: FnOnce(i32) -> i32,
{
f(1)
}

fn main() {
let mut x = 1;
let closure = |y: i32| {
x += 1;
y + x
};
let result = apply(closure);
assert!(result == 3 && x == 2);
}