sort.rs (6904B)
1 use std::cell::RefCell; 2 use std::collections::HashMap; 3 use std::rc::Rc; 4 5 use kanit_common::error::{Result, StaticError, WithError}; 6 use kanit_unit::{Dependencies, RcUnit, UnitName}; 7 8 #[derive(Debug, Clone)] 9 pub struct SortableUnit { 10 pub name: UnitName, 11 pub dependencies: Rc<Dependencies>, 12 } 13 14 impl From<&RcUnit> for SortableUnit { 15 fn from(unit: &RcUnit) -> Self { 16 SortableUnit { 17 name: unit.borrow().name(), 18 dependencies: Rc::new(unit.borrow().dependencies()), 19 } 20 } 21 } 22 23 #[derive(Clone, Debug)] 24 struct Edge<T> { 25 from: Rc<RefCell<Node<T>>>, 26 deleted: bool, 27 } 28 29 #[derive(Debug)] 30 struct Node<T> { 31 data: T, 32 edges: Vec<Rc<RefCell<Edge<T>>>>, 33 idx: usize, 34 } 35 36 pub fn obtain_load_order(units: Vec<SortableUnit>) -> Result<Vec<Vec<SortableUnit>>> { 37 let mut nodes = vec![]; 38 let mut map = HashMap::new(); 39 40 for (idx, unit) in units.iter().enumerate() { 41 let node = Rc::new(RefCell::new(Node { 42 data: unit.clone(), 43 edges: vec![], 44 idx, 45 })); 46 47 nodes.push(node.clone()); 48 map.insert(unit.name.clone(), node); 49 } 50 51 for node in nodes.iter() { 52 let mut node_b = node.borrow_mut(); 53 let dependencies = node_b.data.dependencies.clone(); 54 55 for dep in dependencies.needs.iter() { 56 if let Some(unit) = map.get(&dep.clone()) { 57 let edge = Rc::new(RefCell::new(Edge { 58 from: node.clone(), 59 deleted: false, 60 })); 61 62 node_b.edges.push(edge.clone()); 63 unit.borrow_mut().edges.push(edge); 64 } else { 65 let dep = dep.to_string(); 66 Err(WithError::with(move || { 67 format!("failed to find needed dependency `{dep}`") 68 }))?; 69 } 70 } 71 72 for dep in dependencies.wants.iter().chain(dependencies.after.iter()) { 73 if let Some(unit) = map.get(dep) { 74 let edge = Rc::new(RefCell::new(Edge { 75 from: node.clone(), 76 deleted: false, 77 })); 78 79 node_b.edges.push(edge.clone()); 80 unit.borrow_mut().edges.push(edge); 81 } 82 } 83 84 for before in dependencies.before.iter() { 85 if let Some(unit) = map.get(before) { 86 let mut unit_b = unit.borrow_mut(); 87 88 let edge = Rc::new(RefCell::new(Edge { 89 from: unit.clone(), 90 deleted: false, 91 })); 92 93 node_b.edges.push(edge.clone()); 94 unit_b.edges.push(edge); 95 } 96 } 97 } 98 99 let mut order = vec![]; 100 101 while !nodes.is_empty() { 102 let starting_amount = nodes.len(); 103 let nodes_c = nodes.clone(); 104 105 let mut without_incoming: Vec<_> = nodes_c 106 .iter() 107 .filter(|n| { 108 !n.borrow().edges.iter().any(|e| { 109 let e_b = e.borrow(); 110 !e_b.deleted && e_b.from.borrow().idx == n.borrow().idx 111 }) 112 }) 113 .collect(); 114 115 let mut round = vec![]; 116 117 while let Some(node) = without_incoming.pop() { 118 round.push(node.clone()); 119 120 for edge in node.borrow().edges.iter() { 121 let mut edge_b = edge.borrow_mut(); 122 edge_b.deleted = true; 123 } 124 125 if let Some(pos) = nodes 126 .iter() 127 .position(|n| n.borrow().idx == node.borrow().idx) 128 { 129 nodes.remove(pos); 130 } 131 } 132 133 order.push(round); 134 135 if starting_amount == nodes.len() { 136 Err(StaticError("cyclic dependency detected"))?; 137 } 138 } 139 140 Ok(order 141 .iter() 142 .map(|r| r.iter().map(|n| n.borrow().data.clone()).collect()) 143 .collect()) 144 } 145 146 // TODO; proper testing 147 // these are more like a scratch pad 148 #[cfg(test)] 149 mod tests { 150 use async_trait::async_trait; 151 152 use kanit_unit::{Dependencies, RcUnit, Unit, UnitName, wrap_unit}; 153 154 use super::*; 155 156 struct NullUnit(&'static str, Dependencies); 157 158 #[async_trait] 159 impl Unit for NullUnit { 160 fn name(&self) -> UnitName { 161 UnitName::from(self.0) 162 } 163 164 fn dependencies(&self) -> Dependencies { 165 self.1.clone() 166 } 167 168 async fn start(&mut self) -> Result<()> { 169 Ok(()) 170 } 171 } 172 173 fn print_plan(units: Vec<Vec<SortableUnit>>) { 174 for (i, order) in units.iter().enumerate() { 175 println!("group {i}"); 176 order.iter().for_each(|s| println!("|> {}", s.name)); 177 } 178 } 179 180 fn to_unit_info(units: Vec<RcUnit>) -> Vec<SortableUnit> { 181 units.iter().map(SortableUnit::from).collect() 182 } 183 184 #[test] 185 fn simple_generate_order() { 186 let c = NullUnit("c", Dependencies::new()); 187 let d = NullUnit("d", Dependencies::new()); 188 let b = NullUnit("b", Dependencies::new().need(d.name()).clone()); 189 let a = NullUnit( 190 "a", 191 Dependencies::new().need(b.name()).need(c.name()).clone(), 192 ); 193 194 let units = vec![wrap_unit(a), wrap_unit(b), wrap_unit(c), wrap_unit(d)]; 195 196 if let Ok(order) = obtain_load_order(to_unit_info(units)) { 197 print_plan(order); 198 } 199 } 200 201 #[test] 202 fn complex_generate_order() { 203 let f = NullUnit("2", Dependencies::new()); 204 let g = NullUnit("9", Dependencies::new()); 205 let h = NullUnit("10", Dependencies::new()); 206 let e = NullUnit("8", Dependencies::new().need(g.name()).clone()); 207 208 let c = NullUnit( 209 "3", 210 Dependencies::new().need(e.name()).need(h.name()).clone(), 211 ); 212 let d = NullUnit( 213 "11", 214 Dependencies::new() 215 .need(f.name()) 216 .need(g.name()) 217 .need(h.name()) 218 .clone(), 219 ); 220 let b = NullUnit( 221 "7", 222 Dependencies::new().need(d.name()).need(e.name()).clone(), 223 ); 224 225 let a = NullUnit("5", Dependencies::new().need(d.name()).clone()); 226 227 let units = vec![ 228 wrap_unit(a), 229 wrap_unit(b), 230 wrap_unit(c), 231 wrap_unit(d), 232 wrap_unit(e), 233 wrap_unit(f), 234 wrap_unit(g), 235 wrap_unit(h), 236 ]; 237 238 if let Ok(order) = obtain_load_order(to_unit_info(units)) { 239 print_plan(order); 240 } 241 } 242 243 #[test] 244 fn cyclic_chain() { 245 let a = NullUnit("a", Dependencies::new().need(UnitName::from("b")).clone()); 246 let b = NullUnit("b", Dependencies::new().need(UnitName::from("a")).clone()); 247 248 let units = vec![wrap_unit(a), wrap_unit(b)]; 249 250 assert!(obtain_load_order(to_unit_info(units)).is_err()); 251 } 252 }