sylveos

Toy Operating System
Log | Files | Refs

coroutine.zig (4097B)


      1 const std = @import("std");
      2 
      3 const Self = @This();
      4 
      5 tid_counter: usize,
      6 tasks: std.DoublyLinkedList,
      7 scheduler_task: *Task,
      8 allocator: std.mem.Allocator,
      9 
     10 pub const Task = struct {
     11     tid: usize,
     12     node: std.DoublyLinkedList.Node,
     13 
     14     func: *const fn (*Self, ?*anyopaque) callconv(.c) void,
     15     arg: ?*anyopaque,
     16 
     17     saved_sp: *usize,
     18     // r0-r3 are caller-saved and so we don't need to store them
     19     // r4-r11 are callee-saved
     20     // r12 is scratch
     21     // r13 is sp
     22     // r14 is lr
     23     // r15 is pc
     24     stack: [1024]usize align(8),
     25 };
     26 
     27 pub fn new(gpa: std.mem.Allocator) !*Self {
     28     const scheduler = try gpa.create(Self);
     29     scheduler.* = .{ .tid_counter = 1, .tasks = .{}, .scheduler_task = undefined, .allocator = gpa };
     30 
     31     return scheduler;
     32 }
     33 
     34 pub fn destroy(self: *Self) void {
     35     var task_node = self.tasks.first;
     36 
     37     while (task_node) |n| {
     38         task_node = n.next;
     39         self.allocator.destroy(self.get_current());
     40     }
     41 
     42     self.allocator.destroy(self);
     43 }
     44 
     45 fn push_stack(sp: *usize, value: usize) *usize {
     46     const sp_new: *usize = @ptrFromInt(@intFromPtr(sp) - @sizeOf(usize));
     47     sp_new.* = value;
     48     return sp_new;
     49 }
     50 
     51 pub fn fork(self: *Self, f: *const fn (*Self, ?*anyopaque) callconv(.c) void, arg: ?*anyopaque) !void {
     52     var task = try self.allocator.create(Task);
     53 
     54     task.tid = self.tid_counter;
     55     self.tid_counter += 1;
     56 
     57     task.func = f;
     58     task.arg = arg;
     59 
     60     const stack_top = @intFromPtr(&task.stack) + task.stack.len;
     61     var sp: *usize = @ptrFromInt(stack_top);
     62     std.debug.assertAligned(sp, .@"8");
     63 
     64     // lr
     65     sp = push_stack(sp, @intFromPtr(&coroutine_trampoline));
     66 
     67     // r4-r11
     68     for (4..12) |_| {
     69         sp = push_stack(sp, 0);
     70     }
     71     // r3 (for trampoline)
     72     sp = push_stack(sp, @intFromPtr(self));
     73 
     74     task.saved_sp = sp;
     75 
     76     self.tasks.append(&task.node);
     77 }
     78 
     79 fn get_current(self: *Self) *Task {
     80     return @alignCast(@fieldParentPtr("node", self.tasks.first.?));
     81 }
     82 
     83 pub fn join(self: *Self) !void {
     84     if (self.tasks.first == null) {
     85         // Nothing to do
     86         return;
     87     }
     88 
     89     self.scheduler_task = try self.allocator.create(Task);
     90     self.scheduler_task.tid = 0;
     91     self.scheduler_task.saved_sp = @ptrFromInt(@frameAddress());
     92 
     93     coroutine_context_switch(&self.scheduler_task.saved_sp, self.get_current().saved_sp);
     94 }
     95 
     96 pub fn yield(self: *Self) void {
     97     // Single task
     98     if (self.tasks.first == self.tasks.last) {
     99         return;
    100     }
    101 
    102     const node_task = self.tasks.popFirst().?;
    103     self.tasks.append(node_task);
    104 
    105     const current: *Task = @alignCast(@fieldParentPtr("node", node_task));
    106     coroutine_context_switch(&current.saved_sp, self.get_current().saved_sp);
    107 }
    108 
    109 pub fn get_tid(self: *Self) usize {
    110     return self.get_current().tid;
    111 }
    112 
    113 export fn coroutine_trampoline_inner(self: *Self) callconv(.c) void {
    114     {
    115         const current: *Task = self.get_current();
    116         current.func(self, current.arg);
    117     }
    118 
    119     // We aren't preserving sp as the function ended
    120     const current: *Task = @alignCast(@fieldParentPtr("node", self.tasks.popFirst().?));
    121     var saved_sp = current.saved_sp;
    122     // Free the task
    123     self.allocator.destroy(current);
    124 
    125     if (self.tasks.first) |node_task| {
    126         const next: *Task = @alignCast(@fieldParentPtr("node", node_task));
    127         coroutine_context_switch(&saved_sp, next.saved_sp);
    128     } else {
    129         coroutine_context_switch(&saved_sp, self.scheduler_task.saved_sp);
    130     }
    131 }
    132 
    133 comptime {
    134     asm (
    135         \\ .global coroutine_trampoline
    136         \\ .type coroutine_trampoline, %function;
    137         \\ coroutine_trampoline:
    138         \\   mov r0, r3
    139         \\   bl  coroutine_trampoline_inner
    140         \\ .global coroutine_context_switch;
    141         \\ .type coroutine_context_switch, %function;
    142         \\ coroutine_context_switch:
    143         \\   push {r3-r11,lr}
    144         \\   str  sp, [r0]
    145         \\   mov  sp, r1
    146         \\   pop  {r3-r11,lr}
    147         \\   bx   lr
    148     );
    149 }
    150 
    151 // TODO; maybe should be naked functions?
    152 extern fn coroutine_trampoline() void;
    153 extern fn coroutine_context_switch(current_task_sp: **usize, next_task_sp: *usize) void;