sylveos

Toy Operating System
Log | Files | Refs

commit cd7592289729e6df225b4a4cffcba3a519343f82
parent d9c1ea959de741d9412808d73e344d99afd40b8f
Author: Sylvia Ivory <git@sivory.net>
Date:   Tue, 10 Feb 2026 11:53:34 -0800

TCP echo

Diffstat:
Aprograms/tcp-echo.zig | 41+++++++++++++++++++++++++++++++++++++++++
Mshared/root.zig | 1+
Ashared/tcp.zig | 134+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mtools/src/main.rs | 13+++++++++----
Atools/src/tcp.rs | 302++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
5 files changed, 487 insertions(+), 4 deletions(-)

diff --git a/programs/tcp-echo.zig b/programs/tcp-echo.zig @@ -0,0 +1,41 @@ +const std = @import("std"); +const pi = @import("pi"); + +const tcp = @import("shared").tcp_protocol; + +const uart = pi.devices.mini_uart; + +pub fn main() !void { + const w = &uart.writer; + const r = &uart.reader; + + var buffer: [1024 * 1024]u8 = undefined; + var fba: std.heap.FixedBufferAllocator = .init(&buffer); + var allocator = fba.allocator(); + + try tcp.send(w, .{ .Bind = .{ .port = 8080, .ttl = 1000 } }); + + while (true) { + switch (try tcp.receive(r, allocator)) { + .Request => |*request| { + const ip: *const [4]u8 = @ptrCast(&request.ip); + const msg = try std.fmt.allocPrint(allocator, "got connection {d}: {d}.{d}.{d}.{d}", .{ request.id, ip[3], ip[2], ip[1], ip[0] }); + try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } }); + allocator.free(msg); + }, + .Write => |*write| { + const msg = try std.fmt.allocPrint(allocator, "read {d}: {s}", .{ write.id, write.data }); + try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } }); + allocator.free(msg); + + try tcp.send(w, .{ .Write = .{ .data = write.data, .length = write.length, .id = write.id } }); + }, + .Disconnect => |connection| { + const msg = try std.fmt.allocPrint(allocator, "lost connection {d}", .{connection}); + try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } }); + allocator.free(msg); + }, + else => {}, + } + } +} diff --git a/shared/root.zig b/shared/root.zig @@ -1,4 +1,5 @@ pub const bootloader_protocol = @import("./net.zig"); +pub const tcp_protocol = @import("./tcp.zig"); pub const lists = @import("./lists.zig"); pub const pubsub = @import("./pubsub.zig"); diff --git a/shared/tcp.zig b/shared/tcp.zig @@ -0,0 +1,134 @@ +const std = @import("std"); + +pub const Error = error{ + InvalidKind, + ExpectedBind, +} || std.Io.Reader.Error || std.Io.Writer.Error || std.mem.Allocator.Error; + +pub const PacketKind = enum(u8) { + Bind, + Request, + Read, + Write, + Disconnect, + Print, +}; +pub const Packet = union(PacketKind) { + pub const ConnectionId = u16; + + Bind: struct { port: u16, ttl: u32 }, + Request: struct { id: ConnectionId, ip: u32 }, + Read: struct { id: ConnectionId, length: u16 }, + Write: struct { id: ConnectionId, length: u16, data: []const u8 }, + Disconnect: ConnectionId, + // Still print while going through UART + Print: struct { length: u16, data: []const u8 }, +}; + +fn get(comptime T: type, r: *std.Io.Reader) !T { + var buffer: [@sizeOf(T)]u8 = undefined; + try r.readSliceAll(&buffer); + return std.mem.readInt(T, &buffer, .little); +} + +fn put(w: *std.Io.Writer, v: anytype) !void { + const T = @TypeOf(v); + var buffer: [@sizeOf(T)]u8 = undefined; + std.mem.writeInt(T, &buffer, v, .little); + + try w.writeAll(&buffer); + try w.flush(); +} + +pub fn receive(r: *std.Io.Reader, allocator: std.mem.Allocator) Error!Packet { + const kind_raw = try get(u8, r); + if (kind_raw > @intFromEnum(PacketKind.Print)) return Error.InvalidKind; + const kind: PacketKind = @enumFromInt(kind_raw); + + switch (kind) { + .Bind => { + const port = try get(u16, r); + const ttl = try get(u32, r); + + return Packet{ .Bind = .{ + .port = port, + .ttl = ttl, + } }; + }, + .Request => { + const id = try get(Packet.ConnectionId, r); + const ip = try get(u32, r); + + return Packet{ .Request = .{ .id = id, .ip = ip } }; + }, + .Read => { + const id = try get(Packet.ConnectionId, r); + const length = try get(u16, r); + + return Packet{ .Read = .{ + .id = id, + .length = length, + } }; + }, + .Write => { + const id = try get(Packet.ConnectionId, r); + const length = try get(u16, r); + const buffer = try allocator.alloc(u8, length); + + try r.readSliceAll(buffer); + + return Packet{ .Write = .{ + .id = id, + .length = length, + .data = buffer, + } }; + }, + .Disconnect => { + return Packet{ .Disconnect = try get(Packet.ConnectionId, r) }; + }, + .Print => { + const length = try get(u16, r); + const buffer = try allocator.alloc(u8, length); + + try r.readSliceAll(buffer); + return Packet{ .Print = .{ + .length = length, + .data = buffer, + } }; + }, + } +} + +pub fn send(w: *std.Io.Writer, packet: Packet) Error!void { + try put(w, @intFromEnum(packet)); + + switch (packet) { + .Bind => |*bind| { + try put(w, bind.port); + try put(w, bind.ttl); + }, + .Request => |*request| { + try put(w, request.id); + try put(w, request.ip); + }, + .Disconnect => |id| { + try put(w, id); + }, + .Read => |*read| { + try put(w, read.id); + try put(w, read.length); + }, + .Write => |*write| { + try put(w, write.id); + try put(w, write.length); + + try w.writeAll(write.data); + try w.flush(); + }, + .Print => |*print| { + try put(w, print.length); + try w.writeAll(print.data); + try w.flush(); + }, + } +} diff --git a/tools/src/main.rs b/tools/src/main.rs @@ -8,6 +8,7 @@ use tokio_serial::{DataBits, FlowControl, Parity, SerialPortBuilderExt, SerialSt use std::{path::PathBuf, process::exit, time::Duration}; mod bootload; +mod tcp; #[derive(Parser, Debug)] #[command(about)] @@ -20,6 +21,8 @@ struct Args { timeout: f64, #[arg(short, long, default_value = "0x8000", value_parser = hex_value, value_name = "ADDRESS")] arm_base: u32, + #[arg(long, default_value_t = false)] + tcp: bool, #[arg(value_hint = clap::ValueHint::DirPath)] file: PathBuf, } @@ -39,9 +42,6 @@ async fn default_action(port: SerialStream) -> anyhow::Result<()> { let mut stdin = io::stdin(); let mut stdout = io::stdout(); - // let stdin_to_port = tokio::spawn(async move { io::copy(&mut stdin, &mut port_write).await }); - // let port_to_stdout = tokio::spawn(async move { io::copy(&mut port_read, &mut stdout).await }); - tokio::select!( _ = io::copy(&mut stdin, &mut port_write) => { warn!("stdin disconnected, exiting"); @@ -88,7 +88,12 @@ async fn main() -> anyhow::Result<()> { ); bootload::boot(&multi, &mut port, &args.file, args.arm_base).await?; - default_action(port).await?; + if args.tcp { + let (port_read, port_write) = io::split(port); + tcp::serve(port_read, port_write).await?; + } else { + default_action(port).await?; + } Ok(()) } diff --git a/tools/src/tcp.rs b/tools/src/tcp.rs @@ -0,0 +1,302 @@ +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddrV4}, + sync::Arc, +}; + +use log::{debug, error, info}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest}, + net::{ + TcpListener, + tcp::{OwnedReadHalf, OwnedWriteHalf}, + }, + sync::RwLock, +}; + +pub type ConnectionId = u16; + +#[derive(Debug)] +pub enum Message { + Bind { + port: u16, + ttl: u32, + }, + Request { + id: ConnectionId, + ip: u32, + }, + Read { + id: ConnectionId, + length: u16, + }, + Write { + id: ConnectionId, + length: u16, + data: Vec<u8>, + }, + Disconnect(ConnectionId), + Print { + length: u16, + data: Vec<u8>, + }, +} + +impl From<&Message> for u8 { + fn from(value: &Message) -> Self { + match value { + Message::Bind { .. } => 0, + Message::Request { .. } => 1, + Message::Read { .. } => 2, + Message::Write { .. } => 3, + Message::Disconnect(_) => 4, + Message::Print { .. } => 5, + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Invalid message kind")] + InvalidKind, + #[error("io: {0}")] + Io(#[from] std::io::Error), +} + +pub async fn parse_message<R: AsyncRead + Unpin>(r: &mut R) -> Result<Message, Error> { + let kind = r.read_u8().await?; + + match kind { + // Bind: client -> server + 0 => { + let tcp_port = r.read_u16_le().await?; + let ttl = r.read_u32_le().await?; + + Ok(Message::Bind { + port: tcp_port, + ttl, + }) + } + // Request: server -> client + // 1 => {}, + // Read: client -> server + 2 => Ok(Message::Read { + id: r.read_u16_le().await?, + length: r.read_u16_le().await?, + }), + // Write: client -> server, server -> client + 3 => { + let id = r.read_u16_le().await?; + let length = r.read_u16_le().await?; + let mut data = vec![0; length as usize]; + + r.read_exact(&mut data).await?; + + Ok(Message::Write { id, length, data }) + } + // Disconnect: client -> server, server -> client + 4 => Ok(Message::Disconnect(r.read_u16_le().await?)), + // Print: client -> server + 5 => { + let length = r.read_u16_le().await?; + let mut data = vec![0; length as usize]; + + r.read_exact(&mut data).await?; + Ok(Message::Print { length, data }) + } + _ => { + debug!("got invalid kind: {kind}"); + Err(Error::InvalidKind) + } + } +} + +pub async fn send_message<W: AsyncWrite + Unpin>(w: &mut W, msg: &Message) -> Result<(), Error> { + debug!("sending message: {msg:#?}"); + + w.write_u8(msg.into()).await?; + + match msg { + // Message::Bind { tcp_port, ttl } => {}, + Message::Request { id, ip } => { + w.write_u16_le(*id).await?; + w.write_u32_le(*ip).await?; + } + Message::Write { id, length, data } => { + w.write_u16_le(*id).await?; + w.write_u16_le(*length).await?; + w.write_all(&data[..(*length as usize)]).await?; + } + Message::Disconnect(id) => { + w.write_u16_le(*id).await?; + } + _ => return Err(Error::InvalidKind), + } + + w.flush().await?; + + Ok(()) +} + +async fn socket_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>( + id: ConnectionId, + socket: OwnedReadHalf, + w: Arc<RwLock<W>>, +) -> Result<(), Error> { + let ip = match socket.peer_addr()?.ip() { + IpAddr::V4(v4) => v4.to_bits(), + _ => unreachable!(), + }; + + { + let mut writer = w.write().await; + send_message(&mut *writer, &Message::Request { id, ip }).await?; + } + + loop { + let ready = socket.ready(Interest::READABLE).await?; + + if ready.is_readable() { + debug!("socket is readable"); + let mut data = vec![0; 1024]; + + match socket.try_read(&mut data) { + Ok(0) => { + debug!("socket disconnected: {id}"); + let mut writer = w.write().await; + + if let Err(e) = send_message(&mut *writer, &Message::Disconnect(id)).await { + error!("failed to send message: {e}"); + } + + return Ok(()); + } + Ok(n) => { + debug!("got data: n={n}: {:?}", &data[..n]); + let mut writer = w.write().await; + + if let Err(e) = send_message( + &mut *writer, + &Message::Write { + id, + length: n as u16, + data, + }, + ) + .await + { + error!("failed to send message: {e}"); + } + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + error!("failed to read: {e}"); + return Err(e.into()); + } + } + } + } +} + +async fn tcp_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>( + server: TcpListener, + w: Arc<RwLock<W>>, + map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>, +) -> Result<(), Error> { + let mut counter: ConnectionId = 0; + // TODO; should automatically detect when sockets drop + loop { + let (socket, _) = server.accept().await?; + let (socket_read, socket_write) = socket.into_split(); + + map.write().await.insert(counter, socket_write); + tokio::spawn(socket_handler(counter, socket_read, w.clone())); + + counter += 1; + } +} + +async fn port_handler<R: AsyncRead + Unpin, W: AsyncWrite + Send + Sync + Unpin + 'static>( + mut r: R, + w: Arc<RwLock<W>>, + map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>, +) -> Result<(), Error> { + loop { + let msg = parse_message(&mut r).await?; + debug!("got message: {msg:#?}"); + + match msg { + Message::Write { + id, + length: _, + data, + } => { + let mut lock = map.write().await; + + if let Some(socket) = lock.get_mut(&id) { + if let Err(e) = socket.write_all(&data).await { + error!("failed to write to connection {id}: {e}"); + } + } else { + error!("got invalid id: {id}"); + } + } + Message::Disconnect(id) => { + let mut lock = map.write().await; + + if let Some(socket) = lock.get_mut(&id) { + _ = socket.shutdown().await; + lock.remove(&id); + } else { + error!("got invalid id: {id}"); + } + } + Message::Print { length: _, data } => { + info!("pi: {}", String::from_utf8_lossy(&data).trim()) + } + _ => {} + } + } +} + +async fn wait_for_bind<R: AsyncRead + Unpin>(r: &mut R) -> Result<TcpListener, Error> { + info!("waiting for bind"); + + loop { + if let Ok(Message::Bind { port, ttl }) = parse_message(r).await { + info!("got response, port={port}, ttl={ttl}"); + let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port); + let listener = TcpListener::bind(addr).await?; + // listener.set_ttl(ttl)?; + return Ok(listener); + } + } +} + +pub async fn serve< + R: AsyncRead + Send + Sync + Unpin + 'static, + W: AsyncWrite + Send + Sync + Unpin + 'static, +>( + mut r: R, + w: W, +) -> Result<(), Error> { + let server = wait_for_bind(&mut r).await?; + info!("listening at {}", server.local_addr()?); + + // Rust my beloved + let connection_map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>> = + Arc::new(RwLock::new(HashMap::new())); + let writer: Arc<RwLock<W>> = Arc::new(RwLock::new(w)); + + let server_task = tokio::spawn(tcp_handler(server, writer.clone(), connection_map.clone())); + let port_task = tokio::spawn(port_handler(r, writer.clone(), connection_map.clone())); + + tokio::select! { + r = server_task => r.unwrap(), + r = port_task => r.unwrap(), + }?; + + Ok(()) +}