Skip to content

Commit 7c8a3a7

Browse files
committed
Add support for literals
1 parent c0f1f03 commit 7c8a3a7

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/llama_cpp_agent/gbnf_grammar_generator/gbnf_grammar_from_pydantic_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_args,
1818
get_origin,
1919
get_type_hints,
20+
Literal
2021
)
2122

2223
from docstring_parser import parse
@@ -65,6 +66,7 @@ class PydanticDataType(Enum):
6566
CUSTOM_CLASS = "custom-class"
6667
CUSTOM_DICT = "custom-dict"
6768
SET = "set"
69+
LITERAL = "literal"
6870

6971

7072
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
@@ -99,6 +101,10 @@ def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
99101
elif get_origin(pydantic_type) is dict:
100102
key_type, value_type = get_args(pydantic_type)
101103
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
104+
elif get_origin(pydantic_type) is Literal:
105+
literal_types = get_args(pydantic_type)
106+
literal_rules = [map_pydantic_type_to_gbnf(lt) for lt in literal_types]
107+
return f"literal-{'-or-'.join(literal_rules)}"
102108
else:
103109
return "unknown"
104110

@@ -540,6 +546,13 @@ def generate_gbnf_rule_for_type(
540546
gbnf_type, rules = generate_gbnf_integer_rules(
541547
max_digit=max_digits, min_digit=min_digits
542548
)
549+
elif gbnf_type.startswith("literal-"):
550+
literal_types = get_args(field_type)
551+
literal_types_str = [
552+
json.dumps(lt).replace('"', '\\"') for lt in literal_types
553+
]
554+
literal_types_str=[f'"{lt}"' for lt in literal_types_str]
555+
gbnf_type = "|".join(literal_types_str)
543556
else:
544557
gbnf_type, rules = gbnf_type, []
545558

0 commit comments

Comments
 (0)