sylveos

Toy Operating System
Log | Files | Refs

tcp.rs (9446B)


      1 use std::{
      2     collections::HashMap,
      3     net::{IpAddr, Ipv4Addr, SocketAddrV4},
      4     sync::Arc,
      5 };
      6 
      7 use tokio::{
      8     io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest},
      9     net::{
     10         TcpListener,
     11         tcp::{OwnedReadHalf, OwnedWriteHalf},
     12     },
     13     sync::RwLock,
     14 };
     15 use tracing::{Instrument, Level, debug, error, field, info, instrument, span, warn};
     16 
     17 pub type ConnectionId = u16;
     18 
     19 #[derive(Debug)]
     20 pub enum Message {
     21     Bind {
     22         port: u16,
     23         ttl: u32,
     24     },
     25     Request {
     26         id: ConnectionId,
     27         ip: u32,
     28     },
     29     Read {
     30         id: ConnectionId,
     31         length: u16,
     32     },
     33     Write {
     34         id: ConnectionId,
     35         length: u16,
     36         data: Vec<u8>,
     37     },
     38     Disconnect(ConnectionId),
     39     Print {
     40         length: u16,
     41         data: Vec<u8>,
     42     },
     43 }
     44 
     45 impl From<&Message> for u8 {
     46     fn from(value: &Message) -> Self {
     47         match value {
     48             Message::Bind { .. } => 0,
     49             Message::Request { .. } => 1,
     50             Message::Read { .. } => 2,
     51             Message::Write { .. } => 3,
     52             Message::Disconnect(_) => 4,
     53             Message::Print { .. } => 5,
     54         }
     55     }
     56 }
     57 
     58 #[derive(thiserror::Error, Debug)]
     59 pub enum Error {
     60     #[error("invalid message kind: {0}")]
     61     InvalidKind(u8),
     62     #[error("io: {0}")]
     63     Io(#[from] std::io::Error),
     64 }
     65 
     66 #[instrument(skip_all)]
     67 pub async fn parse_message<R: AsyncRead + Unpin>(
     68     r: &mut R,
     69     log_err: bool,
     70 ) -> Result<Message, Error> {
     71     let kind = r.read_u8().await?;
     72 
     73     match kind {
     74         // Bind: client -> server
     75         0 => {
     76             let tcp_port = r.read_u16_le().await?;
     77             let ttl = r.read_u32_le().await?;
     78 
     79             Ok(Message::Bind {
     80                 port: tcp_port,
     81                 ttl,
     82             })
     83         }
     84         // Request: server -> client
     85         // 1 => {},
     86         // Read: client -> server
     87         2 => Ok(Message::Read {
     88             id: r.read_u16_le().await?,
     89             length: r.read_u16_le().await?,
     90         }),
     91         // Write: client -> server, server -> client
     92         3 => {
     93             let id = r.read_u16_le().await?;
     94             let length = r.read_u16_le().await?;
     95             let mut data = vec![0; length as usize];
     96 
     97             r.read_exact(&mut data).await?;
     98 
     99             Ok(Message::Write { id, length, data })
    100         }
    101         // Disconnect: client -> server, server -> client
    102         4 => Ok(Message::Disconnect(r.read_u16_le().await?)),
    103         // Print: client -> server
    104         5 => {
    105             let length = r.read_u16_le().await?;
    106             let mut data = vec![0; length as usize];
    107 
    108             r.read_exact(&mut data).await?;
    109             Ok(Message::Print { length, data })
    110         }
    111         k => {
    112             if log_err {
    113                 error!(kind = kind, "invalid message kind");
    114             }
    115 
    116             Err(Error::InvalidKind(k))
    117         }
    118     }
    119 }
    120 
    121 #[instrument(err, skip_all)]
    122 pub async fn send_message<W: AsyncWrite + Unpin>(w: &mut W, msg: &Message) -> Result<(), Error> {
    123     w.write_u8(msg.into()).await?;
    124 
    125     match msg {
    126         // Message::Bind { tcp_port, ttl } => {},
    127         Message::Request { id, ip } => {
    128             w.write_u16_le(*id).await?;
    129             w.write_u32_le(*ip).await?;
    130         }
    131         Message::Write { id, length, data } => {
    132             w.write_u16_le(*id).await?;
    133             w.write_u16_le(*length).await?;
    134             w.write_all(&data[..(*length as usize)]).await?;
    135         }
    136         Message::Disconnect(id) => {
    137             w.write_u16_le(*id).await?;
    138         }
    139         m => return Err(Error::InvalidKind(m.into())),
    140     }
    141 
    142     w.flush().await?;
    143 
    144     Ok(())
    145 }
    146 
    147 async fn socket_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>(
    148     id: ConnectionId,
    149     ip: IpAddr,
    150     socket: OwnedReadHalf,
    151     w: Arc<RwLock<W>>,
    152 ) -> Result<(), Error> {
    153     let ip = match ip {
    154         IpAddr::V4(v4) => v4.to_bits(),
    155         _ => unreachable!(),
    156     };
    157 
    158     {
    159         let mut writer = w.write().await;
    160         send_message(&mut *writer, &Message::Request { id, ip }).await?;
    161     }
    162 
    163     debug!("ready");
    164 
    165     loop {
    166         let ready = socket.ready(Interest::READABLE).await?;
    167 
    168         if !ready.is_readable() {
    169             continue;
    170         };
    171 
    172         let mut data = vec![0; 1024];
    173 
    174         match socket.try_read(&mut data) {
    175             Ok(0) => {
    176                 debug!("disconnected");
    177                 let mut writer = w.write().await;
    178 
    179                 if let Err(e) = send_message(&mut *writer, &Message::Disconnect(id)).await {
    180                     error!("failed to send message: {e}");
    181                 }
    182 
    183                 return Ok(());
    184             }
    185             Ok(n) => {
    186                 debug!("read {n} bytes");
    187                 let mut writer = w.write().await;
    188 
    189                 if let Err(e) = send_message(
    190                     &mut *writer,
    191                     &Message::Write {
    192                         id,
    193                         length: n as u16,
    194                         data,
    195                     },
    196                 )
    197                 .await
    198                 {
    199                     error!("failed to send message: {e}");
    200                 }
    201             }
    202             Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
    203                 continue;
    204             }
    205             Err(e) => {
    206                 error!("failed to read: {e}");
    207                 return Err(e.into());
    208             }
    209         }
    210     }
    211 }
    212 
    213 #[instrument(name = "server", skip_all, fields(indicatif.pb_show))]
    214 async fn tcp_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>(
    215     server: TcpListener,
    216     w: Arc<RwLock<W>>,
    217     map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>,
    218 ) -> Result<(), Error> {
    219     let mut counter: ConnectionId = 0;
    220     // TODO; should automatically detect when sockets drop
    221     loop {
    222         let (socket, peer) = server.accept().await?;
    223         let (socket_read, socket_write) = socket.into_split();
    224 
    225         map.write().await.insert(counter, socket_write);
    226 
    227         let socket_span = tracing::info_span!(
    228             "socket",
    229             id = counter,
    230             ip = peer.ip().to_string(),
    231             indicatif.pb_show = field::Empty,
    232         );
    233 
    234         tokio::spawn(
    235             socket_handler(counter, peer.ip(), socket_read, w.clone())
    236                 .instrument(socket_span.or_current()),
    237         );
    238 
    239         counter += 1;
    240     }
    241 }
    242 
    243 #[instrument(name = "serial", skip_all, fields(ip, indicatif.pb_show))]
    244 async fn port_handler<R: AsyncRead + Unpin>(
    245     mut r: R,
    246     map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>,
    247 ) -> Result<(), Error> {
    248     loop {
    249         let msg = match parse_message(&mut r, true).await {
    250             Err(Error::InvalidKind(k)) => {
    251                 warn!(kind = k, "invalid message kind");
    252                 continue;
    253             }
    254             Err(Error::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
    255                 warn!("serial port disconnected, exiting");
    256                 return Ok(());
    257             }
    258             Err(e) => return Err(e),
    259             Ok(m) => m,
    260         };
    261 
    262         match msg {
    263             Message::Write {
    264                 id,
    265                 length: _,
    266                 data,
    267             } => {
    268                 let mut lock = map.write().await;
    269 
    270                 if let Some(socket) = lock.get_mut(&id) {
    271                     if let Err(e) = socket.write_all(&data).await {
    272                         error!("failed to write to connection {id}: {e}");
    273                     }
    274                 } else {
    275                     error!("got invalid id: {id}");
    276                 }
    277             }
    278             Message::Disconnect(id) => {
    279                 let mut lock = map.write().await;
    280 
    281                 if let Some(socket) = lock.get_mut(&id) {
    282                     _ = socket.shutdown().await;
    283                     lock.remove(&id);
    284                 } else {
    285                     error!("got invalid id: {id}");
    286                 }
    287             }
    288             Message::Print { length: _, data } => {
    289                 let _enter = span!(Level::INFO, "pi").entered();
    290                 info!("{}", String::from_utf8_lossy(&data).trim())
    291             }
    292             _ => {}
    293         }
    294     }
    295 }
    296 
    297 #[instrument(skip_all)]
    298 async fn wait_for_bind<R: AsyncRead + Unpin>(r: &mut R) -> Result<TcpListener, Error> {
    299     loop {
    300         if let Ok(Message::Bind { port, ttl }) = parse_message(r, false).await {
    301             debug!(port = port, ttl = ttl, "bind request");
    302             let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port);
    303             let listener = TcpListener::bind(addr).await?;
    304             if ttl != 0 {
    305                 listener.set_ttl(ttl)?;
    306             }
    307             return Ok(listener);
    308         }
    309     }
    310 }
    311 
    312 #[instrument(name = "tcp", skip_all)]
    313 pub async fn serve<
    314     R: AsyncRead + Send + Sync + Unpin + 'static,
    315     W: AsyncWrite + Send + Sync + Unpin + 'static,
    316 >(
    317     mut r: R,
    318     w: W,
    319 ) -> Result<(), Error> {
    320     let server = wait_for_bind(&mut r).await?;
    321     info!("listening at {}", server.local_addr()?);
    322 
    323     // Rust my beloved
    324     let connection_map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>> =
    325         Arc::new(RwLock::new(HashMap::new()));
    326     let writer: Arc<RwLock<W>> = Arc::new(RwLock::new(w));
    327 
    328     let server_task = tokio::spawn(tcp_handler(server, writer.clone(), connection_map.clone()));
    329     let port_task = tokio::spawn(port_handler(r, connection_map.clone()));
    330 
    331     tokio::select! {
    332         r = server_task => r.unwrap(),
    333         r = port_task => r.unwrap(),
    334     }?;
    335 
    336     Ok(())
    337 }