coroutine.zig (4097B)
//! Cooperative coroutine scheduler for 32-bit ARM (AArch32).
//!
//! Every task owns a fixed-size stack embedded in its `Task` allocation. Runnable tasks are
//! kept in `tasks` in scheduling order; whichever task is at the head of that list is the one
//! currently executing. Switching between stacks is done by the hand-written
//! `coroutine_context_switch` routine at the bottom of this file.

const std = @import("std");

const Self = @This();

/// Next tid handed out by `fork`; tid 0 is used for the scheduler's own context in `join`.
tid_counter: usize,
/// Runnable tasks in scheduling order. The head is the task currently running.
tasks: std.DoublyLinkedList,
/// Saved context of the caller of `join`; only valid while `join` is running.
scheduler_task: *Task,
/// Allocator used for the scheduler itself, every `Task`, and the scheduler context.
allocator: std.mem.Allocator,

pub const Task = struct {
    tid: usize,
    node: std.DoublyLinkedList.Node,

    func: *const fn (*Self, ?*anyopaque) callconv(.c) void,
    arg: ?*anyopaque,

    /// Stack pointer stored by `coroutine_context_switch` when this task was switched out.
    saved_sp: *usize,
    // r0-r3 are caller-saved and so we don't need to store them
    // (r3 is pushed anyway, to hand `self` to a new task's trampoline)
    // r4-r11 are callee-saved
    // r12 is scratch
    // r13 is sp
    // r14 is lr
    // r15 is pc
    stack: [1024]usize align(8),
};

/// Allocate an empty scheduler using `gpa`. Release it with `destroy`.
pub fn new(gpa: std.mem.Allocator) !*Self {
    const scheduler = try gpa.create(Self);
    scheduler.* = .{ .tid_counter = 1, .tasks = .{}, .scheduler_task = undefined, .allocator = gpa };

    return scheduler;
}

/// Free every task still queued (i.e. tasks that never ran to completion), then the scheduler.
pub fn destroy(self: *Self) void {
    var it = self.tasks.first;

    while (it) |node| {
        // Read `next` first: the node is embedded in the Task that is about to be freed.
        it = node.next;
        self.allocator.destroy(task_from_node(node));
    }

    self.allocator.destroy(self);
}

/// Map an intrusive list node back to the `Task` that embeds it.
fn task_from_node(node: *std.DoublyLinkedList.Node) *Task {
    return @alignCast(@fieldParentPtr("node", node));
}

/// Store `value` one word below `sp` and return the new stack pointer (descending stack).
fn push_stack(sp: *usize, value: usize) *usize {
    const sp_new: *usize = @ptrFromInt(@intFromPtr(sp) - @sizeOf(usize));
    sp_new.* = value;
    return sp_new;
}

/// Create a task that will run `f(self, arg)` on its own stack and queue it behind the
/// existing tasks. The task starts running only when `join`, or a `yield` from another
/// task, switches to it.
pub fn fork(self: *Self, f: *const fn (*Self, ?*anyopaque) callconv(.c) void, arg: ?*anyopaque) !void {
    const task = try self.allocator.create(Task);
    task.* = .{
        .tid = self.tid_counter,
        .node = .{},
        .func = f,
        .arg = arg,
        .saved_sp = undefined,
        .stack = undefined,
    };
    self.tid_counter += 1;

    // One past the last byte of the stack array. `stack.len` counts words, so the byte size
    // has to come from @sizeOf; adding `len` alone would leave most of the stack unused.
    const stack_top = @intFromPtr(&task.stack) + @sizeOf(@TypeOf(task.stack));
    var sp: *usize = @ptrFromInt(stack_top);
    std.debug.assertAligned(sp, .@"8");

    // Build the frame that `coroutine_context_switch` pops ({r3-r11, lr}, lowest register at
    // the lowest address), so the first switch to this task "returns" into the trampoline.

    // lr
    sp = push_stack(sp, @intFromPtr(&coroutine_trampoline));
    // r4-r11
    for (4..12) |_| sp = push_stack(sp, 0);
    // r3: the trampoline moves this into r0 as the `self` argument
    sp = push_stack(sp, @intFromPtr(self));

    task.saved_sp = sp;

    self.tasks.append(&task.node);
}

/// Return the task currently running (the head of `tasks`).
/// Panics in safe build modes if no task is queued.
fn get_current(self: *Self) *Task {
    return task_from_node(self.tasks.first.?);
}

/// Run queued tasks until every one of them has finished, then return to the caller.
/// Fails with error.OutOfMemory if the scheduler context cannot be allocated.
pub fn join(self: *Self) !void {
    if (self.tasks.first == null) {
        // Nothing to do
        return;
    }

    // Only `saved_sp` of this Task is used: it holds our stack pointer while tasks run.
    self.scheduler_task = try self.allocator.create(Task);
    defer {
        self.allocator.destroy(self.scheduler_task);
        self.scheduler_task = undefined;
    }
    self.scheduler_task.tid = 0;

    while (self.tasks.first) |node| {
        coroutine_context_switch(&self.scheduler_task.saved_sp, task_from_node(node).saved_sp);
        // We are resumed only by `coroutine_trampoline_inner`, after the task at the head of
        // the list has returned. Back on our own stack, it is now safe to unlink and free it.
        self.allocator.destroy(task_from_node(self.tasks.popFirst().?));
    }
}

/// Move the running task to the back of the queue and switch to the next queued task.
/// Returns immediately when there is no other task to switch to.
pub fn yield(self: *Self) void {
    // Zero or one task queued: nothing else to run.
    if (self.tasks.first == self.tasks.last) {
        return;
    }

    const current_node = self.tasks.popFirst().?;
    self.tasks.append(current_node);

    const current = task_from_node(current_node);
    coroutine_context_switch(&current.saved_sp, self.get_current().saved_sp);
}

/// Return the tid of the task currently running.
pub fn get_tid(self: *Self) usize {
    return self.get_current().tid;
}

/// Entry point of every task, reached from the `coroutine_trampoline` stub with `self` in r0.
/// Runs the task body, then hands control back to `join`, which frees the finished task.
export fn coroutine_trampoline_inner(self: *Self) callconv(.c) noreturn {
    const current = self.get_current();
    current.func(self, current.arg);

    // A task only runs while it is at the head of `tasks`, so it is still there now.
    std.debug.assert(self.get_current() == current);

    // The task cannot be freed here: we are still executing on its stack. The scheduler loop
    // in `join` pops and frees it once it is running on its own stack again.
    coroutine_context_switch(&current.saved_sp, self.scheduler_task.saved_sp);

    // A finished task is never switched back to.
    unreachable;
}

// coroutine_trampoline: first code a new task executes. It forwards the `self` pointer that
// `fork` placed in the r3 slot to `coroutine_trampoline_inner` as r0. The branch-to-self after
// the call is a guard so that a return can never fall through into coroutine_context_switch.
//
// coroutine_context_switch(current_sp, next_sp): pushes r3-r11 and lr on the current stack,
// stores the resulting sp to *current_sp (r0), loads next_sp (r1) into sp, pops the same
// register set from the new stack and returns through the popped lr. Ten words are pushed,
// so sp keeps the 8-byte alignment AAPCS requires.
comptime {
    asm (
        \\ .text
        \\ .global coroutine_trampoline
        \\ .type coroutine_trampoline, %function
        \\ coroutine_trampoline:
        \\     mov r0, r3
        \\     bl coroutine_trampoline_inner
        \\     b .
        \\
        \\ .global coroutine_context_switch
        \\ .type coroutine_context_switch, %function
        \\ coroutine_context_switch:
        \\     push {r3-r11, lr}
        \\     str sp, [r0]
        \\     mov sp, r1
        \\     pop {r3-r11, lr}
        \\     bx lr
    );
}

// TODO: maybe these should be naked functions?
extern fn coroutine_trampoline() void;
extern fn coroutine_context_switch(current_task_sp: **usize, next_task_sp: *usize) void;