|
1 | 1 | # frozen_string_literal: true |
2 | 2 |
|
3 | | -require "active_agent/providers/common/model" |
4 | | -require_relative "requests/_types" |
| 3 | +require "delegate" |
| 4 | +require "json" |
| 5 | +require_relative "transforms" |
5 | 6 |
|
6 | 7 | module ActiveAgent |
7 | 8 | module Providers |
8 | 9 | module OpenAI |
9 | 10 | module Chat |
10 | | - class Request < Common::BaseModel |
11 | | - # Messages array (required) |
12 | | - attribute :messages, Requests::Messages::MessagesType.new |
13 | | - |
14 | | - # Model ID (required) |
15 | | - attribute :model, :string |
16 | | - |
17 | | - # Audio output parameters |
18 | | - attribute :audio, Requests::AudioType.new |
19 | | - |
20 | | - # Frequency penalty |
21 | | - attribute :frequency_penalty, :float, default: 0 |
22 | | - |
23 | | - # Deprecated: function_call (use tool_choice instead) |
24 | | - attribute :function_call # String or object |
25 | | - |
26 | | - # Deprecated: functions (use tools instead) |
27 | | - attribute :functions # Array of function objects |
28 | | - |
29 | | - # Logit bias |
30 | | - attribute :logit_bias # Hash of token_id => bias_value |
31 | | - |
32 | | - # Log probabilities |
33 | | - attribute :logprobs, :boolean, default: false |
34 | | - |
35 | | - # Max completion tokens |
36 | | - attribute :max_completion_tokens, :integer |
37 | | - |
38 | | - # Deprecated: max_tokens (use max_completion_tokens) |
39 | | - attribute :max_tokens, :integer |
40 | | - |
41 | | - # Metadata |
42 | | - attribute :metadata # Hash of key-value pairs |
43 | | - |
44 | | - # Modalities |
45 | | - attribute :modalities, default: -> { [ "text" ] } # Array of strings |
46 | | - |
47 | | - # Number of completions |
48 | | - attribute :n, :integer, default: 1 |
49 | | - |
50 | | - # Parallel tool calls |
51 | | - attribute :parallel_tool_calls, :boolean, default: true |
52 | | - |
53 | | - # Prediction configuration |
54 | | - attribute :prediction, Requests::PredictionType.new |
55 | | - |
56 | | - # Presence penalty |
57 | | - attribute :presence_penalty, :float, default: 0 |
58 | | - |
59 | | - # Prompt cache key |
60 | | - attribute :prompt_cache_key, :string |
61 | | - |
62 | | - # Reasoning effort (for reasoning models) |
63 | | - attribute :reasoning_effort, :string |
64 | | - |
65 | | - # Response format |
66 | | - attribute :response_format, Requests::ResponseFormatType.new |
67 | | - |
68 | | - # Safety identifier |
69 | | - attribute :safety_identifier, :string |
70 | | - |
71 | | - # Deprecated: seed |
72 | | - attribute :seed, :integer |
73 | | - |
74 | | - # Service tier |
75 | | - attribute :service_tier, :string, default: "auto" |
76 | | - |
77 | | - # Stop sequences |
78 | | - attribute :stop # String, array, or null |
79 | | - |
80 | | - # Storage |
81 | | - attribute :store, :boolean, default: false |
82 | | - |
83 | | - # Streaming |
84 | | - attribute :stream, :boolean, default: false |
85 | | - attribute :stream_options, Requests::StreamOptionsType.new |
86 | | - |
87 | | - # Temperature sampling |
88 | | - attribute :temperature, :float, default: 1 |
89 | | - |
90 | | - # Tool choice |
91 | | - attribute :tool_choice, Requests::ToolChoiceType.new |
92 | | - |
93 | | - # Tools array |
94 | | - attribute :tools # Array of tool objects |
95 | | - |
96 | | - # Top logprobs |
97 | | - attribute :top_logprobs, :integer |
98 | | - |
99 | | - # Top P sampling |
100 | | - attribute :top_p, :float, default: 1 |
101 | | - |
102 | | - # Deprecated: user (use safety_identifier or prompt_cache_key) |
103 | | - attribute :user, :string |
104 | | - |
105 | | - # Verbosity (for reasoning models) |
106 | | - attribute :verbosity, :string |
107 | | - |
108 | | - # Web search options |
109 | | - attribute :web_search_options, Requests::WebSearchOptionsType.new |
110 | | - |
111 | | - # Validations |
112 | | - validates :model, :messages, presence: true |
113 | | - |
114 | | - validates :frequency_penalty, numericality: { greater_than_or_equal_to: -2.0, less_than_or_equal_to: 2.0 }, allow_nil: true |
115 | | - validates :presence_penalty, numericality: { greater_than_or_equal_to: -2.0, less_than_or_equal_to: 2.0 }, allow_nil: true |
116 | | - validates :temperature, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 2 }, allow_nil: true |
117 | | - validates :top_p, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 1 }, allow_nil: true |
118 | | - validates :top_logprobs, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 20 }, allow_nil: true |
119 | | - validates :n, numericality: { greater_than: 0 }, allow_nil: true |
120 | | - validates :max_completion_tokens, numericality: { greater_than: 0 }, allow_nil: true |
121 | | - validates :max_tokens, numericality: { greater_than: 0 }, allow_nil: true |
122 | | - |
123 | | - validates :service_tier, inclusion: { in: %w[auto default flex priority] }, allow_nil: true |
124 | | - validates :reasoning_effort, inclusion: { in: %w[minimal low medium high] }, allow_nil: true |
125 | | - validates :verbosity, inclusion: { in: %w[low medium high] }, allow_nil: true |
126 | | - validates :modalities, inclusion: { in: %w[text audio] }, allow_nil: true |
127 | | - |
128 | | - # Custom validations |
129 | | - validate :validate_metadata_format |
130 | | - validate :validate_logit_bias_format |
131 | | - validate :validate_stop_sequences |
132 | | - |
133 | | - def serialize |
134 | | - super.tap do |hash| |
135 | | - # Can be an empty hash, to enable the feature |
136 | | - hash[:web_search_options] ||= {} if web_search_options |
137 | | - end |
| 11 | + # Request wrapper that delegates to OpenAI gem model. |
| 12 | + # |
| 13 | + # Uses SimpleDelegator to wrap ::OpenAI::Models::Chat::CompletionCreateParams, |
| 14 | + # eliminating the need to maintain duplicate attribute definitions while |
| 15 | + # providing convenience transformations. |
| 16 | + # |
| 17 | + # All standard OpenAI Chat API fields are automatically available via delegation: |
| 18 | + # - model, messages, temperature, max_tokens, max_completion_tokens |
| 19 | + # - top_p, frequency_penalty, presence_penalty |
| 20 | + # - tools, tool_choice, response_format, stream_options |
| 21 | + # - audio, prediction, metadata, modalities |
| 22 | + # - service_tier, store, parallel_tool_calls, reasoning_effort, verbosity |
| 23 | + # - stop, seed, logit_bias, logprobs, top_logprobs |
| 24 | + # - prompt_cache_key, safety_identifier, user |
| 25 | + # - web_search_options |
| 26 | + # - function_call, functions (deprecated) |
| 27 | + # |
| 28 | + # @example Basic usage |
| 29 | + # request = Request.new( |
| 30 | + # model: "gpt-4o", |
| 31 | + # messages: [{role: "user", content: "Hello"}] |
| 32 | + # ) |
| 33 | + # request.model #=> "gpt-4o" |
| 34 | + # request.temperature #=> 1 (default) |
| 35 | + # |
| 36 | + # @example With transformations |
| 37 | + # # String messages are automatically normalized |
| 38 | + # request = Request.new( |
| 39 | + # model: "gpt-4o", |
| 40 | + # messages: "Hello" |
| 41 | + # ) |
| 42 | + # # Internally becomes: [{role: "user", content: "Hello"}] |
| 43 | + # |
| 44 | + # @example Common format compatibility |
| 45 | + # request = Request.new( |
| 46 | + # model: "gpt-4o", |
| 47 | + # messages: [{role: "user", content: "Hi"}], |
| 48 | + # instructions: ["You are helpful", "Be concise"] |
| 49 | + # ) |
| 50 | + # # instructions become developer messages |
| 51 | + class Request < SimpleDelegator |
| 52 | + # Default parameter values applied during initialization |
| 53 | + DEFAULTS = { |
| 54 | + frequency_penalty: 0, |
| 55 | + logprobs: false, |
| 56 | + modalities: [ "text" ], |
| 57 | + n: 1, |
| 58 | + parallel_tool_calls: true, |
| 59 | + presence_penalty: 0, |
| 60 | + service_tier: "auto", |
| 61 | + store: false, |
| 62 | + stream: false, |
| 63 | + temperature: 1, |
| 64 | + top_p: 1 |
| 65 | + }.freeze |
| 66 | + |
| 67 | + # @return [Boolean, nil] |
| 68 | + attr_reader :stream |
| 69 | + |
| 70 | + # Initializes request with field mapping and normalization. |
| 71 | + # |
| 72 | + # Maps common format fields (instructions) and normalizes messages. |
| 73 | + # |
| 74 | + # @param params [Hash] |
| 75 | + # @option params [String] :model required |
| 76 | + # @option params [Array, String, Hash] :messages required |
| 77 | + # @option params [Array<String>, String] :instructions system/developer prompts |
| 78 | + # @option params [Hash, String, Symbol] :response_format |
| 79 | + # @raise [ArgumentError] when gem model initialization fails |
| 80 | + def initialize(**params) |
| 81 | + # Extract stream flag |
| 82 | + @stream = params[:stream] |
| 83 | + |
| 84 | + # Apply defaults |
| 85 | + params = apply_defaults(params) |
| 86 | + |
| 87 | + # Normalize all parameters (instructions, messages, response_format) |
| 88 | + params = Chat::Transforms.normalize_params(params) |
| 89 | + |
| 90 | + # Create gem model - this validates all parameters! |
| 91 | + gem_model = ::OpenAI::Models::Chat::CompletionCreateParams.new(**params) |
| 92 | + |
| 93 | + # Delegate all method calls to gem model |
| 94 | + super(gem_model) |
| 95 | + rescue ArgumentError => e |
| 96 | + # Re-raise with more context |
| 97 | + raise ArgumentError, "Invalid OpenAI Chat request parameters: #{e.message}" |
138 | 98 | end |
139 | 99 |
|
140 | | - # Common Format Compatability |
141 | | - def instructions=(*values) |
142 | | - self.messages ||= [] |
| 100 | + # Serializes request for API call. |
| 101 | + # |
| 102 | + # Uses gem's JSON serialization, removes default values to keep request |
| 103 | + # body minimal, and simplifies messages where possible. |
| 104 | + # |
| 105 | + # @return [Hash] |
| 106 | + def serialize |
| 107 | + # Use gem's JSON serialization (handles all nested objects) |
| 108 | + hash = Chat::Transforms.gem_to_hash(__getobj__) |
143 | 109 |
|
144 | | - values.flatten.reverse.each do |value| |
145 | | - self.messages.unshift({ role: "developer", content: value }) |
146 | | - end |
| 110 | + # Cleanup and simplify for API request |
| 111 | + Chat::Transforms.cleanup_serialized_request(hash, DEFAULTS, __getobj__) |
147 | 112 | end |
148 | 113 |
|
149 | | - # Common Format Compatability |
150 | | - alias_attribute :message, :messages |
| 114 | + # Accessor for messages. |
| 115 | + # |
| 116 | + # @return [Array<Hash>, nil] |
| 117 | + def messages |
| 118 | + __getobj__.instance_variable_get(:@data)[:messages] |
| 119 | + end |
151 | 120 |
|
152 | | - # Common Format Compatability |
| 121 | + # Sets messages with normalization. |
| 122 | + # |
| 123 | + # @param value [Array, String, Hash] |
| 124 | + # @return [void] |
153 | 125 | def messages=(value) |
154 | | - case value |
155 | | - when Array |
156 | | - super((messages || []) | value) |
157 | | - else |
158 | | - super((messages || []) | [ value ]) |
159 | | - end |
| 126 | + normalized_value = Chat::Transforms.normalize_messages(value) |
| 127 | + __getobj__.instance_variable_get(:@data)[:messages] = normalized_value |
160 | 128 | end |
161 | 129 |
|
162 | | - private |
163 | | - |
164 | | - def validate_metadata_format |
165 | | - return if metadata.nil? |
166 | | - |
167 | | - unless metadata.is_a?(Hash) |
168 | | - errors.add(:metadata, "must be a hash") |
169 | | - return |
170 | | - end |
171 | | - |
172 | | - metadata.each do |key, value| |
173 | | - if key.to_s.length > 64 |
174 | | - errors.add(:metadata, "keys must be 64 characters or less") |
175 | | - end |
176 | | - if value.to_s.length > 512 |
177 | | - errors.add(:metadata, "values must be 512 characters or less") |
178 | | - end |
179 | | - end |
180 | | - |
181 | | - if metadata.size > 16 |
182 | | - errors.add(:metadata, "must have 16 key-value pairs or less") |
183 | | - end |
| 130 | + # Alias for messages (common format compatibility). |
| 131 | + # |
| 132 | + # @return [Array<Hash>, nil] |
| 133 | + def message |
| 134 | + messages |
184 | 135 | end |
185 | 136 |
|
186 | | - def validate_logit_bias_format |
187 | | - return if logit_bias.nil? |
188 | | - |
189 | | - unless logit_bias.is_a?(Hash) |
190 | | - errors.add(:logit_bias, "must be a hash") |
191 | | - return |
192 | | - end |
| 137 | + # @param value [Array, String, Hash] |
| 138 | + def message=(value) |
| 139 | + self.messages = value |
| 140 | + end |
193 | 141 |
|
194 | | - logit_bias.each do |token_id, bias| |
195 | | - unless bias.is_a?(Numeric) && bias >= -100 && bias <= 100 |
196 | | - errors.add(:logit_bias, "bias values must be between -100 and 100") |
197 | | - end |
198 | | - end |
| 142 | + # Sets instructions as developer messages (common format compatibility). |
| 143 | + # |
| 144 | + # Prepends developer messages to the messages array. |
| 145 | + # |
| 146 | + # @param values [Array<String>, String] |
| 147 | + # @return [void] |
| 148 | + def instructions=(*values) |
| 149 | + instructions_messages = Chat::Transforms.normalize_instructions(values.flatten) |
| 150 | + current_messages = messages || [] |
| 151 | + self.messages = instructions_messages + current_messages |
199 | 152 | end |
200 | 153 |
|
201 | | - def validate_stop_sequences |
202 | | - return if stop.nil? |
203 | | - return if stop.is_a?(String) |
| 154 | + private |
204 | 155 |
|
205 | | - if stop.is_a?(Array) |
206 | | - errors.add(:stop, "can have at most 4 sequences") if stop.length > 4 |
207 | | - else |
208 | | - errors.add(:stop, "must be a string, array, or null") |
| 156 | + # @api private |
| 157 | + # @param params [Hash] |
| 158 | + # @return [Hash] |
| 159 | + def apply_defaults(params) |
| 160 | + # Only apply defaults for keys that aren't present |
| 161 | + DEFAULTS.each do |key, value| |
| 162 | + params[key] = value unless params.key?(key) |
209 | 163 | end |
| 164 | + |
| 165 | + params |
210 | 166 | end |
211 | 167 | end |
212 | 168 | end |
|
0 commit comments