diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart index 902fa9812bec..08984b41f0fb 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/function_calling_page.dart @@ -15,6 +15,7 @@ import 'package:flutter/material.dart'; import 'package:firebase_ai/firebase_ai.dart'; import 'package:firebase_auth/firebase_auth.dart'; +import '../utils/function_call_utils.dart'; import '../widgets/message_widget.dart'; class FunctionCallingPage extends StatefulWidget { @@ -31,13 +32,6 @@ class FunctionCallingPage extends StatefulWidget { State createState() => _FunctionCallingPageState(); } -class Location { - final String city; - final String state; - - Location(this.city, this.state); -} - class _FunctionCallingPageState extends State { late GenerativeModel _functionCallModel; late GenerativeModel _autoFunctionCallModel; @@ -76,7 +70,7 @@ class _FunctionCallingPageState extends State { 'The date for which to get the weather. Date must be in the format: YYYY-MM-DD.', ), }, - callable: _fetchWeatherCallable, + callable: fetchWeatherCallable, ); _autoFindRestaurantsTool = AutoFunctionDeclaration( name: 'findRestaurants', @@ -148,16 +142,6 @@ class _FunctionCallingPageState extends State { }; } - Future> _fetchWeatherCallable( - Map args, - ) async { - final locationData = args['location']! as Map; - final city = locationData['city']! as String; - final state = locationData['state']! as String; - final date = args['date']! as String; - return fetchWeather(Location(city, state), date); - } - void _initializeModel() { final generationConfig = GenerationConfig( thinkingConfig: _enableThinking @@ -204,23 +188,6 @@ class _FunctionCallingPageState extends State { ); } - // This is a hypothetical API to return a fake weather data collection for - // certain location - Future> fetchWeather( - Location location, - String date, - ) async { - // TODO(developer): Call a real weather API. - // Mock response from the API. In developer live code this would call the - // external API and return what that API returns. - final apiResponse = { - 'temperature': 38, - 'chancePrecipitation': '56%', - 'cloudConditions': 'partly-cloudy', - }; - return apiResponse; - } - /// Actual function to demonstrate the function calling feature. final fetchWeatherTool = FunctionDeclaration( 'fetchWeather', diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart index 86e5d8222e5e..155c401d81a3 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart @@ -14,6 +14,7 @@ import 'dart:convert'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; +import '../utils/function_call_utils.dart'; import '../widgets/message_widget.dart'; import 'package:firebase_ai/firebase_ai.dart'; @@ -41,6 +42,10 @@ class _ServerTemplatePageState extends State { TemplateGenerativeModel? _templateGenerativeModel; TemplateImagenModel? _templateImagenModel; + TemplateChatSession? _chatSession; + TemplateChatSession? _chatFunctionSession; + TemplateChatSession? _chatAutoFunctionSession; + @override void initState() { super.initState(); @@ -58,6 +63,26 @@ class _ServerTemplatePageState extends State { FirebaseAI.googleAI().templateGenerativeModel(); _templateImagenModel = FirebaseAI.googleAI().templateImagenModel(); } + + // Inputs are now provided ONCE here when creating the session + _chatSession = _templateGenerativeModel?.startChat( + 'chat_history.prompt', + inputs: {}, + ); + _chatFunctionSession = _templateGenerativeModel?.startChat( + 'cj-function-calling-weather', + inputs: {}, + ); + _chatAutoFunctionSession = _templateGenerativeModel?.startChat( + 'cj-function-calling-weather', + inputs: {}, + autoFunctions: [ + TemplateAutoFunction( + name: 'fetchWeather', + callable: fetchWeatherCallable, + ), + ], + ); } void _scrollDown() { @@ -122,6 +147,41 @@ class _ServerTemplatePageState extends State { const SizedBox.square( dimension: 15, ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateAutoFunctionCall( + _textController.text, + ); + }, + icon: Icon( + Icons.auto_mode, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Auto Function Calling', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateFunctionCall(_textController.text); + }, + icon: Icon( + Icons.functions, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Function Calling', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateChat(_textController.text); + }, + icon: Icon( + Icons.chat, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Chat', + ), if (!_loading) IconButton( onPressed: () async { @@ -177,77 +237,134 @@ class _ServerTemplatePageState extends State { ); } - Future _serverTemplateUrlContext(String message) async { + Future _handleServerTemplateMessage( + String message, + Future Function(String) generateContent, + ) async { setState(() { _loading = true; }); try { _messages.add(MessageData(text: message, fromUser: true)); - var response = await _templateGenerativeModel - ?.generateContent('cj-urlcontext', inputs: {'url': message}); - - final candidate = response?.candidates.first; - if (candidate == null) { - _messages.add(MessageData(text: 'No response', fromUser: false)); - } else { - final responseText = candidate.text ?? ''; - final groundingMetadata = candidate.groundingMetadata; - final urlContextMetadata = candidate.urlContextMetadata; - - final buffer = StringBuffer(responseText); - if (groundingMetadata != null) { - buffer.writeln('\n\n--- Grounding Metadata ---'); - buffer.writeln('Web Search Queries:'); - for (final query in groundingMetadata.webSearchQueries) { - buffer.writeln(' - $query'); - } - buffer.writeln('\nGrounding Chunks:'); - for (final chunk in groundingMetadata.groundingChunks) { - if (chunk.web != null) { - buffer.writeln(' - Web Chunk:'); - buffer.writeln(' - Title: ${chunk.web!.title}'); - buffer.writeln(' - URI: ${chunk.web!.uri}'); - buffer.writeln(' - Domain: ${chunk.web!.domain}'); - } - } - } - - if (urlContextMetadata != null) { - buffer.writeln('\n\n--- URL Context Metadata ---'); - for (final data in urlContextMetadata.urlMetadata) { - buffer.writeln(' - URL: ${data.retrievedUrl}'); - buffer.writeln(' Status: ${data.urlRetrievalStatus}'); - } - } - _messages.add(MessageData(text: buffer.toString(), fromUser: false)); - } - - setState(() { - _loading = false; - _scrollDown(); - }); + await generateContent(message); } catch (e) { _showError(e.toString()); - setState(() { - _loading = false; - }); } finally { _textController.clear(); setState(() { _loading = false; }); _textFieldFocus.requestFocus(); + _scrollDown(); } } - Future _serverTemplateImagen(String message) async { - setState(() { - _loading = true; + Future _serverTemplateUrlContext(String message) async { + await _handleServerTemplateMessage( + message, + (message) async { + _messages.add(MessageData(text: message, fromUser: true)); + var response = await _templateGenerativeModel + ?.generateContent('cj-urlcontext', inputs: {'url': message}); + + final candidate = response?.candidates.first; + if (candidate == null) { + _messages.add(MessageData(text: 'No response', fromUser: false)); + } else { + final responseText = candidate.text ?? ''; + final groundingMetadata = candidate.groundingMetadata; + final urlContextMetadata = candidate.urlContextMetadata; + + final buffer = StringBuffer(responseText); + if (groundingMetadata != null) { + buffer.writeln('\n\n--- Grounding Metadata ---'); + buffer.writeln('Web Search Queries:'); + for (final query in groundingMetadata.webSearchQueries) { + buffer.writeln(' - $query'); + } + buffer.writeln('\nGrounding Chunks:'); + for (final chunk in groundingMetadata.groundingChunks) { + if (chunk.web != null) { + buffer.writeln(' - Web Chunk:'); + buffer.writeln(' - Title: ${chunk.web!.title}'); + buffer.writeln(' - URI: ${chunk.web!.uri}'); + buffer.writeln(' - Domain: ${chunk.web!.domain}'); + } + } + } + + if (urlContextMetadata != null) { + buffer.writeln('\n\n--- URL Context Metadata ---'); + for (final data in urlContextMetadata.urlMetadata) { + buffer.writeln(' - URL: ${data.retrievedUrl}'); + buffer.writeln(' Status: ${data.urlRetrievalStatus}'); + } + } + _messages.add(MessageData(text: buffer.toString(), fromUser: false)); + } + }, + ); + } + + Future _serverTemplateAutoFunctionCall(String message) async { + await _handleServerTemplateMessage(message, (message) async { + // Inputs are no longer passed during sendMessage + var response = await _chatAutoFunctionSession?.sendMessage( + Content.text(message), + ); + + _messages.add(MessageData(text: response?.text, fromUser: false)); }); - MessageData? resultMessage; - try { - _messages.add(MessageData(text: message, fromUser: true)); + } + + Future _serverTemplateFunctionCall(String message) async { + await _handleServerTemplateMessage(message, (message) async { + // Inputs are no longer passed during sendMessage + var response = await _chatFunctionSession?.sendMessage( + Content.text(message), + ); + + _messages.add(MessageData(text: response?.text, fromUser: false)); + final functionCalls = response?.functionCalls.toList(); + if (functionCalls!.isNotEmpty) { + final functionCall = functionCalls.first; + if (functionCall.name == 'fetchWeather') { + final location = + functionCall.args['location']! as Map; + final date = functionCall.args['date']! as String; + final city = location['city'] as String; + final state = location['state'] as String; + final functionResult = + await fetchWeather(Location(city, state), date); + + // Respond to the function call + var functionResponse = await _chatFunctionSession?.sendMessage( + Content.functionResponse(functionCall.name, functionResult), + ); + _messages + .add(MessageData(text: functionResponse?.text, fromUser: false)); + } + } + }); + } + + Future _serverTemplateChat(String message) async { + await _handleServerTemplateMessage(message, (message) async { + // Inputs are no longer passed during sendMessage + var response = await _chatSession?.sendMessage( + Content.text(message), + ); + + var text = response?.text; + + _messages.add(MessageData(text: text, fromUser: false)); + }); + } + + Future _serverTemplateImagen(String message) async { + await _handleServerTemplateMessage(message, (message) async { + MessageData? resultMessage; var response = await _templateImagenModel?.generateImages( 'portrait-googleai', inputs: { @@ -267,34 +384,14 @@ class _ServerTemplatePageState extends State { // Handle the case where no images were generated _showError('Error: No images were generated.'); } - - setState(() { - if (resultMessage != null) { - _messages.add(resultMessage); - } - _loading = false; - _scrollDown(); - }); - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } + if (resultMessage != null) { + _messages.add(resultMessage); + } + }); } Future _serverTemplateImageInput(String message) async { - setState(() { - _loading = true; - }); - - try { + await _handleServerTemplateMessage(message, (message) async { ByteData catBytes = await rootBundle.load('assets/images/cat.jpg'); var imageBytes = catBytes.buffer.asUint8List(); _messages.add( @@ -306,7 +403,7 @@ class _ServerTemplatePageState extends State { ); var response = await _templateGenerativeModel?.generateContent( - 'media.prompt', + 'media', inputs: { 'imageData': { 'isInline': true, @@ -316,53 +413,16 @@ class _ServerTemplatePageState extends State { }, ); _messages.add(MessageData(text: response?.text, fromUser: false)); - - setState(() { - _loading = false; - _scrollDown(); - }); - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } + }); } Future _sendServerTemplateMessage(String message) async { - setState(() { - _loading = true; - }); - - try { - _messages.add(MessageData(text: message, fromUser: true)); + await _handleServerTemplateMessage(message, (message) async { var response = await _templateGenerativeModel ?.generateContent('new-greeting', inputs: {}); _messages.add(MessageData(text: response?.text, fromUser: false)); - - setState(() { - _loading = false; - _scrollDown(); - }); - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } + }); } void _showError(String message) { diff --git a/packages/firebase_ai/firebase_ai/example/lib/utils/function_call_utils.dart b/packages/firebase_ai/firebase_ai/example/lib/utils/function_call_utils.dart new file mode 100644 index 000000000000..ff4a5d6991e0 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/example/lib/utils/function_call_utils.dart @@ -0,0 +1,47 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +class Location { + final String city; + final String state; + + Location(this.city, this.state); +} + +// This is a hypothetical API to return a fake weather data collection for +// certain location +Future> fetchWeather( + Location location, + String date, +) async { + // TODO(developer): Call a real weather API. + // Mock response from the API. In developer live code this would call the + // external API and return what that API returns. + final apiResponse = { + 'temperature': 38, + 'chancePrecipitation': '56%', + 'cloudConditions': 'partly-cloudy', + }; + return apiResponse; +} + +Future> fetchWeatherCallable( + Map args, +) async { + final locationData = args['location']! as Map; + final city = locationData['city']! as String; + final state = locationData['state']! as String; + final date = args['date']! as String; + return fetchWeather(Location(city, state), date); +} diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 681989831e53..ad35755ff06e 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -111,7 +111,8 @@ export 'src/live_api.dart' Transcription; export 'src/live_session.dart' show LiveSession; export 'src/schema.dart' show Schema, SchemaType; - +export 'src/server_template/template_chat.dart' + show TemplateChatSession, TemplateAutoFunction, StartTemplateChatExtension; export 'src/tool.dart' show AutoFunctionDeclaration, diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart new file mode 100644 index 000000000000..b945d3b4ea6e --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart @@ -0,0 +1,205 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:async'; + +import '../api.dart'; +import '../base_model.dart'; +import '../content.dart'; +import '../utils/chat_utils.dart'; +import '../utils/mutex.dart'; + +final class TemplateAutoFunction { + TemplateAutoFunction({ + required this.name, + required this.callable, + }); + + /// The name of the function. + /// + /// Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum + /// length of 63. + final String name; + + /// The callable function that this declaration represents. + final FutureOr> Function(Map args) + callable; +} + +/// A back-and-forth chat with a server template. +/// +/// Records messages sent and received in [history]. The history will always +/// record the content from the first candidate in the +/// [GenerateContentResponse], other candidates may be available on the returned +/// response. The history reflects the most current state of the chat session. +final class TemplateChatSession { + TemplateChatSession._( + this._templateHistoryGenerateContent, + this._templateHistoryGenerateContentStream, + this._templateId, + this._inputs, + this._history, + List autoFunctionLists, + this._maxTurns, + ) : _autoFunctions = {for (var item in autoFunctionLists) item.name: item}; + + final Future Function( + Iterable content, String templateId, + {required Map inputs}) _templateHistoryGenerateContent; + + final Stream Function( + Iterable content, String templateId, + {required Map inputs}) + _templateHistoryGenerateContentStream; + + final String _templateId; + final Map _inputs; + final List _history; + final Map _autoFunctions; + final int _maxTurns; + + final _mutex = Mutex(); + + /// The content that has been successfully sent to, or received from, the + /// generative model. + /// + /// If there are outstanding requests from calls to [sendMessage], + /// these will not be reflected in the history. + /// Messages without a candidate in the response are not recorded in history, + /// including the message sent to the model. + Iterable get history => _history.skip(0); + + /// Sends [message] to the server template as a continuation of the chat [history]. + /// + /// Prepends the history to the request and uses the provided model to + /// generate new content, providing the session's initialized inputs. + /// + /// When there are no candidates in the response, the [message] and response + /// are ignored and will not be recorded in the [history]. + Future sendMessage(Content message) async { + final lock = await _mutex.acquire(); + try { + final requestHistory = [message]; + var turn = 0; + while (turn < _maxTurns) { + final response = await _templateHistoryGenerateContent( + _history.followedBy(requestHistory), + _templateId, + inputs: _inputs, + ); + + final functionCalls = response.functionCalls; + final shouldAutoExecute = _autoFunctions.isNotEmpty && + functionCalls.isNotEmpty && + functionCalls.every((c) => _autoFunctions.containsKey(c.name)); + + if (!shouldAutoExecute) { + // Standard handling: Update history and return the response to the user. + if (response.candidates case [final candidate, ...]) { + _history.add(message); + final normalizedContent = candidate.content.role == null + ? Content('model', candidate.content.parts) + : candidate.content; + _history.add(normalizedContent); + } + return response; + } + + // Auto function execution + requestHistory.add(response.candidates.first.content); + final functionResponses = []; + for (final functionCall in functionCalls) { + final function = _autoFunctions[functionCall.name]; + + Object? result; + try { + result = await function!.callable(functionCall.args); + } catch (e) { + result = e.toString(); + } + functionResponses + .add(FunctionResponse(functionCall.name, {'result': result})); + } + requestHistory.add(Content('function', functionResponses)); + turn++; + } + throw Exception('Max turns of $_maxTurns reached.'); + } finally { + lock.release(); + } + } + + /// Sends [message] to the server template as a continuation of the chat + /// [history]. + /// + /// Returns a stream of responses, which may be chunks of a single aggregate + /// response. + /// + /// Prepends the history to the request and uses the provided model to + /// generate new content, providing the session's initialized inputs. + /// + /// When there are no candidates in the response, the [message] and response + /// are ignored and will not be recorded in the [history]. + Stream sendMessageStream(Content message) { + final controller = StreamController(sync: true); + _mutex.acquire().then((lock) async { + try { + final responses = _templateHistoryGenerateContentStream( + _history.followedBy([message]), + _templateId, + inputs: _inputs, + ); + final content = []; + await for (final response in responses) { + if (response.candidates case [final candidate, ...]) { + content.add(candidate.content); + } + controller.add(response); + } + if (content.isNotEmpty) { + _history.add(message); + _history.add(historyAggregate(content)); + } + } catch (e, s) { + controller.addError(e, s); + } + lock.release(); + unawaited(controller.close()); + }); + return controller.stream; + } +} + +/// An extension on [TemplateGenerativeModel] that provides a `startChat` method. +extension StartTemplateChatExtension on TemplateGenerativeModel { + /// Starts a [TemplateChatSession] that will use this model to respond to messages. + /// + /// ```dart + /// final chat = model.startChat('my_template', inputs: {'language': 'en'}); + /// final response = await chat.sendMessage(Content.text('Hello there.')); + /// print(response.text); + /// ``` + TemplateChatSession startChat(String templateId, + {required Map inputs, + List? history, + List? autoFunctions, + int? maxTurns}) => + TemplateChatSession._( + templateGenerateContentWithHistory, + templateGenerateContentWithHistoryStream, + templateId, + inputs, + history ?? [], + autoFunctions ?? [], + maxTurns ?? 5); +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart index 75e9029f44b4..78e302b5df47 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart @@ -87,6 +87,29 @@ final class TemplateGenerativeModel extends BaseTemplateApiClientModel { null, _serializationStrategy.parseGenerateContentResponse); } + + /// Generates content from a template with the given [templateId], [inputs] and + /// [history]. + @experimental + Future templateGenerateContentWithHistory( + Iterable history, String templateId, + {required Map inputs}) => + makeTemplateRequest(TemplateTask.templateGenerateContent, templateId, + inputs, history, _serializationStrategy.parseGenerateContentResponse); + + /// Generates a stream of content from a template with the given [templateId], + /// [inputs] and [history]. + @experimental + Stream templateGenerateContentWithHistoryStream( + Iterable history, String templateId, + {required Map inputs}) { + return streamTemplateRequest( + TemplateTask.templateStreamGenerateContent, + templateId, + inputs, + history, + _serializationStrategy.parseGenerateContentResponse); + } } /// Returns a [TemplateGenerativeModel] using its private constructor.