sylveos

Toy Operating System
Log | Files | Refs

pubsub.zig (4128B)


      1 const std = @import("std");
      2 
      3 const ListenerFn = *const fn (*const anyopaque) void;
      4 
      5 fn validate_listener_fn(comptime ptr: anytype) type {
      6     const P = @TypeOf(ptr);
      7     const ptr_info = @typeInfo(P);
      8     if (ptr_info != .pointer) {
      9         @compileError("expected *fn, found " ++ @typeName(ptr_info));
     10     }
     11     const L = ptr_info.pointer.child;
     12 
     13     const listener_info = @typeInfo(L);
     14     if (listener_info != .@"fn") {
     15         @compileError("expected *fn, found " ++ @typeName(P));
     16     }
     17 
     18     const listener_fn = listener_info.@"fn";
     19     if (listener_fn.params.len != 1) {
     20         @compileError("expected 1 parameter, found " ++ listener_fn.params.len);
     21     }
     22     const msg_type = listener_fn.params[0].type orelse unreachable;
     23     const msg_type_info = @typeInfo(msg_type);
     24     if (msg_type_info != .pointer) {
     25         @compileError("expected *const, found " ++ @typeName(msg_type));
     26     }
     27     if (!msg_type_info.pointer.is_const) {
     28         @compileError("expected *const, found " ++ @typeName(msg_type));
     29     }
     30 
     31     // Questionable legality
     32     // return @ptrCast(ptr);
     33     return msg_type;
     34 }
     35 
     36 // pub fn create_listener_static(comptime ptr: anytype) Listener {
     37 //     return .{
     38 //         .node = undefined,
     39 //         .callback = @ptrCast(ptr),
     40 //         .callback_fn_param = validate_listener_fn(ptr),
     41 //     };
     42 // }
     43 
     44 pub fn Publisher(comptime M: type) type {
     45     const msg_info = @typeInfo(M);
     46     if (msg_info != .@"union") {
     47         @compileError("expected tagged union, found " ++ @typeName(M));
     48     }
     49 
     50     const msg_union = msg_info.@"union";
     51     const msg_tag_type = msg_union.tag_type orelse @compileError("expected tagged union, found untagged union");
     52 
     53     const Slot = struct { []const u8, usize };
     54 
     55     var slots: [msg_union.fields.len]Slot = undefined;
     56 
     57     comptime for (msg_union.fields, 0..) |field, index| {
     58         slots[index] = .{ field.name, index };
     59     };
     60 
     61     const event_map: std.StaticStringMap(usize) = .initComptime(slots);
     62 
     63     return struct {
     64         const Self = @This();
     65 
     66         listeners: [msg_union.fields.len]std.ArrayList(ListenerFn),
     67         gpa: std.mem.Allocator,
     68 
     69         pub fn init(gpa: std.mem.Allocator) !Self {
     70             var self: Self = undefined;
     71 
     72             for (0..self.listeners.len) |index| {
     73                 self.listeners[index] = try .initCapacity(gpa, 1);
     74             }
     75             self.gpa = gpa;
     76 
     77             return self;
     78         }
     79 
     80         pub fn listen(self: *Self, comptime event: msg_tag_type, comptime listener: anytype) !void {
     81             comptime {
     82                 const param_type = validate_listener_fn(listener);
     83                 for (msg_union.fields) |f| {
     84                     if (std.mem.eql(u8, f.name, @tagName(event))) {
     85                         // We know that param_type is a pointer so its safe to index into
     86                         if (f.type != @typeInfo(param_type).pointer.child) {
     87                             @compileError("expected *const " ++ @typeName(f.type) ++ ", found " ++ @typeName(param_type));
     88                         }
     89 
     90                         break;
     91                     }
     92                 }
     93             }
     94 
     95             const index = event_map.get(@tagName(event)) orelse unreachable;
     96             try self.listeners[index].append(self.gpa, @ptrCast(listener));
     97         }
     98 
     99         pub fn publish(self: *Self, comptime event: msg_tag_type, data: anytype) void {
    100             comptime {
    101                 for (msg_union.fields) |f| {
    102                     if (std.mem.eql(u8, f.name, @tagName(event))) {
    103                         // We know that param_type is a pointer so its safe to index into
    104                         if (f.type != @TypeOf(data)) {
    105                             @compileError("expected " ++ @typeName(f.type) ++ ", found " ++ @typeName(@TypeOf(data)));
    106                         }
    107 
    108                         break;
    109                     }
    110                 }
    111             }
    112 
    113             const index = event_map.get(@tagName(event)) orelse unreachable;
    114             const listeners = self.listeners[index];
    115 
    116             for (listeners.items) |listener| {
    117                 listener(@ptrCast(&data));
    118             }
    119         }
    120     };
    121 }