Skip to content

Commit ce120c7

Browse files
committed
feat(knowledge): introduce junction table for per-proposition category classification
Replaces the JOIN-based category filter (through knowledge_sources) with a proper many-to-many junction table (knowledge_proposition_categories), fixing the silent exclusion of agent observations (item_id IS NULL) from category-filtered queries.

Key changes:
- schema.py: migration v8 creates knowledge_proposition_categories with proposition_id, category_id, assigned_by, assigned_at; backfills existing document-derived propositions from source.category_ids (INHERITED)
- search.py: category filter switched from JOIN to EXISTS subquery in both build_vector_search_query and build_fts_search_query; eliminates duplicate rows when a proposition matches multiple categories
- executors_knowledge.py: _op_store writes EXPLICIT junction rows after INSERT; _build_where_clause returns (where, params) — no longer returns join_clause; source re-categorisation propagates INHERITED rows only
- tests: updated for new two-value _build_where_clause return, EXISTS subquery assertions, explicit junction row tests, and mock fixes
1 parent 503b25b commit ce120c7

4 files changed

Lines changed: 339 additions & 156 deletions

File tree

src/workflows_mcp/engine/executors_knowledge.py

Lines changed: 86 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,9 @@ async def _op_store(
601601
source_type: str = inputs.source_type
602602

603603
if inputs.source and inputs.path:
604-
# Full provenance: upsert source + item, link proposition
604+
# Full provenance: upsert source + item, link proposition.
605+
# RETURNING also surfaces the pre-update category_ids so we can propagate
606+
# any re-categorization to existing INHERITED junction rows.
605607
source_name = inputs.source
606608
source_result = await backend.query(
607609
"""
@@ -615,7 +617,9 @@ async def _op_store(
615617
ELSE knowledge_sources.category_ids
616618
END,
617619
updated_at = NOW()
618-
RETURNING id
620+
RETURNING id,
621+
(SELECT category_ids FROM knowledge_sources
622+
WHERE name = $2) AS old_category_ids
619623
""",
620624
(str(uuid.uuid4()), inputs.source, category_ids),
621625
)
@@ -688,6 +692,55 @@ async def _op_store(
688692
),
689693
)
690694

695+
# Write category junction rows (EXPLICIT — caller specified these categories)
696+
if category_ids:
697+
for cat_id in category_ids:
698+
await backend.execute(
699+
"""
700+
INSERT INTO knowledge_proposition_categories
701+
(proposition_id, category_id, assigned_by)
702+
VALUES ($1::uuid, $2::uuid, 'EXPLICIT')
703+
ON CONFLICT (proposition_id, category_id) DO NOTHING
704+
""",
705+
(prop_id, cat_id),
706+
)
707+
708+
# Propagate source re-categorization to existing INHERITED junction rows.
709+
# Only runs when the source UPSERT actually changed category_ids.
710+
# EXPLICIT and INFERRED rows on the same propositions are never touched.
711+
if inputs.source and inputs.path and source_result.rows and category_ids:
712+
old_cat_ids = source_result.rows[0].get("old_category_ids") or []
713+
old_cats = {str(c) for c in old_cat_ids}
714+
new_cats = set(category_ids)
715+
if old_cats != new_cats:
716+
removed = old_cats - new_cats
717+
if removed:
718+
await backend.execute(
719+
"""
720+
DELETE FROM knowledge_proposition_categories
721+
WHERE category_id = ANY($1::uuid[])
722+
AND assigned_by = 'INHERITED'
723+
AND proposition_id IN (
724+
SELECT id FROM knowledge_propositions
725+
WHERE source_name = $2
726+
)
727+
""",
728+
(list(removed), inputs.source),
729+
)
730+
added = new_cats - old_cats
731+
for cat_id in added:
732+
await backend.execute(
733+
"""
734+
INSERT INTO knowledge_proposition_categories
735+
(proposition_id, category_id, assigned_by)
736+
SELECT id, $1::uuid, 'INHERITED'
737+
FROM knowledge_propositions
738+
WHERE source_name = $2
739+
ON CONFLICT (proposition_id, category_id) DO NOTHING
740+
""",
741+
(cat_id, inputs.source),
742+
)
743+
691744
# SECURITY: Log to audit table
692745
await self._log_audit_entry(
693746
backend=backend,
@@ -757,15 +810,19 @@ async def _log_audit_entry(
757810

758811
async def _build_where_clause(
759812
self, inputs: KnowledgeInput, backend: Any
760-
) -> tuple[str, str, list[Any]]:
761-
"""Build WHERE and JOIN clauses from recall/forget filters.
813+
) -> tuple[str, list[Any]]:
814+
"""Build WHERE clause from recall/forget filters.
762815
763816
Resolves category names to UUIDs if needed.
764-
Returns (where_clause, join_clause, params).
817+
Returns (where_clause, params).
765818
766819
Source filtering uses the denormalized source_name column for consistent,
767820
performant queries across both document-derived and agent observation
768821
propositions.
822+
823+
Category filtering uses an EXISTS subquery on knowledge_proposition_categories,
824+
which works for all proposition types (including agent observations with
825+
item_id IS NULL) and never produces duplicate rows.
769826
"""
770827
params: list[Any] = []
771828
param_idx = 0
@@ -777,7 +834,6 @@ def next_param(value: Any) -> str:
777834
return f"${param_idx}"
778835

779836
where_clauses: list[str] = []
780-
needs_join = False
781837

782838
# Handle source filter (top-level inputs.source or where.source_name)
783839
source_filter = None
@@ -789,21 +845,25 @@ def next_param(value: Any) -> str:
789845
if source_filter:
790846
if isinstance(source_filter, str) and source_filter.endswith("*"):
791847
prefix_param = next_param(source_filter[:-1] + "%")
792-
# Use source_name column directly (denormalized for performance)
793848
where_clauses.append(f"kp.source_name LIKE {prefix_param}")
794849
else:
795850
source_param = next_param(source_filter)
796-
# Use source_name column directly (denormalized for performance)
797851
where_clauses.append(f"kp.source_name = {source_param}")
798852

799-
# Handle category filter (requires JOIN to knowledge_sources)
853+
# Category filter — EXISTS subquery on junction table.
854+
# Works for agent observations (item_id IS NULL) and document propositions alike.
800855
if inputs.where and "category" in inputs.where:
801-
needs_join = True
802856
cat_value = inputs.where["category"]
803-
# Resolve name to UUID if needed
804857
resolved = await self._resolve_categories([cat_value], backend)
805858
cat_uuid = resolved[0] if resolved else cat_value
806-
where_clauses.append(f"ks.category_ids && ARRAY[{next_param(cat_uuid)}]::uuid[]")
859+
cat_param = next_param(cat_uuid)
860+
where_clauses.append(
861+
f"EXISTS ("
862+
f" SELECT 1 FROM knowledge_proposition_categories kpc"
863+
f" WHERE kpc.proposition_id = kp.id"
864+
f" AND kpc.category_id = {cat_param}::uuid"
865+
f")"
866+
)
807867

808868
# Handle lifecycle_state filter
809869
if inputs.where and "lifecycle_state" in inputs.where:
@@ -837,15 +897,8 @@ def next_param(value: Any) -> str:
837897
f"kp.created_at <= {next_param(inputs.created_before)}::timestamptz"
838898
)
839899

840-
join_clause = (
841-
"JOIN knowledge_items ki ON kp.item_id = ki.id "
842-
"JOIN knowledge_sources ks ON ki.source_id = ks.id"
843-
if needs_join
844-
else ""
845-
)
846-
847900
where_sql = " AND ".join(where_clauses) if where_clauses else "TRUE"
848-
return where_sql, join_clause, params
901+
return where_sql, params
849902

850903
async def _op_recall(
851904
self,
@@ -859,7 +912,7 @@ async def _op_recall(
859912
"""
860913
limit = int(inputs.limit) if isinstance(inputs.limit, str) else inputs.limit
861914

862-
where_sql, join_clause, params = await self._build_where_clause(inputs, backend)
915+
where_sql, params = await self._build_where_clause(inputs, backend)
863916

864917
# Order clause
865918
order_clause = "ORDER BY kp.created_at DESC"
@@ -897,7 +950,6 @@ async def _op_recall(
897950
kp.created_by, kp.auth_method
898951
FROM knowledge_propositions kp
899952
LEFT JOIN knowledge_items ki_ip ON kp.item_id = ki_ip.id
900-
{join_clause}
901953
WHERE {where_sql}
902954
{order_clause}
903955
LIMIT {limit_param}
@@ -1008,47 +1060,28 @@ async def _op_forget(
10081060
)
10091061

10101062
# Path 2: Archive by filter
1011-
where_sql, join_clause, params = await self._build_where_clause(inputs, backend)
1063+
where_sql, params = await self._build_where_clause(inputs, backend)
10121064

10131065
# Count total matching (for skipped_count calculation)
10141066
count_sql = f"""
10151067
SELECT COUNT(*) AS total FROM knowledge_propositions kp
1016-
{join_clause}
10171068
WHERE {where_sql}
10181069
"""
10191070
count_result = await backend.query(count_sql, tuple(params))
10201071
total = count_result.rows[0]["total"] if count_result.rows else 0
10211072

10221073
# Archive with USER_VALIDATED immunity
1023-
# Use subquery to handle JOIN-based filters in UPDATE
1024-
if join_clause:
1025-
param_offset = len(params)
1026-
update_sql = f"""
1027-
UPDATE knowledge_propositions
1028-
SET lifecycle_state = '{LifecycleState.ARCHIVED.value}',
1029-
archived_by = ${param_offset + 1}::uuid,
1030-
archive_reason = ${param_offset + 2}
1031-
WHERE id IN (
1032-
SELECT kp.id FROM knowledge_propositions kp
1033-
{join_clause}
1034-
WHERE {where_sql}
1035-
)
1036-
AND authority != '{Authority.USER_VALIDATED}'
1037-
RETURNING id
1038-
"""
1039-
params.extend([str(archived_by) if archived_by else None, archive_reason])
1040-
else:
1041-
param_offset = len(params)
1042-
update_sql = f"""
1043-
UPDATE knowledge_propositions kp
1044-
SET lifecycle_state = '{LifecycleState.ARCHIVED.value}',
1045-
archived_by = ${param_offset + 1}::uuid,
1046-
archive_reason = ${param_offset + 2}
1047-
WHERE {where_sql}
1048-
AND authority != '{Authority.USER_VALIDATED}'
1049-
RETURNING id
1050-
"""
1051-
params.extend([str(archived_by) if archived_by else None, archive_reason])
1074+
param_offset = len(params)
1075+
update_sql = f"""
1076+
UPDATE knowledge_propositions kp
1077+
SET lifecycle_state = '{LifecycleState.ARCHIVED.value}',
1078+
archived_by = ${param_offset + 1}::uuid,
1079+
archive_reason = ${param_offset + 2}
1080+
WHERE {where_sql}
1081+
AND authority != '{Authority.USER_VALIDATED}'
1082+
RETURNING id
1083+
"""
1084+
params.extend([str(archived_by) if archived_by else None, archive_reason])
10521085

10531086
result = await backend.query(update_sql, tuple(params))
10541087
archived = len(result.rows) if result and result.rows else 0

src/workflows_mcp/engine/knowledge/schema.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,38 @@
336336
END $$;
337337
""",
338338
),
339+
(
340+
8,
341+
"Add knowledge_proposition_categories junction table for per-proposition categories",
342+
"""
343+
CREATE TABLE IF NOT EXISTS knowledge_proposition_categories (
344+
proposition_id UUID NOT NULL
345+
REFERENCES knowledge_propositions(id) ON DELETE CASCADE,
346+
category_id UUID NOT NULL
347+
REFERENCES knowledge_entities(id) ON DELETE CASCADE,
348+
assigned_by VARCHAR(20) NOT NULL DEFAULT 'EXPLICIT',
349+
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
350+
PRIMARY KEY (proposition_id, category_id)
351+
);
352+
353+
CREATE INDEX IF NOT EXISTS idx_kpc_category_id
354+
ON knowledge_proposition_categories(category_id);
355+
CREATE INDEX IF NOT EXISTS idx_kpc_proposition_id
356+
ON knowledge_proposition_categories(proposition_id);
357+
358+
-- Backfill: inherit source-level categories onto existing document-derived propositions.
359+
-- Agent observations (item_id IS NULL) are untouched — they had no categories before
360+
-- and correctly get none unless explicitly set at store time.
361+
INSERT INTO knowledge_proposition_categories (proposition_id, category_id, assigned_by)
362+
SELECT kp.id, unnest(ks.category_ids), 'INHERITED'
363+
FROM knowledge_propositions kp
364+
JOIN knowledge_items ki ON kp.item_id = ki.id
365+
JOIN knowledge_sources ks ON ki.source_id = ks.id
366+
WHERE ks.category_ids IS NOT NULL
367+
AND ks.category_ids != '{}'
368+
ON CONFLICT (proposition_id, category_id) DO NOTHING;
369+
""",
370+
),
339371
]
340372

341373
SCHEMA_VERSION = MIGRATIONS[-1][0] if MIGRATIONS else 0

src/workflows_mcp/engine/knowledge/search.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def next_param(value: Any) -> str:
6666
f"kp.confidence >= {confidence_param}",
6767
]
6868

69-
joins = []
70-
7169
# Source filter (exact match or prefix with *)
7270
# Uses the denormalized source_name column for consistent, performant queries
7371
if source:
@@ -78,16 +76,19 @@ def next_param(value: Any) -> str:
7876
source_param = next_param(source)
7977
where_clauses.append(f"kp.source_name = {source_param}")
8078

81-
# Category filter - requires JOIN to knowledge_sources
79+
# Category filter — EXISTS subquery on junction table.
80+
# Works for all proposition types including agent observations (item_id IS NULL).
81+
# No JOIN needed; no duplicate rows when a proposition matches multiple categories.
8282
if categories:
83-
joins.append(
84-
"JOIN knowledge_items ki ON kp.item_id = ki.id "
85-
"JOIN knowledge_sources ks ON ki.source_id = ks.id"
86-
)
8783
cat_param = next_param(categories)
88-
where_clauses.append(f"ks.category_ids && {cat_param}::uuid[]")
84+
where_clauses.append(
85+
f"EXISTS ("
86+
f" SELECT 1 FROM knowledge_proposition_categories kpc"
87+
f" WHERE kpc.proposition_id = kp.id"
88+
f" AND kpc.category_id = ANY({cat_param}::uuid[])"
89+
f")"
90+
)
8991

90-
join_clause = "\n".join(joins)
9192
where_clause = " AND ".join(where_clauses)
9293

9394
embedding_col = ", kp.embedding" if include_embeddings else ""
@@ -99,7 +100,6 @@ def next_param(value: Any) -> str:
99100
ki_path.path AS item_path{embedding_col}
100101
FROM knowledge_propositions kp
101102
LEFT JOIN knowledge_items ki_path ON kp.item_id = ki_path.id
102-
{join_clause}
103103
WHERE {where_clause}
104104
AND kp.embedding IS NOT NULL
105105
ORDER BY kp.embedding <=> {embedding_param}::vector
@@ -144,8 +144,6 @@ def next_param(value: Any) -> str:
144144
f"kp.confidence >= {confidence_param}",
145145
]
146146

147-
joins = []
148-
149147
# Source filter (exact match or prefix with *)
150148
# Uses the denormalized source_name column for consistent, performant queries
151149
if source:
@@ -156,16 +154,19 @@ def next_param(value: Any) -> str:
156154
source_param = next_param(source)
157155
where_clauses.append(f"kp.source_name = {source_param}")
158156

159-
# Category filter - requires JOIN to knowledge_sources
157+
# Category filter — EXISTS subquery on junction table.
158+
# Works for all proposition types including agent observations (item_id IS NULL).
159+
# No JOIN needed; no duplicate rows when a proposition matches multiple categories.
160160
if categories:
161-
joins.append(
162-
"JOIN knowledge_items ki ON kp.item_id = ki.id "
163-
"JOIN knowledge_sources ks ON ki.source_id = ks.id"
164-
)
165161
cat_param = next_param(categories)
166-
where_clauses.append(f"ks.category_ids && {cat_param}::uuid[]")
162+
where_clauses.append(
163+
f"EXISTS ("
164+
f" SELECT 1 FROM knowledge_proposition_categories kpc"
165+
f" WHERE kpc.proposition_id = kp.id"
166+
f" AND kpc.category_id = ANY({cat_param}::uuid[])"
167+
f")"
168+
)
167169

168-
join_clause = "\n".join(joins)
169170
where_clause = " AND ".join(where_clauses)
170171

171172
sql = f"""
@@ -175,7 +176,6 @@ def next_param(value: Any) -> str:
175176
ki_path.path AS item_path
176177
FROM knowledge_propositions kp
177178
LEFT JOIN knowledge_items ki_path ON kp.item_id = ki_path.id
178-
{join_clause}
179179
WHERE {where_clause}
180180
ORDER BY fts_rank DESC
181181
LIMIT {candidate_param}

0 commit comments

Comments
 (0)