Skip to content

Commit f638117

Browse files
authored
Merge pull request #1193 from Kiln-AI/KIL-502/proj-export
Kiln project export broken with skills
2 parents 5109268 + 0ccbc7a commit f638117

File tree

2 files changed

+329
-4
lines changed

2 files changed

+329
-4
lines changed

libs/core/kiln_ai/cli/commands/package_project.py

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
from kiln_ai.datamodel.prompt import Prompt
1717
from kiln_ai.datamodel.prompt_id import PromptGenerators
1818
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties
19+
from kiln_ai.datamodel.skill import Skill
1920
from kiln_ai.datamodel.task import TaskRunConfig
2021
from kiln_ai.datamodel.tool_id import (
2122
KILN_TASK_TOOL_ID_PREFIX,
2223
MCP_LOCAL_TOOL_ID_PREFIX,
2324
MCP_REMOTE_TOOL_ID_PREFIX,
2425
RAG_TOOL_ID_PREFIX,
26+
SKILL_TOOL_ID_PREFIX,
2527
KilnBuiltInToolId,
2628
kiln_task_server_id_from_tool_id,
2729
mcp_server_and_tool_name_from_id,
30+
skill_id_from_tool_id,
2831
)
2932

3033
console = Console()
@@ -232,10 +235,12 @@ def collect_subtask_ids_from_tools(
232235

233236
def classify_tool_id(
234237
tool_id: str,
235-
) -> Literal["builtin", "kiln_task", "mcp_remote", "mcp_local", "rag", "unknown"]:
238+
) -> Literal[
239+
"builtin", "kiln_task", "mcp_remote", "mcp_local", "rag", "skill", "unknown"
240+
]:
236241
"""Classify a tool ID into its type category.
237242
238-
Returns one of: 'builtin', 'kiln_task', 'mcp_remote', 'mcp_local', 'rag', 'unknown'
243+
Returns one of: 'builtin', 'kiln_task', 'mcp_remote', 'mcp_local', 'rag', 'skill', 'unknown'
239244
"""
240245
if tool_id in [member.value for member in KilnBuiltInToolId]:
241246
return "builtin"
@@ -247,6 +252,8 @@ def classify_tool_id(
247252
return "mcp_local"
248253
elif tool_id.startswith(RAG_TOOL_ID_PREFIX):
249254
return "rag"
255+
elif tool_id.startswith(SKILL_TOOL_ID_PREFIX):
256+
return "skill"
250257
else:
251258
return "unknown"
252259

@@ -293,6 +300,9 @@ def validate_tools(tasks: list[Task], run_configs: dict[str, TaskRunConfig]) ->
293300
has_mcp_local = True
294301
mcp_local_task_names.append(task.name)
295302
pass
303+
elif tool_type == "skill":
304+
# Skills are exported separately
305+
pass
296306
elif tool_type == "rag":
297307
console.print(f"[red]Error:[/red] Task '{task.name}' uses a RAG tool.")
298308
console.print(
@@ -367,6 +377,77 @@ def collect_required_tool_servers(
367377
return server_ids
368378

369379

380+
def collect_required_skills(
381+
tasks: list[Task], run_configs: dict[str, TaskRunConfig]
382+
) -> set[str]:
383+
"""Collect the IDs of skills needed by the tasks.
384+
385+
Args:
386+
tasks: List of tasks to check
387+
run_configs: Dictionary mapping task IDs to their run configs
388+
389+
Returns:
390+
Set of skill IDs that need to be exported
391+
"""
392+
skill_ids: set[str] = set()
393+
394+
for task in tasks:
395+
run_config = run_configs.get(task.id) # type: ignore
396+
if not run_config:
397+
continue
398+
399+
tools = get_tools_from_run_config(run_config)
400+
for tool_id in tools:
401+
tool_type = classify_tool_id(tool_id)
402+
if tool_type == "skill":
403+
skill_id = skill_id_from_tool_id(tool_id)
404+
skill_ids.add(skill_id)
405+
406+
return skill_ids
407+
408+
409+
def export_skills(
410+
skill_ids: set[str],
411+
project: Project,
412+
exported_project: Project,
413+
) -> None:
414+
"""Export skills needed by the tasks.
415+
416+
Copies each skill's entire directory (skill.kiln, SKILL.md, references/, assets/)
417+
to the exported project.
418+
419+
Args:
420+
skill_ids: Set of skill IDs to export
421+
project: The original project
422+
exported_project: The exported project to copy skills into
423+
"""
424+
if not skill_ids:
425+
return
426+
427+
if exported_project.path is None:
428+
raise ValueError("Exported project path is not set")
429+
430+
skills_by_id = {skill.id: skill for skill in project.skills() if skill.id}
431+
missing_skill_ids = skill_ids - set(skills_by_id)
432+
if missing_skill_ids:
433+
raise ValueError(
434+
"Skill ID(s) referenced by exported tasks were not found in the project: "
435+
+ ", ".join(sorted(missing_skill_ids))
436+
)
437+
438+
for skill_id in skill_ids:
439+
skill = skills_by_id[skill_id]
440+
if skill.path is None:
441+
raise ValueError(f"Skill '{skill.name}' path is not set")
442+
443+
folder_name = skill.path.parent.name
444+
dest_dir = exported_project.path.parent / "skills" / folder_name
445+
shutil.copytree(skill.path.parent, dest_dir, dirs_exist_ok=True)
446+
447+
exported_skill = Skill.load_from_file(dest_dir / skill.path.name)
448+
exported_skill.parent = exported_project
449+
450+
370451
def is_dynamic_prompt(prompt_id: str) -> bool:
371452
"""Check if a prompt ID refers to a dynamic prompt generator."""
372453
return prompt_id in DYNAMIC_PROMPT_GENERATORS
@@ -729,8 +810,9 @@ def package_project(
729810
validate_tools(validated_tasks, run_configs)
730811
console.print("[green]✓[/green] Validated tools")
731812

732-
# 6. Collect required tool servers
813+
# 6. Collect required tool servers and skills
733814
required_server_ids = collect_required_tool_servers(validated_tasks, run_configs)
815+
required_skill_ids = collect_required_skills(validated_tasks, run_configs)
734816

735817
# 7. Build and validate prompts
736818
task_prompts = validate_and_build_prompts(validated_tasks, run_configs)
@@ -781,7 +863,18 @@ def package_project(
781863
f"[green]✓[/green] Exported {len(required_server_ids)} tool server(s)"
782864
)
783865

784-
# 6. Create zip file
866+
# 6. Export required skills
867+
try:
868+
export_skills(required_skill_ids, project, exported_project)
869+
except ValueError as e:
870+
console.print(f"[red]Error exporting skills: {e}")
871+
raise typer.Exit(code=1)
872+
if required_skill_ids:
873+
console.print(
874+
f"[green]✓[/green] Exported {len(required_skill_ids)} skill(s)"
875+
)
876+
877+
# 7. Create zip file
785878
create_zip(temp_dir, output)
786879
console.print(f"[green]✓[/green] Created zip file: {output}")
787880

@@ -949,6 +1042,7 @@ def package_project_for_training(
9491042
run_configs[task.id] = run_config # type: ignore
9501043

9511044
required_server_ids = collect_required_tool_servers(validated_tasks, run_configs)
1045+
required_skill_ids = collect_required_skills(validated_tasks, run_configs)
9521046

9531047
task_prompts = validate_and_build_prompts_noncli(validated_tasks, run_configs)
9541048

@@ -975,6 +1069,7 @@ def package_project_for_training(
9751069
validate_exported_prompts(task_prompts, exported_tasks, exported_run_configs)
9761070

9771071
export_tool_servers(required_server_ids, project, exported_project)
1072+
export_skills(required_skill_ids, project, exported_project)
9781073

9791074
for task in validated_tasks:
9801075
exported_task = exported_tasks[task.id] # type: ignore

0 commit comments

Comments
 (0)