1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
|
//! Tests for cycles where the cycle head is stored on a tracked struct
//! and that tracked struct is freed in a later revision.
mod common;
use crate::common::{EventLoggerDatabase, LogDatabase};
use expect_test::expect;
use salsa::{CycleRecoveryAction, Database, Setter};
/// The nodes built by `create_graph` for one `GraphInput`.
///
/// `Edge::to` values are indices into `nodes`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, salsa::Update)]
struct Graph<'db> {
// Node at index 0 is the "start" node that `cost_to_start` measures distance to.
nodes: Vec<Node<'db>>,
}
impl<'db> Graph<'db> {
    /// Looks up a node by name.
    ///
    /// Scans `nodes` in order and returns the first node whose `name` equals `name`,
    /// or `None` when no node matches.
    fn find_node(&self, db: &dyn salsa::Database, name: &str) -> Option<Node<'db>> {
        for &candidate in &self.nodes {
            if candidate.name(db) == name {
                return Some(candidate);
            }
        }
        None
    }
}
/// A directed, weighted edge leaving a node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Edge {
    /// Index of the target node within `graph.nodes`.
    to: usize,
    /// Amount added to a path's total cost when this edge is taken.
    cost: usize,
}
/// A graph node, stored as a salsa tracked struct and created by `create_graph`.
///
/// When a later revision of `create_graph` no longer creates a node, that node is
/// discarded; the test in this file depends on that.
#[salsa::tracked(debug)]
struct Node<'db> {
// Used by `Graph::find_node` to look nodes up by name.
#[returns(ref)]
name: String,
// Outgoing edges; `Edge::to` indexes into the owning `Graph::nodes`.
#[returns(deref)]
#[tracked]
edges: Vec<Edge>,
// The input this node's graph was built from; `cost_to_start` uses it to
// re-fetch the graph via `create_graph`.
graph: GraphInput,
}
/// Salsa input selecting which graph `create_graph` builds.
#[salsa::input(debug)]
struct GraphInput {
// `true`: the acyclic three-node graph (`a`, `b`, `c`).
// `false`: the four-node graph containing the `b` <-> `d` cycle.
simple: bool,
}
/// Builds the test graph for `input`, creating one `Node` tracked struct per node.
///
/// Node `a` is always at index 0 (the start node for `cost_to_start`). The order of the
/// `Node::new` calls matters: the Ids in the expected logs in `main` (e.g. `c` = `Id(402)`,
/// `d` = `Id(403)`) appear to follow creation order, so do not reorder them.
#[salsa::tracked(returns(ref))]
fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> {
if input.simple(db) {
// ```
// flowchart TD
//
// A("a")
// B("b")
// C("c")
//
// B -- 20 --> A
// C -- 2 --> B
// ```
let a = Node::new(db, "a".to_string(), vec![], input);
let b = Node::new(db, "b".to_string(), vec![Edge { to: 0, cost: 20 }], input);
let c = Node::new(db, "c".to_string(), vec![Edge { to: 1, cost: 2 }], input);
Graph {
nodes: vec![a, b, c],
}
} else {
// ```
// flowchart TD
//
// A("a")
// B("b")
// C("c")
// D{"d"}
//
// B -- 20 --> D
// C -- 4 --> D
// D -- 4 --> A
// D -- 4 --> B
// ```
let a = Node::new(db, "a".to_string(), vec![], input);
let b = Node::new(db, "b".to_string(), vec![Edge { to: 3, cost: 20 }], input);
let c = Node::new(db, "c".to_string(), vec![Edge { to: 3, cost: 4 }], input);
// `d` is only created in this branch; switching to the simple graph frees it.
let d = Node::new(
db,
"d".to_string(),
vec![Edge { to: 0, cost: 4 }, Edge { to: 1, cost: 4 }],
input,
);
Graph {
nodes: vec![a, b, c, d],
}
}
}
/// Computes the minimum cost of travelling from `node` to the start node (the node at
/// offset `0` in `graph.nodes`) by following outgoing edges.
///
/// Returns `usize::MAX` when no path to the start node is known. The same value is the
/// provisional result used while a cycle is iterated (see `max_initial`). The start node
/// itself has no outgoing edges in these graphs, so its own query also yields
/// `usize::MAX`. That is why direct edges to index 0 are handled explicitly below.
#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)]
fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize {
    let mut min_cost = usize::MAX;
    let graph = create_graph(db, node.graph(db));
    for edge in node.edges(db) {
        // A direct edge to the start node costs exactly `edge.cost`.
        if edge.to == 0 {
            min_cost = min_cost.min(edge.cost);
        }
        // Note: the target is queried even when it is the start node; this call order is
        // reflected in the expected logs of `main`, so keep it.
        let edge_cost_to_start = cost_to_start(db, graph.nodes[edge.to]);
        // We hit a cycle (or the target has no known path): never take this edge, because
        // it will always be more expensive than any other edge.
        if edge_cost_to_start == usize::MAX {
            continue;
        }
        // Saturate instead of `+` so large costs cannot overflow (a panic in debug builds);
        // a saturated sum collapses to the same "no path" sentinel.
        min_cost = min_cost.min(edge.cost.saturating_add(edge_cost_to_start));
    }
    min_cost
}
/// Initial (provisional) value for `cost_to_start` when a cycle is detected.
///
/// Returns `usize::MAX`, which `cost_to_start` treats as "no path" and skips.
fn max_initial(_db: &dyn Database, _node: Node) -> usize {
usize::MAX
}
/// Cycle-recovery callback for `cost_to_start`.
///
/// Always returns `CycleRecoveryAction::Iterate`. It ignores the current provisional
/// value, the iteration count and the node.
fn cycle_recover(
_db: &dyn Database,
_value: &usize,
_count: u32,
_inputs: Node,
) -> CycleRecoveryAction<usize> {
CycleRecoveryAction::Iterate
}
/// Regression test: a provisional memo whose cycle head lives on a tracked struct must
/// not dereference that head after the struct is freed in a later revision.
///
/// Revision 1 uses the cyclic graph (`simple = false`). Revision 2 switches to the simple
/// graph, which no longer creates `d`. As a result, `d`'s tracked struct and its
/// `cost_to_start` memo are discarded (see the `DidDiscard` entries in the expected logs).
#[test]
fn main() {
let mut db = EventLoggerDatabase::default();
let input = GraphInput::new(&db, false);
let graph = create_graph(&db, input);
let c = graph.find_node(&db, "c").unwrap();
// Query the cost from `c` to `a`.
// There's a cycle between `b` and `d`, where `d` becomes the cycle head and `b` is a provisional, non-finalized result.
// Cheapest path: c -(4)-> d -(4)-> a = 8.
assert_eq!(cost_to_start(&db, c), 8);
// Change the graph. This removes `d`, leaving `b` pointing to a cycle head that has now been collected.
// Querying the cost from `c` to `a` should try to verify the result of `b`, and it is important
// that `b` doesn't try to dereference the cycle head (because its memo is stored on a tracked
// struct that has been freed).
input.set_simple(&mut db).to(true);
let graph = create_graph(&db, input);
let c = graph.find_node(&db, "c").unwrap();
// Simple graph: c -(2)-> b -(20)-> a = 22.
assert_eq!(cost_to_start(&db, c), 22);
// Expected event sequence: the first revision's cycle iteration, then the discard of `d`
// (`Node(Id(403))`) and the re-executions in the second revision.
db.assert_logs(expect![[r#"
[
"WillCheckCancellation",
"WillExecute { database_key: create_graph(Id(0)) }",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(402)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(403)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(400)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillIterateCycle { database_key: cost_to_start(Id(403)), iteration_count: IterationCount(1), fell_back: false }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillExecute { database_key: create_graph(Id(0)) }",
"WillDiscardStaleOutput { execute_key: create_graph(Id(0)), output_key: Node(Id(403)) }",
"DidDiscard { key: Node(Id(403)) }",
"DidDiscard { key: cost_to_start(Id(403)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(402)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(400)) }",
"WillCheckCancellation",
]"#]]);
}
|