kanit

Toy init system
Log | Files | Refs | README | LICENSE

sort.rs (6904B)


      1 use std::cell::RefCell;
      2 use std::collections::HashMap;
      3 use std::rc::Rc;
      4 
      5 use kanit_common::error::{Result, StaticError, WithError};
      6 use kanit_unit::{Dependencies, RcUnit, UnitName};
      7 
      8 #[derive(Debug, Clone)]
      9 pub struct SortableUnit {
     10     pub name: UnitName,
     11     pub dependencies: Rc<Dependencies>,
     12 }
     13 
     14 impl From<&RcUnit> for SortableUnit {
     15     fn from(unit: &RcUnit) -> Self {
     16         SortableUnit {
     17             name: unit.borrow().name(),
     18             dependencies: Rc::new(unit.borrow().dependencies()),
     19         }
     20     }
     21 }
     22 
     23 #[derive(Clone, Debug)]
     24 struct Edge<T> {
     25     from: Rc<RefCell<Node<T>>>,
     26     deleted: bool,
     27 }
     28 
     29 #[derive(Debug)]
     30 struct Node<T> {
     31     data: T,
     32     edges: Vec<Rc<RefCell<Edge<T>>>>,
     33     idx: usize,
     34 }
     35 
     36 pub fn obtain_load_order(units: Vec<SortableUnit>) -> Result<Vec<Vec<SortableUnit>>> {
     37     let mut nodes = vec![];
     38     let mut map = HashMap::new();
     39 
     40     for (idx, unit) in units.iter().enumerate() {
     41         let node = Rc::new(RefCell::new(Node {
     42             data: unit.clone(),
     43             edges: vec![],
     44             idx,
     45         }));
     46 
     47         nodes.push(node.clone());
     48         map.insert(unit.name.clone(), node);
     49     }
     50 
     51     for node in nodes.iter() {
     52         let mut node_b = node.borrow_mut();
     53         let dependencies = node_b.data.dependencies.clone();
     54 
     55         for dep in dependencies.needs.iter() {
     56             if let Some(unit) = map.get(&dep.clone()) {
     57                 let edge = Rc::new(RefCell::new(Edge {
     58                     from: node.clone(),
     59                     deleted: false,
     60                 }));
     61 
     62                 node_b.edges.push(edge.clone());
     63                 unit.borrow_mut().edges.push(edge);
     64             } else {
     65                 let dep = dep.to_string();
     66                 Err(WithError::with(move || {
     67                     format!("failed to find needed dependency `{dep}`")
     68                 }))?;
     69             }
     70         }
     71 
     72         for dep in dependencies.wants.iter().chain(dependencies.after.iter()) {
     73             if let Some(unit) = map.get(dep) {
     74                 let edge = Rc::new(RefCell::new(Edge {
     75                     from: node.clone(),
     76                     deleted: false,
     77                 }));
     78 
     79                 node_b.edges.push(edge.clone());
     80                 unit.borrow_mut().edges.push(edge);
     81             }
     82         }
     83 
     84         for before in dependencies.before.iter() {
     85             if let Some(unit) = map.get(before) {
     86                 let mut unit_b = unit.borrow_mut();
     87 
     88                 let edge = Rc::new(RefCell::new(Edge {
     89                     from: unit.clone(),
     90                     deleted: false,
     91                 }));
     92 
     93                 node_b.edges.push(edge.clone());
     94                 unit_b.edges.push(edge);
     95             }
     96         }
     97     }
     98 
     99     let mut order = vec![];
    100 
    101     while !nodes.is_empty() {
    102         let starting_amount = nodes.len();
    103         let nodes_c = nodes.clone();
    104 
    105         let mut without_incoming: Vec<_> = nodes_c
    106             .iter()
    107             .filter(|n| {
    108                 !n.borrow().edges.iter().any(|e| {
    109                     let e_b = e.borrow();
    110                     !e_b.deleted && e_b.from.borrow().idx == n.borrow().idx
    111                 })
    112             })
    113             .collect();
    114 
    115         let mut round = vec![];
    116 
    117         while let Some(node) = without_incoming.pop() {
    118             round.push(node.clone());
    119 
    120             for edge in node.borrow().edges.iter() {
    121                 let mut edge_b = edge.borrow_mut();
    122                 edge_b.deleted = true;
    123             }
    124 
    125             if let Some(pos) = nodes
    126                 .iter()
    127                 .position(|n| n.borrow().idx == node.borrow().idx)
    128             {
    129                 nodes.remove(pos);
    130             }
    131         }
    132 
    133         order.push(round);
    134 
    135         if starting_amount == nodes.len() {
    136             Err(StaticError("cyclic dependency detected"))?;
    137         }
    138     }
    139 
    140     Ok(order
    141         .iter()
    142         .map(|r| r.iter().map(|n| n.borrow().data.clone()).collect())
    143         .collect())
    144 }
    145 
    146 // TODO; proper testing
    147 // these are more like a scratch pad
    148 #[cfg(test)]
    149 mod tests {
    150     use async_trait::async_trait;
    151 
    152     use kanit_unit::{Dependencies, RcUnit, Unit, UnitName, wrap_unit};
    153 
    154     use super::*;
    155 
    156     struct NullUnit(&'static str, Dependencies);
    157 
    158     #[async_trait]
    159     impl Unit for NullUnit {
    160         fn name(&self) -> UnitName {
    161             UnitName::from(self.0)
    162         }
    163 
    164         fn dependencies(&self) -> Dependencies {
    165             self.1.clone()
    166         }
    167 
    168         async fn start(&mut self) -> Result<()> {
    169             Ok(())
    170         }
    171     }
    172 
    173     fn print_plan(units: Vec<Vec<SortableUnit>>) {
    174         for (i, order) in units.iter().enumerate() {
    175             println!("group {i}");
    176             order.iter().for_each(|s| println!("|> {}", s.name));
    177         }
    178     }
    179 
    180     fn to_unit_info(units: Vec<RcUnit>) -> Vec<SortableUnit> {
    181         units.iter().map(SortableUnit::from).collect()
    182     }
    183 
    184     #[test]
    185     fn simple_generate_order() {
    186         let c = NullUnit("c", Dependencies::new());
    187         let d = NullUnit("d", Dependencies::new());
    188         let b = NullUnit("b", Dependencies::new().need(d.name()).clone());
    189         let a = NullUnit(
    190             "a",
    191             Dependencies::new().need(b.name()).need(c.name()).clone(),
    192         );
    193 
    194         let units = vec![wrap_unit(a), wrap_unit(b), wrap_unit(c), wrap_unit(d)];
    195 
    196         if let Ok(order) = obtain_load_order(to_unit_info(units)) {
    197             print_plan(order);
    198         }
    199     }
    200 
    201     #[test]
    202     fn complex_generate_order() {
    203         let f = NullUnit("2", Dependencies::new());
    204         let g = NullUnit("9", Dependencies::new());
    205         let h = NullUnit("10", Dependencies::new());
    206         let e = NullUnit("8", Dependencies::new().need(g.name()).clone());
    207 
    208         let c = NullUnit(
    209             "3",
    210             Dependencies::new().need(e.name()).need(h.name()).clone(),
    211         );
    212         let d = NullUnit(
    213             "11",
    214             Dependencies::new()
    215                 .need(f.name())
    216                 .need(g.name())
    217                 .need(h.name())
    218                 .clone(),
    219         );
    220         let b = NullUnit(
    221             "7",
    222             Dependencies::new().need(d.name()).need(e.name()).clone(),
    223         );
    224 
    225         let a = NullUnit("5", Dependencies::new().need(d.name()).clone());
    226 
    227         let units = vec![
    228             wrap_unit(a),
    229             wrap_unit(b),
    230             wrap_unit(c),
    231             wrap_unit(d),
    232             wrap_unit(e),
    233             wrap_unit(f),
    234             wrap_unit(g),
    235             wrap_unit(h),
    236         ];
    237 
    238         if let Ok(order) = obtain_load_order(to_unit_info(units)) {
    239             print_plan(order);
    240         }
    241     }
    242 
    243     #[test]
    244     fn cyclic_chain() {
    245         let a = NullUnit("a", Dependencies::new().need(UnitName::from("b")).clone());
    246         let b = NullUnit("b", Dependencies::new().need(UnitName::from("a")).clone());
    247 
    248         let units = vec![wrap_unit(a), wrap_unit(b)];
    249 
    250         assert!(obtain_load_order(to_unit_info(units)).is_err());
    251     }
    252 }