1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #if !defined(_XOPEN_SOURCE)
16 // This must come before other #includes, otherwise we'll end up with ucontext_t
17 // definition mismatches, leading to memory corruption hilarity.
18 #define _XOPEN_SOURCE
19 #endif // !defined(_XOPEN_SOURCE)
20
#include "marl/debug.h"
#include "marl/memory.h"

#include <cstring>
#include <functional>
#include <memory>

#include <ucontext.h>
28
29 #if defined(__clang__)
30 #pragma clang diagnostic push
31 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
32 #endif // defined(__clang__)
33
34 namespace marl {
35
36 class OSFiber {
37 public:
38 inline OSFiber(Allocator*);
39 inline ~OSFiber();
40
41 // createFiberFromCurrentThread() returns a fiber created from the current
42 // thread.
43 static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
44 Allocator* allocator);
45
46 // createFiber() returns a new fiber with the given stack size that will
47 // call func when switched to. func() must end by switching back to another
48 // fiber, and must not return.
49 static inline Allocator::unique_ptr<OSFiber> createFiber(
50 Allocator* allocator,
51 size_t stackSize,
52 const std::function<void()>& func);
53
54 // switchTo() immediately switches execution to the given fiber.
55 // switchTo() must be called on the currently executing fiber.
56 inline void switchTo(OSFiber*);
57
58 private:
59 Allocator* allocator;
60 ucontext_t context;
61 std::function<void()> target;
62 Allocation stack;
63 };
64
OSFiber(Allocator * allocator)65 OSFiber::OSFiber(Allocator* allocator) : allocator(allocator) {}
66
~OSFiber()67 OSFiber::~OSFiber() {
68 if (stack.ptr != nullptr) {
69 allocator->free(stack);
70 }
71 }
72
createFiberFromCurrentThread(Allocator * allocator)73 Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
74 Allocator* allocator) {
75 auto out = allocator->make_unique<OSFiber>(allocator);
76 out->context = {};
77 getcontext(&out->context);
78 return out;
79 }
80
createFiber(Allocator * allocator,size_t stackSize,const std::function<void ()> & func)81 Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
82 Allocator* allocator,
83 size_t stackSize,
84 const std::function<void()>& func) {
85 union Args {
86 OSFiber* self;
87 struct {
88 int a;
89 int b;
90 };
91 };
92
93 struct Target {
94 static void Main(int a, int b) {
95 Args u;
96 u.a = a;
97 u.b = b;
98 u.self->target();
99 }
100 };
101
102 Allocation::Request request;
103 request.size = stackSize;
104 request.alignment = 16;
105 request.usage = Allocation::Usage::Stack;
106 #if MARL_USE_FIBER_STACK_GUARDS
107 request.useGuards = true;
108 #endif
109
110 auto out = allocator->make_unique<OSFiber>(allocator);
111 out->context = {};
112 out->stack = allocator->allocate(request);
113 out->target = func;
114
115 auto res = getcontext(&out->context);
116 (void)res;
117 MARL_ASSERT(res == 0, "getcontext() returned %d", int(res));
118 out->context.uc_stack.ss_sp = out->stack.ptr;
119 out->context.uc_stack.ss_size = stackSize;
120 out->context.uc_link = nullptr;
121
122 Args args{};
123 args.self = out.get();
124 makecontext(&out->context, reinterpret_cast<void (*)()>(&Target::Main), 2,
125 args.a, args.b);
126
127 return out;
128 }
129
switchTo(OSFiber * fiber)130 void OSFiber::switchTo(OSFiber* fiber) {
131 auto res = swapcontext(&context, &fiber->context);
132 (void)res;
133 MARL_ASSERT(res == 0, "swapcontext() returned %d", int(res));
134 }
135
136 } // namespace marl
137
138 #if defined(__clang__)
139 #pragma clang diagnostic pop
140 #endif // defined(__clang__)
141