In my last post I implemented a tiny coroutine class that is fully functional, but it doesn’t support arguments and return values. This time I provide an implementation which adds those features as well as exception safety.
I also used this as an example to compare D and C++, as an excuse to learn C++’s variadic template syntax, and as a reason to see how you can improve on Boost’s implementation by using variadic templates.
Let’s start off with the punchline. Here is my coroutine implementation in C++:
#pragma once
#include <ucontext.h>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <tuple>
#ifndef CORO_DEFAULT_STACK_SIZE
#define CORO_DEFAULT_STACK_SIZE SIGSTKSZ
#endif
namespace coro
{
namespace detail
{
/**
 * the coroutine_context holds the stack context and the current state
 * of the coroutine. it also starts the coroutine
 *
 * The constructor prepares `callee` to run `coroutine_call` on a freshly
 * allocated stack of `stack_size` bytes. The first operator() call switches
 * into it. `coroutine_call` is actually invoked with two uint32_t arguments:
 * the high and the low 32 bits of `this` (see the makecontext call below).
 */
struct coroutine_context
{
    coroutine_context(size_t stack_size, void (*coroutine_call)())
        : stack{new unsigned char[stack_size]}
    {
        // create a context for the callee
        getcontext(&callee);
        callee.uc_link = &caller; // tell it to return to me when done
        callee.uc_stack.ss_size = stack_size;
        callee.uc_stack.ss_sp = stack.get();
        // makecontext officially only forwards int-sized arguments, so the
        // this pointer is split into two explicit 32 bit halves. the split
        // is done in 64 bit arithmetic: shifting a 32 bit size_t by 32 would
        // be undefined behavior, and passing the raw pointer as the second
        // argument relied on the callee reading just the low 32 bits of a
        // pointer-sized argument
        const std::uint64_t address = reinterpret_cast<std::uintptr_t>(this);
        makecontext(&callee, coroutine_call, 2,
                    static_cast<std::uint32_t>(address >> 32),
                    static_cast<std::uint32_t>(address));
    }
    /**
     * Switches into the coroutine and comes back when it yields or returns.
     * @throws const char * if `returned` has been set
     * @throws the stored `exception`, if the code running on the coroutine
     *         stack put one there
     */
    void operator()()
    {
        if (returned) throw "This coroutine has already finished";
        swapcontext(&caller, &callee); // continue here if yielded or returned
        if (exception)
        {
            std::rethrow_exception(std::move(exception));
        }
    }
protected:
    ucontext_t caller; // saved by operator() when switching into the coroutine
    ucontext_t callee; // the coroutine's own execution context
    std::unique_ptr<unsigned char[]> stack; // memory backing the callee's stack
    std::exception_ptr exception; // rethrown by operator() when non-null
    bool returned = false; // once true, operator() refuses to resume
};
/**
 * any_storage<T> is a slot that can hold a T whether T is an object type,
 * an lvalue reference or an rvalue reference. A reference member can neither
 * be default-constructed nor re-seated, so references are stored as pointers.
 * Reading the slot (via the conversion operator) yields the contents with
 * the value category of T.
 */
template<typename T>
struct any_storage
{
    any_storage() = default;
    any_storage(T to_store)
        : stored{std::move(to_store)}
    {
    }
    any_storage & operator=(T to_store)
    {
        stored = std::move(to_store);
        return *this;
    }
    // hands the stored object out as an rvalue, so the reader may move from it
    operator T &&()
    {
        return std::move(stored);
    }
private:
    // requires T to be default-constructible
    T stored = T{};
};
// specialization for void: an empty placeholder, so that a `result` slot
// can still be declared for coroutines that return nothing
template<>
struct any_storage<void>
{
};
// specialization for lvalue references: remembers the address of the referent
template<typename T>
struct any_storage<T &>
{
    any_storage() = default;
    any_storage(T & to_store)
        : stored{&to_store}
    {
    }
    any_storage & operator=(T & to_store)
    {
        stored = &to_store;
        return *this;
    }
    // must not be called on a default-constructed (empty) slot
    operator T &()
    {
        return *stored;
    }
private:
    T * stored = nullptr;
};
// specialization for rvalue references: remembers the address of the referent
template<typename T>
struct any_storage<T &&>
{
    any_storage() = default;
    any_storage(T && to_store)
        : stored{&to_store}
    {
    }
    any_storage & operator=(T && to_store)
    {
        stored = &to_store;
        return *this;
    }
    // must not be called on a default-constructed (empty) slot
    operator T &&()
    {
        // *stored is an lvalue; it has to be cast back to an rvalue,
        // otherwise binding it to the T && return type is ill-formed
        return std::move(*stored);
    }
private:
    T * stored = nullptr;
};
// implements the shared code among all specializations of coroutine_yielder
template<typename Result, typename... Arguments>
struct coroutine_yielder_base
: protected coroutine_context
{
coroutine_yielder_base(size_t stack_size, void (*coroutine_call)())
: coroutine_context{stack_size, coroutine_call}
{
}
std::tuple<Arguments...> yield()
{
swapcontext(&this->callee, &this->caller);
return std::move(this->arguments);
}
protected:
any_storage<Result> result;
std::tuple<any_storage<Arguments>...> arguments;
};
/**
 * The coroutine_yielder is responsible for providing a yield function.
 * yield(value) stores the value for the caller, suspends the coroutine and
 * returns the arguments of the next resumption as a tuple.
 */
template<typename Result, typename... Arguments>
struct coroutine_yielder
    : protected coroutine_yielder_base<Result, Arguments...>
{
private:
    typedef coroutine_yielder_base<Result, Arguments...> base_type;
public:
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : base_type{stack_size, coroutine_call}
    {
    }
    // moving yield (Result is never a reference in this primary template)
    std::tuple<Arguments...> yield(Result && result)
    {
        this->result = std::move(result);
        return base_type::yield();
    }
    // copying yield
    std::tuple<Arguments...> yield(const Result & result)
    {
        this->result = result;
        return base_type::yield();
    }
};
// specialization for void: nothing is handed back, so the argument-only
// yield() of the base is exposed directly through public inheritance
template<typename... Arguments>
struct coroutine_yielder<void, Arguments...>
    : public coroutine_yielder_base<void, Arguments...>
{
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : coroutine_yielder_base<void, Arguments...>{stack_size, coroutine_call}
    {
    }
};
// specialization for lvalue reference: the referent itself is handed back
template<typename Result, typename... Arguments>
struct coroutine_yielder<Result &, Arguments...>
    : protected coroutine_yielder_base<Result &, Arguments...>
{
private:
    typedef coroutine_yielder_base<Result &, Arguments...> base_type;
public:
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : base_type{stack_size, coroutine_call}
    {
    }
    std::tuple<Arguments...> yield(Result & result)
    {
        this->result = result;
        return base_type::yield();
    }
};
// specialization for rvalue reference: only rvalues may be yielded
template<typename Result, typename... Arguments>
struct coroutine_yielder<Result &&, Arguments...>
    : protected coroutine_yielder_base<Result &&, Arguments...>
{
private:
    typedef coroutine_yielder_base<Result &&, Arguments...> base_type;
public:
    coroutine_yielder(size_t stack_size, void (*coroutine_call)())
        : base_type{stack_size, coroutine_call}
    {
    }
    std::tuple<Arguments...> yield(Result && result)
    {
        this->result = std::move(result);
        return base_type::yield();
    }
};
}
/**
 * coroutine<Result (Arguments...)> runs a function on its own stack.
 *
 * Calling the coroutine with Arguments... resumes it. The call returns the
 * value passed to the next yield, or the value the function returns. The
 * function receives a `self &` through which it yields. An exception that
 * escapes the function is caught on the coroutine stack and rethrown from
 * the call that resumed it. After that the coroutine counts as finished.
 *
 * NOTE(review): `this` is handed to makecontext at construction, and the
 * callee's uc_link points at this object's `caller` member. After the
 * defaulted move, the context therefore still refers to the moved-from
 * object. Moving a coroutine looks unsafe; confirm before relying on it.
 */
template<typename Result, typename... Arguments>
struct coroutine;
template<typename Result, typename... Arguments>
struct coroutine<Result (Arguments...)>
    : detail::coroutine_yielder<Result, Arguments...>
{
private:
    template<int N, int... S>
    struct starter;
    typedef starter<sizeof...(Arguments)> Starter;
public:
    // the type the coroutine body receives and yields through
    typedef typename detail::coroutine_yielder<Result, Arguments...> self;
    /**
     * @param func the coroutine body; it first runs on the first call
     * @param stack_size size in bytes of the coroutine's stack
     */
    coroutine(std::function<Result (self &, Arguments...)> func, size_t stack_size = CORO_DEFAULT_STACK_SIZE)
        : detail::coroutine_yielder<Result, Arguments...>{stack_size, reinterpret_cast<void (*)()>(&Starter::coroutine_start)}
        , func{std::move(func)}
    {
    }
    // I don't need to specify these. the default behavior would result
    // in the same constructors and assignment operators being generated
    // but you get better error messages if I'm explicit about it
    coroutine(const coroutine &) = delete;
    coroutine & operator=(const coroutine &) = delete;
    coroutine(coroutine &&) = default;
    coroutine & operator=(coroutine &&) = default;
    // true while the coroutine body has not finished
    operator bool() const
    {
        return !this->returned;
    }
    /**
     * Resumes the coroutine with the given arguments.
     * @return the value passed to the next yield, or the final return value
     * @throws const char * if the coroutine has already finished
     * @throws any exception that escaped the coroutine body
     */
    Result operator()(Arguments... args)
    {
        this->arguments = std::make_tuple(detail::any_storage<Arguments>{std::forward<Arguments>(args)}...);
        detail::coroutine_yielder<Result, Arguments...>::operator()();
        return Starter::Finisher::return_result(*this);
    }
private:
    // returning a value needs to be specialized for void return
    // this struct handles that
    template<typename R, int... S>
    struct finisher
    {
        static R return_result(coroutine & this_)
        {
            return std::forward<Result>(this_.result);
        }
        static void start_and_store_result(coroutine & caller)
        {
            caller.result = caller.coroutine_start<S...>();
        }
    };
    template<int... S>
    struct finisher<void, S...>
    {
        static void return_result(coroutine &)
        {
        }
        static void start_and_store_result(coroutine & caller)
        {
            caller.coroutine_start<S...>();
        }
    };
    // calls func with the stored arguments; ArgCountList is 0..N-1
    template<int ...ArgCountList>
    Result coroutine_start()
    {
        return func(*this, std::forward<Arguments>(std::get<ArgCountList>(this->arguments))...);
    }
    // the calling of the coroutine needs to be specialized by
    // the number of arguments. this struct handles that
    template<int N, int... S>
    struct starter : starter<N - 1, N - 1, S...>
    {
    };
    template<int... S>
    struct starter<0, S...>
    {
        typedef finisher<Result, S...> Finisher;
        // entry point handed to makecontext; runs on the coroutine stack
        // and receives the coroutine's address split into two 32 bit halves
        static void coroutine_start(uint32_t this_pointer_left_half, uint32_t this_pointer_right_half)
        {
            // reassemble in 64 bit arithmetic: shifting a 32 bit size_t
            // left by 32 would be undefined behavior
            const std::uint64_t address = (static_cast<std::uint64_t>(this_pointer_left_half) << 32) | this_pointer_right_half;
            coroutine & this_ = *reinterpret_cast<coroutine *>(static_cast<std::uintptr_t>(address));
            try
            {
                Finisher::start_and_store_result(this_);
            }
            catch(...)
            {
                // rethrown to the resuming caller once control switches back
                this_.exception = std::current_exception();
            }
            this_.returned = true;
        }
    };
    std::function<Result (self &, Arguments...)> func;
};
}
And here it is in D:
module coro.coroutine;
import coro.ucontext;
import std.algorithm;
import std.conv;
import std.typecons;
/**
 * Builds, at compile time, a string of D source meant to be mixed in:
 * before ~ "move(toforward[0]), move(toforward[1]), ..." ~ after,
 * listing the first `i` elements of `toforward`, each wrapped in move().
 */
template VariadicForward(string before, string toforward, string after, size_t i)
{
    immutable string VariadicForward = ()
    {
        string forwarded = before;
        foreach (index; 0 .. i)
        {
            if (index != 0)
                forwarded ~= ", ";
            forwarded ~= "move(" ~ toforward ~ "[" ~ to!string(index) ~ "])";
        }
        return forwarded ~ after;
    }();
}
/**
 * A coroutine that runs `func` on its own stack.
 *
 * opCall(arguments) resumes it. The call returns the value passed to the
 * next yield, or the value func returns. func receives the coroutine itself
 * and calls yield on it to suspend. A Throwable that escapes func is
 * rethrown from opCall, and the coroutine then counts as finished.
 *
 * NOTE(review): the stack is a GC-allocated ubyte[]. The D GC most likely
 * does not scan ubyte arrays for pointers, so references that live only on
 * the coroutine's stack may be invisible to the collector. Confirm this and
 * consider registering the stack with core.memory.GC.addRange.
 */
class coroutine(Result, Arguments...)
{
    private alias Tuple!Arguments.Types Types;
    /**
     * Params:
     *   func = the coroutine body; it first runs on the first opCall
     *   stack_size = size in bytes of the coroutine's stack
     */
    this(Result delegate(coroutine!(Result, Arguments), Types) func, size_t stack_size = 8192)
    {
        this.func = func;
        stack = new ubyte[stack_size];
        getcontext(&callee);
        callee.uc_link = &caller; // return to opCall when func finishes
        callee.uc_stack.ss_size = stack_size;
        callee.uc_stack.ss_sp = stack.ptr;
        extern(C) void function() call = cast(void function())&coroutine_start;
        // makecontext officially only forwards int-sized arguments, so `this`
        // is split into two explicit 32 bit halves. the split is done in
        // 64 bit arithmetic so it is also valid where size_t is 32 bits
        immutable ulong address = cast(size_t)cast(void *)this;
        makecontext(&callee, call, 2, cast(uint)(address >> 32), cast(uint)address);
    }
    /**
     * Resumes the coroutine with the given arguments.
     * Returns: the value passed to the next yield, or the final return value
     * Throws: Throwable if the coroutine has already finished, or whatever
     *         escaped the coroutine body
     */
    Result opCall(Types arguments)
    {
        if (returned) throw new Throwable("This coroutine has already finished");
        static if (arguments.length)
        {
            foreach(i, ref argument; arguments)
                this.arguments[i] = move(argument);
        }
        swapcontext(&caller, &callee);
        if (exception)
        {
            throw exception;
        }
        else static if (!is(Result == void))
        {
            return move(result);
        }
    }
    /// true while the coroutine body has not finished
    @property bool callable() const
    {
        return !returned;
    }
    static if (is(Result == void))
    {
        alias yield_common yield;
    }
    else
    {
        /// copying yield: hands a copy of result back to the caller
        Tuple!Arguments yield(ref Result result)
        {
            this.result = result;
            return yield_common();
        }
        /// moving yield
        Tuple!Arguments yield(Result result)
        {
            this.result = move(result);
            return yield_common();
        }
    }
    // suspends the coroutine; once resumed, returns the new arguments
    private Tuple!Arguments yield_common()
    {
        swapcontext(&callee, &caller);
        mixin(VariadicForward!("return Tuple!Arguments(", "arguments", ");", Types.length));
    }
private:
    Result delegate(coroutine!(Result, Arguments), Types) func;
    ubyte[] stack;
    ucontext_t caller;
    ucontext_t callee;
    Throwable exception;
    bool returned = false;
    static if (!is(Result == void))
    {
        Result result;
    }
    Types arguments;
    // entry point handed to makecontext; runs on the coroutine stack and
    // receives the object's address split into two 32 bit halves
    extern(C) static void coroutine_start(uint this_pointer_left_half, uint this_pointer_right_half)
    {
        // reassemble in 64 bits, then narrow to the native pointer width
        immutable ulong address = (cast(ulong)this_pointer_left_half << 32) | this_pointer_right_half;
        coroutine this_ = cast(coroutine)cast(void *)cast(size_t)address;
        try
        {
            static if (is(Result == void))
            {
                mixin(VariadicForward!("this_.func(this_, ", "this_.arguments", ");", Types.length));
            }
            else
            {
                mixin(VariadicForward!("this_.result = this_.func(this_, ", "this_.arguments", ");", Types.length));
            }
        }
        catch(Throwable ex)
        {
            this_.exception = ex;
        }
        this_.returned = true;
    }
}
The most obvious point first: D is much cleaner.
Read the rest of this entry »