Variadic Coroutines in C++ and D
In my last post I implemented a tiny coroutine class that is fully functional, but it doesn’t support arguments and return values. This time I provide an implementation which adds those features as well as exception safety.
I also used this as an opportunity to compare D and C++, as an excuse to learn C++’s variadic template syntax, and as a way to see how variadic templates can improve on boost’s implementation.
Let’s start off with the punchline. Here is my coroutine implementation in C++:
#pragma once

#include <ucontext.h>

#include <csignal>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <tuple>
#include <utility>

// Stack size for coroutines constructed without an explicit size. Define
// CORO_DEFAULT_STACK_SIZE before including this header to override it.
// SIGSTKSZ is declared by <signal.h> (<csignal>), which is included above
// because <ucontext.h> is not guaranteed to provide it.
#ifndef CORO_DEFAULT_STACK_SIZE
#define CORO_DEFAULT_STACK_SIZE SIGSTKSZ
#endif

namespace coro
{
namespace detail
{
// makecontext officially forwards its extra arguments as ints, so a pointer
// travels to the coroutine entry point as two 32-bit halves. The arithmetic
// goes through uint64_t so both shifts stay well defined even on platforms
// where pointers and size_t are only 32 bits wide.
inline std::uint32_t pointer_high_half(void * pointer)
{
    return static_cast<std::uint32_t>(
        static_cast<std::uint64_t>(reinterpret_cast<std::uintptr_t>(pointer)) >> 32);
}

inline std::uint32_t pointer_low_half(void * pointer)
{
    return static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(pointer));
}

// inverse of pointer_high_half / pointer_low_half
inline void * pointer_from_halves(std::uint32_t high_half, std::uint32_t low_half)
{
    return reinterpret_cast<void *>(static_cast<std::uintptr_t>(
        (static_cast<std::uint64_t>(high_half) << 32) | low_half));
}

/**
 * the coroutine_context holds the stack context and the current state
 * of the coroutine. it also starts the coroutine
 */
struct coroutine_context
{
    /**
     * Allocates a stack of stack_size bytes and prepares a context that will
     * run coroutine_call on it. coroutine_call receives the address of this
     * coroutine_context split into two 32-bit halves (high half first).
     */
    coroutine_context(size_t stack_size, void (*coroutine_call)())
        : stack{new unsigned char[stack_size]}
    {
        // create a context for the callee
        getcontext(&callee);
        callee.uc_link = &caller; // tell it to return to me when done
        callee.uc_stack.ss_size = stack_size;
        callee.uc_stack.ss_sp = stack.get();
        // pass the this pointer as two 32-bit arguments, because officially
        // makecontext passes its extra arguments as ints
        makecontext(&callee, coroutine_call, 2,
                    pointer_high_half(this), pointer_low_half(this));
    }

    /**
     * Switches to the coroutine and comes back when it yields or finishes.
     * Rethrows any exception that escaped the coroutine function.
     * Throws a const char * if the coroutine has already finished
     * (kept as-is because callers may be catching that type).
     */
    void operator()()
    {
        if (returned)
            throw "This coroutine has already finished";
        swapcontext(&caller, &callee);
        // continue here if yielded or returned
        if (exception)
        {
            std::rethrow_exception(std::move(exception));
        }
    }

protected:
    ucontext_t caller;            // saved by operator(); resumed on yield or return
    ucontext_t callee;            // the coroutine's own context
    std::unique_ptr<unsigned char[]> stack;
    std::exception_ptr exception; // set when the coroutine function throws
    bool returned = false;        // set once the coroutine function has exited
};

// a type that can store both objects and references
template<typename T>
struct any_storage
{
    any_storage() = default;
    any_storage(T to_store)
        : stored{std::move(to_store)}
    {
    }
    any_storage & operator=(T to_store)
    {
        stored = std::move(to_store);
        return *this;
    }
    // reading moves the stored value out
    operator T &&()
    {
        return std::move(stored);
    }

private:
    T stored = T{};
};

// specialization for void
template<>
struct any_storage<void>
{
};

// specialization for lvalue references: stores the address of the referent
template<typename T>
struct any_storage<T &>
{
    any_storage() = default;
    any_storage(T & to_store)
        : stored{&to_store}
    {
    }
    any_storage & operator=(T & to_store)
    {
        stored = &to_store;
        return *this;
    }
    operator T &()
    {
        return *stored;
    }

private:
    T * stored = nullptr;
};

// specialization for rvalue references: stores the address of the referent,
// so the referenced object must still be alive when the value is read
template<typename T>
struct any_storage<T &&>
{
    any_storage() = default;
    any_storage(T && to_store)
        : stored{&to_store}
    {
    }
    any_storage & operator=(T && to_store)
    {
        stored = &to_store;
        return *this;
    }
    operator T &&()
    {
        // *stored is an lvalue; it has to be cast back to an rvalue
        // reference, otherwise this conversion does not compile
        return std::move(*stored);
    }

private:
    T * stored = nullptr;
};

// implements the shared code among all specializations of coroutine_yielder
template<typename Result, typename... Arguments>
struct coroutine_yielder_base : protected coroutine_context
{
    coroutine_yielder_base(size_t stack_size, void (*coroutine_call)())
        : coroutine_context{stack_size, coroutine_call}
    {
    }

    // switches back to the caller; resumes on the next call of the
    // coroutine and returns the arguments of that call
    std::tuple<Arguments...> yield()
    {
        swapcontext(&this->callee, &this->caller);
        return std::move(this->arguments);
    }

protected:
    any_storage<Result> result;
    std::tuple<any_storage<Arguments>...> arguments;
};

/**
 * The coroutine_yielder is responsible for providing a yield
 * function
 */
template<typename Result, typename... Arguments>
struct coroutine_yielder : protected coroutine_yielder_base<Result, Arguments...>
{
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : coroutine_yielder_base<Result, Arguments...>{stack_size, coroutine_call}
    {
    }

    // moving yield
    std::tuple<Arguments...> yield(Result && result)
    {
        this->result = std::forward<Result>(result);
        return coroutine_yielder_base<Result, Arguments...>::yield();
    }
    // copying yield
    std::tuple<Arguments...> yield(const Result & result)
    {
        this->result = result;
        return coroutine_yielder_base<Result, Arguments...>::yield();
    }
};

// specialization for void: the argument-less yield of the base is used
template<typename... Arguments>
struct coroutine_yielder<void, Arguments...> : coroutine_yielder_base<void, Arguments...>
{
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : coroutine_yielder_base<void, Arguments...>{stack_size, coroutine_call}
    {
    }
};

// specialization for lvalue reference
template<typename Result, typename... Arguments>
struct coroutine_yielder<Result &, Arguments...> : protected coroutine_yielder_base<Result &, Arguments...>
{
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : coroutine_yielder_base<Result &, Arguments...>{stack_size, coroutine_call}
    {
    }

    std::tuple<Arguments...> yield(Result & result)
    {
        this->result = result;
        return coroutine_yielder_base<Result &, Arguments...>::yield();
    }
};

// specialization for rvalue reference
template<typename Result, typename... Arguments>
struct coroutine_yielder<Result &&, Arguments...> : protected coroutine_yielder_base<Result &&, Arguments...>
{
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : coroutine_yielder_base<Result &&, Arguments...>{stack_size, coroutine_call}
    {
    }

    std::tuple<Arguments...> yield(Result && result)
    {
        this->result = std::move(result);
        return coroutine_yielder_base<Result &&, Arguments...>::yield();
    }
};
} // namespace detail

template<typename Result, typename... Arguments>
struct coroutine;

/**
 * A coroutine with the signature Result (Arguments...).
 *
 * The wrapped function receives a `self &` (the yielder) followed by the
 * arguments of the first call. Each `self.yield(value)` hands value back to
 * the caller and returns the arguments of the next call as a std::tuple.
 * The function's return value is the result of the final call.
 */
template<typename Result, typename... Arguments>
struct coroutine<Result (Arguments...)> : detail::coroutine_yielder<Result, Arguments...>
{
private:
    template<int N, int... S>
    struct starter;
    typedef starter<sizeof...(Arguments)> Starter;

public:
    typedef typename detail::coroutine_yielder<Result, Arguments...> self;

    coroutine(std::function<Result (self &, Arguments...)> func, size_t stack_size = CORO_DEFAULT_STACK_SIZE)
        : detail::coroutine_yielder<Result, Arguments...>{stack_size, reinterpret_cast<void (*)()>(&Starter::coroutine_start)}
        , func{std::move(func)}
    {
    }

    // I don't need to specify these. the default behavior would result
    // in the same constructors and assignment operators being generated
    // but you get better error messages if I'm explicit about it
    //
    // WARNING: the context created in the constructor refers to this
    // object's address (uc_link points at `caller`, and `this` is handed to
    // makecontext), so a coroutine that is really moved still runs against
    // the moved-from object. NOTE(review): presumably the defaulted moves are
    // kept so that returning a coroutine by value compiles (pre-C++17);
    // don't rely on actually moving one.
    coroutine(const coroutine &) = delete;
    coroutine & operator=(const coroutine &) = delete;
    coroutine(coroutine &&) = default;
    coroutine & operator=(coroutine &&) = default;

    // true until the coroutine function has returned or thrown
    operator bool() const
    {
        return !this->returned;
    }

    /**
     * Starts or resumes the coroutine with args. Returns the value passed to
     * the next yield, or the function's return value if it finishes.
     * Rethrows exceptions from the coroutine function; throws a
     * const char * when called after the coroutine has finished.
     */
    Result operator()(Arguments... args)
    {
        this->arguments = std::make_tuple(detail::any_storage<Arguments>{std::forward<Arguments>(args)}...);
        detail::coroutine_yielder<Result, Arguments...>::operator()();
        return Starter::Finisher::return_result(*this);
    }

private:
    // returning a value needs to be specialized for void return
    // this struct handles that
    template<typename R, int... S>
    struct finisher
    {
        static R return_result(coroutine & this_)
        {
            return std::forward<Result>(this_.result);
        }
        static void start_and_store_result(coroutine & caller)
        {
            caller.result = caller.template coroutine_start<S...>();
        }
    };
    template<int... S>
    struct finisher<void, S...>
    {
        static void return_result(coroutine &)
        {
        }
        static void start_and_store_result(coroutine & caller)
        {
            caller.template coroutine_start<S...>();
        }
    };

    // calls the user function with the stored arguments;
    // ArgCountList is the index sequence 0 .. sizeof...(Arguments) - 1
    template<int... ArgCountList>
    Result coroutine_start()
    {
        return func(*this, std::forward<Arguments>(std::get<ArgCountList>(this->arguments))...);
    }

    // the calling of the coroutine needs to be specialized by
    // the number of arguments. this struct handles that
    template<int N, int... S>
    struct starter : starter<N - 1, N - 1, S...>
    {
    };
    template<int... S>
    struct starter<0, S...>
    {
        typedef finisher<Result, S...> Finisher;

        // entry point run by makecontext on the coroutine's own stack
        static void coroutine_start(std::uint32_t this_pointer_high_half, std::uint32_t this_pointer_low_half)
        {
            // the pointer was taken inside the coroutine_context constructor,
            // so convert back through that base type instead of assuming it
            // lives at the same address as the full coroutine object
            coroutine & this_ = static_cast<coroutine &>(*static_cast<detail::coroutine_context *>(
                detail::pointer_from_halves(this_pointer_high_half, this_pointer_low_half)));
            try
            {
                Finisher::start_and_store_result(this_);
            }
            catch (...)
            {
                // stored here and rethrown on the caller's stack by operator()
                this_.exception = std::current_exception();
            }
            this_.returned = true;
            // returning resumes `caller` through uc_link
        }
    };

    std::function<Result (self &, Arguments...)> func;
};
} // namespace coro
And here it is in D:
module coro.coroutine;

import coro.ucontext;
import std.algorithm;
import std.conv;
import std.typecons;

/**
 * Builds the source string
 *     before ~ "move(toforward[0]), move(toforward[1]), ..." ~ after
 * with i moved elements, for use in a mixin.
 */
template VariadicForward(string before, string toforward, string after, size_t i)
{
    static if (i == 0)
    {
        immutable string VariadicForward = before ~ after;
    }
    else static if (i == 1)
    {
        immutable string VariadicForward = VariadicForward!(before, toforward, "move(" ~ toforward ~ "[" ~ to!string(i - 1) ~ "])" ~ after, i - 1);
    }
    else
    {
        immutable string VariadicForward = VariadicForward!(before, toforward, ", move(" ~ toforward ~ "[" ~ to!string(i - 1) ~ "])" ~ after, i - 1);
    }
}

/**
 * A coroutine returning Result and taking Arguments. The delegate receives
 * the coroutine itself followed by the arguments of the first call; each
 * yield hands a value back to the caller and returns the arguments of the
 * next call as a Tuple.
 */
class coroutine(Result, Arguments...)
{
    private alias Tuple!Arguments.Types Types;

    this(Result delegate(coroutine!(Result, Arguments), Types) func, size_t stack_size = 8192)
    {
        this.func = func;
        stack = new ubyte[stack_size];
        getcontext(&callee);
        callee.uc_link = &caller; // return to the caller when the delegate finishes
        callee.uc_stack.ss_size = stack_size;
        callee.uc_stack.ss_sp = stack.ptr;
        extern(C) void function() call = cast(void function())&coroutine_start;
        // makecontext officially passes its extra arguments as ints, so send
        // this as two 32-bit halves; going through ulong keeps the shift
        // valid even where size_t is only 32 bits wide
        immutable ulong self_bits = cast(ulong) cast(size_t) cast(void*) this;
        makecontext(&callee, call, 2, cast(uint)(self_bits >> 32), cast(uint) self_bits);
    }

    /// Starts or resumes the coroutine; rethrows what the delegate threw.
    Result opCall(Types arguments)
    {
        if (returned) throw new Throwable("This coroutine has already finished");
        static if (arguments.length)
        {
            foreach (i, ref argument; arguments) this.arguments[i] = move(argument);
        }
        swapcontext(&caller, &callee);
        if (exception)
        {
            throw exception;
        }
        else static if (!is(Result == void))
        {
            return move(result);
        }
    }

    /// true until the delegate has returned or thrown
    @property bool callable() const
    {
        return !returned;
    }

    static if (is(Result == void))
    {
        alias yield_common yield;
    }
    else
    {
        Tuple!Arguments yield(ref Result result)
        {
            this.result = result;
            return yield_common();
        }
        Tuple!Arguments yield(Result result)
        {
            this.result = move(result);
            return yield_common();
        }
    }

    // switch back to the caller, then hand over the next call's arguments
    private Tuple!Arguments yield_common()
    {
        swapcontext(&callee, &caller);
        mixin(VariadicForward!("return Tuple!Arguments(", "arguments", ");", Types.length));
    }

private:
    Result delegate(coroutine!(Result, Arguments), Types) func;
    ubyte[] stack;
    ucontext_t caller;
    ucontext_t callee;
    Throwable exception;
    bool returned = false;
    static if (!is(Result == void))
    {
        Result result;
    }
    Types arguments;

    // entry point run by makecontext on the coroutine's own stack
    extern(C) static void coroutine_start(uint this_pointer_left_half, uint this_pointer_right_half)
    {
        immutable ulong self_bits = (cast(ulong) this_pointer_left_half << 32) | this_pointer_right_half;
        coroutine this_ = cast(coroutine) cast(void*) cast(size_t) self_bits;
        try
        {
            static if (is(Result == void))
            {
                mixin(VariadicForward!("this_.func(this_, ", "this_.arguments", ");", Types.length));
            }
            else
            {
                mixin(VariadicForward!("this_.result = this_.func(this_, ", "this_.arguments", ");", Types.length));
            }
        }
        catch (Throwable ex)
        {
            // rethrown on the caller's stack by opCall
            this_.exception = ex;
        }
        this_.returned = true;
    }
}
The most obvious point first: D is much cleaner.