diff --git a/src/expr/linkplan.rs b/src/expr/linkplan.rs index 3cadab0..a98432c 100644 --- a/src/expr/linkplan.rs +++ b/src/expr/linkplan.rs @@ -1,6 +1,6 @@ //! query planning for link derivative computation. -//use super::Link; +use super::Link; use std::collections::BTreeMap; /* @@ -34,35 +34,86 @@ fn parents(drvs: &[u8]) -> impl Iterator> + '_ { }) } +#[derive(Clone, Debug, Default)] +struct VoteEntry { + votes: u16, + best_parent: u16, + is_used: bool, +} + /// calculate votes for each possible parent -fn calculate_votes(mut que: Vec>) -> Vec, u16>> { +fn calculate_votes(mut que: Vec>) -> Vec, VoteEntry>> { assert!(que.len() <= usize::from(u16::MAX)); // split que into levels let mut levels = core::iter::repeat_with(BTreeMap::new) .take(que.iter().map(|i| sum(&i[..])).max().map(|i| i.checked_add(1).unwrap()).unwrap_or(0)) - .collect::, u16>>>(); + .collect::, VoteEntry>>>(); for i in core::mem::replace(&mut que, Vec::new()) { let mut l = &mut levels[sum(&i[..])]; - l.insert(i, 1); + l.insert(i, VoteEntry { + votes: 1, + best_parent: 0, + is_used: true, + }); } // process levels (start from largest) for (lid, i) in levels.iter_mut().enumerate().rev() { // insert queued items into this level for j in core::mem::replace(&mut que, Vec::new()) { - assert_eq!(lid, sum(&j[..])); let mut v = i.entry(j).or_default(); - *v = v.checked_add(1).unwrap(); + v.votes = v.votes.checked_add(1).unwrap(); } // handle all items in this level - for (j, v) in i.iter_mut() { + for j in i.keys() { + assert_eq!(lid, sum(&j[..])); que.extend(parents(&j[..])); } } + // calculate best parent for each entry + for lid in (0..levels.len() - 1).rev() { + let mut taken = core::mem::take(&mut levels[lid + 1]); + let plvl = &mut levels[lid]; + taken.retain(|_, j| j.is_used); + + for (j, v) in taken.iter_mut() { + let mut pbs = None; + for k in parents(j) { + pbs = Some(match pbs { + None => k, + Some(x) => { + if plvl[&k].votes > plvl[&x].votes { + k + } else { + x + } + }, + }); + } + let pbs = pbs.unwrap(); + let pbs_index = { + let mut pbsi = j.iter().zip(pbs.iter()).map(|(i, p)| i - p).enumerate().filter(|(_, i)| *i > 0); + let pbs_index = pbsi.next().unwrap(); + assert_eq!(pbsi.next(), None); + pbs_index + }; + assert_eq!(pbs_index.1, 1); + + v.best_parent = pbs_index.0.try_into().unwrap(); + + // propagate dependency + if v.is_used { + plvl.get_mut(&pbs).unwrap().is_used = true; + } + } + + levels[lid + 1] = taken; + } + levels } @@ -70,27 +121,36 @@ fn calculate_votes(mut que: Vec>) -> Vec, u16>> { /// and compute the parent /// /// maximize reuse of the same derivatives -/* -pub fn fullfill_links(que: Vec) -> HashMap, u16)>> { - let mut g = HashMap::>::new(); +/// +/// return value is structures as: `[root variable] [level] [.] -> (derivative, best_parent)` +pub fn fullfill_links(que: Vec) -> Vec, u16)>>> { + let mut g = BTreeMap::>::new(); // sort the links by root variable - for (vid, drv) in que { + let mut max_vid = 1; + for Link(vid, drv) in que { g.entry(vid).or_default().push(drv); + if usize::from(vid) >= max_vid { + max_vid = usize::from(vid) + 1; + } } - for (vid, que) in que { - let votes = calculate_votes(que.clone()); + let mut ret = vec![Vec::new(); max_vid]; - for i in que { - - for j in parents(i) { - + for (vid, que) in g { + let mut votes = calculate_votes(que.clone()); + let mut ret = &mut ret[usize::from(vid)]; + ret.resize_with(votes.len(), || Vec::new()); + + for (i, j) in ret.iter_mut().zip(votes.into_iter()) { + for (k, l) in j.into_iter() { + i.push((k, l.best_parent)); } } } + + ret } -*/ #[cfg(test)] mod tests { @@ -103,4 +163,12 @@ mod tests { vec![0, 1, 1, 1].into_boxed_slice(), ])); } + + #[test] + fn test_fullfill_links() { + insta::assert_debug_snapshot!(fullfill_links(vec![ + Link(0, vec![0, 1, 2, 0].into_boxed_slice()), + Link(0, vec![0, 1, 1, 1].into_boxed_slice()), + ])); + } } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 95b0c7a..7a2ce6d 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -14,6 +14,8 @@ use yz_string_utils::StrLexerBase; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Link(pub u16, pub Box<[u8]>); +pub use linkplan::fullfill_links; + #[derive(Clone, Debug, PartialEq)] pub enum Expr { Const(f64), diff --git a/src/expr/snapshots/pdxdrive__expr__linkplan__tests__calculate_votes.snap b/src/expr/snapshots/pdxdrive__expr__linkplan__tests__calculate_votes.snap index fbcb734..94cea76 100644 --- a/src/expr/snapshots/pdxdrive__expr__linkplan__tests__calculate_votes.snap +++ b/src/expr/snapshots/pdxdrive__expr__linkplan__tests__calculate_votes.snap @@ -9,53 +9,35 @@ expression: "calculate_votes(vec![vec![0, 1, 2, 0].into_boxed_slice(),\n 0, 0, 0, - ]: 3, - }, - { - [ - 0, - 0, - 0, - 1, - ]: 2, - [ - 0, - 0, - 1, - 0, - ]: 3, - [ - 0, - 1, - 0, - 0, - ]: 2, + ]: VoteEntry { + votes: 3, + best_parent: 0, + is_used: true, + }, }, { [ 0, 0, 1, - 1, - ]: 1, - [ 0, - 0, - 2, - 0, - ]: 1, - [ - 0, - 1, - 0, - 1, - ]: 1, + ]: VoteEntry { + votes: 3, + best_parent: 2, + is_used: true, + }, + }, + { [ 0, 1, 1, 0, - ]: 2, + ]: VoteEntry { + votes: 2, + best_parent: 1, + is_used: true, + }, }, { [ @@ -63,12 +45,20 @@ expression: "calculate_votes(vec![vec![0, 1, 2, 0].into_boxed_slice(),\n 1, 1, 1, - ]: 1, + ]: VoteEntry { + votes: 1, + best_parent: 3, + is_used: true, + }, [ 0, 1, 2, 0, - ]: 1, + ]: VoteEntry { + votes: 1, + best_parent: 2, + is_used: true, + }, }, ] diff --git a/src/expr/snapshots/pdxdrive__expr__linkplan__tests__fullfill_links.snap b/src/expr/snapshots/pdxdrive__expr__linkplan__tests__fullfill_links.snap new file mode 100644 index 0000000..82957bb --- /dev/null +++ b/src/expr/snapshots/pdxdrive__expr__linkplan__tests__fullfill_links.snap @@ -0,0 +1,61 @@ +--- +source: src/expr/linkplan.rs +expression: "fullfill_links(vec![Link(0, vec![0, 1, 2, 0].into_boxed_slice()),\n Link(0, vec![0, 1, 1, 1].into_boxed_slice()),])" +--- +[ + [ + [ + ( + [ + 0, + 0, + 0, + 0, + ], + 0, + ), + ], + [ + ( + [ + 0, + 0, + 1, + 0, + ], + 2, + ), + ], + [ + ( + [ + 0, + 1, + 1, + 0, + ], + 1, + ), + ], + [ + ( + [ + 0, + 1, + 1, + 1, + ], + 3, + ), + ( + [ + 0, + 1, + 2, + 0, + ], + 2, + ), + ], + ], +]