event.rs (4164B)
1 use std::collections::HashSet; 2 3 use kanit_common::error::{Context, Result, StaticError}; 4 use kanit_unit::UnitName; 5 use log::warn; 6 7 use crate::loader::Loader; 8 9 async fn modify_service(start: bool, level: usize, name: &[u8]) -> Result<()> { 10 // this is horrible but it makes the compiler happy 11 let (diff, groups) = { 12 let loader = Loader::obtain()?.borrow(); 13 14 let unit_name = UnitName::from(String::from_utf8_lossy(name)); 15 16 if (start && loader.is_started(level, &unit_name)) 17 || (!start && !loader.is_started(level, &unit_name)) 18 { 19 Err(StaticError("unit already started/stopped"))?; 20 } 21 22 // rebuild database and diff to find out what needs to start 23 let mut db = loader.database().clone(); 24 25 { 26 let enabled = db.enabled.get_mut(level).context("failed to get level")?; 27 28 if start { 29 enabled.insert(unit_name.clone()); 30 } else { 31 enabled.remove(&unit_name); 32 } 33 } 34 35 db.rebuild_levels()?; 36 37 if !db.unit_infos.contains_key(&unit_name) { 38 Err(StaticError("failed to find unit in database"))?; 39 } 40 41 let started = loader.started.get(level).context("failed to get level")?; 42 43 // unwrap: we `get_mut` earlier 44 let enabled = db.enabled.get(level).unwrap(); 45 46 let (diff, levels) = if start { 47 ( 48 enabled.difference(started).cloned().collect::<HashSet<_>>(), 49 db.levels, 50 ) 51 } else { 52 ( 53 started.difference(enabled).cloned().collect::<HashSet<_>>(), 54 loader.database().clone().levels, 55 ) 56 }; 57 58 let groups = levels 59 .get(level) 60 .context("failed to get level")? 61 .get_order(); 62 63 (diff, groups.clone()) 64 }; 65 66 let mut loader = Loader::obtain()?.borrow_mut(); 67 68 if start { 69 for group in groups { 70 for unit_n in group.iter().filter(|u| diff.contains(*u)) { 71 let unit = loader.get_unit(unit_n).context("failed to get unit")?; 72 73 let mut unit_b = unit.borrow_mut(); 74 75 if !unit_b.prepare().await? { 76 warn!("failed preparations for {}", unit_b.name()); 77 78 continue; 79 } 80 81 if let Err(e) = unit_b.start().await { 82 warn!("{e}"); 83 return Err(e); 84 } else { 85 loader.mark_started(level, unit_b.name()); 86 } 87 } 88 } 89 } else { 90 for group in groups.iter().rev() { 91 for unit_n in group.iter().filter(|u| diff.contains(*u)) { 92 let unit = loader.get_unit(unit_n).context("failed to get unit")?; 93 94 let mut unit_b = unit.borrow_mut(); 95 96 if let Err(e) = unit_b.stop().await { 97 warn!("{e}"); 98 return Err(e); 99 } else { 100 loader.mark_stopped(level, &unit_b.name()); 101 } 102 } 103 } 104 } 105 106 Ok(()) 107 } 108 109 pub async fn event(data: Vec<u8>) -> Result<()> { 110 if data.starts_with(b"db-reload") { 111 let mut loader = Loader::obtain()?.borrow_mut(); 112 113 let ev_lock = loader.ev_lock.clone(); 114 115 let lock = ev_lock.lock().await; 116 117 loader.reload()?; 118 119 drop(lock); 120 } else if data.starts_with(b"start") || data.starts_with(b"stop") { 121 // start:tty:1 122 let mut parts = data.split(|b| *b == b':'); 123 124 parts.next(); // forward start/stop 125 126 let name = parts.next().context("failed to get name")?; 127 128 let level = String::from_utf8_lossy(parts.next().context("failed to get level")?) 129 .trim() 130 .parse::<usize>() 131 .context("failed to parse level")?; 132 133 let ev_lock = Loader::obtain()?.borrow().ev_lock.clone(); 134 135 let lock = ev_lock.lock().await; // get lock to ensure no one else is using the loader 136 137 modify_service(data.starts_with(b"start"), level, name).await?; 138 139 drop(lock); 140 } 141 142 Ok(()) 143 }