kanit

Toy init system
Log | Files | Refs | README | LICENSE

event.rs (4164B)


      1 use std::collections::HashSet;
      2 
      3 use kanit_common::error::{Context, Result, StaticError};
      4 use kanit_unit::UnitName;
      5 use log::warn;
      6 
      7 use crate::loader::Loader;
      8 
      9 async fn modify_service(start: bool, level: usize, name: &[u8]) -> Result<()> {
     10     // this is horrible but it makes the compiler happy
     11     let (diff, groups) = {
     12         let loader = Loader::obtain()?.borrow();
     13 
     14         let unit_name = UnitName::from(String::from_utf8_lossy(name));
     15 
     16         if (start && loader.is_started(level, &unit_name))
     17             || (!start && !loader.is_started(level, &unit_name))
     18         {
     19             Err(StaticError("unit already started/stopped"))?;
     20         }
     21 
     22         // rebuild database and diff to find out what needs to start
     23         let mut db = loader.database().clone();
     24 
     25         {
     26             let enabled = db.enabled.get_mut(level).context("failed to get level")?;
     27 
     28             if start {
     29                 enabled.insert(unit_name.clone());
     30             } else {
     31                 enabled.remove(&unit_name);
     32             }
     33         }
     34 
     35         db.rebuild_levels()?;
     36 
     37         if !db.unit_infos.contains_key(&unit_name) {
     38             Err(StaticError("failed to find unit in database"))?;
     39         }
     40 
     41         let started = loader.started.get(level).context("failed to get level")?;
     42 
     43         // unwrap: we `get_mut` earlier
     44         let enabled = db.enabled.get(level).unwrap();
     45 
     46         let (diff, levels) = if start {
     47             (
     48                 enabled.difference(started).cloned().collect::<HashSet<_>>(),
     49                 db.levels,
     50             )
     51         } else {
     52             (
     53                 started.difference(enabled).cloned().collect::<HashSet<_>>(),
     54                 loader.database().clone().levels,
     55             )
     56         };
     57 
     58         let groups = levels
     59             .get(level)
     60             .context("failed to get level")?
     61             .get_order();
     62 
     63         (diff, groups.clone())
     64     };
     65 
     66     let mut loader = Loader::obtain()?.borrow_mut();
     67 
     68     if start {
     69         for group in groups {
     70             for unit_n in group.iter().filter(|u| diff.contains(*u)) {
     71                 let unit = loader.get_unit(unit_n).context("failed to get unit")?;
     72 
     73                 let mut unit_b = unit.borrow_mut();
     74 
     75                 if !unit_b.prepare().await? {
     76                     warn!("failed preparations for {}", unit_b.name());
     77 
     78                     continue;
     79                 }
     80 
     81                 if let Err(e) = unit_b.start().await {
     82                     warn!("{e}");
     83                     return Err(e);
     84                 } else {
     85                     loader.mark_started(level, unit_b.name());
     86                 }
     87             }
     88         }
     89     } else {
     90         for group in groups.iter().rev() {
     91             for unit_n in group.iter().filter(|u| diff.contains(*u)) {
     92                 let unit = loader.get_unit(unit_n).context("failed to get unit")?;
     93 
     94                 let mut unit_b = unit.borrow_mut();
     95 
     96                 if let Err(e) = unit_b.stop().await {
     97                     warn!("{e}");
     98                     return Err(e);
     99                 } else {
    100                     loader.mark_stopped(level, &unit_b.name());
    101                 }
    102             }
    103         }
    104     }
    105 
    106     Ok(())
    107 }
    108 
    109 pub async fn event(data: Vec<u8>) -> Result<()> {
    110     if data.starts_with(b"db-reload") {
    111         let mut loader = Loader::obtain()?.borrow_mut();
    112 
    113         let ev_lock = loader.ev_lock.clone();
    114 
    115         let lock = ev_lock.lock().await;
    116 
    117         loader.reload()?;
    118 
    119         drop(lock);
    120     } else if data.starts_with(b"start") || data.starts_with(b"stop") {
    121         // start:tty:1
    122         let mut parts = data.split(|b| *b == b':');
    123 
    124         parts.next(); // forward start/stop
    125 
    126         let name = parts.next().context("failed to get name")?;
    127 
    128         let level = String::from_utf8_lossy(parts.next().context("failed to get level")?)
    129             .trim()
    130             .parse::<usize>()
    131             .context("failed to parse level")?;
    132 
    133         let ev_lock = Loader::obtain()?.borrow().ev_lock.clone();
    134 
    135         let lock = ev_lock.lock().await; // get lock to ensure no one else is using the loader
    136 
    137         modify_service(data.starts_with(b"start"), level, name).await?;
    138 
    139         drop(lock);
    140     }
    141 
    142     Ok(())
    143 }