|
17 | 17 | get_args, |
18 | 18 | get_origin, |
19 | 19 | get_type_hints, |
| 20 | + Literal |
20 | 21 | ) |
21 | 22 |
|
22 | 23 | from docstring_parser import parse |
@@ -65,6 +66,7 @@ class PydanticDataType(Enum): |
65 | 66 | CUSTOM_CLASS = "custom-class" |
66 | 67 | CUSTOM_DICT = "custom-dict" |
67 | 68 | SET = "set" |
| 69 | + LITERAL = "literal" |
68 | 70 |
|
69 | 71 |
|
70 | 72 | def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: |
@@ -99,6 +101,10 @@ def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: |
99 | 101 | elif get_origin(pydantic_type) is dict: |
100 | 102 | key_type, value_type = get_args(pydantic_type) |
101 | 103 | return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" |
| 104 | + elif get_origin(pydantic_type) is Literal: |
| 105 | + literal_types = get_args(pydantic_type) |
| 106 | + literal_rules = [map_pydantic_type_to_gbnf(lt) for lt in literal_types] |
| 107 | + return f"literal-{'-or-'.join(literal_rules)}" |
102 | 108 | else: |
103 | 109 | return "unknown" |
104 | 110 |
|
@@ -540,6 +546,13 @@ def generate_gbnf_rule_for_type( |
540 | 546 | gbnf_type, rules = generate_gbnf_integer_rules( |
541 | 547 | max_digit=max_digits, min_digit=min_digits |
542 | 548 | ) |
| 549 | + elif gbnf_type.startswith("literal-"): |
| 550 | + literal_types = get_args(field_type) |
| 551 | + literal_types_str = [ |
| 552 | + json.dumps(lt).replace('"', '\\"') for lt in literal_types |
| 553 | + ] |
| 554 | + literal_types_str=[f'"{lt}"' for lt in literal_types_str] |
| 555 | + gbnf_type = "|".join(literal_types_str) |
543 | 556 | else: |
544 | 557 | gbnf_type, rules = gbnf_type, [] |
545 | 558 |
|
|
0 commit comments