commit cd7592289729e6df225b4a4cffcba3a519343f82
parent d9c1ea959de741d9412808d73e344d99afd40b8f
Author: Sylvia Ivory <git@sivory.net>
Date: Tue, 10 Feb 2026 11:53:34 -0800
TCP echo
Diffstat:
5 files changed, 487 insertions(+), 4 deletions(-)
diff --git a/programs/tcp-echo.zig b/programs/tcp-echo.zig
@@ -0,0 +1,41 @@
+const std = @import("std");
+const pi = @import("pi");
+
+const tcp = @import("shared").tcp_protocol;
+
+const uart = pi.devices.mini_uart;
+
+pub fn main() !void {
+ const w = &uart.writer;
+ const r = &uart.reader;
+
+ var buffer: [1024 * 1024]u8 = undefined;
+ var fba: std.heap.FixedBufferAllocator = .init(&buffer);
+ var allocator = fba.allocator();
+
+ try tcp.send(w, .{ .Bind = .{ .port = 8080, .ttl = 1000 } });
+
+ while (true) {
+ switch (try tcp.receive(r, allocator)) {
+ .Request => |*request| {
+ const ip: *const [4]u8 = @ptrCast(&request.ip);
+ const msg = try std.fmt.allocPrint(allocator, "got connection {d}: {d}.{d}.{d}.{d}", .{ request.id, ip[3], ip[2], ip[1], ip[0] });
+ try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } });
+ allocator.free(msg);
+ },
+ .Write => |*write| {
+ const msg = try std.fmt.allocPrint(allocator, "read {d}: {s}", .{ write.id, write.data });
+ try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } });
+ allocator.free(msg);
+
+ try tcp.send(w, .{ .Write = .{ .data = write.data, .length = write.length, .id = write.id } });
+ },
+ .Disconnect => |connection| {
+ const msg = try std.fmt.allocPrint(allocator, "lost connection {d}", .{connection});
+ try tcp.send(w, .{ .Print = .{ .data = msg, .length = @truncate(msg.len) } });
+ allocator.free(msg);
+ },
+ else => {},
+ }
+ }
+}
diff --git a/shared/root.zig b/shared/root.zig
@@ -1,4 +1,5 @@
pub const bootloader_protocol = @import("./net.zig");
+pub const tcp_protocol = @import("./tcp.zig");
pub const lists = @import("./lists.zig");
pub const pubsub = @import("./pubsub.zig");
diff --git a/shared/tcp.zig b/shared/tcp.zig
@@ -0,0 +1,134 @@
+const std = @import("std");
+
+pub const Error = error{
+ InvalidKind,
+ ExpectedBind,
+} || std.Io.Reader.Error || std.Io.Writer.Error || std.mem.Allocator.Error;
+
+pub const PacketKind = enum(u8) {
+ Bind,
+ Request,
+ Read,
+ Write,
+ Disconnect,
+ Print,
+};
+pub const Packet = union(PacketKind) {
+ pub const ConnectionId = u16;
+
+ Bind: struct { port: u16, ttl: u32 },
+ Request: struct { id: ConnectionId, ip: u32 },
+ Read: struct { id: ConnectionId, length: u16 },
+ Write: struct { id: ConnectionId, length: u16, data: []const u8 },
+ Disconnect: ConnectionId,
+ // Still print while going through UART
+ Print: struct { length: u16, data: []const u8 },
+};
+
+fn get(comptime T: type, r: *std.Io.Reader) !T {
+ var buffer: [@sizeOf(T)]u8 = undefined;
+ try r.readSliceAll(&buffer);
+ return std.mem.readInt(T, &buffer, .little);
+}
+
+fn put(w: *std.Io.Writer, v: anytype) !void {
+ const T = @TypeOf(v);
+ var buffer: [@sizeOf(T)]u8 = undefined;
+ std.mem.writeInt(T, &buffer, v, .little);
+
+ try w.writeAll(&buffer);
+ try w.flush();
+}
+
+pub fn receive(r: *std.Io.Reader, allocator: std.mem.Allocator) Error!Packet {
+ const kind_raw = try get(u8, r);
+ if (kind_raw > @intFromEnum(PacketKind.Print)) return Error.InvalidKind;
+ const kind: PacketKind = @enumFromInt(kind_raw);
+
+ switch (kind) {
+ .Bind => {
+ const port = try get(u16, r);
+ const ttl = try get(u32, r);
+
+ return Packet{ .Bind = .{
+ .port = port,
+ .ttl = ttl,
+ } };
+ },
+ .Request => {
+ const id = try get(Packet.ConnectionId, r);
+ const ip = try get(u32, r);
+
+ return Packet{ .Request = .{ .id = id, .ip = ip } };
+ },
+ .Read => {
+ const id = try get(Packet.ConnectionId, r);
+ const length = try get(u16, r);
+
+ return Packet{ .Read = .{
+ .id = id,
+ .length = length,
+ } };
+ },
+ .Write => {
+ const id = try get(Packet.ConnectionId, r);
+ const length = try get(u16, r);
+ const buffer = try allocator.alloc(u8, length);
+
+ try r.readSliceAll(buffer);
+
+ return Packet{ .Write = .{
+ .id = id,
+ .length = length,
+ .data = buffer,
+ } };
+ },
+ .Disconnect => {
+ return Packet{ .Disconnect = try get(Packet.ConnectionId, r) };
+ },
+ .Print => {
+ const length = try get(u16, r);
+ const buffer = try allocator.alloc(u8, length);
+
+ try r.readSliceAll(buffer);
+ return Packet{ .Print = .{
+ .length = length,
+ .data = buffer,
+ } };
+ },
+ }
+}
+
+pub fn send(w: *std.Io.Writer, packet: Packet) Error!void {
+ try put(w, @intFromEnum(packet));
+
+ switch (packet) {
+ .Bind => |*bind| {
+ try put(w, bind.port);
+ try put(w, bind.ttl);
+ },
+ .Request => |*request| {
+ try put(w, request.id);
+ try put(w, request.ip);
+ },
+ .Disconnect => |id| {
+ try put(w, id);
+ },
+ .Read => |*read| {
+ try put(w, read.id);
+ try put(w, read.length);
+ },
+ .Write => |*write| {
+ try put(w, write.id);
+ try put(w, write.length);
+
+ try w.writeAll(write.data);
+ try w.flush();
+ },
+ .Print => |*print| {
+ try put(w, print.length);
+ try w.writeAll(print.data);
+ try w.flush();
+ },
+ }
+}
diff --git a/tools/src/main.rs b/tools/src/main.rs
@@ -8,6 +8,7 @@ use tokio_serial::{DataBits, FlowControl, Parity, SerialPortBuilderExt, SerialSt
use std::{path::PathBuf, process::exit, time::Duration};
mod bootload;
+mod tcp;
#[derive(Parser, Debug)]
#[command(about)]
@@ -20,6 +21,8 @@ struct Args {
timeout: f64,
#[arg(short, long, default_value = "0x8000", value_parser = hex_value, value_name = "ADDRESS")]
arm_base: u32,
+ #[arg(long, default_value_t = false)]
+ tcp: bool,
#[arg(value_hint = clap::ValueHint::DirPath)]
file: PathBuf,
}
@@ -39,9 +42,6 @@ async fn default_action(port: SerialStream) -> anyhow::Result<()> {
let mut stdin = io::stdin();
let mut stdout = io::stdout();
- // let stdin_to_port = tokio::spawn(async move { io::copy(&mut stdin, &mut port_write).await });
- // let port_to_stdout = tokio::spawn(async move { io::copy(&mut port_read, &mut stdout).await });
-
tokio::select!(
_ = io::copy(&mut stdin, &mut port_write) => {
warn!("stdin disconnected, exiting");
@@ -88,7 +88,12 @@ async fn main() -> anyhow::Result<()> {
);
bootload::boot(&multi, &mut port, &args.file, args.arm_base).await?;
- default_action(port).await?;
+ if args.tcp {
+ let (port_read, port_write) = io::split(port);
+ tcp::serve(port_read, port_write).await?;
+ } else {
+ default_action(port).await?;
+ }
Ok(())
}
diff --git a/tools/src/tcp.rs b/tools/src/tcp.rs
@@ -0,0 +1,302 @@
+use std::{
+ collections::HashMap,
+ net::{IpAddr, Ipv4Addr, SocketAddrV4},
+ sync::Arc,
+};
+
+use log::{debug, error, info};
+use tokio::{
+ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest},
+ net::{
+ TcpListener,
+ tcp::{OwnedReadHalf, OwnedWriteHalf},
+ },
+ sync::RwLock,
+};
+
+pub type ConnectionId = u16;
+
+#[derive(Debug)]
+pub enum Message {
+ Bind {
+ port: u16,
+ ttl: u32,
+ },
+ Request {
+ id: ConnectionId,
+ ip: u32,
+ },
+ Read {
+ id: ConnectionId,
+ length: u16,
+ },
+ Write {
+ id: ConnectionId,
+ length: u16,
+ data: Vec<u8>,
+ },
+ Disconnect(ConnectionId),
+ Print {
+ length: u16,
+ data: Vec<u8>,
+ },
+}
+
+impl From<&Message> for u8 {
+ fn from(value: &Message) -> Self {
+ match value {
+ Message::Bind { .. } => 0,
+ Message::Request { .. } => 1,
+ Message::Read { .. } => 2,
+ Message::Write { .. } => 3,
+ Message::Disconnect(_) => 4,
+ Message::Print { .. } => 5,
+ }
+ }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+ #[error("Invalid message kind")]
+ InvalidKind,
+ #[error("io: {0}")]
+ Io(#[from] std::io::Error),
+}
+
+pub async fn parse_message<R: AsyncRead + Unpin>(r: &mut R) -> Result<Message, Error> {
+ let kind = r.read_u8().await?;
+
+ match kind {
+ // Bind: client -> server
+ 0 => {
+ let tcp_port = r.read_u16_le().await?;
+ let ttl = r.read_u32_le().await?;
+
+ Ok(Message::Bind {
+ port: tcp_port,
+ ttl,
+ })
+ }
+ // Request: server -> client
+ // 1 => {},
+ // Read: client -> server
+ 2 => Ok(Message::Read {
+ id: r.read_u16_le().await?,
+ length: r.read_u16_le().await?,
+ }),
+ // Write: client -> server, server -> client
+ 3 => {
+ let id = r.read_u16_le().await?;
+ let length = r.read_u16_le().await?;
+ let mut data = vec![0; length as usize];
+
+ r.read_exact(&mut data).await?;
+
+ Ok(Message::Write { id, length, data })
+ }
+ // Disconnect: client -> server, server -> client
+ 4 => Ok(Message::Disconnect(r.read_u16_le().await?)),
+ // Print: client -> server
+ 5 => {
+ let length = r.read_u16_le().await?;
+ let mut data = vec![0; length as usize];
+
+ r.read_exact(&mut data).await?;
+ Ok(Message::Print { length, data })
+ }
+ _ => {
+ debug!("got invalid kind: {kind}");
+ Err(Error::InvalidKind)
+ }
+ }
+}
+
+pub async fn send_message<W: AsyncWrite + Unpin>(w: &mut W, msg: &Message) -> Result<(), Error> {
+ debug!("sending message: {msg:#?}");
+
+ w.write_u8(msg.into()).await?;
+
+ match msg {
+ // Message::Bind { tcp_port, ttl } => {},
+ Message::Request { id, ip } => {
+ w.write_u16_le(*id).await?;
+ w.write_u32_le(*ip).await?;
+ }
+ Message::Write { id, length, data } => {
+ w.write_u16_le(*id).await?;
+ w.write_u16_le(*length).await?;
+ w.write_all(&data[..(*length as usize)]).await?;
+ }
+ Message::Disconnect(id) => {
+ w.write_u16_le(*id).await?;
+ }
+ _ => return Err(Error::InvalidKind),
+ }
+
+ w.flush().await?;
+
+ Ok(())
+}
+
+async fn socket_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>(
+ id: ConnectionId,
+ socket: OwnedReadHalf,
+ w: Arc<RwLock<W>>,
+) -> Result<(), Error> {
+ let ip = match socket.peer_addr()?.ip() {
+ IpAddr::V4(v4) => v4.to_bits(),
+ _ => unreachable!(),
+ };
+
+ {
+ let mut writer = w.write().await;
+ send_message(&mut *writer, &Message::Request { id, ip }).await?;
+ }
+
+ loop {
+ let ready = socket.ready(Interest::READABLE).await?;
+
+ if ready.is_readable() {
+ debug!("socket is readable");
+ let mut data = vec![0; 1024];
+
+ match socket.try_read(&mut data) {
+ Ok(0) => {
+ debug!("socket disconnected: {id}");
+ let mut writer = w.write().await;
+
+ if let Err(e) = send_message(&mut *writer, &Message::Disconnect(id)).await {
+ error!("failed to send message: {e}");
+ }
+
+ return Ok(());
+ }
+ Ok(n) => {
+ debug!("got data: n={n}: {:?}", &data[..n]);
+ let mut writer = w.write().await;
+
+ if let Err(e) = send_message(
+ &mut *writer,
+ &Message::Write {
+ id,
+ length: n as u16,
+ data,
+ },
+ )
+ .await
+ {
+ error!("failed to send message: {e}");
+ }
+ }
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+ continue;
+ }
+ Err(e) => {
+ error!("failed to read: {e}");
+ return Err(e.into());
+ }
+ }
+ }
+ }
+}
+
+async fn tcp_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>(
+ server: TcpListener,
+ w: Arc<RwLock<W>>,
+ map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>,
+) -> Result<(), Error> {
+ let mut counter: ConnectionId = 0;
+ // TODO; should automatically detect when sockets drop
+ loop {
+ let (socket, _) = server.accept().await?;
+ let (socket_read, socket_write) = socket.into_split();
+
+ map.write().await.insert(counter, socket_write);
+ tokio::spawn(socket_handler(counter, socket_read, w.clone()));
+
+ counter += 1;
+ }
+}
+
+async fn port_handler<R: AsyncRead + Unpin, W: AsyncWrite + Send + Sync + Unpin + 'static>(
+ mut r: R,
+ w: Arc<RwLock<W>>,
+ map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>,
+) -> Result<(), Error> {
+ loop {
+ let msg = parse_message(&mut r).await?;
+ debug!("got message: {msg:#?}");
+
+ match msg {
+ Message::Write {
+ id,
+ length: _,
+ data,
+ } => {
+ let mut lock = map.write().await;
+
+ if let Some(socket) = lock.get_mut(&id) {
+ if let Err(e) = socket.write_all(&data).await {
+ error!("failed to write to connection {id}: {e}");
+ }
+ } else {
+ error!("got invalid id: {id}");
+ }
+ }
+ Message::Disconnect(id) => {
+ let mut lock = map.write().await;
+
+ if let Some(socket) = lock.get_mut(&id) {
+ _ = socket.shutdown().await;
+ lock.remove(&id);
+ } else {
+ error!("got invalid id: {id}");
+ }
+ }
+ Message::Print { length: _, data } => {
+ info!("pi: {}", String::from_utf8_lossy(&data).trim())
+ }
+ _ => {}
+ }
+ }
+}
+
+async fn wait_for_bind<R: AsyncRead + Unpin>(r: &mut R) -> Result<TcpListener, Error> {
+ info!("waiting for bind");
+
+ loop {
+ if let Ok(Message::Bind { port, ttl }) = parse_message(r).await {
+ info!("got response, port={port}, ttl={ttl}");
+ let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port);
+ let listener = TcpListener::bind(addr).await?;
+ // listener.set_ttl(ttl)?;
+ return Ok(listener);
+ }
+ }
+}
+
+pub async fn serve<
+ R: AsyncRead + Send + Sync + Unpin + 'static,
+ W: AsyncWrite + Send + Sync + Unpin + 'static,
+>(
+ mut r: R,
+ w: W,
+) -> Result<(), Error> {
+ let server = wait_for_bind(&mut r).await?;
+ info!("listening at {}", server.local_addr()?);
+
+ // Rust my beloved
+ let connection_map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>> =
+ Arc::new(RwLock::new(HashMap::new()));
+ let writer: Arc<RwLock<W>> = Arc::new(RwLock::new(w));
+
+ let server_task = tokio::spawn(tcp_handler(server, writer.clone(), connection_map.clone()));
+ let port_task = tokio::spawn(port_handler(r, writer.clone(), connection_map.clone()));
+
+ tokio::select! {
+ r = server_task => r.unwrap(),
+ r = port_task => r.unwrap(),
+ }?;
+
+ Ok(())
+}