Skip to content

Commit 73e4064

Browse files
authored
feat(vertexai): Allow serializing to Developer API models (#17294)
Move parsing from top level methods, and serialization from a method on the model,into instance methods on a new `SerializationStrategy` class. Use the existing parsing code for the vertex strategy and copy over code form the developer SDK for the Google AI strategy. Add a `_ModelUri` class to allow a vertex and developer specialization for the different backends. Copy the serialization and parsing tests from the developer SDK, skip a few tests that rely on arguments or methods that are unsupported in the vertex model. Tests skipped for content embedding and extra arguments passed to countTokens.
1 parent 69cd2a6 commit 73e4064

13 files changed

+2192
-224
lines changed

packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart

+130-41
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import 'content.dart';
1616
import 'error.dart';
17+
import 'function_calling.dart' show Tool, ToolConfig;
1718
import 'schema.dart';
1819

1920
/// Response for Count Tokens
@@ -155,6 +156,23 @@ final class UsageMetadata {
155156
final List<ModalityTokenCount>? candidatesTokensDetails;
156157
}
157158

159+
/// Constructe a UsageMetadata with all it's fields.
160+
///
161+
/// Expose access to the private constructor for use within the package..
162+
UsageMetadata createUsageMetadata({
163+
required int? promptTokenCount,
164+
required int? candidatesTokenCount,
165+
required int? totalTokenCount,
166+
required List<ModalityTokenCount>? promptTokensDetails,
167+
required List<ModalityTokenCount>? candidatesTokensDetails,
168+
}) =>
169+
UsageMetadata._(
170+
promptTokenCount: promptTokenCount,
171+
candidatesTokenCount: candidatesTokenCount,
172+
totalTokenCount: totalTokenCount,
173+
promptTokensDetails: promptTokensDetails,
174+
candidatesTokensDetails: candidatesTokensDetails);
175+
158176
/// Response candidate generated from a [GenerativeModel].
159177
final class Candidate {
160178
// TODO: token count?
@@ -842,53 +860,124 @@ enum TaskType {
842860
Object toJson() => _jsonString;
843861
}
844862

845-
/// Parse the json to [GenerateContentResponse]
846-
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
847-
if (jsonObject case {'error': final Object error}) throw parseError(error);
848-
final candidates = switch (jsonObject) {
849-
{'candidates': final List<Object?> candidates} =>
850-
candidates.map(_parseCandidate).toList(),
851-
_ => <Candidate>[]
852-
};
853-
final promptFeedback = switch (jsonObject) {
854-
{'promptFeedback': final promptFeedback?} =>
855-
_parsePromptFeedback(promptFeedback),
856-
_ => null,
857-
};
858-
final usageMedata = switch (jsonObject) {
859-
{'usageMetadata': final usageMetadata?} =>
860-
_parseUsageMetadata(usageMetadata),
861-
_ => null,
862-
};
863-
return GenerateContentResponse(candidates, promptFeedback,
864-
usageMetadata: usageMedata);
863+
// ignore: public_member_api_docs
864+
abstract interface class SerializationStrategy {
865+
// ignore: public_member_api_docs
866+
GenerateContentResponse parseGenerateContentResponse(Object jsonObject);
867+
// ignore: public_member_api_docs
868+
CountTokensResponse parseCountTokensResponse(Object jsonObject);
869+
// ignore: public_member_api_docs
870+
Map<String, Object?> generateContentRequest(
871+
Iterable<Content> contents,
872+
({String prefix, String name}) model,
873+
List<SafetySetting> safetySettings,
874+
GenerationConfig? generationConfig,
875+
List<Tool>? tools,
876+
ToolConfig? toolConfig,
877+
Content? systemInstruction,
878+
);
879+
880+
// ignore: public_member_api_docs
881+
Map<String, Object?> countTokensRequest(
882+
Iterable<Content> contents,
883+
({String prefix, String name}) model,
884+
List<SafetySetting> safetySettings,
885+
GenerationConfig? generationConfig,
886+
List<Tool>? tools,
887+
ToolConfig? toolConfig,
888+
);
865889
}
866890

867-
/// Parse the json to [CountTokensResponse]
868-
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
869-
if (jsonObject case {'error': final Object error}) throw parseError(error);
891+
// ignore: public_member_api_docs
892+
final class VertexSerialization implements SerializationStrategy {
893+
/// Parse the json to [GenerateContentResponse]
894+
@override
895+
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
896+
if (jsonObject case {'error': final Object error}) throw parseError(error);
897+
final candidates = switch (jsonObject) {
898+
{'candidates': final List<Object?> candidates} =>
899+
candidates.map(_parseCandidate).toList(),
900+
_ => <Candidate>[]
901+
};
902+
final promptFeedback = switch (jsonObject) {
903+
{'promptFeedback': final promptFeedback?} =>
904+
_parsePromptFeedback(promptFeedback),
905+
_ => null,
906+
};
907+
final usageMedata = switch (jsonObject) {
908+
{'usageMetadata': final usageMetadata?} =>
909+
_parseUsageMetadata(usageMetadata),
910+
{'totalTokens': final int totalTokens} =>
911+
UsageMetadata._(totalTokenCount: totalTokens),
912+
_ => null,
913+
};
914+
return GenerateContentResponse(candidates, promptFeedback,
915+
usageMetadata: usageMedata);
916+
}
870917

871-
if (jsonObject is! Map) {
872-
throw unhandledFormat('CountTokensResponse', jsonObject);
918+
/// Parse the json to [CountTokensResponse]
919+
@override
920+
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
921+
if (jsonObject case {'error': final Object error}) throw parseError(error);
922+
923+
if (jsonObject is! Map) {
924+
throw unhandledFormat('CountTokensResponse', jsonObject);
925+
}
926+
927+
final totalTokens = jsonObject['totalTokens'] as int;
928+
final totalBillableCharacters = switch (jsonObject) {
929+
{'totalBillableCharacters': final int totalBillableCharacters} =>
930+
totalBillableCharacters,
931+
_ => null,
932+
};
933+
final promptTokensDetails = switch (jsonObject) {
934+
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
935+
promptTokensDetails.map(_parseModalityTokenCount).toList(),
936+
_ => null,
937+
};
938+
939+
return CountTokensResponse(
940+
totalTokens,
941+
totalBillableCharacters: totalBillableCharacters,
942+
promptTokensDetails: promptTokensDetails,
943+
);
873944
}
874945

875-
final totalTokens = jsonObject['totalTokens'] as int;
876-
final totalBillableCharacters = switch (jsonObject) {
877-
{'totalBillableCharacters': final int totalBillableCharacters} =>
878-
totalBillableCharacters,
879-
_ => null,
880-
};
881-
final promptTokensDetails = switch (jsonObject) {
882-
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
883-
promptTokensDetails.map(_parseModalityTokenCount).toList(),
884-
_ => null,
885-
};
946+
@override
947+
Map<String, Object?> generateContentRequest(
948+
Iterable<Content> contents,
949+
({String prefix, String name}) model,
950+
List<SafetySetting> safetySettings,
951+
GenerationConfig? generationConfig,
952+
List<Tool>? tools,
953+
ToolConfig? toolConfig,
954+
Content? systemInstruction,
955+
) {
956+
return {
957+
'model': '${model.prefix}/${model.name}',
958+
'contents': contents.map((c) => c.toJson()).toList(),
959+
if (safetySettings.isNotEmpty)
960+
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
961+
if (generationConfig != null)
962+
'generationConfig': generationConfig.toJson(),
963+
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
964+
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
965+
if (systemInstruction != null)
966+
'systemInstruction': systemInstruction.toJson(),
967+
};
968+
}
886969

887-
return CountTokensResponse(
888-
totalTokens,
889-
totalBillableCharacters: totalBillableCharacters,
890-
promptTokensDetails: promptTokensDetails,
891-
);
970+
@override
971+
Map<String, Object?> countTokensRequest(
972+
Iterable<Content> contents,
973+
({String prefix, String name}) model,
974+
List<SafetySetting> safetySettings,
975+
GenerationConfig? generationConfig,
976+
List<Tool>? tools,
977+
ToolConfig? toolConfig,
978+
) =>
979+
// Everything except contents is ignored.
980+
{'contents': contents.map((c) => c.toJson()).toList()};
892981
}
893982

894983
Candidate _parseCandidate(Object? jsonObject) {

packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart

+86-24
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import 'package:web_socket_channel/io.dart';
2626
import 'api.dart';
2727
import 'client.dart';
2828
import 'content.dart';
29+
import 'developer/api.dart';
2930
import 'function_calling.dart';
3031
import 'imagen_api.dart';
3132
import 'imagen_content.dart';
@@ -52,33 +53,28 @@ enum Task {
5253
predict,
5354
}
5455

55-
/// Base class for models.
56-
///
57-
/// Do not instantiate directly.
58-
abstract class BaseModel {
59-
// ignore: public_member_api_docs
60-
BaseModel(
56+
abstract interface class _ModelUri {
57+
String get baseAuthority;
58+
Uri taskUri(Task task);
59+
({String prefix, String name}) get model;
60+
}
61+
62+
final class _VertexUri implements _ModelUri {
63+
_VertexUri(
6164
{required String model,
6265
required String location,
6366
required FirebaseApp app})
64-
: _model = normalizeModelName(model),
67+
: model = _normalizeModelName(model),
6568
_projectUri = _vertexUri(app, location);
6669

67-
static const _baseUrl = 'firebasevertexai.googleapis.com';
70+
static const _baseAuthority = 'firebasevertexai.googleapis.com';
6871
static const _apiVersion = 'v1beta';
6972

70-
final ({String prefix, String name}) _model;
71-
72-
final Uri _projectUri;
73-
74-
/// The normalized model name.
75-
({String prefix, String name}) get model => _model;
76-
7773
/// Returns the model code for a user friendly model name.
7874
///
7975
/// If the model name is already a model code (contains a `/`), use the parts
8076
/// directly. Otherwise, return a `models/` model code.
81-
static ({String prefix, String name}) normalizeModelName(String modelName) {
77+
static ({String prefix, String name}) _normalizeModelName(String modelName) {
8278
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
8379
final parts = modelName.split('/');
8480
return (prefix: parts.first, name: parts.skip(1).join('/'));
@@ -87,11 +83,79 @@ abstract class BaseModel {
8783
static Uri _vertexUri(FirebaseApp app, String location) {
8884
var projectId = app.options.projectId;
8985
return Uri.https(
90-
_baseUrl,
86+
_baseAuthority,
9187
'/$_apiVersion/projects/$projectId/locations/$location/publishers/google',
9288
);
9389
}
9490

91+
final Uri _projectUri;
92+
@override
93+
final ({String prefix, String name}) model;
94+
95+
@override
96+
String get baseAuthority => _baseAuthority;
97+
98+
@override
99+
Uri taskUri(Task task) {
100+
return _projectUri.replace(
101+
pathSegments: _projectUri.pathSegments
102+
.followedBy([model.prefix, '${model.name}:${task.name}']));
103+
}
104+
}
105+
106+
final class _GoogleAIUri implements _ModelUri {
107+
_GoogleAIUri({
108+
required String model,
109+
required FirebaseApp app,
110+
}) : model = _normalizeModelName(model),
111+
_baseUri = _googleAIBaseUri(app: app);
112+
113+
/// Returns the model code for a user friendly model name.
114+
///
115+
/// If the model name is already a model code (contains a `/`), use the parts
116+
/// directly. Otherwise, return a `models/` model code.
117+
static ({String prefix, String name}) _normalizeModelName(String modelName) {
118+
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
119+
final parts = modelName.split('/');
120+
return (prefix: parts.first, name: parts.skip(1).join('/'));
121+
}
122+
123+
static const _apiVersion = 'v1beta';
124+
static const _baseAuthority = 'firebasevertexai.googleapis.com';
125+
static Uri _googleAIBaseUri(
126+
{String apiVersion = _apiVersion, required FirebaseApp app}) =>
127+
Uri.https(
128+
_baseAuthority, '$apiVersion/projects/${app.options.projectId}');
129+
final Uri _baseUri;
130+
131+
@override
132+
final ({String prefix, String name}) model;
133+
134+
@override
135+
String get baseAuthority => _baseAuthority;
136+
137+
@override
138+
Uri taskUri(Task task) => _baseUri.replace(
139+
pathSegments: _baseUri.pathSegments
140+
.followedBy([model.prefix, '${model.name}:${task.name}']));
141+
}
142+
143+
/// Base class for models.
144+
///
145+
/// Do not instantiate directly.
146+
abstract class BaseModel {
147+
BaseModel._(
148+
{required SerializationStrategy serializationStrategy,
149+
required _ModelUri modelUri})
150+
: _serializationStrategy = serializationStrategy,
151+
_modelUri = modelUri;
152+
153+
final SerializationStrategy _serializationStrategy;
154+
final _ModelUri _modelUri;
155+
156+
/// The normalized model name.
157+
({String prefix, String name}) get model => _modelUri.model;
158+
95159
/// Returns a function that generates Firebase auth tokens.
96160
static FutureOr<Map<String, String>> Function() firebaseTokens(
97161
FirebaseAppCheck? appCheck, FirebaseAuth? auth, FirebaseApp? app) {
@@ -120,9 +184,7 @@ abstract class BaseModel {
120184
}
121185

122186
/// Returns a URI for the given [task].
123-
Uri taskUri(Task task) => _projectUri.replace(
124-
pathSegments: _projectUri.pathSegments
125-
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
187+
Uri taskUri(Task task) => _modelUri.taskUri(task);
126188
}
127189

128190
/// An abstract base class for models that interact with an API using an [ApiClient].
@@ -136,11 +198,11 @@ abstract class BaseModel {
136198
abstract class BaseApiClientModel extends BaseModel {
137199
// ignore: public_member_api_docs
138200
BaseApiClientModel({
139-
required super.model,
140-
required super.location,
141-
required super.app,
201+
required super.serializationStrategy,
202+
required super.modelUri,
142203
required ApiClient client,
143-
}) : _client = client;
204+
}) : _client = client,
205+
super._();
144206

145207
final ApiClient _client;
146208

0 commit comments

Comments
 (0)