diff --git a/src/Std/Sync.lean b/src/Std/Sync.lean index c79ce06a5c8e..5c42254e64f2 100644 --- a/src/Std/Sync.lean +++ b/src/Std/Sync.lean @@ -17,5 +17,6 @@ public import Std.Sync.Broadcast public import Std.Sync.StreamMap public import Std.Sync.CancellationToken public import Std.Sync.CancellationContext +public import Std.Sync.Future @[expose] public section diff --git a/src/Std/Sync/Future.lean b/src/Std/Sync/Future.lean new file mode 100644 index 000000000000..6af9f11b83ac --- /dev/null +++ b/src/Std/Sync/Future.lean @@ -0,0 +1,158 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Std.Sync.Mutex +public import Std.Internal.Async.IO + +public section + +/-! +This module contains the implementation of `Std.Future` that is a write-once container for a value of type `α`. +Once resolved with a value, it cannot be changed or resolved again. It's similar to an `IO.Promise` but it exists +in order to make `Seletor` work correctly. +-/ + +namespace Std + +open Internal.IO.Async + +private inductive Consumer (α : Type) where + | normal (promise : IO.Promise α) + | select (finished : Waiter α) + +private def Consumer.resolve (c : Consumer α) (x : α) : BaseIO Bool := do + match c with + | .normal promise => + promise.resolve x + return true + | .select waiter => + let lose := return false + let win promise := do + promise.resolve (.ok x) + return true + waiter.race lose win + +/-- +A `Future` is a write-once container for a value of type `α`. Once resolved with a value, it cannot be +changed or resolved again. +-/ +structure Future (α : Type) where + private mk :: + private state : Mutex (Option α) + private consumers : Mutex (Array (Consumer α)) + private nonEmpty : Nonempty α + +namespace Future + +/-- +Create a new unresolved `Future`. +-/ +def new [h : Nonempty α] : BaseIO (Future α) := do + return { + state := ← Mutex.new none + consumers := ← Mutex.new #[] + nonEmpty := h + } + +/-- +Attempt to resolve the future with the given value. Returns `true` if the future was successfully resolved +(was not already resolved). Returns `false` if the future was already resolved. When resolved, all +waiting consumers will be notified. +-/ +def resolve (p : Future α) (value : α) : BaseIO Bool := do + let consumersToNotify ← p.state.atomically do + let current ← get + match current with + | some _ => + return none + | none => + set (some value) + let cs ← p.consumers.atomically do + let cs ← get + MonadState.set #[] + return some cs + return cs + + match consumersToNotify with + | none => + return false + + | some consumers => + if consumers.isEmpty then + return true + + for consumer in consumers do + discard <| consumer.resolve value + + return true + +/-- +Check if the future has been resolved. +-/ +def isResolved (p : Future α) : BaseIO Bool := do + p.state.atomically do + return (← get).isSome + +/-- +Get the value if the future is resolved, otherwise return `none`. +-/ +def tryGet (p : Future α) : BaseIO (Option α) := do + p.state.atomically do + return (← get) + +/-- +Wait for the future to be resolved and return its value. Returns a task that will complete once the +future is resolved. +-/ +def get [Inhabited α] (p : Future α) : BaseIO (Task α) := do + p.state.atomically do + match ← MonadState.get with + | some value => + return .pure value + | none => + let promise ← IO.Promise.new + p.consumers.atomically do + modify (·.push (.normal promise)) + + BaseIO.bindTask promise.result? fun res => + match res with + | some res => pure (Task.pure res) + | none => unreachable! + +/-- +Creates a `Selector` that resolves once the future is resolved. +-/ +def selector (p : Future α) : Selector α where + tryFn := p.tryGet + + registerFn waiter := do + p.state.atomically do + match ← MonadState.get with + | some value => + let lose := return () + let win promise := promise.resolve (.ok value) + waiter.race lose win + | none => + p.consumers.atomically do + modify (·.push (.select waiter)) + + unregisterFn := do + p.consumers.atomically do + let cs ← MonadState.get + let filtered ← cs.filterM fun + | .normal .. => return true + | .select waiter => return !(← waiter.checkFinished) + set filtered + +def ofPromise (promise : IO.Promise α) : BaseIO (Std.Future (Option α)) := do + let stdFuture ← Std.Future.new + BaseIO.chainTask promise.result? (fun x => discard <| stdFuture.resolve x) + return stdFuture + +end Future +end Std diff --git a/tests/lean/run/sync_future.lean b/tests/lean/run/sync_future.lean new file mode 100644 index 000000000000..c7186b5a08f3 --- /dev/null +++ b/tests/lean/run/sync_future.lean @@ -0,0 +1,76 @@ +import Std.Sync + +open Std + +def assertBEq [BEq α] [ToString α] (is should : α) : IO Unit := do + if is != should then + throw <| .userError s!"{is} should be {should}" + +def resolveOnce (f : Future Nat) : IO Unit := do + assertBEq (← f.isResolved) false + assertBEq (← f.tryGet) none + assertBEq (← f.resolve 42) true + assertBEq (← f.isResolved) true + assertBEq (← f.tryGet) (some 42) + assertBEq (← f.resolve 43) false + assertBEq (← f.tryGet) (some 42) + +def getAfterResolve (f : Future Nat) : IO Unit := do + assertBEq (← f.resolve 37) true + let task ← f.get + assertBEq (← IO.wait task) 37 + +def getBeforeResolve (f : Future Nat) : IO Unit := do + let task ← f.get + assertBEq (← f.resolve 37) true + assertBEq (← IO.wait task) 37 + +def multipleGets (f : Future Nat) : IO Unit := do + let task1 ← f.get + let task2 ← f.get + let task3 ← f.get + assertBEq (← f.resolve 99) true + assertBEq (← IO.wait task1) 99 + assertBEq (← IO.wait task2) 99 + assertBEq (← IO.wait task3) 99 + +def concurrentResolve (f : Future Nat) : IO Unit := do + let resolveTask1 ← IO.asTask (f.resolve 10) + let resolveTask2 ← IO.asTask (f.resolve 20) + let resolveTask3 ← IO.asTask (f.resolve 30) + + let result1 ← IO.ofExcept =<< IO.wait resolveTask1 + let result2 ← IO.ofExcept =<< IO.wait resolveTask2 + let result3 ← IO.ofExcept =<< IO.wait resolveTask3 + + let successCount := [result1, result2, result3].filter id |>.length + assertBEq successCount 1 + + let value ← f.tryGet + assertBEq (value.isSome) true + assertBEq ([10, 20, 30].contains value.get!) true + +def concurrentGetResolve (f : Future Nat) : IO Unit := do + let getTask1 ← f.get + let getTask2 ← f.get + let resolveTask ← f.resolve 55 + let getTask3 ← f.get + + let value1 ← IO.wait getTask1 + let value2 ← IO.wait getTask2 + let value3 ← IO.wait getTask3 + + assertBEq resolveTask true + assertBEq value1 55 + assertBEq value2 55 + assertBEq value3 55 + +def suite : IO Unit := do + resolveOnce (← Future.new) + getAfterResolve (← Future.new) + getBeforeResolve (← Future.new) + multipleGets (← Future.new) + concurrentResolve (← Future.new) + concurrentGetResolve (← Future.new) + +#eval suite