Skip to content

feat(vertexai): Allow serializing to Developer API models #17294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
171 changes: 130 additions & 41 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import 'content.dart';
import 'error.dart';
import 'function_calling.dart' show Tool, ToolConfig;
import 'schema.dart';

/// Response for Count Tokens
Expand Down Expand Up @@ -155,6 +156,23 @@ final class UsageMetadata {
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Constructe a UsageMetadata with all it's fields.
///
/// Expose access to the private constructor for use within the package..
UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

/// Response candidate generated from a [GenerativeModel].
final class Candidate {
// TODO: token count?
Expand Down Expand Up @@ -842,53 +860,124 @@ enum TaskType {
Object toJson() => _jsonString;
}

/// Parse the json to [GenerateContentResponse]
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
final candidates = switch (jsonObject) {
{'candidates': final List<Object?> candidates} =>
candidates.map(_parseCandidate).toList(),
_ => <Candidate>[]
};
final promptFeedback = switch (jsonObject) {
{'promptFeedback': final promptFeedback?} =>
_parsePromptFeedback(promptFeedback),
_ => null,
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
usageMetadata: usageMedata);
// ignore: public_member_api_docs
abstract interface class SerializationStrategy {
// ignore: public_member_api_docs
GenerateContentResponse parseGenerateContentResponse(Object jsonObject);
// ignore: public_member_api_docs
CountTokensResponse parseCountTokensResponse(Object jsonObject);
// ignore: public_member_api_docs
Map<String, Object?> generateContentRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
);

// ignore: public_member_api_docs
Map<String, Object?> countTokensRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
);
}

/// Parse the json to [CountTokensResponse]
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
// ignore: public_member_api_docs
final class VertexSerialization implements SerializationStrategy {
/// Parse the json to [GenerateContentResponse]
@override
GenerateContentResponse parseGenerateContentResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);
final candidates = switch (jsonObject) {
{'candidates': final List<Object?> candidates} =>
candidates.map(_parseCandidate).toList(),
_ => <Candidate>[]
};
final promptFeedback = switch (jsonObject) {
{'promptFeedback': final promptFeedback?} =>
_parsePromptFeedback(promptFeedback),
_ => null,
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
{'totalTokens': final int totalTokens} =>
UsageMetadata._(totalTokenCount: totalTokens),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
usageMetadata: usageMedata);
}

if (jsonObject is! Map) {
throw unhandledFormat('CountTokensResponse', jsonObject);
/// Parse the json to [CountTokensResponse]
@override
CountTokensResponse parseCountTokensResponse(Object jsonObject) {
if (jsonObject case {'error': final Object error}) throw parseError(error);

if (jsonObject is! Map) {
throw unhandledFormat('CountTokensResponse', jsonObject);
}

final totalTokens = jsonObject['totalTokens'] as int;
final totalBillableCharacters = switch (jsonObject) {
{'totalBillableCharacters': final int totalBillableCharacters} =>
totalBillableCharacters,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};

return CountTokensResponse(
totalTokens,
totalBillableCharacters: totalBillableCharacters,
promptTokensDetails: promptTokensDetails,
);
}

final totalTokens = jsonObject['totalTokens'] as int;
final totalBillableCharacters = switch (jsonObject) {
{'totalBillableCharacters': final int totalBillableCharacters} =>
totalBillableCharacters,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
_ => null,
};
@override
Map<String, Object?> generateContentRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
) {
return {
'model': '${model.prefix}/${model.name}',
'contents': contents.map((c) => c.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (systemInstruction != null)
'systemInstruction': systemInstruction.toJson(),
};
}

return CountTokensResponse(
totalTokens,
totalBillableCharacters: totalBillableCharacters,
promptTokensDetails: promptTokensDetails,
);
@override
Map<String, Object?> countTokensRequest(
Iterable<Content> contents,
({String prefix, String name}) model,
List<SafetySetting> safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
) =>
// Everything except contents is ignored.
{'contents': contents.map((c) => c.toJson()).toList()};
}

Candidate _parseCandidate(Object? jsonObject) {
Expand Down
110 changes: 86 additions & 24 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import 'package:web_socket_channel/io.dart';
import 'api.dart';
import 'client.dart';
import 'content.dart';
import 'developer/api.dart';
import 'function_calling.dart';
import 'imagen_api.dart';
import 'imagen_content.dart';
Expand All @@ -52,33 +53,28 @@ enum Task {
predict,
}

/// Base class for models.
///
/// Do not instantiate directly.
abstract class BaseModel {
// ignore: public_member_api_docs
BaseModel(
abstract interface class _ModelUri {
String get baseAuthority;
Uri taskUri(Task task);
({String prefix, String name}) get model;
}

final class _VertexUri implements _ModelUri {
_VertexUri(
{required String model,
required String location,
required FirebaseApp app})
: _model = normalizeModelName(model),
: model = _normalizeModelName(model),
_projectUri = _vertexUri(app, location);

static const _baseUrl = 'firebasevertexai.googleapis.com';
static const _baseAuthority = 'firebasevertexai.googleapis.com';
static const _apiVersion = 'v1beta';

final ({String prefix, String name}) _model;

final Uri _projectUri;

/// The normalized model name.
({String prefix, String name}) get model => _model;

/// Returns the model code for a user friendly model name.
///
/// If the model name is already a model code (contains a `/`), use the parts
/// directly. Otherwise, return a `models/` model code.
static ({String prefix, String name}) normalizeModelName(String modelName) {
static ({String prefix, String name}) _normalizeModelName(String modelName) {
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
final parts = modelName.split('/');
return (prefix: parts.first, name: parts.skip(1).join('/'));
Expand All @@ -87,11 +83,79 @@ abstract class BaseModel {
static Uri _vertexUri(FirebaseApp app, String location) {
var projectId = app.options.projectId;
return Uri.https(
_baseUrl,
_baseAuthority,
'/$_apiVersion/projects/$projectId/locations/$location/publishers/google',
);
}

final Uri _projectUri;
@override
final ({String prefix, String name}) model;

@override
String get baseAuthority => _baseAuthority;

@override
Uri taskUri(Task task) {
return _projectUri.replace(
pathSegments: _projectUri.pathSegments
.followedBy([model.prefix, '${model.name}:${task.name}']));
}
}

final class _GoogleAIUri implements _ModelUri {
_GoogleAIUri({
required String model,
required FirebaseApp app,
}) : model = _normalizeModelName(model),
_baseUri = _googleAIBaseUri(app: app);

/// Returns the model code for a user friendly model name.
///
/// If the model name is already a model code (contains a `/`), use the parts
/// directly. Otherwise, return a `models/` model code.
static ({String prefix, String name}) _normalizeModelName(String modelName) {
if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
final parts = modelName.split('/');
return (prefix: parts.first, name: parts.skip(1).join('/'));
}

static const _apiVersion = 'v1beta';
static const _baseAuthority = 'firebasevertexai.googleapis.com';
static Uri _googleAIBaseUri(
{String apiVersion = _apiVersion, required FirebaseApp app}) =>
Uri.https(
_baseAuthority, '$apiVersion/projects/${app.options.projectId}');
final Uri _baseUri;

@override
final ({String prefix, String name}) model;

@override
String get baseAuthority => _baseAuthority;

@override
Uri taskUri(Task task) => _baseUri.replace(
pathSegments: _baseUri.pathSegments
.followedBy([model.prefix, '${model.name}:${task.name}']));
}

/// Base class for models.
///
/// Do not instantiate directly.
abstract class BaseModel {
BaseModel._(
{required SerializationStrategy serializationStrategy,
required _ModelUri modelUri})
: _serializationStrategy = serializationStrategy,
_modelUri = modelUri;

final SerializationStrategy _serializationStrategy;
final _ModelUri _modelUri;

/// The normalized model name.
({String prefix, String name}) get model => _modelUri.model;

/// Returns a function that generates Firebase auth tokens.
static FutureOr<Map<String, String>> Function() firebaseTokens(
FirebaseAppCheck? appCheck, FirebaseAuth? auth, FirebaseApp? app) {
Expand Down Expand Up @@ -120,9 +184,7 @@ abstract class BaseModel {
}

/// Returns a URI for the given [task].
Uri taskUri(Task task) => _projectUri.replace(
pathSegments: _projectUri.pathSegments
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
Uri taskUri(Task task) => _modelUri.taskUri(task);
}

/// An abstract base class for models that interact with an API using an [ApiClient].
Expand All @@ -136,11 +198,11 @@ abstract class BaseModel {
abstract class BaseApiClientModel extends BaseModel {
// ignore: public_member_api_docs
BaseApiClientModel({
required super.model,
required super.location,
required super.app,
required super.serializationStrategy,
required super.modelUri,
required ApiClient client,
}) : _client = client;
}) : _client = client,
super._();

final ApiClient _client;

Expand Down
Loading
Loading