Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 279 additions & 0 deletions crates/csvizmo-depgraph/src/algorithm/between.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
use std::collections::HashSet;

use clap::Parser;
use petgraph::Direction;
use petgraph::graph::NodeIndex;

use super::{MatchKey, build_globset};
use crate::{DepGraph, FlatGraphView};

#[derive(Clone, Debug, Default, Parser)]
pub struct BetweenArgs {
/// Glob pattern selecting query endpoints (can be repeated, OR logic)
#[clap(short, long)]
pub pattern: Vec<String>,

/// Match patterns against 'id' or 'label'
#[clap(long, default_value_t = MatchKey::default())]
pub key: MatchKey,
}

impl BetweenArgs {
pub fn pattern(mut self, p: impl Into<String>) -> Self {
self.pattern.push(p.into());
self
}

pub fn key(mut self, k: MatchKey) -> Self {
self.key = k;
self
}
}

/// Extract the subgraph formed by all directed paths between any pair of matched query nodes.
///
/// For matched query nodes q1..qk, computes forward and backward reachability from each,
/// then for each ordered pair (qi, qj) collects nodes on directed paths from qi to qj
/// via `forward(qi) & backward(qj)`. The union of all pairwise results is the keep set.
pub fn between(graph: &DepGraph, args: &BetweenArgs) -> eyre::Result<DepGraph> {
let globset = build_globset(&args.pattern)?;
let view = FlatGraphView::new(graph);

// Match query nodes by glob pattern (OR logic).
let matched: Vec<NodeIndex> = graph
.all_nodes()
.iter()
.filter_map(|(id, info)| {
let text = match args.key {
MatchKey::Id => id.as_str(),
MatchKey::Label => info.label.as_str(),
};
if globset.is_match(text) {
view.id_to_idx.get(id.as_str()).copied()
} else {
None
}
})
.collect();

// Need at least 2 matched nodes to have a path between them.
if matched.len() < 2 {
return Ok(view.filter(&HashSet::new()));
}

// BFS forward and backward from each query node.
let forwards: Vec<HashSet<NodeIndex>> = matched
.iter()
.map(|&q| view.bfs([q], Direction::Outgoing, None))
.collect();
let backwards: Vec<HashSet<NodeIndex>> = matched
.iter()
.map(|&q| view.bfs([q], Direction::Incoming, None))
.collect();

// Pairwise intersect: for each pair (i, j) where i != j, nodes on directed paths
// from qi to qj are in forward(qi) & backward(qj).
let mut keep = HashSet::new();
for (i, fwd) in forwards.iter().enumerate() {
for (j, bwd) in backwards.iter().enumerate() {
if i == j {
continue;
}
for &node in fwd {
if bwd.contains(&node) {
keep.insert(node);
}
}
}
}

Ok(view.filter(&keep))
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{DepGraph, Edge, NodeInfo};

fn make_graph(
nodes: &[(&str, &str)],
edges: &[(&str, &str)],
subgraphs: Vec<DepGraph>,
) -> DepGraph {
DepGraph {
nodes: nodes
.iter()
.map(|(id, label)| (id.to_string(), NodeInfo::new(*label)))
.collect(),
edges: edges
.iter()
.map(|(from, to)| Edge {
from: from.to_string(),
to: to.to_string(),
..Default::default()
})
.collect(),
subgraphs,
..Default::default()
}
}

fn sorted_node_ids(graph: &DepGraph) -> Vec<&str> {
let mut ids: Vec<&str> = graph.nodes.keys().map(|s| s.as_str()).collect();
ids.sort();
ids
}

fn sorted_edge_pairs(graph: &DepGraph) -> Vec<(&str, &str)> {
let mut pairs: Vec<(&str, &str)> = graph
.edges
.iter()
.map(|e| (e.from.as_str(), e.to.as_str()))
.collect();
pairs.sort();
pairs
}

#[test]
fn direct_path() {
// a -> b: between a and b yields both
let g = make_graph(&[("a", "a"), ("b", "b")], &[("a", "b")], vec![]);
let args = BetweenArgs::default().pattern("a").pattern("b");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b"]);
assert_eq!(sorted_edge_pairs(&result), vec![("a", "b")]);
}

#[test]
fn intermediate_nodes() {
// a -> b -> c: between a and c includes intermediate b
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c")],
&[("a", "b"), ("b", "c")],
vec![],
);
let args = BetweenArgs::default().pattern("a").pattern("c");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c"]);
assert_eq!(sorted_edge_pairs(&result), vec![("a", "b"), ("b", "c")]);
}

#[test]
fn no_path_returns_empty() {
// a -> b, c -> d: no path between a and c
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c"), ("d", "d")],
&[("a", "b"), ("c", "d")],
vec![],
);
let args = BetweenArgs::default().pattern("a").pattern("c");
let result = between(&g, &args).unwrap();
assert!(result.nodes.is_empty());
assert!(result.edges.is_empty());
}

#[test]
fn diamond() {
// a -> b -> d, a -> c -> d: between a and d includes both paths
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c"), ("d", "d")],
&[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
vec![],
);
let args = BetweenArgs::default().pattern("a").pattern("d");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c", "d"]);
assert_eq!(
sorted_edge_pairs(&result),
vec![("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]
);
}

#[test]
fn multiple_query_nodes() {
// a -> b -> c -> d: between a, b, and d includes everything on paths a->b, a->d, b->d
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c"), ("d", "d")],
&[("a", "b"), ("b", "c"), ("c", "d")],
vec![],
);
let args = BetweenArgs::default()
.pattern("a")
.pattern("b")
.pattern("d");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c", "d"]);
}

#[test]
fn no_match_returns_empty() {
let g = make_graph(&[("a", "a"), ("b", "b")], &[("a", "b")], vec![]);
let args = BetweenArgs::default().pattern("nonexistent");
let result = between(&g, &args).unwrap();
assert!(result.nodes.is_empty());
assert!(result.edges.is_empty());
}

#[test]
fn single_match_returns_empty() {
// Only one node matches -- need at least 2 for a path
let g = make_graph(&[("a", "a"), ("b", "b")], &[("a", "b")], vec![]);
let args = BetweenArgs::default().pattern("a");
let result = between(&g, &args).unwrap();
assert!(result.nodes.is_empty());
assert!(result.edges.is_empty());
}

#[test]
fn cycle() {
// a -> b -> c -> a: between a and c includes all nodes in the cycle
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c")],
&[("a", "b"), ("b", "c"), ("c", "a")],
vec![],
);
let args = BetweenArgs::default().pattern("a").pattern("c");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c"]);
}

#[test]
fn match_by_id() {
let g = make_graph(&[("1", "libfoo"), ("2", "libbar")], &[("1", "2")], vec![]);
let args = BetweenArgs::default()
.pattern("1")
.pattern("2")
.key(MatchKey::Id);
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["1", "2"]);
assert_eq!(sorted_edge_pairs(&result), vec![("1", "2")]);
}

#[test]
fn excludes_unrelated_nodes() {
// a -> b -> c, d -> e: between a and c should not include d or e
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c"), ("d", "d"), ("e", "e")],
&[("a", "b"), ("b", "c"), ("d", "e")],
vec![],
);
let args = BetweenArgs::default().pattern("a").pattern("c");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c"]);
assert_eq!(sorted_edge_pairs(&result), vec![("a", "b"), ("b", "c")]);
}

#[test]
fn glob_matching_multiple_nodes() {
// a -> b -> c: glob "?" matches a, b, c -- all pairs have paths
let g = make_graph(
&[("a", "a"), ("b", "b"), ("c", "c")],
&[("a", "b"), ("b", "c")],
vec![],
);
let args = BetweenArgs::default().pattern("?");
let result = between(&g, &args).unwrap();
assert_eq!(sorted_node_ids(&result), vec!["a", "b", "c"]);
assert_eq!(sorted_edge_pairs(&result), vec![("a", "b"), ("b", "c")]);
}
}
Loading