unit.rs (11799B)
1 // rewrite cmd => exec, args 2 3 use std::collections::HashMap; 4 use std::fmt::Formatter; 5 use std::path::Path; 6 use std::str::FromStr; 7 use std::{error, fmt}; 8 9 use async_trait::async_trait; 10 use blocking::unblock; 11 use log::warn; 12 use nix::libc::pid_t; 13 use nix::sys::signal::{Signal, kill}; 14 use nix::unistd::Pid; 15 use nom::branch::alt; 16 use nom::bytes::complete::{is_a, take_while, take_while1}; 17 use nom::character::complete::{char, multispace0, multispace1}; 18 use nom::combinator::{all_consuming, cut, map, map_res, value, verify}; 19 use nom::error::Error; 20 use nom::multi::{fold_many0, separated_list0}; 21 use nom::sequence::{preceded, separated_pair, terminated}; 22 use nom::{Finish, IResult}; 23 24 use kanit_common::constants; 25 use kanit_common::error::{Context, Result}; 26 use kanit_supervisor::{RestartPolicy, Supervisor, SupervisorBuilder}; 27 28 use crate::formats::config_file; 29 use crate::{Dependencies, UnitName}; 30 31 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 32 pub enum UnitKind { 33 Oneshot, 34 Daemon, 35 Builtin, 36 } 37 38 #[derive(Debug, PartialEq, Eq)] 39 pub struct ParseUnitKindError; 40 41 impl fmt::Display for ParseUnitKindError { 42 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 43 write!(f, "expected `oneshot`, `daemon`, or `builtin`") 44 } 45 } 46 47 impl error::Error for ParseUnitKindError {} 48 49 impl FromStr for UnitKind { 50 type Err = ParseUnitKindError; 51 52 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { 53 match s { 54 "oneshot" => Ok(Self::Oneshot), 55 "daemon" => Ok(Self::Daemon), 56 "builtin" => Ok(Self::Builtin), 57 _ => Err(ParseUnitKindError), 58 } 59 } 60 } 61 62 fn unit_name_vec(val: Option<&&str>) -> Vec<UnitName> { 63 val.map(|s| s.split(',').map(UnitName::from).collect()) 64 .unwrap_or_default() 65 } 66 67 impl Dependencies { 68 fn from_body(body: &HashMap<&str, &str>) -> Self { 69 Self { 70 before: unit_name_vec(body.get("before")), 71 after: unit_name_vec(body.get("after")), 72 needs: unit_name_vec(body.get("needs")), 73 uses: unit_name_vec(body.get("uses")), 74 wants: unit_name_vec(body.get("wants")), 75 } 76 } 77 } 78 79 // this is stupid but it works so it isn't stupid 80 #[derive(Debug, PartialEq, Eq)] 81 pub(crate) struct Command(String, Vec<String>); 82 83 impl FromStr for Command { 84 type Err = Error<String>; 85 86 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { 87 match parse_cmd(s).finish() { 88 Ok((_remaining, cmd)) => Ok(cmd), 89 Err(e) => Err(Error { 90 input: e.input.to_string(), 91 code: e.code, 92 }), 93 } 94 } 95 } 96 97 // could this be lowered? 98 // to be fair, using supervisor with a shell allows you to get all the benefits 99 // the shell will automatically do splitting 100 fn parse_cmd(input: &str) -> IResult<&str, Command> { 101 all_consuming(map(separated_pair(command, multispace0, args), |(c, a)| { 102 Command(c, a) 103 }))(input) 104 } 105 106 fn command(input: &str) -> IResult<&str, String> { 107 map(take_while1(|c| c != ' '), String::from)(input) 108 } 109 110 fn args(input: &str) -> IResult<&str, Vec<String>> { 111 separated_list0( 112 multispace1, 113 alt((quoted_arg, map(take_while1(|c| c != ' '), String::from))), 114 )(input) 115 } 116 117 enum Fragment<'a> { 118 Char(char), 119 Str(&'a str), 120 } 121 122 fn quoted_arg(input: &str) -> IResult<&str, String> { 123 let (input, open_quote) = is_a("'\"")(input)?; 124 let quote_char = open_quote.chars().next().unwrap(); 125 126 let inner_str = fold_many0( 127 alt(( 128 map( 129 verify( 130 take_while(move |c| c != '\\' && c != quote_char), 131 |s: &str| !s.is_empty(), 132 ), 133 Fragment::Str, 134 ), 135 map( 136 preceded( 137 char('\\'), 138 alt(( 139 value('\n', char('n')), 140 value('\r', char('r')), 141 value('\t', char('t')), 142 value('\\', char('\\')), 143 value(quote_char, char(quote_char)), 144 )), 145 ), 146 Fragment::Char, 147 ), 148 )), 149 String::new, 150 |mut buff, fragment| { 151 match fragment { 152 Fragment::Char(c) => buff.push(c), 153 Fragment::Str(s) => buff.push_str(s), 154 } 155 buff 156 }, 157 ); 158 159 cut(terminated(inner_str, char(quote_char)))(input) 160 } 161 162 #[derive(Debug, Clone)] 163 pub struct Unit { 164 pub kind: UnitKind, 165 pub description: Option<Box<str>>, 166 pub dependencies: Option<Dependencies>, 167 pub supervisor: Option<Supervisor>, 168 } 169 170 fn parse_unit(input: &str) -> IResult<&str, Unit> { 171 map_res(config_file, Unit::from_config)(input) 172 } 173 174 impl FromStr for Unit { 175 type Err = Error<String>; 176 177 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { 178 match parse_unit(s).finish() { 179 Ok((_remaining, unit)) => Ok(unit), 180 Err(e) => Err(Error { 181 input: e.input.to_string(), 182 code: e.code, 183 }), 184 } 185 } 186 } 187 188 impl Unit { 189 pub fn from_config(config: HashMap<&str, HashMap<&str, &str>>) -> Result<Self> { 190 let root: HashMap<String, String> = config 191 .get(".root") 192 .map(|n| { 193 n.iter() 194 .map(|(k, v)| (k.to_string(), v.to_string())) 195 .collect() 196 }) 197 .context("expected root node")?; // should be impossible 198 let kind = root 199 .get("kind") 200 .context("expected kind")? 201 .parse::<UnitKind>() 202 .context("failed to parse unit kind")?; 203 let description = root.get("description").map(|n| Box::from(n.as_str())); 204 let dependencies = config.get("depends").map(Dependencies::from_body); 205 let supervisor = if let UnitKind::Daemon | UnitKind::Oneshot = kind { 206 let (restart_delay, restart_attempts, restart_policy) = 207 if let Some(restart) = config.get("restart") { 208 let delay = if let Some(delay) = restart.get("delay") { 209 Some(delay.parse::<u64>().context("failed to parse delay")?) 210 } else { 211 None 212 }; 213 let attempts = if let Some(attempts) = restart.get("attempts") { 214 Some( 215 attempts 216 .parse::<u64>() 217 .context("failed to parse attempts")?, 218 ) 219 } else { 220 None 221 }; 222 let policy = if let Some(policy) = restart.get("policy") { 223 Some( 224 policy 225 .parse::<RestartPolicy>() 226 .context("failed to parse policy")?, 227 ) 228 } else { 229 Some(RestartPolicy::OnFailure) 230 }; 231 (delay, attempts, policy) 232 } else { 233 (None, None, Some(RestartPolicy::OnFailure)) 234 }; 235 236 let pwd = root.get("pwd").map(|n| n.to_string()); 237 let root_dir = root.get("root").map(|n| n.to_string()); 238 let group = root.get("group").map(|n| n.to_string()); 239 let user = root.get("user").map(|n| n.to_string()); 240 let stdout = root.get("stdout").map(|n| n.to_string()); 241 let stderr = root.get("stderr").map(|n| n.to_string()); 242 243 let cmd = root 244 .get("cmd") 245 .context("expected cmd")? 246 .parse::<Command>() 247 .context("failed to parse cmd")?; 248 249 let env = config 250 .get("environment") 251 .map(|n| n.iter().map(|(k, v)| format!("{k}={v}")).collect()) 252 .unwrap_or_default(); 253 254 Some(Supervisor { 255 exec: cmd.0, 256 args: cmd.1, 257 cgroup: None, 258 restart_delay, 259 restart_attempts, 260 restart_policy, 261 pwd, 262 root: root_dir, 263 env, 264 group, 265 user, 266 stdout, 267 stderr, 268 }) 269 } else { 270 None 271 }; 272 273 Ok(Unit { 274 kind, 275 description, 276 dependencies, 277 supervisor, 278 }) 279 } 280 } 281 282 #[async_trait] 283 impl crate::Unit for (UnitName, Unit) { 284 fn name(&self) -> UnitName { 285 self.0.clone() 286 } 287 288 fn description(&self) -> Option<&str> { 289 self.1.description.as_deref() 290 } 291 292 fn dependencies(&self) -> Dependencies { 293 self.1.dependencies.clone().unwrap_or_default() 294 } 295 296 async fn start(&mut self) -> Result<()> { 297 if let Some(mut supervisor) = self.1.supervisor.clone() { 298 supervisor.cgroup = Some(self.0.to_string()); 299 300 let child = SupervisorBuilder::from_supervisor(supervisor).spawn()?; 301 302 async_fs::write( 303 Path::new(constants::KAN_PIDS).join(self.0.to_string()), 304 child.id().to_string(), 305 ) 306 .await 307 .context("failed to write pid")?; 308 } 309 310 Ok(()) 311 } 312 313 async fn stop(&mut self) -> Result<()> { 314 if self.1.supervisor.is_none() { 315 return Ok(()); 316 } 317 318 if let Ok(pid) = 319 async_fs::read_to_string(Path::new(constants::KAN_PIDS).join(self.0.to_string())).await 320 { 321 unblock(move || { 322 kill( 323 Pid::from_raw(pid.parse::<u32>().context("failed to parse pid")? as pid_t), 324 Signal::SIGTERM, 325 ) 326 .context("failed to kill service") 327 }) 328 .await?; 329 } else { 330 warn!("failed to find pid file"); 331 } 332 333 Ok(()) 334 } 335 } 336 337 #[cfg(test)] 338 mod test { 339 use crate::parser_test; 340 341 use super::*; 342 343 #[test] 344 fn parse_unit_kind() { 345 assert_eq!("oneshot".parse::<UnitKind>().unwrap(), UnitKind::Oneshot); 346 assert_eq!("daemon".parse::<UnitKind>().unwrap(), UnitKind::Daemon); 347 assert_eq!("builtin".parse::<UnitKind>().unwrap(), UnitKind::Builtin); 348 assert_eq!("foo".parse::<UnitKind>(), Err(ParseUnitKindError)); 349 } 350 351 #[test] 352 fn parse_command() -> std::result::Result<(), Error<&'static str>> { 353 parser_test!( 354 command, 355 [ 356 "foo" => "foo", 357 "foo2" => "foo2" 358 ] 359 ); 360 361 Ok(()) 362 } 363 364 #[test] 365 fn parse_args() -> std::result::Result<(), Error<&'static str>> { 366 parser_test!( 367 args, 368 [ 369 "foo" => vec!["foo"], 370 "foo2 \"bar\"" => vec!["foo2", "bar"], 371 "\"\\\"baz\\\"\"" => vec!["\"baz\""], 372 "'\\n'" => vec!["\n"], 373 "'hello\\nworld'" => vec!["hello\nworld"] 374 ] 375 ); 376 377 Ok(()) 378 } 379 380 #[test] 381 fn parse_parse_cmd() -> std::result::Result<(), Error<&'static str>> { 382 parser_test!( 383 parse_cmd, 384 [ 385 "foo" => Command(String::from("foo"), vec![]), 386 "foo bar" => Command(String::from("foo"), vec![String::from("bar")]), 387 "foo \"bar baz\"" => Command(String::from("foo"), vec![String::from("bar baz")]), 388 "foo \"bar baz\" \"\\n\"" => Command( 389 String::from("foo"), 390 vec![String::from("bar baz"), String::from("\n")] 391 ) 392 ] 393 ); 394 395 Ok(()) 396 } 397 }