Skip to content

Commit 353f614

Browse files
committed
implement type-erased allocator for parallel_scheduler_backend
1 parent 15ab912 commit 353f614

File tree

7 files changed

+344
-259
lines changed

7 files changed

+344
-259
lines changed

include/exec/any_sender_of.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,15 @@ namespace exec {
342342
static_assert(sizeof(_Tp) <= __buffer_size && alignof(_Tp) <= __alignment);
343343
_Tp* __pointer = reinterpret_cast<_Tp*>(&__buffer_[0]);
344344
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
345-
_Alloc __alloc{__allocator_};
345+
_Alloc __alloc{__alloc_};
346346
std::allocator_traits<_Alloc>::construct(__alloc, __pointer, static_cast<_As&&>(__args)...);
347347
__object_pointer_ = __pointer;
348348
}
349349

350350
template <class _Tp, class... _As>
351351
void __construct_large(_As&&... __args) {
352352
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
353-
_Alloc __alloc{__allocator_};
353+
_Alloc __alloc{__alloc_};
354354
_Tp* __pointer = std::allocator_traits<_Alloc>::allocate(__alloc, 1);
355355
STDEXEC_TRY {
356356
std::allocator_traits<_Alloc>::construct(
@@ -369,7 +369,7 @@ namespace exec {
369369
return;
370370
}
371371
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
372-
_Alloc __alloc{__allocator_};
372+
_Alloc __alloc{__alloc_};
373373
_Tp* __pointer = static_cast<_Tp*>(std::exchange(__object_pointer_, nullptr));
374374
std::allocator_traits<_Alloc>::destroy(__alloc, __pointer);
375375
if constexpr (!__is_small<_Tp>) {
@@ -382,7 +382,7 @@ namespace exec {
382382
void* __object_pointer_{nullptr};
383383
alignas(__alignment) std::byte __buffer_[__buffer_size]{};
384384
STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS
385-
_Allocator __allocator_{};
385+
_Allocator __alloc_{};
386386
};
387387

388388
template <
@@ -510,15 +510,15 @@ namespace exec {
510510
static_assert(sizeof(_Tp) <= __buffer_size && alignof(_Tp) <= __alignment);
511511
_Tp* __pointer = reinterpret_cast<_Tp*>(&__buffer_[0]);
512512
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
513-
_Alloc __alloc{__allocator_};
513+
_Alloc __alloc{__alloc_};
514514
std::allocator_traits<_Alloc>::construct(__alloc, __pointer, static_cast<_As&&>(__args)...);
515515
__object_pointer_ = __pointer;
516516
}
517517

518518
template <class _Tp, class... _As>
519519
void __construct_large(_As&&... __args) {
520520
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
521-
_Alloc __alloc{__allocator_};
521+
_Alloc __alloc{__alloc_};
522522
_Tp* __pointer = std::allocator_traits<_Alloc>::allocate(__alloc, 1);
523523
STDEXEC_TRY {
524524
std::allocator_traits<_Alloc>::construct(
@@ -537,7 +537,7 @@ namespace exec {
537537
return;
538538
}
539539
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
540-
_Alloc __alloc{__allocator_};
540+
_Alloc __alloc{__alloc_};
541541
_Tp* __pointer = static_cast<_Tp*>(std::exchange(__object_pointer_, nullptr));
542542
std::allocator_traits<_Alloc>::destroy(__alloc, __pointer);
543543
if constexpr (!__is_small<_Tp>) {
@@ -555,7 +555,7 @@ namespace exec {
555555
_Tp& __other_object = *__pointer;
556556
this->template __construct_small<_Tp>(static_cast<_Tp&&>(__other_object));
557557
using _Alloc = std::allocator_traits<_Allocator>::template rebind_alloc<_Tp>;
558-
_Alloc __alloc{__allocator_};
558+
_Alloc __alloc{__alloc_};
559559
std::allocator_traits<_Alloc>::destroy(__alloc, __pointer);
560560
} else {
561561
__object_pointer_ = __pointer;
@@ -582,7 +582,7 @@ namespace exec {
582582
const __vtable_t* __vtable_{__default_storage_vtable(static_cast<__vtable_t*>(nullptr))};
583583
void* __object_pointer_{nullptr};
584584
alignas(__alignment) std::byte __buffer_[__buffer_size]{};
585-
STDEXEC_ATTRIBUTE(no_unique_address) _Allocator __allocator_ { };
585+
STDEXEC_ATTRIBUTE(no_unique_address) _Allocator __alloc_ { };
586586
};
587587

588588
struct __empty_vtable {

include/stdexec/__detail/__any.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,10 @@ namespace STDEXEC::__any {
12101210
// __bad_any_cast
12111211
struct __bad_any_cast : std::exception {
12121212
[[nodiscard]]
1213-
constexpr char const *what() const noexcept override {
1213+
#if __cpp_lib_constexpr_exceptions >= 2025'02L // constexpr support for std::exception
1214+
constexpr
1215+
#endif
1216+
char const *what() const noexcept override {
12141217
return "__bad_any_cast";
12151218
}
12161219
};
include/stdexec/__detail/__any_allocator.hpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) 2026 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "__execution_fwd.hpp"
20+
21+
#include "__any.hpp"
22+
#include "__concepts.hpp"
23+
#include "__memory.hpp"
24+
#include "__typeinfo.hpp"
25+
26+
namespace STDEXEC {
27+
namespace __detail {
28+
// Forward declaration: the alias below has to name the __byte_allocator template
// before its definition appears.
template <class _Base>
29+
struct __byte_allocator;
30+
31+
// Base-class alias used by __byte_allocator. It instantiates __any::interface with
// the __byte_allocator template, the caller-supplied _Base, and an __extends list
// naming __icopyable and __iequality_comparable.
template <class _Base>
32+
using __byte_allocator_interface_t = __any::interface<
33+
__byte_allocator,
34+
_Base,
35+
__any::__extends<__any::__icopyable, __any::__iequality_comparable>
36+
>;
37+
38+
// NOLINTBEGIN(modernize-use-override)
39+
// Interface describing a std::byte allocator. It derives from
// __byte_allocator_interface_t<_Base>, inherits that base's constructors, and
// declares two virtual members whose default bodies forward to the object
// returned by __any::__value(*this).
template <class _Base>
40+
struct __byte_allocator : __byte_allocator_interface_t<_Base> {
41+
using __byte_allocator_interface_t<_Base>::interface::interface;
42+
43+
// Forwards to the wrapped value's allocate(__n); __n is passed through unchanged
// and the result is returned as std::byte*.
[[nodiscard]]
44+
constexpr virtual auto allocate(size_t __n) -> std::byte* {
45+
return __any::__value(*this).allocate(__n);
46+
}
47+
48+
// Forwards to the wrapped value's deallocate(__byte, __n). Declared noexcept.
constexpr virtual void deallocate(std::byte* __byte, size_t __n) noexcept {
49+
__any::__value(*this).deallocate(__byte, __n);
50+
}
51+
};
52+
// NOLINTEND(modernize-use-override)
53+
} // namespace __detail
54+
55+
template <class _Ty>
56+
struct __any_allocator {
57+
using value_type = _Ty;
58+
59+
__any_allocator() = default;
60+
61+
template <__not_same_as<__any_allocator> _Alloc>
62+
requires __is_not_instance_of<_Alloc, __any_allocator> && __simple_allocator<_Alloc>
63+
__any_allocator(_Alloc __alloc) noexcept {
64+
using __value_t = std::allocator_traits<_Alloc>::value_type;
65+
static_assert(
66+
__same_as<_Ty, __value_t>,
67+
"__any_allocator<T> must be constructed with an allocator of the same value type");
68+
__alloc_.emplace(STDEXEC::__rebind_allocator<std::byte>(__alloc));
69+
}
70+
71+
template <__not_same_as<_Ty> _Uy>
72+
__any_allocator(__any_allocator<_Uy> __other) noexcept
73+
: __alloc_(std::move(__other.__alloc_)) {
74+
}
75+
76+
[[nodiscard]]
77+
constexpr bool has_value() const noexcept {
78+
return !__any::__empty(__alloc_);
79+
}
80+
81+
[[nodiscard]]
82+
constexpr auto type() const noexcept -> __type_index const & {
83+
return __any::__type(__alloc_);
84+
}
85+
86+
[[nodiscard]]
87+
constexpr auto allocate(size_t __n) -> _Ty* {
88+
void* __void_ptr = __alloc_.allocate(__n * sizeof(_Ty));
89+
return static_cast<_Ty*>(__void_ptr);
90+
}
91+
92+
constexpr virtual void deallocate(_Ty* __ptr, size_t __n) noexcept {
93+
void* __void_ptr = static_cast<void*>(__ptr);
94+
__alloc_.deallocate(static_cast<std::byte*>(__void_ptr), __n * sizeof(_Ty));
95+
}
96+
97+
private:
98+
__any::__any<__detail::__byte_allocator> __alloc_{};
99+
};
100+
101+
// Deduction guide: an allocator argument deduces __any_allocator of that
// allocator's nested value_type.
template <class _Alloc>
102+
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE
103+
__any_allocator(_Alloc) -> __any_allocator<typename _Alloc::value_type>;
104+
105+
// Deduction guide: std::allocator<void> deduces __any_allocator<std::byte>.
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE
106+
__any_allocator(std::allocator<void>) -> __any_allocator<std::byte>;
107+
} // namespace STDEXEC

include/stdexec/__detail/__concepts.hpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -294,22 +294,13 @@ namespace STDEXEC {
294294
template <class _Ty, class _Up>
295295
concept __decays_to_derived_from = __std::derived_from<__decay_t<_Ty>, _Up>;
296296

297-
namespace __detail {
298-
template <class _Alloc>
299-
constexpr auto __test_alloc_pointer(int) -> _Alloc::pointer;
300-
template <class _Alloc>
301-
constexpr auto __test_alloc_pointer(long) -> _Alloc::value_type*;
302-
303-
template <class _Alloc>
304-
using __alloc_pointer_t = decltype(__detail::__test_alloc_pointer<__decay_t<_Alloc>>(0));
305-
} // namespace __detail
306-
297+
// See [allocator.requirements.general]/p99 (https://eel.is/c++draft/allocator.requirements.general#99)
307298
template <class _Alloc>
308-
concept __allocator_ = //
309-
requires(__decay_t<_Alloc>& __alloc, std::size_t __bytes) {
310-
{ __alloc.allocate(__bytes) } -> __std::same_as<__detail::__alloc_pointer_t<_Alloc>>;
311-
__alloc.deallocate(__alloc.allocate(__bytes), __bytes);
299+
concept __simple_allocator = //
300+
requires(_Alloc __alloc, std::size_t __count) {
301+
{ *__alloc.allocate(__count) } -> __std::same_as<typename _Alloc::value_type&>;
302+
__alloc.deallocate(__alloc.allocate(__count), __count);
312303
} //
313-
&& __std::copy_constructible<__decay_t<_Alloc>> //
314-
&& __std::equality_comparable<__decay_t<_Alloc>>;
304+
&& __std::copy_constructible<_Alloc> //
305+
&& __std::equality_comparable<_Alloc>;
315306
} // namespace STDEXEC

include/stdexec/__detail/__parallel_scheduler_backend.hpp

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
// include these after __execution_fwd.hpp
2323
// #include "any_allocator.cuh"
2424
#include "../functional.hpp" // IWYU pragma: keep for __with_default
25-
#include "../stop_token.hpp"
25+
#include "../stop_token.hpp" // IWYU pragma: keep for get_stop_token_t
26+
#include "__any_allocator.hpp"
2627
#include "__queries.hpp"
2728
#include "__typeinfo.hpp"
2829

@@ -34,34 +35,14 @@ STDEXEC_PRAGMA_PUSH()
3435
STDEXEC_PRAGMA_IGNORE_MSVC(4702) // warning C4702: unreachable code
3536

3637
namespace STDEXEC {
37-
template <class _Ty>
38-
class any_allocator : public std::allocator<_Ty> {
39-
public:
40-
template <class _OtherTy>
41-
struct rebind {
42-
using other = any_allocator<_OtherTy>;
43-
};
44-
45-
template <__not_same_as<any_allocator> _Alloc>
46-
any_allocator(const _Alloc&) noexcept {
47-
}
48-
};
49-
50-
template <class _Alloc>
51-
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE
52-
any_allocator(_Alloc) -> any_allocator<typename _Alloc::value_type>;
53-
54-
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE
55-
any_allocator(std::allocator<void>) -> any_allocator<std::byte>;
56-
5738
class task_scheduler;
5839

5940
// namespace __detail {
6041
// struct __env_proxy : __immovable {
6142
// [[nodiscard]]
6243
// virtual auto query(const get_stop_token_t&) const noexcept -> inplace_stop_token = 0;
6344
// [[nodiscard]]
64-
// virtual auto query(const get_allocator_t&) const noexcept -> any_allocator<std::byte> = 0;
45+
// virtual auto query(const get_allocator_t&) const noexcept -> __any_allocator<std::byte> = 0;
6546
// [[nodiscard]]
6647
// virtual auto query(const get_scheduler_t&) const noexcept -> task_scheduler = 0;
6748
// };
@@ -168,10 +149,12 @@ namespace STDEXEC {
168149
}
169150

170151
void __query(get_allocator_t, __type_index __value_type, void* __dest) const noexcept {
171-
if (__value_type == __mtypeid<any_allocator<std::byte>>) {
172-
using __dest_t = std::optional<any_allocator<std::byte>>;
173-
*static_cast<__dest_t*>(__dest) = any_allocator{
174-
__with_default(get_allocator, std::allocator<std::byte>())(STDEXEC::get_env(__rcvr_))};
152+
if (__value_type == __mtypeid<__any_allocator<std::byte>>) {
153+
using __dest_t = std::optional<__any_allocator<std::byte>>;
154+
constexpr auto __get_alloc = __with_default(get_allocator, std::allocator<std::byte>());
155+
auto __alloc = STDEXEC::__rebind_allocator<std::byte>(
156+
__get_alloc(STDEXEC::get_env(__rcvr_)));
157+
*static_cast<__dest_t*>(__dest) = __any_allocator{std::move(__alloc)};
175158
}
176159
}
177160

@@ -212,9 +195,10 @@ namespace STDEXEC {
212195
}
213196
};
214197

215-
// A receiver type that forwards its completion operations to a _RcvrProxy member held by
216-
// reference (where _RcvrProxy is one of receiver_proxy or bulk_item_receiver_proxy). It
217-
// is also responsible to destroying and, if necessary, deallocating the operation state.
198+
// A receiver type that forwards its completion operations to a _RcvrProxy member held
199+
// by reference (where _RcvrProxy is one of receiver_proxy or
200+
// bulk_item_receiver_proxy). It is also responsible for destroying and, if necessary,
201+
// deallocating the operation state.
218202
template <class _RcvrProxy>
219203
struct __proxy_receiver {
220204
using receiver_concept = receiver_t;

include/stdexec/__detail/__queries.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ struct get_allocator_t : STDEXEC::__query<get_allocator_t> {
129129
STDEXEC_ATTRIBUTE(always_inline, host, device)
130130
static constexpr void __validate() noexcept {
131131
static_assert(STDEXEC::__nothrow_callable<get_allocator_t, const _Env&>);
132-
static_assert(STDEXEC::__allocator_<STDEXEC::__call_result_t<get_allocator_t, const _Env&>>);
132+
using __alloc_t = STDEXEC::__call_result_t<get_allocator_t, const _Env&>;
133+
static_assert(STDEXEC::__simple_allocator<STDEXEC::__decay_t<__alloc_t>>);
133134
}
134135

135136
STDEXEC_ATTRIBUTE(nodiscard, always_inline, host, device)

0 commit comments

Comments
 (0)