Skip to content

Commit 4671886

Browse files
committed
Refactor OpenAI's Chat Provider to use Native Gem Types
1 parent 0b1798d commit 4671886

31 files changed

Lines changed: 506 additions & 1536 deletions

lib/active_agent/providers/open_ai/chat/request.rb

Lines changed: 142 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -1,212 +1,168 @@
11
# frozen_string_literal: true
22

3-
require "active_agent/providers/common/model"
4-
require_relative "requests/_types"
3+
require "delegate"
4+
require "json"
5+
require_relative "transforms"
56

67
module ActiveAgent
78
module Providers
89
module OpenAI
910
module Chat
10-
class Request < Common::BaseModel
11-
# Messages array (required)
12-
attribute :messages, Requests::Messages::MessagesType.new
13-
14-
# Model ID (required)
15-
attribute :model, :string
16-
17-
# Audio output parameters
18-
attribute :audio, Requests::AudioType.new
19-
20-
# Frequency penalty
21-
attribute :frequency_penalty, :float, default: 0
22-
23-
# Deprecated: function_call (use tool_choice instead)
24-
attribute :function_call # String or object
25-
26-
# Deprecated: functions (use tools instead)
27-
attribute :functions # Array of function objects
28-
29-
# Logit bias
30-
attribute :logit_bias # Hash of token_id => bias_value
31-
32-
# Log probabilities
33-
attribute :logprobs, :boolean, default: false
34-
35-
# Max completion tokens
36-
attribute :max_completion_tokens, :integer
37-
38-
# Deprecated: max_tokens (use max_completion_tokens)
39-
attribute :max_tokens, :integer
40-
41-
# Metadata
42-
attribute :metadata # Hash of key-value pairs
43-
44-
# Modalities
45-
attribute :modalities, default: -> { [ "text" ] } # Array of strings
46-
47-
# Number of completions
48-
attribute :n, :integer, default: 1
49-
50-
# Parallel tool calls
51-
attribute :parallel_tool_calls, :boolean, default: true
52-
53-
# Prediction configuration
54-
attribute :prediction, Requests::PredictionType.new
55-
56-
# Presence penalty
57-
attribute :presence_penalty, :float, default: 0
58-
59-
# Prompt cache key
60-
attribute :prompt_cache_key, :string
61-
62-
# Reasoning effort (for reasoning models)
63-
attribute :reasoning_effort, :string
64-
65-
# Response format
66-
attribute :response_format, Requests::ResponseFormatType.new
67-
68-
# Safety identifier
69-
attribute :safety_identifier, :string
70-
71-
# Deprecated: seed
72-
attribute :seed, :integer
73-
74-
# Service tier
75-
attribute :service_tier, :string, default: "auto"
76-
77-
# Stop sequences
78-
attribute :stop # String, array, or null
79-
80-
# Storage
81-
attribute :store, :boolean, default: false
82-
83-
# Streaming
84-
attribute :stream, :boolean, default: false
85-
attribute :stream_options, Requests::StreamOptionsType.new
86-
87-
# Temperature sampling
88-
attribute :temperature, :float, default: 1
89-
90-
# Tool choice
91-
attribute :tool_choice, Requests::ToolChoiceType.new
92-
93-
# Tools array
94-
attribute :tools # Array of tool objects
95-
96-
# Top logprobs
97-
attribute :top_logprobs, :integer
98-
99-
# Top P sampling
100-
attribute :top_p, :float, default: 1
101-
102-
# Deprecated: user (use safety_identifier or prompt_cache_key)
103-
attribute :user, :string
104-
105-
# Verbosity (for reasoning models)
106-
attribute :verbosity, :string
107-
108-
# Web search options
109-
attribute :web_search_options, Requests::WebSearchOptionsType.new
110-
111-
# Validations
112-
validates :model, :messages, presence: true
113-
114-
validates :frequency_penalty, numericality: { greater_than_or_equal_to: -2.0, less_than_or_equal_to: 2.0 }, allow_nil: true
115-
validates :presence_penalty, numericality: { greater_than_or_equal_to: -2.0, less_than_or_equal_to: 2.0 }, allow_nil: true
116-
validates :temperature, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 2 }, allow_nil: true
117-
validates :top_p, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 1 }, allow_nil: true
118-
validates :top_logprobs, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 20 }, allow_nil: true
119-
validates :n, numericality: { greater_than: 0 }, allow_nil: true
120-
validates :max_completion_tokens, numericality: { greater_than: 0 }, allow_nil: true
121-
validates :max_tokens, numericality: { greater_than: 0 }, allow_nil: true
122-
123-
validates :service_tier, inclusion: { in: %w[auto default flex priority] }, allow_nil: true
124-
validates :reasoning_effort, inclusion: { in: %w[minimal low medium high] }, allow_nil: true
125-
validates :verbosity, inclusion: { in: %w[low medium high] }, allow_nil: true
126-
validates :modalities, inclusion: { in: %w[text audio] }, allow_nil: true
127-
128-
# Custom validations
129-
validate :validate_metadata_format
130-
validate :validate_logit_bias_format
131-
validate :validate_stop_sequences
132-
133-
def serialize
134-
super.tap do |hash|
135-
# Can be an empty hash, to enable the feature
136-
hash[:web_search_options] ||= {} if web_search_options
137-
end
11+
# Request wrapper that delegates to OpenAI gem model.
12+
#
13+
# Uses SimpleDelegator to wrap ::OpenAI::Models::Chat::CompletionCreateParams,
14+
# eliminating the need to maintain duplicate attribute definitions while
15+
# providing convenience transformations.
16+
#
17+
# All standard OpenAI Chat API fields are automatically available via delegation:
18+
# - model, messages, temperature, max_tokens, max_completion_tokens
19+
# - top_p, frequency_penalty, presence_penalty
20+
# - tools, tool_choice, response_format, stream_options
21+
# - audio, prediction, metadata, modalities
22+
# - service_tier, store, parallel_tool_calls, reasoning_effort, verbosity
23+
# - stop, seed, logit_bias, logprobs, top_logprobs
24+
# - prompt_cache_key, safety_identifier, user
25+
# - web_search_options
26+
# - function_call, functions (deprecated)
27+
#
28+
# @example Basic usage
29+
# request = Request.new(
30+
# model: "gpt-4o",
31+
# messages: [{role: "user", content: "Hello"}]
32+
# )
33+
# request.model #=> "gpt-4o"
34+
# request.temperature #=> 1 (default)
35+
#
36+
# @example With transformations
37+
# # String messages are automatically normalized
38+
# request = Request.new(
39+
# model: "gpt-4o",
40+
# messages: "Hello"
41+
# )
42+
# # Internally becomes: [{role: "user", content: "Hello"}]
43+
#
44+
# @example Common format compatibility
45+
# request = Request.new(
46+
# model: "gpt-4o",
47+
# messages: [{role: "user", content: "Hi"}],
48+
# instructions: ["You are helpful", "Be concise"]
49+
# )
50+
# # instructions become developer messages
51+
class Request < SimpleDelegator
52+
# Default parameter values applied during initialization
53+
DEFAULTS = {
54+
frequency_penalty: 0,
55+
logprobs: false,
56+
modalities: [ "text" ],
57+
n: 1,
58+
parallel_tool_calls: true,
59+
presence_penalty: 0,
60+
service_tier: "auto",
61+
store: false,
62+
stream: false,
63+
temperature: 1,
64+
top_p: 1
65+
}.freeze
66+
67+
# @return [Boolean, nil]
68+
attr_reader :stream
69+
70+
# Initializes request with field mapping and normalization.
71+
#
72+
# Maps common format fields (instructions) and normalizes messages.
73+
#
74+
# @param params [Hash]
75+
# @option params [String] :model required
76+
# @option params [Array, String, Hash] :messages required
77+
# @option params [Array<String>, String] :instructions system/developer prompts
78+
# @option params [Hash, String, Symbol] :response_format
79+
# @raise [ArgumentError] when gem model initialization fails
80+
def initialize(**params)
81+
# Extract stream flag
82+
@stream = params[:stream]
83+
84+
# Apply defaults
85+
params = apply_defaults(params)
86+
87+
# Normalize all parameters (instructions, messages, response_format)
88+
params = Chat::Transforms.normalize_params(params)
89+
90+
# Create gem model - this validates all parameters!
91+
gem_model = ::OpenAI::Models::Chat::CompletionCreateParams.new(**params)
92+
93+
# Delegate all method calls to gem model
94+
super(gem_model)
95+
rescue ArgumentError => e
96+
# Re-raise with more context
97+
raise ArgumentError, "Invalid OpenAI Chat request parameters: #{e.message}"
13898
end
13999

140-
# Common Format Compatability
141-
def instructions=(*values)
142-
self.messages ||= []
100+
# Serializes request for API call.
101+
#
102+
# Uses gem's JSON serialization, removes default values to keep request
103+
# body minimal, and simplifies messages where possible.
104+
#
105+
# @return [Hash]
106+
def serialize
107+
# Use gem's JSON serialization (handles all nested objects)
108+
hash = Chat::Transforms.gem_to_hash(__getobj__)
143109

144-
values.flatten.reverse.each do |value|
145-
self.messages.unshift({ role: "developer", content: value })
146-
end
110+
# Cleanup and simplify for API request
111+
Chat::Transforms.cleanup_serialized_request(hash, DEFAULTS, __getobj__)
147112
end
148113

149-
# Common Format Compatability
150-
alias_attribute :message, :messages
114+
# Accessor for messages.
115+
#
116+
# @return [Array<Hash>, nil]
117+
def messages
118+
__getobj__.instance_variable_get(:@data)[:messages]
119+
end
151120

152-
# Common Format Compatability
121+
# Sets messages with normalization.
122+
#
123+
# @param value [Array, String, Hash]
124+
# @return [void]
153125
def messages=(value)
154-
case value
155-
when Array
156-
super((messages || []) | value)
157-
else
158-
super((messages || []) | [ value ])
159-
end
126+
normalized_value = Chat::Transforms.normalize_messages(value)
127+
__getobj__.instance_variable_get(:@data)[:messages] = normalized_value
160128
end
161129

162-
private
163-
164-
def validate_metadata_format
165-
return if metadata.nil?
166-
167-
unless metadata.is_a?(Hash)
168-
errors.add(:metadata, "must be a hash")
169-
return
170-
end
171-
172-
metadata.each do |key, value|
173-
if key.to_s.length > 64
174-
errors.add(:metadata, "keys must be 64 characters or less")
175-
end
176-
if value.to_s.length > 512
177-
errors.add(:metadata, "values must be 512 characters or less")
178-
end
179-
end
180-
181-
if metadata.size > 16
182-
errors.add(:metadata, "must have 16 key-value pairs or less")
183-
end
130+
# Alias for messages (common format compatibility).
131+
#
132+
# @return [Array<Hash>, nil]
133+
def message
134+
messages
184135
end
185136

186-
def validate_logit_bias_format
187-
return if logit_bias.nil?
188-
189-
unless logit_bias.is_a?(Hash)
190-
errors.add(:logit_bias, "must be a hash")
191-
return
192-
end
137+
# @param value [Array, String, Hash]
138+
def message=(value)
139+
self.messages = value
140+
end
193141

194-
logit_bias.each do |token_id, bias|
195-
unless bias.is_a?(Numeric) && bias >= -100 && bias <= 100
196-
errors.add(:logit_bias, "bias values must be between -100 and 100")
197-
end
198-
end
142+
# Sets instructions as developer messages (common format compatibility).
143+
#
144+
# Prepends developer messages to the messages array.
145+
#
146+
# @param values [Array<String>, String]
147+
# @return [void]
148+
def instructions=(*values)
149+
instructions_messages = Chat::Transforms.normalize_instructions(values.flatten)
150+
current_messages = messages || []
151+
self.messages = instructions_messages + current_messages
199152
end
200153

201-
def validate_stop_sequences
202-
return if stop.nil?
203-
return if stop.is_a?(String)
154+
private
204155

205-
if stop.is_a?(Array)
206-
errors.add(:stop, "can have at most 4 sequences") if stop.length > 4
207-
else
208-
errors.add(:stop, "must be a string, array, or null")
156+
# @api private
157+
# @param params [Hash]
158+
# @return [Hash]
159+
def apply_defaults(params)
160+
# Only apply defaults for keys that aren't present
161+
DEFAULTS.each do |key, value|
162+
params[key] = value unless params.key?(key)
209163
end
164+
165+
params
210166
end
211167
end
212168
end

0 commit comments

Comments
 (0)