Skip to content

Commit b5e205c

Browse files
committed
fix(knowledge): wire category_ids into store + add name resolution
- Fix _op_store source INSERT to proper UPSERT on (org_id, name) - Include category_ids from inputs.categories in source UPSERT - Add _resolve_categories() helper for name -> UUID resolution - Add knowledge_entities table DDL for category entity storage - Add unique index on knowledge_sources(org_id, name) - Integrate category resolution into search, recall, forget ops - Make _build_where_clause async for category resolution - Add unit tests for category resolution and DDL changes
1 parent 92f8bfa commit b5e205c

3 files changed

Lines changed: 269 additions & 51 deletions

File tree

src/workflows_mcp/engine/executors_knowledge.py

Lines changed: 89 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,51 @@ def _create_config(self, inputs: KnowledgeInput) -> Any:
350350
password=inputs.password,
351351
)
352352

353+
# ------------------------------------------------------------------
354+
# Helpers
355+
# ------------------------------------------------------------------
356+
357+
async def _resolve_categories(
358+
self,
359+
categories: list[str],
360+
org_id: str,
361+
backend: Any,
362+
) -> list[str]:
363+
"""Resolve category names or UUIDs to a list of UUID strings.
364+
365+
For each entry in categories:
366+
- If it's a valid UUID, use as-is.
367+
- Otherwise, look up by (org_id, entity_type='category', name) in
368+
knowledge_entities. Auto-create the entity if missing.
369+
"""
370+
resolved: list[str] = []
371+
for entry in categories:
372+
# Try parsing as UUID first
373+
try:
374+
uuid.UUID(entry)
375+
resolved.append(entry)
376+
continue
377+
except ValueError:
378+
pass
379+
380+
# Name-based resolution: upsert into knowledge_entities
381+
result = await backend.query(
382+
"""
383+
INSERT INTO knowledge_entities (id, org_id, entity_type, name)
384+
VALUES ($1::uuid, $2::uuid, 'category', $3)
385+
ON CONFLICT (org_id, entity_type, name) DO UPDATE
386+
SET name = EXCLUDED.name
387+
RETURNING id
388+
""",
389+
(str(uuid.uuid4()), org_id, entry),
390+
)
391+
if result.rows:
392+
resolved.append(str(result.rows[0]["id"]))
393+
else:
394+
logger.warning("Category resolution returned no rows for %r", entry)
395+
396+
return resolved
397+
353398
# ------------------------------------------------------------------
354399
# Operation Handlers
355400
# ------------------------------------------------------------------
@@ -366,6 +411,11 @@ async def _op_search(
366411
limit = int(inputs.limit) if isinstance(inputs.limit, str) else inputs.limit
367412
org_id = inputs.org_id or str(uuid.UUID(int=0))
368413

414+
# Resolve category names to UUIDs if provided
415+
resolved_categories = None
416+
if inputs.categories:
417+
resolved_categories = await self._resolve_categories(inputs.categories, org_id, backend)
418+
369419
# Compute query embedding
370420
embedding, _, _, _ = await compute_embedding(
371421
text=inputs.query,
@@ -378,7 +428,7 @@ async def _op_search(
378428
org_id=org_id,
379429
query_embedding=embedding,
380430
source=inputs.source,
381-
categories=inputs.categories,
431+
categories=resolved_categories,
382432
min_confidence=inputs.min_confidence,
383433
lifecycle_state=inputs.lifecycle_state,
384434
limit=limit,
@@ -391,7 +441,7 @@ async def _op_search(
391441
org_id=org_id,
392442
query_text=inputs.query,
393443
source=inputs.source,
394-
categories=inputs.categories,
444+
categories=resolved_categories,
395445
min_confidence=inputs.min_confidence,
396446
lifecycle_state=inputs.lifecycle_state,
397447
limit=limit,
@@ -446,6 +496,11 @@ async def _op_store(
446496
org_id = inputs.org_id or str(uuid.UUID(int=0))
447497
prop_id = str(uuid.uuid4())
448498

499+
# Resolve category names to UUIDs if provided
500+
category_ids: list[str] = []
501+
if inputs.categories:
502+
category_ids = await self._resolve_categories(inputs.categories, org_id, backend)
503+
449504
# Compute embedding for the content
450505
embedding, model_name, dimensions, _ = await compute_embedding(
451506
text=inputs.content,
@@ -456,20 +511,22 @@ async def _op_store(
456511
# Find or create source if specified
457512
item_id: str | None = None
458513
if inputs.source:
459-
# Upsert source
460-
source_id = str(uuid.uuid4())
461-
await backend.execute(
514+
# Upsert source with category_ids, using unique (org_id, name)
515+
source_result = await backend.query(
462516
"""
463-
INSERT INTO knowledge_sources (id, org_id, name, source_type)
464-
VALUES ($1::uuid, $2::uuid, $3, 'WORKFLOW')
465-
ON CONFLICT DO NOTHING
517+
INSERT INTO knowledge_sources
518+
(id, org_id, name, source_type, category_ids)
519+
VALUES ($1::uuid, $2::uuid, $3, 'WORKFLOW', $4::uuid[])
520+
ON CONFLICT (org_id, name) DO UPDATE SET
521+
category_ids = CASE
522+
WHEN EXCLUDED.category_ids != '{}'
523+
THEN EXCLUDED.category_ids
524+
ELSE knowledge_sources.category_ids
525+
END,
526+
updated_at = NOW()
527+
RETURNING id
466528
""",
467-
(source_id, org_id, inputs.source),
468-
)
469-
# Get the actual source id (may already exist)
470-
source_result = await backend.query(
471-
"SELECT id FROM knowledge_sources WHERE org_id = $1::uuid AND name = $2 LIMIT 1",
472-
(org_id, inputs.source),
529+
(str(uuid.uuid4()), org_id, inputs.source, category_ids),
473530
)
474531
if source_result.rows:
475532
actual_source_id = str(source_result.rows[0]["id"])
@@ -480,7 +537,12 @@ async def _op_store(
480537
INSERT INTO knowledge_items (id, org_id, source_id, title)
481538
VALUES ($1::uuid, $2::uuid, $3::uuid, $4)
482539
""",
483-
(item_id, org_id, actual_source_id, f"workflow-store-{prop_id[:8]}"),
540+
(
541+
item_id,
542+
org_id,
543+
actual_source_id,
544+
f"workflow-store-{prop_id[:8]}",
545+
),
484546
)
485547

486548
# Insert proposition with server-side tsvector computation
@@ -491,7 +553,8 @@ async def _op_store(
491553
authority, lifecycle_state, confidence,
492554
embedding_model, embedding_dimensions, metadata_)
493555
VALUES
494-
($1::uuid, $2::uuid, $3::uuid, $4, $5::vector, to_tsvector('english', $4),
556+
($1::uuid, $2::uuid, $3::uuid, $4, $5::vector,
557+
to_tsvector('english', $4),
495558
$6, $7, $8,
496559
$9, $10, $11::jsonb)
497560
""",
@@ -516,11 +579,12 @@ async def _op_store(
516579
stored_count=1,
517580
)
518581

519-
def _build_where_clause(
520-
self, inputs: KnowledgeInput, org_id: str
582+
async def _build_where_clause(
583+
self, inputs: KnowledgeInput, org_id: str, backend: Any
521584
) -> tuple[str, str, list[Any]]:
522585
"""Build WHERE and JOIN clauses from recall/forget filters.
523586
587+
Resolves category names to UUIDs if needed.
524588
Returns (where_clause, join_clause, params).
525589
"""
526590
params: list[Any] = []
@@ -549,9 +613,11 @@ def next_param(value: Any) -> str:
549613
where_clauses.append(f"kp.lifecycle_state = {next_param(state)}")
550614
if "category" in inputs.where:
551615
needs_join = True
552-
where_clauses.append(
553-
f"ks.category_ids && ARRAY[{next_param(inputs.where['category'])}]::uuid[]"
554-
)
616+
cat_value = inputs.where["category"]
617+
# Resolve name to UUID if needed
618+
resolved = await self._resolve_categories([cat_value], org_id, backend)
619+
cat_uuid = resolved[0] if resolved else cat_value
620+
where_clauses.append(f"ks.category_ids && ARRAY[{next_param(cat_uuid)}]::uuid[]")
555621
if "min_confidence" in inputs.where:
556622
where_clauses.append(
557623
f"kp.confidence >= {next_param(float(inputs.where['min_confidence']))}"
@@ -589,7 +655,7 @@ async def _op_recall(
589655
org_id = inputs.org_id or str(uuid.UUID(int=0))
590656
limit = int(inputs.limit) if isinstance(inputs.limit, str) else inputs.limit
591657

592-
where_sql, join_clause, params = self._build_where_clause(inputs, org_id)
658+
where_sql, join_clause, params = await self._build_where_clause(inputs, org_id, backend)
593659

594660
# Order clause
595661
order_clause = "ORDER BY kp.created_at DESC"
@@ -699,7 +765,7 @@ async def _op_forget(
699765
)
700766

701767
# Path 2: Archive by filter
702-
where_sql, join_clause, params = self._build_where_clause(inputs, org_id)
768+
where_sql, join_clause, params = await self._build_where_clause(inputs, org_id, backend)
703769

704770
# Count total matching (for skipped_count calculation)
705771
count_sql = f"""

src/workflows_mcp/engine/knowledge/schema.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,30 @@
5757
);
5858
"""
5959

60+
_CREATE_KNOWLEDGE_ENTITIES = """
61+
CREATE TABLE IF NOT EXISTS knowledge_entities (
62+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
63+
org_id UUID NOT NULL,
64+
entity_type VARCHAR(50) NOT NULL,
65+
name VARCHAR(500) NOT NULL,
66+
properties JSONB DEFAULT '{}'::jsonb,
67+
confidence FLOAT DEFAULT 1.0,
68+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
69+
);
70+
"""
71+
6072
_CREATE_INDEXES = """
6173
CREATE INDEX IF NOT EXISTS idx_kp_org_id ON knowledge_propositions(org_id);
6274
CREATE INDEX IF NOT EXISTS idx_kp_lifecycle ON knowledge_propositions(lifecycle_state);
6375
CREATE INDEX IF NOT EXISTS idx_kp_item_id ON knowledge_propositions(item_id);
6476
CREATE INDEX IF NOT EXISTS idx_kp_search_vector ON knowledge_propositions USING gin(search_vector);
6577
CREATE INDEX IF NOT EXISTS idx_ks_org_id ON knowledge_sources(org_id);
6678
CREATE INDEX IF NOT EXISTS idx_ks_category_ids ON knowledge_sources USING gin(category_ids);
79+
CREATE UNIQUE INDEX IF NOT EXISTS idx_ks_org_name ON knowledge_sources(org_id, name);
6780
CREATE INDEX IF NOT EXISTS idx_ki_org_id ON knowledge_items(org_id);
6881
CREATE INDEX IF NOT EXISTS idx_ki_source_id ON knowledge_items(source_id);
82+
CREATE INDEX IF NOT EXISTS idx_ke_org_id ON knowledge_entities(org_id);
83+
CREATE UNIQUE INDEX IF NOT EXISTS idx_ke_org_type_name ON knowledge_entities(org_id, entity_type, name);
6984
"""
7085

7186

@@ -81,6 +96,7 @@ def get_init_schema_sql() -> str:
8196
_CREATE_KNOWLEDGE_SOURCES,
8297
_CREATE_KNOWLEDGE_ITEMS,
8398
_CREATE_KNOWLEDGE_PROPOSITIONS,
99+
_CREATE_KNOWLEDGE_ENTITIES,
84100
_CREATE_INDEXES,
85101
]
86102
)

0 commit comments

Comments
 (0)