Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ruby-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.2.2
18 changes: 2 additions & 16 deletions lib/ruby_llm/active_record/acts_as_legacy.rb
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,12 @@ def with_schema(...)
end

def on_new_message(&block)
to_llm

existing_callback = @chat.instance_variable_get(:@on)[:new_message]

@chat.on_new_message do
existing_callback&.call
block&.call
end
to_llm.on_new_message(&block)
self
end

def on_end_message(&block)
to_llm

existing_callback = @chat.instance_variable_get(:@on)[:end_message]

@chat.on_end_message do |msg|
existing_callback&.call(msg)
block&.call(msg)
end
to_llm.on_end_message(&block)
self
end

Expand Down
18 changes: 2 additions & 16 deletions lib/ruby_llm/active_record/chat_methods.rb
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,12 @@ def with_schema(...)
end

def on_new_message(&block)
to_llm

existing_callback = @chat.instance_variable_get(:@on)[:new_message]

@chat.on_new_message do
existing_callback&.call
block&.call
end
to_llm.on_new_message(&block)
self
end

def on_end_message(&block)
to_llm

existing_callback = @chat.instance_variable_get(:@on)[:end_message]

@chat.on_end_message do |msg|
existing_callback&.call(msg)
block&.call(msg)
end
to_llm.on_end_message(&block)
self
end

Expand Down
93 changes: 52 additions & 41 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@ class Chat

attr_reader :model, :messages, :tools, :params, :headers, :schema

# Stores multiple callbacks per key and invokes all of them.
#
# Internally we keep a callable per event (via `callback_for`) so higher
# level code can safely chain callbacks without overwriting persistence.
class CallbackFanout
def initialize
@callbacks = Hash.new { |h, k| h[k] = [] }
end

def add(key, callable)
return unless callable

@callbacks[key] << callable
end

def callback_for(key)
callbacks = @callbacks[key]
return if callbacks.empty?

->(*args) { callbacks.each { |cb| cb.call(*args) } }
end
end

def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
raise ArgumentError, 'Provider must be specified if assume_model_exists is true'
Expand All @@ -22,17 +45,12 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
@params = {}
@headers = {}
@schema = nil
@on = {
new_message: nil,
end_message: nil,
tool_call: nil,
tool_result: nil
}
@on = CallbackFanout.new
end

def ask(message = nil, with: nil, &)
add_message role: :user, content: build_content(message, with)
complete(&)
def ask(message = nil, with: nil, &block)
add_message role: :user, content: Content.new(message, with)
complete(&block)
end

alias say ask
Expand All @@ -57,7 +75,7 @@ def with_tools(*tools, replace: false)
end

def with_model(model_id, provider: nil, assume_exists: false)
@model, @provider = Models.resolve(model_id, provider:, assume_exists:, config: @config)
@model, @provider = Models.resolve(model_id, provider: provider, assume_exists: assume_exists, config: @config)
@connection = @provider.connection
self
end
Expand Down Expand Up @@ -98,30 +116,30 @@ def with_schema(schema)
end

def on_new_message(&block)
@on[:new_message] = block
@on.add(:new_message, block)
self
end

def on_end_message(&block)
@on[:end_message] = block
@on.add(:end_message, block)
self
end

def on_tool_call(&block)
@on[:tool_call] = block
@on.add(:tool_call, block)
self
end

def on_tool_result(&block)
@on[:tool_result] = block
@on.add(:tool_result, block)
self
end

def each(&)
messages.each(&)
def each(&block)
messages.each(&block)
end

def complete(&) # rubocop:disable Metrics/PerceivedComplexity
def complete(&block) # rubocop:disable Metrics/PerceivedComplexity
response = @provider.complete(
messages,
tools: @tools,
Expand All @@ -130,10 +148,10 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
params: @params,
headers: @headers,
schema: @schema,
&wrap_streaming_block(&)
&wrap_streaming_block(&block)
)

@on[:new_message]&.call unless block_given?
callback_for(:new_message)&.call unless block

if @schema && response.content.is_a?(String)
begin
Expand All @@ -144,10 +162,10 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
end

add_message response
@on[:end_message]&.call(response)
callback_for(:end_message)&.call(response)

if response.tool_call?
handle_tool_calls(response, &)
handle_tool_calls(response, &block)
else
response
end
Expand All @@ -169,6 +187,10 @@ def instance_variables

private

def callback_for(key)
@on.callback_for(key)
end

def wrap_streaming_block(&block)
return nil unless block_given?

Expand All @@ -178,46 +200,35 @@ def wrap_streaming_block(&block)
# Create message on first content chunk
unless first_chunk_received
first_chunk_received = true
@on[:new_message]&.call
callback_for(:new_message)&.call
end

block.call chunk
end
end

def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
def handle_tool_calls(response, &block) # rubocop:disable Metrics/PerceivedComplexity
halt_result = nil

response.tool_calls.each_value do |tool_call|
@on[:new_message]&.call
@on[:tool_call]&.call(tool_call)
callback_for(:new_message)&.call
callback_for(:tool_call)&.call(tool_call)
result = execute_tool tool_call
@on[:tool_result]&.call(result)
tool_payload = result.is_a?(Tool::Halt) ? result.content : result
content = content_like?(tool_payload) ? tool_payload : tool_payload.to_s
message = add_message role: :tool, content:, tool_call_id: tool_call.id
@on[:end_message]&.call(message)
callback_for(:tool_result)&.call(result)
content = result.is_a?(Content) ? result : result.to_s
message = add_message role: :tool, content: content, tool_call_id: tool_call.id
callback_for(:end_message)&.call(message)

halt_result = result if result.is_a?(Tool::Halt)
end

halt_result || complete(&)
halt_result || complete(&block)
end

def execute_tool(tool_call)
tool = tools[tool_call.name.to_sym]
args = tool_call.arguments
tool.call(args)
end

def build_content(message, attachments)
return message if content_like?(message)

Content.new(message, attachments)
end

def content_like?(object)
object.is_a?(Content) || object.is_a?(Content::Raw)
end
end
end

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions spec/ruby_llm/active_record/acts_as_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,20 @@ def uploaded_file(path, type)
expect(chat.messages.count).to eq(2) # Persistence still works
end

it 'allows chaining callbacks on to_llm without losing persistence' do
chat = Chat.create!(model: model)
llm_chat = chat.to_llm

user_callback_called = false
# Directly attach callback to the underlying Chat object
llm_chat.on_new_message { user_callback_called = true }

chat.ask('Hello')

expect(user_callback_called).to be true
expect(chat.messages.count).to eq(2) # Persistence still works
end

it 'calls on_tool_call and on_tool_result callbacks' do
tool_call_received = nil
tool_result_received = nil
Expand Down