tcp.rs (9446B)
1 use std::{ 2 collections::HashMap, 3 net::{IpAddr, Ipv4Addr, SocketAddrV4}, 4 sync::Arc, 5 }; 6 7 use tokio::{ 8 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Interest}, 9 net::{ 10 TcpListener, 11 tcp::{OwnedReadHalf, OwnedWriteHalf}, 12 }, 13 sync::RwLock, 14 }; 15 use tracing::{Instrument, Level, debug, error, field, info, instrument, span, warn}; 16 17 pub type ConnectionId = u16; 18 19 #[derive(Debug)] 20 pub enum Message { 21 Bind { 22 port: u16, 23 ttl: u32, 24 }, 25 Request { 26 id: ConnectionId, 27 ip: u32, 28 }, 29 Read { 30 id: ConnectionId, 31 length: u16, 32 }, 33 Write { 34 id: ConnectionId, 35 length: u16, 36 data: Vec<u8>, 37 }, 38 Disconnect(ConnectionId), 39 Print { 40 length: u16, 41 data: Vec<u8>, 42 }, 43 } 44 45 impl From<&Message> for u8 { 46 fn from(value: &Message) -> Self { 47 match value { 48 Message::Bind { .. } => 0, 49 Message::Request { .. } => 1, 50 Message::Read { .. } => 2, 51 Message::Write { .. } => 3, 52 Message::Disconnect(_) => 4, 53 Message::Print { .. } => 5, 54 } 55 } 56 } 57 58 #[derive(thiserror::Error, Debug)] 59 pub enum Error { 60 #[error("invalid message kind: {0}")] 61 InvalidKind(u8), 62 #[error("io: {0}")] 63 Io(#[from] std::io::Error), 64 } 65 66 #[instrument(skip_all)] 67 pub async fn parse_message<R: AsyncRead + Unpin>( 68 r: &mut R, 69 log_err: bool, 70 ) -> Result<Message, Error> { 71 let kind = r.read_u8().await?; 72 73 match kind { 74 // Bind: client -> server 75 0 => { 76 let tcp_port = r.read_u16_le().await?; 77 let ttl = r.read_u32_le().await?; 78 79 Ok(Message::Bind { 80 port: tcp_port, 81 ttl, 82 }) 83 } 84 // Request: server -> client 85 // 1 => {}, 86 // Read: client -> server 87 2 => Ok(Message::Read { 88 id: r.read_u16_le().await?, 89 length: r.read_u16_le().await?, 90 }), 91 // Write: client -> server, server -> client 92 3 => { 93 let id = r.read_u16_le().await?; 94 let length = r.read_u16_le().await?; 95 let mut data = vec![0; length as usize]; 96 97 r.read_exact(&mut data).await?; 98 99 Ok(Message::Write { id, length, data }) 100 } 101 // Disconnect: client -> server, server -> client 102 4 => Ok(Message::Disconnect(r.read_u16_le().await?)), 103 // Print: client -> server 104 5 => { 105 let length = r.read_u16_le().await?; 106 let mut data = vec![0; length as usize]; 107 108 r.read_exact(&mut data).await?; 109 Ok(Message::Print { length, data }) 110 } 111 k => { 112 if log_err { 113 error!(kind = kind, "invalid message kind"); 114 } 115 116 Err(Error::InvalidKind(k)) 117 } 118 } 119 } 120 121 #[instrument(err, skip_all)] 122 pub async fn send_message<W: AsyncWrite + Unpin>(w: &mut W, msg: &Message) -> Result<(), Error> { 123 w.write_u8(msg.into()).await?; 124 125 match msg { 126 // Message::Bind { tcp_port, ttl } => {}, 127 Message::Request { id, ip } => { 128 w.write_u16_le(*id).await?; 129 w.write_u32_le(*ip).await?; 130 } 131 Message::Write { id, length, data } => { 132 w.write_u16_le(*id).await?; 133 w.write_u16_le(*length).await?; 134 w.write_all(&data[..(*length as usize)]).await?; 135 } 136 Message::Disconnect(id) => { 137 w.write_u16_le(*id).await?; 138 } 139 m => return Err(Error::InvalidKind(m.into())), 140 } 141 142 w.flush().await?; 143 144 Ok(()) 145 } 146 147 async fn socket_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>( 148 id: ConnectionId, 149 ip: IpAddr, 150 socket: OwnedReadHalf, 151 w: Arc<RwLock<W>>, 152 ) -> Result<(), Error> { 153 let ip = match ip { 154 IpAddr::V4(v4) => v4.to_bits(), 155 _ => unreachable!(), 156 }; 157 158 { 159 let mut writer = w.write().await; 160 send_message(&mut *writer, &Message::Request { id, ip }).await?; 161 } 162 163 debug!("ready"); 164 165 loop { 166 let ready = socket.ready(Interest::READABLE).await?; 167 168 if !ready.is_readable() { 169 continue; 170 }; 171 172 let mut data = vec![0; 1024]; 173 174 match socket.try_read(&mut data) { 175 Ok(0) => { 176 debug!("disconnected"); 177 let mut writer = w.write().await; 178 179 if let Err(e) = send_message(&mut *writer, &Message::Disconnect(id)).await { 180 error!("failed to send message: {e}"); 181 } 182 183 return Ok(()); 184 } 185 Ok(n) => { 186 debug!("read {n} bytes"); 187 let mut writer = w.write().await; 188 189 if let Err(e) = send_message( 190 &mut *writer, 191 &Message::Write { 192 id, 193 length: n as u16, 194 data, 195 }, 196 ) 197 .await 198 { 199 error!("failed to send message: {e}"); 200 } 201 } 202 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { 203 continue; 204 } 205 Err(e) => { 206 error!("failed to read: {e}"); 207 return Err(e.into()); 208 } 209 } 210 } 211 } 212 213 #[instrument(name = "server", skip_all, fields(indicatif.pb_show))] 214 async fn tcp_handler<W: AsyncWrite + Send + Sync + Unpin + 'static>( 215 server: TcpListener, 216 w: Arc<RwLock<W>>, 217 map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>, 218 ) -> Result<(), Error> { 219 let mut counter: ConnectionId = 0; 220 // TODO; should automatically detect when sockets drop 221 loop { 222 let (socket, peer) = server.accept().await?; 223 let (socket_read, socket_write) = socket.into_split(); 224 225 map.write().await.insert(counter, socket_write); 226 227 let socket_span = tracing::info_span!( 228 "socket", 229 id = counter, 230 ip = peer.ip().to_string(), 231 indicatif.pb_show = field::Empty, 232 ); 233 234 tokio::spawn( 235 socket_handler(counter, peer.ip(), socket_read, w.clone()) 236 .instrument(socket_span.or_current()), 237 ); 238 239 counter += 1; 240 } 241 } 242 243 #[instrument(name = "serial", skip_all, fields(ip, indicatif.pb_show))] 244 async fn port_handler<R: AsyncRead + Unpin>( 245 mut r: R, 246 map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>>, 247 ) -> Result<(), Error> { 248 loop { 249 let msg = match parse_message(&mut r, true).await { 250 Err(Error::InvalidKind(k)) => { 251 warn!(kind = k, "invalid message kind"); 252 continue; 253 } 254 Err(Error::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { 255 warn!("serial port disconnected, exiting"); 256 return Ok(()); 257 } 258 Err(e) => return Err(e), 259 Ok(m) => m, 260 }; 261 262 match msg { 263 Message::Write { 264 id, 265 length: _, 266 data, 267 } => { 268 let mut lock = map.write().await; 269 270 if let Some(socket) = lock.get_mut(&id) { 271 if let Err(e) = socket.write_all(&data).await { 272 error!("failed to write to connection {id}: {e}"); 273 } 274 } else { 275 error!("got invalid id: {id}"); 276 } 277 } 278 Message::Disconnect(id) => { 279 let mut lock = map.write().await; 280 281 if let Some(socket) = lock.get_mut(&id) { 282 _ = socket.shutdown().await; 283 lock.remove(&id); 284 } else { 285 error!("got invalid id: {id}"); 286 } 287 } 288 Message::Print { length: _, data } => { 289 let _enter = span!(Level::INFO, "pi").entered(); 290 info!("{}", String::from_utf8_lossy(&data).trim()) 291 } 292 _ => {} 293 } 294 } 295 } 296 297 #[instrument(skip_all)] 298 async fn wait_for_bind<R: AsyncRead + Unpin>(r: &mut R) -> Result<TcpListener, Error> { 299 loop { 300 if let Ok(Message::Bind { port, ttl }) = parse_message(r, false).await { 301 debug!(port = port, ttl = ttl, "bind request"); 302 let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port); 303 let listener = TcpListener::bind(addr).await?; 304 if ttl != 0 { 305 listener.set_ttl(ttl)?; 306 } 307 return Ok(listener); 308 } 309 } 310 } 311 312 #[instrument(name = "tcp", skip_all)] 313 pub async fn serve< 314 R: AsyncRead + Send + Sync + Unpin + 'static, 315 W: AsyncWrite + Send + Sync + Unpin + 'static, 316 >( 317 mut r: R, 318 w: W, 319 ) -> Result<(), Error> { 320 let server = wait_for_bind(&mut r).await?; 321 info!("listening at {}", server.local_addr()?); 322 323 // Rust my beloved 324 let connection_map: Arc<RwLock<HashMap<ConnectionId, OwnedWriteHalf>>> = 325 Arc::new(RwLock::new(HashMap::new())); 326 let writer: Arc<RwLock<W>> = Arc::new(RwLock::new(w)); 327 328 let server_task = tokio::spawn(tcp_handler(server, writer.clone(), connection_map.clone())); 329 let port_task = tokio::spawn(port_handler(r, connection_map.clone())); 330 331 tokio::select! { 332 r = server_task => r.unwrap(), 333 r = port_task => r.unwrap(), 334 }?; 335 336 Ok(()) 337 }