kanit

Toy init system
Log | Files | Refs | README | LICENSE

unit.rs (11799B)


      1 // rewrite cmd => exec, args
      2 
      3 use std::collections::HashMap;
      4 use std::fmt::Formatter;
      5 use std::path::Path;
      6 use std::str::FromStr;
      7 use std::{error, fmt};
      8 
      9 use async_trait::async_trait;
     10 use blocking::unblock;
     11 use log::warn;
     12 use nix::libc::pid_t;
     13 use nix::sys::signal::{Signal, kill};
     14 use nix::unistd::Pid;
     15 use nom::branch::alt;
     16 use nom::bytes::complete::{is_a, take_while, take_while1};
     17 use nom::character::complete::{char, multispace0, multispace1};
     18 use nom::combinator::{all_consuming, cut, map, map_res, value, verify};
     19 use nom::error::Error;
     20 use nom::multi::{fold_many0, separated_list0};
     21 use nom::sequence::{preceded, separated_pair, terminated};
     22 use nom::{Finish, IResult};
     23 
     24 use kanit_common::constants;
     25 use kanit_common::error::{Context, Result};
     26 use kanit_supervisor::{RestartPolicy, Supervisor, SupervisorBuilder};
     27 
     28 use crate::formats::config_file;
     29 use crate::{Dependencies, UnitName};
     30 
     31 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
     32 pub enum UnitKind {
     33     Oneshot,
     34     Daemon,
     35     Builtin,
     36 }
     37 
     38 #[derive(Debug, PartialEq, Eq)]
     39 pub struct ParseUnitKindError;
     40 
     41 impl fmt::Display for ParseUnitKindError {
     42     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
     43         write!(f, "expected `oneshot`, `daemon`, or `builtin`")
     44     }
     45 }
     46 
     47 impl error::Error for ParseUnitKindError {}
     48 
     49 impl FromStr for UnitKind {
     50     type Err = ParseUnitKindError;
     51 
     52     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
     53         match s {
     54             "oneshot" => Ok(Self::Oneshot),
     55             "daemon" => Ok(Self::Daemon),
     56             "builtin" => Ok(Self::Builtin),
     57             _ => Err(ParseUnitKindError),
     58         }
     59     }
     60 }
     61 
     62 fn unit_name_vec(val: Option<&&str>) -> Vec<UnitName> {
     63     val.map(|s| s.split(',').map(UnitName::from).collect())
     64         .unwrap_or_default()
     65 }
     66 
     67 impl Dependencies {
     68     fn from_body(body: &HashMap<&str, &str>) -> Self {
     69         Self {
     70             before: unit_name_vec(body.get("before")),
     71             after: unit_name_vec(body.get("after")),
     72             needs: unit_name_vec(body.get("needs")),
     73             uses: unit_name_vec(body.get("uses")),
     74             wants: unit_name_vec(body.get("wants")),
     75         }
     76     }
     77 }
     78 
     79 // this is stupid but it works so it isn't stupid
     80 #[derive(Debug, PartialEq, Eq)]
     81 pub(crate) struct Command(String, Vec<String>);
     82 
     83 impl FromStr for Command {
     84     type Err = Error<String>;
     85 
     86     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
     87         match parse_cmd(s).finish() {
     88             Ok((_remaining, cmd)) => Ok(cmd),
     89             Err(e) => Err(Error {
     90                 input: e.input.to_string(),
     91                 code: e.code,
     92             }),
     93         }
     94     }
     95 }
     96 
     97 // could this be lowered?
     98 // to be fair, using supervisor with a shell allows you to get all the benefits
     99 // the shell will automatically do splitting
    100 fn parse_cmd(input: &str) -> IResult<&str, Command> {
    101     all_consuming(map(separated_pair(command, multispace0, args), |(c, a)| {
    102         Command(c, a)
    103     }))(input)
    104 }
    105 
    106 fn command(input: &str) -> IResult<&str, String> {
    107     map(take_while1(|c| c != ' '), String::from)(input)
    108 }
    109 
    110 fn args(input: &str) -> IResult<&str, Vec<String>> {
    111     separated_list0(
    112         multispace1,
    113         alt((quoted_arg, map(take_while1(|c| c != ' '), String::from))),
    114     )(input)
    115 }
    116 
    117 enum Fragment<'a> {
    118     Char(char),
    119     Str(&'a str),
    120 }
    121 
    122 fn quoted_arg(input: &str) -> IResult<&str, String> {
    123     let (input, open_quote) = is_a("'\"")(input)?;
    124     let quote_char = open_quote.chars().next().unwrap();
    125 
    126     let inner_str = fold_many0(
    127         alt((
    128             map(
    129                 verify(
    130                     take_while(move |c| c != '\\' && c != quote_char),
    131                     |s: &str| !s.is_empty(),
    132                 ),
    133                 Fragment::Str,
    134             ),
    135             map(
    136                 preceded(
    137                     char('\\'),
    138                     alt((
    139                         value('\n', char('n')),
    140                         value('\r', char('r')),
    141                         value('\t', char('t')),
    142                         value('\\', char('\\')),
    143                         value(quote_char, char(quote_char)),
    144                     )),
    145                 ),
    146                 Fragment::Char,
    147             ),
    148         )),
    149         String::new,
    150         |mut buff, fragment| {
    151             match fragment {
    152                 Fragment::Char(c) => buff.push(c),
    153                 Fragment::Str(s) => buff.push_str(s),
    154             }
    155             buff
    156         },
    157     );
    158 
    159     cut(terminated(inner_str, char(quote_char)))(input)
    160 }
    161 
    162 #[derive(Debug, Clone)]
    163 pub struct Unit {
    164     pub kind: UnitKind,
    165     pub description: Option<Box<str>>,
    166     pub dependencies: Option<Dependencies>,
    167     pub supervisor: Option<Supervisor>,
    168 }
    169 
    170 fn parse_unit(input: &str) -> IResult<&str, Unit> {
    171     map_res(config_file, Unit::from_config)(input)
    172 }
    173 
    174 impl FromStr for Unit {
    175     type Err = Error<String>;
    176 
    177     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
    178         match parse_unit(s).finish() {
    179             Ok((_remaining, unit)) => Ok(unit),
    180             Err(e) => Err(Error {
    181                 input: e.input.to_string(),
    182                 code: e.code,
    183             }),
    184         }
    185     }
    186 }
    187 
    188 impl Unit {
    189     pub fn from_config(config: HashMap<&str, HashMap<&str, &str>>) -> Result<Self> {
    190         let root: HashMap<String, String> = config
    191             .get(".root")
    192             .map(|n| {
    193                 n.iter()
    194                     .map(|(k, v)| (k.to_string(), v.to_string()))
    195                     .collect()
    196             })
    197             .context("expected root node")?; // should be impossible
    198         let kind = root
    199             .get("kind")
    200             .context("expected kind")?
    201             .parse::<UnitKind>()
    202             .context("failed to parse unit kind")?;
    203         let description = root.get("description").map(|n| Box::from(n.as_str()));
    204         let dependencies = config.get("depends").map(Dependencies::from_body);
    205         let supervisor = if let UnitKind::Daemon | UnitKind::Oneshot = kind {
    206             let (restart_delay, restart_attempts, restart_policy) =
    207                 if let Some(restart) = config.get("restart") {
    208                     let delay = if let Some(delay) = restart.get("delay") {
    209                         Some(delay.parse::<u64>().context("failed to parse delay")?)
    210                     } else {
    211                         None
    212                     };
    213                     let attempts = if let Some(attempts) = restart.get("attempts") {
    214                         Some(
    215                             attempts
    216                                 .parse::<u64>()
    217                                 .context("failed to parse attempts")?,
    218                         )
    219                     } else {
    220                         None
    221                     };
    222                     let policy = if let Some(policy) = restart.get("policy") {
    223                         Some(
    224                             policy
    225                                 .parse::<RestartPolicy>()
    226                                 .context("failed to parse policy")?,
    227                         )
    228                     } else {
    229                         Some(RestartPolicy::OnFailure)
    230                     };
    231                     (delay, attempts, policy)
    232                 } else {
    233                     (None, None, Some(RestartPolicy::OnFailure))
    234                 };
    235 
    236             let pwd = root.get("pwd").map(|n| n.to_string());
    237             let root_dir = root.get("root").map(|n| n.to_string());
    238             let group = root.get("group").map(|n| n.to_string());
    239             let user = root.get("user").map(|n| n.to_string());
    240             let stdout = root.get("stdout").map(|n| n.to_string());
    241             let stderr = root.get("stderr").map(|n| n.to_string());
    242 
    243             let cmd = root
    244                 .get("cmd")
    245                 .context("expected cmd")?
    246                 .parse::<Command>()
    247                 .context("failed to parse cmd")?;
    248 
    249             let env = config
    250                 .get("environment")
    251                 .map(|n| n.iter().map(|(k, v)| format!("{k}={v}")).collect())
    252                 .unwrap_or_default();
    253 
    254             Some(Supervisor {
    255                 exec: cmd.0,
    256                 args: cmd.1,
    257                 cgroup: None,
    258                 restart_delay,
    259                 restart_attempts,
    260                 restart_policy,
    261                 pwd,
    262                 root: root_dir,
    263                 env,
    264                 group,
    265                 user,
    266                 stdout,
    267                 stderr,
    268             })
    269         } else {
    270             None
    271         };
    272 
    273         Ok(Unit {
    274             kind,
    275             description,
    276             dependencies,
    277             supervisor,
    278         })
    279     }
    280 }
    281 
    282 #[async_trait]
    283 impl crate::Unit for (UnitName, Unit) {
    284     fn name(&self) -> UnitName {
    285         self.0.clone()
    286     }
    287 
    288     fn description(&self) -> Option<&str> {
    289         self.1.description.as_deref()
    290     }
    291 
    292     fn dependencies(&self) -> Dependencies {
    293         self.1.dependencies.clone().unwrap_or_default()
    294     }
    295 
    296     async fn start(&mut self) -> Result<()> {
    297         if let Some(mut supervisor) = self.1.supervisor.clone() {
    298             supervisor.cgroup = Some(self.0.to_string());
    299 
    300             let child = SupervisorBuilder::from_supervisor(supervisor).spawn()?;
    301 
    302             async_fs::write(
    303                 Path::new(constants::KAN_PIDS).join(self.0.to_string()),
    304                 child.id().to_string(),
    305             )
    306             .await
    307             .context("failed to write pid")?;
    308         }
    309 
    310         Ok(())
    311     }
    312 
    313     async fn stop(&mut self) -> Result<()> {
    314         if self.1.supervisor.is_none() {
    315             return Ok(());
    316         }
    317 
    318         if let Ok(pid) =
    319             async_fs::read_to_string(Path::new(constants::KAN_PIDS).join(self.0.to_string())).await
    320         {
    321             unblock(move || {
    322                 kill(
    323                     Pid::from_raw(pid.parse::<u32>().context("failed to parse pid")? as pid_t),
    324                     Signal::SIGTERM,
    325                 )
    326                 .context("failed to kill service")
    327             })
    328             .await?;
    329         } else {
    330             warn!("failed to find pid file");
    331         }
    332 
    333         Ok(())
    334     }
    335 }
    336 
// Unit tests for the unit-kind and command-line parsers.
// `parser_test!` is defined elsewhere in the crate; presumably it runs the given
// parser on each input and compares the parsed value with the expected one,
// propagating parse errors through the `Result` return type. TODO: confirm
// against its definition.
#[cfg(test)]
mod test {
    use crate::parser_test;

    use super::*;

    // Every valid spelling maps to its variant; anything else is rejected.
    #[test]
    fn parse_unit_kind() {
        assert_eq!("oneshot".parse::<UnitKind>().unwrap(), UnitKind::Oneshot);
        assert_eq!("daemon".parse::<UnitKind>().unwrap(), UnitKind::Daemon);
        assert_eq!("builtin".parse::<UnitKind>().unwrap(), UnitKind::Builtin);
        assert_eq!("foo".parse::<UnitKind>(), Err(ParseUnitKindError));
    }

    // The program name is taken verbatim.
    #[test]
    fn parse_command() -> std::result::Result<(), Error<&'static str>> {
        parser_test!(
            command,
            [
                "foo" => "foo",
                "foo2" => "foo2"
            ]
        );

        Ok(())
    }

    // Bare and quoted arguments, including escaped quotes and `\n` escapes.
    #[test]
    fn parse_args() -> std::result::Result<(), Error<&'static str>> {
        parser_test!(
            args,
            [
                "foo" => vec!["foo"],
                "foo2 \"bar\"" => vec!["foo2", "bar"],
                "\"\\\"baz\\\"\"" => vec!["\"baz\""],
                "'\\n'" => vec!["\n"],
                "'hello\\nworld'" => vec!["hello\nworld"]
            ]
        );

        Ok(())
    }

    // Whole command lines: program only, bare argument, quoted arguments.
    #[test]
    fn parse_parse_cmd() -> std::result::Result<(), Error<&'static str>> {
        parser_test!(
            parse_cmd,
            [
                "foo" => Command(String::from("foo"), vec![]),
                "foo bar" => Command(String::from("foo"), vec![String::from("bar")]),
                "foo \"bar baz\"" => Command(String::from("foo"), vec![String::from("bar baz")]),
                "foo \"bar baz\" \"\\n\"" => Command(
                    String::from("foo"),
                    vec![String::from("bar baz"), String::from("\n")]
                )
            ]
        );

        Ok(())
    }
}