1  
//
1  
//
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
2  
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
3  
//
3  
//
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6  
//
6  
//
7  
// Official repository: https://github.com/cppalliance/capy
7  
// Official repository: https://github.com/cppalliance/capy
8  
//
8  
//
9  

9  

10  
#include "src/ex/detail/strand_queue.hpp"
10  
#include "src/ex/detail/strand_queue.hpp"
11  
#include <boost/capy/ex/detail/strand_service.hpp>
11  
#include <boost/capy/ex/detail/strand_service.hpp>
12  
#include <atomic>
12  
#include <atomic>
13  
#include <coroutine>
13  
#include <coroutine>
14  
#include <mutex>
14  
#include <mutex>
15  
#include <thread>
15  
#include <thread>
16  
#include <utility>
16  
#include <utility>
17  

17  

18  
namespace boost {
18  
namespace boost {
19  
namespace capy {
19  
namespace capy {
20  
namespace detail {
20  
namespace detail {
21  

21  

22  
//----------------------------------------------------------
22  
//----------------------------------------------------------
23  

23  

24  
/** Implementation state for a strand.
24  
/** Implementation state for a strand.
25  

25  

26  
    Each strand_impl provides serialization for coroutines
26  
    Each strand_impl provides serialization for coroutines
27  
    dispatched through strands that share it.
27  
    dispatched through strands that share it.
28  
*/
28  
*/
29  
struct strand_impl
29  
struct strand_impl
30  
{
30  
{
31  
    std::mutex mutex_;
31  
    std::mutex mutex_;
32  
    strand_queue pending_;
32  
    strand_queue pending_;
33  
    bool locked_ = false;
33  
    bool locked_ = false;
34  
    std::atomic<std::thread::id> dispatch_thread_{};
34  
    std::atomic<std::thread::id> dispatch_thread_{};
35  
    void* cached_frame_ = nullptr;
35  
    void* cached_frame_ = nullptr;
36  
};
36  
};
37  

37  

38  
//----------------------------------------------------------
38  
//----------------------------------------------------------
39  

39  

40  
/** Invoker coroutine for strand dispatch.
40  
/** Invoker coroutine for strand dispatch.
41  

41  

42  
    Uses custom allocator to recycle frame - one allocation
42  
    Uses custom allocator to recycle frame - one allocation
43  
    per strand_impl lifetime, stored in trailer for recovery.
43  
    per strand_impl lifetime, stored in trailer for recovery.
44  
*/
44  
*/
45  
struct strand_invoker
45  
struct strand_invoker
46  
{
46  
{
47  
    struct promise_type
47  
    struct promise_type
48  
    {
48  
    {
49  
        void* operator new(std::size_t n, strand_impl& impl)
49  
        void* operator new(std::size_t n, strand_impl& impl)
50  
        {
50  
        {
51  
            constexpr auto A = alignof(strand_impl*);
51  
            constexpr auto A = alignof(strand_impl*);
52  
            std::size_t padded = (n + A - 1) & ~(A - 1);
52  
            std::size_t padded = (n + A - 1) & ~(A - 1);
53  
            std::size_t total = padded + sizeof(strand_impl*);
53  
            std::size_t total = padded + sizeof(strand_impl*);
54  

54  

55  
            void* p = impl.cached_frame_
55  
            void* p = impl.cached_frame_
56  
                ? std::exchange(impl.cached_frame_, nullptr)
56  
                ? std::exchange(impl.cached_frame_, nullptr)
57  
                : ::operator new(total);
57  
                : ::operator new(total);
58  

58  

59  
            // Trailer lets delete recover impl
59  
            // Trailer lets delete recover impl
60  
            *reinterpret_cast<strand_impl**>(
60  
            *reinterpret_cast<strand_impl**>(
61  
                static_cast<char*>(p) + padded) = &impl;
61  
                static_cast<char*>(p) + padded) = &impl;
62  
            return p;
62  
            return p;
63  
        }
63  
        }
64  

64  

65  
        void operator delete(void* p, std::size_t n) noexcept
65  
        void operator delete(void* p, std::size_t n) noexcept
66  
        {
66  
        {
67  
            constexpr auto A = alignof(strand_impl*);
67  
            constexpr auto A = alignof(strand_impl*);
68  
            std::size_t padded = (n + A - 1) & ~(A - 1);
68  
            std::size_t padded = (n + A - 1) & ~(A - 1);
69  

69  

70  
            auto* impl = *reinterpret_cast<strand_impl**>(
70  
            auto* impl = *reinterpret_cast<strand_impl**>(
71  
                static_cast<char*>(p) + padded);
71  
                static_cast<char*>(p) + padded);
72  

72  

73  
            if (!impl->cached_frame_)
73  
            if (!impl->cached_frame_)
74  
                impl->cached_frame_ = p;
74  
                impl->cached_frame_ = p;
75  
            else
75  
            else
76  
                ::operator delete(p);
76  
                ::operator delete(p);
77  
        }
77  
        }
78  

78  

79  
        strand_invoker get_return_object() noexcept
79  
        strand_invoker get_return_object() noexcept
80  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
80  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
81  

81  

82  
        std::suspend_always initial_suspend() noexcept { return {}; }
82  
        std::suspend_always initial_suspend() noexcept { return {}; }
83  
        std::suspend_never final_suspend() noexcept { return {}; }
83  
        std::suspend_never final_suspend() noexcept { return {}; }
84  
        void return_void() noexcept {}
84  
        void return_void() noexcept {}
85  
        void unhandled_exception() { std::terminate(); }
85  
        void unhandled_exception() { std::terminate(); }
86  
    };
86  
    };
87  

87  

88  
    std::coroutine_handle<promise_type> h_;
88  
    std::coroutine_handle<promise_type> h_;
89  
};
89  
};
90  

90  

91  
//----------------------------------------------------------
91  
//----------------------------------------------------------
92  

92  

93  
/** Concrete implementation of strand_service.
93  
/** Concrete implementation of strand_service.
94  

94  

95  
    Holds the fixed pool of strand_impl objects.
95  
    Holds the fixed pool of strand_impl objects.
96  
*/
96  
*/
97  
class strand_service_impl : public strand_service
97  
class strand_service_impl : public strand_service
98  
{
98  
{
99  
    static constexpr std::size_t num_impls = 211;
99  
    static constexpr std::size_t num_impls = 211;
100  

100  

101  
    strand_impl impls_[num_impls];
101  
    strand_impl impls_[num_impls];
102  
    std::size_t salt_ = 0;
102  
    std::size_t salt_ = 0;
103  
    std::mutex mutex_;
103  
    std::mutex mutex_;
104  

104  

105  
public:
105  
public:
106  
    explicit
106  
    explicit
107  
    strand_service_impl(execution_context&)
107  
    strand_service_impl(execution_context&)
108  
    {
108  
    {
109  
    }
109  
    }
110  

110  

111  
    strand_impl*
111  
    strand_impl*
112  
    get_implementation() override
112  
    get_implementation() override
113  
    {
113  
    {
114  
        std::lock_guard<std::mutex> lock(mutex_);
114  
        std::lock_guard<std::mutex> lock(mutex_);
115  
        std::size_t index = salt_++;
115  
        std::size_t index = salt_++;
116  
        index = index % num_impls;
116  
        index = index % num_impls;
117  
        return &impls_[index];
117  
        return &impls_[index];
118  
    }
118  
    }
119  

119  

120  
protected:
120  
protected:
121  
    void
121  
    void
122  
    shutdown() override
122  
    shutdown() override
123  
    {
123  
    {
124  
        for(std::size_t i = 0; i < num_impls; ++i)
124  
        for(std::size_t i = 0; i < num_impls; ++i)
125  
        {
125  
        {
126  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
126  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
127  
            impls_[i].locked_ = true;
127  
            impls_[i].locked_ = true;
128  

128  

129  
            if(impls_[i].cached_frame_)
129  
            if(impls_[i].cached_frame_)
130  
            {
130  
            {
131  
                ::operator delete(impls_[i].cached_frame_);
131  
                ::operator delete(impls_[i].cached_frame_);
132  
                impls_[i].cached_frame_ = nullptr;
132  
                impls_[i].cached_frame_ = nullptr;
133  
            }
133  
            }
134  
        }
134  
        }
135  
    }
135  
    }
136  

136  

137  
private:
137  
private:
138  
    static bool
138  
    static bool
139  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
139  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
140  
    {
140  
    {
141  
        std::lock_guard<std::mutex> lock(impl.mutex_);
141  
        std::lock_guard<std::mutex> lock(impl.mutex_);
142  
        impl.pending_.push(h);
142  
        impl.pending_.push(h);
143  
        if(!impl.locked_)
143  
        if(!impl.locked_)
144  
        {
144  
        {
145  
            impl.locked_ = true;
145  
            impl.locked_ = true;
146  
            return true;
146  
            return true;
147  
        }
147  
        }
148  
        return false;
148  
        return false;
149  
    }
149  
    }
150  

150  

151  
    static void
151  
    static void
152  
    dispatch_pending(strand_impl& impl)
152  
    dispatch_pending(strand_impl& impl)
153  
    {
153  
    {
154  
        strand_queue::taken_batch batch;
154  
        strand_queue::taken_batch batch;
155  
        {
155  
        {
156  
            std::lock_guard<std::mutex> lock(impl.mutex_);
156  
            std::lock_guard<std::mutex> lock(impl.mutex_);
157  
            batch = impl.pending_.take_all();
157  
            batch = impl.pending_.take_all();
158  
        }
158  
        }
159  
        impl.pending_.dispatch_batch(batch);
159  
        impl.pending_.dispatch_batch(batch);
160  
    }
160  
    }
161  

161  

162  
    static bool
162  
    static bool
163  
    try_unlock(strand_impl& impl)
163  
    try_unlock(strand_impl& impl)
164  
    {
164  
    {
165  
        std::lock_guard<std::mutex> lock(impl.mutex_);
165  
        std::lock_guard<std::mutex> lock(impl.mutex_);
166  
        if(impl.pending_.empty())
166  
        if(impl.pending_.empty())
167  
        {
167  
        {
168  
            impl.locked_ = false;
168  
            impl.locked_ = false;
169  
            return true;
169  
            return true;
170  
        }
170  
        }
171  
        return false;
171  
        return false;
172  
    }
172  
    }
173  

173  

174  
    static void
174  
    static void
175  
    set_dispatch_thread(strand_impl& impl) noexcept
175  
    set_dispatch_thread(strand_impl& impl) noexcept
176  
    {
176  
    {
177  
        impl.dispatch_thread_.store(std::this_thread::get_id());
177  
        impl.dispatch_thread_.store(std::this_thread::get_id());
178  
    }
178  
    }
179  

179  

180  
    static void
180  
    static void
181  
    clear_dispatch_thread(strand_impl& impl) noexcept
181  
    clear_dispatch_thread(strand_impl& impl) noexcept
182  
    {
182  
    {
183  
        impl.dispatch_thread_.store(std::thread::id{});
183  
        impl.dispatch_thread_.store(std::thread::id{});
184  
    }
184  
    }
185  

185  

186  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
186  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
187  
    // (repost after each batch to let other work run) - explore if starvation observed.
187  
    // (repost after each batch to let other work run) - explore if starvation observed.
188  
    static strand_invoker
188  
    static strand_invoker
189  
    make_invoker(strand_impl& impl)
189  
    make_invoker(strand_impl& impl)
190  
    {
190  
    {
191  
        strand_impl* p = &impl;
191  
        strand_impl* p = &impl;
192  
        for(;;)
192  
        for(;;)
193  
        {
193  
        {
194  
            set_dispatch_thread(*p);
194  
            set_dispatch_thread(*p);
195  
            dispatch_pending(*p);
195  
            dispatch_pending(*p);
196  
            if(try_unlock(*p))
196  
            if(try_unlock(*p))
197  
            {
197  
            {
198  
                clear_dispatch_thread(*p);
198  
                clear_dispatch_thread(*p);
199  
                co_return;
199  
                co_return;
200  
            }
200  
            }
201  
        }
201  
        }
202  
    }
202  
    }
203  

203  

204  
    friend class strand_service;
204  
    friend class strand_service;
205  
};
205  
};
206  

206  

207  
//----------------------------------------------------------
207  
//----------------------------------------------------------
208  

208  

209  
// Default-constructs the strand service; all state lives in the
// derived strand_service_impl.
strand_service::strand_service()
    : service()
{
}
214  

214  

215  
// Nothing to release here; the derived implementation owns the pool.
strand_service::~strand_service()
{
}
217  

217  

218  
bool
218  
bool
219  
strand_service::
219  
strand_service::
220  
running_in_this_thread(strand_impl& impl) noexcept
220  
running_in_this_thread(strand_impl& impl) noexcept
221  
{
221  
{
222  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
222  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
223  
}
223  
}
224  

224  

225  
// Run h on the strand. If the calling thread is already draining this
// strand, h is returned for immediate resumption. Otherwise h is
// queued, and if this call acquired the strand an invoker is posted to
// ex. In both of those cases a noop handle is returned.
std::coroutine_handle<>
strand_service::
dispatch(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
{
    if(running_in_this_thread(impl))
        return h;

    bool const acquired = strand_service_impl::enqueue(impl, h);
    if(!acquired)
        return std::noop_coroutine();

    ex.post(strand_service_impl::make_invoker(impl).h_);
    return std::noop_coroutine();
}
236  

236  

237  
void
237  
void
238  
strand_service::
238  
strand_service::
239  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
239  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
240  
{
240  
{
241  
    if(strand_service_impl::enqueue(impl, h))
241  
    if(strand_service_impl::enqueue(impl, h))
242  
        ex.post(strand_service_impl::make_invoker(impl).h_);
242  
        ex.post(strand_service_impl::make_invoker(impl).h_);
243  
}
243  
}
244  

244  

245  
// Return ctx's strand service, creating the concrete implementation
// through use_service on first use.
strand_service&
get_strand_service(execution_context& ctx)
{
    auto& svc = ctx.use_service<strand_service_impl>();
    return svc;
}
250  

250  

251  
} // namespace detail
251  
} // namespace detail
252  
} // namespace capy
252  
} // namespace capy
253  
} // namespace boost
253  
} // namespace boost