Skip to content

Commit 560b74f

Browse files
committed
Prototype fix
1 parent 5711c5e commit 560b74f

3 files changed

Lines changed: 137 additions & 36 deletions

File tree

src/function/fetch.rs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,19 @@ where
8282
let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
8383
if let Some(memo) = memo_guard {
8484
let database_key_index = self.database_key_index(id);
85-
if memo.value.is_some()
86-
&& (self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
87-
|| self.validate_same_iteration(db, database_key_index, memo))
88-
&& self.shallow_verify_memo(db, zalsa, database_key_index, memo)
89-
{
90-
// SAFETY: memo is present in memo_map and we have verified that it is
91-
// still valid for the current revision.
92-
return unsafe { Some(self.extend_memo_lifetime(memo)) };
85+
if memo.value.is_some() {
86+
let shallow_verify = self.shallow_verify_memo(zalsa, database_key_index, memo);
87+
88+
if shallow_verify.as_bool()
89+
&& (self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
90+
|| self.validate_same_iteration(db, database_key_index, memo))
91+
{
92+
self.update_shallow(db, zalsa, database_key_index, memo, shallow_verify);
93+
94+
// SAFETY: memo is present in memo_map and we have verified that it is
95+
// still valid for the current revision.
96+
return unsafe { Some(self.extend_memo_lifetime(memo)) };
97+
}
9398
}
9499
}
95100
None
@@ -120,10 +125,21 @@ where
120125
if let Some(memo) = memo_guard {
121126
if memo.value.is_some()
122127
&& memo.revisions.cycle_heads.contains(&database_key_index)
123-
&& self.shallow_verify_memo(db, zalsa, database_key_index, memo)
124128
{
125-
// SAFETY: memo is present in memo_map.
126-
return unsafe { Some(self.extend_memo_lifetime(memo)) };
129+
let shallow_verify =
130+
self.shallow_verify_memo(zalsa, database_key_index, memo);
131+
132+
if shallow_verify.as_bool() {
133+
self.update_shallow(
134+
db,
135+
zalsa,
136+
database_key_index,
137+
memo,
138+
shallow_verify,
139+
);
140+
// SAFETY: memo is present in memo_map.
141+
return unsafe { Some(self.extend_memo_lifetime(memo)) };
142+
}
127143
}
128144
}
129145
// no provisional value; create/insert/return initial provisional value

src/function/maybe_changed_after.rs

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin};
1111
use crate::{AsDynDatabase as _, Id, Revision};
1212

1313
/// Result of memo validation.
14+
#[derive(Debug)]
1415
pub enum VerifyResult {
1516
/// Memo has changed and needs to be recomputed.
1617
Changed,
@@ -61,9 +62,12 @@ where
6162
// Check if we have a verified version: this is the hot path.
6263
let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
6364
if let Some(memo) = memo_guard {
64-
if self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
65-
&& self.shallow_verify_memo(db, zalsa, database_key_index, memo)
65+
let shallow_result = self.shallow_verify_memo(zalsa, database_key_index, memo);
66+
if shallow_result.as_bool()
67+
&& self.validate_provisional(db, zalsa, database_key_index, memo)
6668
{
69+
self.update_shallow(db, zalsa, database_key_index, memo, shallow_result);
70+
6771
return if memo.revisions.changed_at > revision {
6872
VerifyResult::Changed
6973
} else {
@@ -177,11 +181,10 @@ where
177181
#[inline]
178182
pub(super) fn shallow_verify_memo(
179183
&self,
180-
db: &C::DbView,
181184
zalsa: &Zalsa,
182185
database_key_index: DatabaseKeyIndex,
183186
memo: &Memo<C::Output<'_>>,
184-
) -> bool {
187+
) -> ShallowVerifyResult {
185188
tracing::debug!(
186189
"{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})",
187190
memo = memo.tracing_debug()
@@ -191,7 +194,7 @@ where
191194

192195
if verified_at == revision_now {
193196
// Already verified.
194-
return true;
197+
return ShallowVerifyResult::Verified;
195198
}
196199

197200
let last_changed = zalsa.last_changed_revision(memo.revisions.durability);
@@ -204,17 +207,32 @@ where
204207
);
205208
if last_changed <= verified_at {
206209
// No input of the suitable durability has changed since last verified.
210+
ShallowVerifyResult::HigherDurability
211+
} else {
212+
ShallowVerifyResult::MaybeChanged
213+
}
214+
}
215+
216+
pub(super) fn update_shallow(
217+
&self,
218+
db: &C::DbView,
219+
zalsa: &Zalsa,
220+
database_key_index: DatabaseKeyIndex,
221+
memo: &Memo<C::Output<'_>>,
222+
verify_result: ShallowVerifyResult,
223+
) {
224+
if verify_result == ShallowVerifyResult::HigherDurability {
225+
let revision_now = zalsa.current_revision();
226+
207227
memo.mark_as_verified(
208228
db,
209229
revision_now,
210230
database_key_index,
211231
memo.revisions.accumulated_inputs.load(),
212232
);
233+
213234
memo.mark_outputs_as_verified(zalsa, db.as_dyn_database(), database_key_index);
214-
return true;
215235
}
216-
217-
false
218236
}
219237

220238
/// Validates this memo if it is a provisional memo. Returns true for non provisional memos or
@@ -311,9 +329,12 @@ where
311329
old_memo = old_memo.tracing_debug()
312330
);
313331

314-
if self.validate_may_be_provisional(db, zalsa, database_key_index, old_memo)
315-
&& self.shallow_verify_memo(db, zalsa, database_key_index, old_memo)
332+
let shallow_result = self.shallow_verify_memo(zalsa, database_key_index, old_memo);
333+
if shallow_result.as_bool()
334+
&& self.validate_provisional(db, zalsa, database_key_index, old_memo)
316335
{
336+
self.update_shallow(db, zalsa, database_key_index, old_memo, shallow_result);
337+
317338
return VerifyResult::unchanged();
318339
}
319340

@@ -339,7 +360,9 @@ where
339360
VerifyResult::Changed
340361
}
341362
QueryOrigin::Derived(edges) => {
342-
if old_memo.may_be_provisional() {
363+
let is_provisional = old_memo.may_be_provisional();
364+
// If the value is from the same revision but is still provisional, consider it changed
365+
if shallow_result.as_bool() && is_provisional {
343366
return VerifyResult::Changed;
344367
}
345368

@@ -433,14 +456,17 @@ where
433456
);
434457

435458
if in_heads {
436-
// Iterate our dependency graph again, starting from the top. We clear the
437-
// cycle heads here because we are starting a fresh traversal. (It might be
438-
// logically clearer to create a new HashSet each time, but clearing the
439-
// existing one is more efficient.)
440-
cycle_heads.clear();
459+
if is_provisional {
460+
old_memo
461+
.revisions
462+
.verified_final
463+
.store(true, Ordering::Relaxed);
464+
}
465+
441466
continue 'cycle;
442467
}
443468
}
469+
444470
break 'cycle VerifyResult::Unchanged(
445471
InputAccumulatedValues::Empty,
446472
CycleHeads::from(cycle_heads),
@@ -450,3 +476,24 @@ where
450476
}
451477
}
452478
}
479+
480+
#[derive(Copy, Clone, Eq, PartialEq)]
481+
pub(super) enum ShallowVerifyResult {
482+
/// The memo is from this revision and has already been verified
483+
Verified,
484+
485+
/// The revision for the memo's durability hasn't changed. It can be marked as verified
486+
HigherDurability,
487+
488+
/// The memo might have changed if any of its inputs have changed
489+
MaybeChanged,
490+
}
491+
492+
impl ShallowVerifyResult {
493+
pub(super) fn as_bool(self) -> bool {
494+
match self {
495+
Self::Verified | Self::HigherDurability => true,
496+
Self::MaybeChanged => false,
497+
}
498+
}
499+
}

tests/cycle_tracked.rs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//! Tests for cycles where the cycle head is stored on a tracked struct.
1+
//! Tests for cycles where the cycle head is stored on a tracked struct
2+
//! and that tracked struct is freed in a later revision.
23
34
mod common;
45

@@ -92,9 +93,7 @@ fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize {
9293

9394
for edge in node.edges(db) {
9495
if edge.to == 0 {
95-
if min_cost.is_none_or(|min| min > edge.cost) {
96-
min_cost = Some(edge.cost);
97-
}
96+
min_cost = min_cost.min(Some(edge.cost));
9897
}
9998

10099
let edge_cost_to_start = cost_to_start(db, graph.nodes[edge.to]);
@@ -106,9 +105,7 @@ fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize {
106105
}
107106

108107
let total_cost = edge.cost + edge_cost_to_start;
109-
if min_cost.is_none_or(|min| min > total_cost) {
110-
min_cost = Some(total_cost);
111-
}
108+
min_cost = min_cost.min(Some(total_cost));
112109
}
113110

114111
// If `None`, it means that there's no path from `node` to the start.
@@ -132,7 +129,7 @@ fn cycle_recover(
132129
fn main() {
133130
let mut db = EventLoggerDatabase::default();
134131

135-
let input = GraphInput::new(&mut db, false);
132+
let input = GraphInput::new(&db, false);
136133
let graph = create_graph(&db, input);
137134
let c = graph.find_node(&db, "c").unwrap();
138135

@@ -151,5 +148,46 @@ fn main() {
151148

152149
assert_eq!(cost_to_start(&db, c), 22);
153150

154-
db.assert_logs(expect![[r#""#]]);
151+
db.assert_logs(expect![[r#"
152+
[
153+
"WillCheckCancellation",
154+
"WillExecute { database_key: create_graph(Id(0)) }",
155+
"WillCheckCancellation",
156+
"WillExecute { database_key: cost_to_start(Id(402)) }",
157+
"WillCheckCancellation",
158+
"WillCheckCancellation",
159+
"WillExecute { database_key: cost_to_start(Id(403)) }",
160+
"WillCheckCancellation",
161+
"WillCheckCancellation",
162+
"WillExecute { database_key: cost_to_start(Id(400)) }",
163+
"WillCheckCancellation",
164+
"WillCheckCancellation",
165+
"WillExecute { database_key: cost_to_start(Id(401)) }",
166+
"WillCheckCancellation",
167+
"WillCheckCancellation",
168+
"WillCheckCancellation",
169+
"WillCheckCancellation",
170+
"WillCheckCancellation",
171+
"WillExecute { database_key: cost_to_start(Id(401)) }",
172+
"WillCheckCancellation",
173+
"WillCheckCancellation",
174+
"DidSetCancellationFlag",
175+
"WillCheckCancellation",
176+
"WillExecute { database_key: create_graph(Id(0)) }",
177+
"WillDiscardStaleOutput { execute_key: create_graph(Id(0)), output_key: Node(Id(403)) }",
178+
"DidDiscard { key: Node(Id(403)) }",
179+
"DidDiscard { key: cost_to_start(Id(403)) }",
180+
"WillCheckCancellation",
181+
"WillCheckCancellation",
182+
"WillExecute { database_key: cost_to_start(Id(402)) }",
183+
"WillCheckCancellation",
184+
"WillCheckCancellation",
185+
"WillCheckCancellation",
186+
"WillExecute { database_key: cost_to_start(Id(401)) }",
187+
"WillCheckCancellation",
188+
"WillCheckCancellation",
189+
"WillCheckCancellation",
190+
"WillExecute { database_key: cost_to_start(Id(400)) }",
191+
"WillCheckCancellation",
192+
]"#]]);
155193
}

0 commit comments

Comments
 (0)