diff --git a/dotCMS/pom.xml b/dotCMS/pom.xml index 7b59167a09b..d54c16ab3c4 100644 --- a/dotCMS/pom.xml +++ b/dotCMS/pom.xml @@ -514,6 +514,21 @@ dev.langchain4j langchain4j-bedrock + + + dev.langchain4j + langchain4j-vertex-ai-gemini + + + org.checkerframework + checker-qual + + + com.google.android + annotations + + + jakarta.inject jakarta.inject-api diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java index 351c015344c..c7af07b3e8a 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java @@ -9,6 +9,8 @@ import dev.langchain4j.model.bedrock.BedrockCohereEmbeddingModel; import dev.langchain4j.model.bedrock.BedrockStreamingChatModel; import dev.langchain4j.model.bedrock.BedrockTitanEmbeddingModel; +import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel; +import dev.langchain4j.model.vertexai.VertexAiGeminiStreamingChatModel; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -36,8 +38,8 @@ * To add support for a new provider, add a case to each switch block below. * No other class needs to change. * - *

Supported providers: {@code openai}, {@code azure_openai}, {@code bedrock} - *

Planned: {@code vertex_ai} + *

Supported providers: {@code openai}, {@code azure_openai}, {@code bedrock}, {@code vertex_ai} + *

Note: {@code vertex_ai} supports chat only; embeddings and image are not available via LangChain4J. */ public class LangChain4jModelFactory { @@ -54,7 +56,8 @@ public static ChatModel buildChatModel(final ProviderConfig config) { return build(config, "chat", LangChain4jModelFactory::buildOpenAiChatModel, LangChain4jModelFactory::buildAzureOpenAiChatModel, - LangChain4jModelFactory::buildBedrockChatModel); + LangChain4jModelFactory::buildBedrockChatModel, + LangChain4jModelFactory::buildVertexAiChatModel); } /** @@ -68,7 +71,8 @@ public static StreamingChatModel buildStreamingChatModel(final ProviderConfig co return build(config, "chat", LangChain4jModelFactory::buildOpenAiStreamingChatModel, LangChain4jModelFactory::buildAzureOpenAiStreamingChatModel, - LangChain4jModelFactory::buildBedrockStreamingChatModel); + LangChain4jModelFactory::buildBedrockStreamingChatModel, + LangChain4jModelFactory::buildVertexAiStreamingChatModel); } /** @@ -82,7 +86,8 @@ public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) { return build(config, "embeddings", LangChain4jModelFactory::buildOpenAiEmbeddingModel, LangChain4jModelFactory::buildAzureOpenAiEmbeddingModel, - LangChain4jModelFactory::buildBedrockEmbeddingModel); + LangChain4jModelFactory::buildBedrockEmbeddingModel, + LangChain4jModelFactory::buildVertexAiEmbeddingModel); } /** @@ -96,14 +101,16 @@ public static ImageModel buildImageModel(final ProviderConfig config) { return build(config, "image", LangChain4jModelFactory::buildOpenAiImageModel, LangChain4jModelFactory::buildAzureOpenAiImageModel, - LangChain4jModelFactory::buildBedrockImageModel); + LangChain4jModelFactory::buildBedrockImageModel, + LangChain4jModelFactory::buildVertexAiImageModel); } private static T build(final ProviderConfig config, final String modelType, final Function openAiFn, final Function azureOpenAiFn, - final Function bedrockFn) { + final Function bedrockFn, + final Function vertexAiFn) { if (config == null || 
config.provider() == null) { throw new IllegalArgumentException("ProviderConfig or provider name is null for model type: " + modelType); } @@ -118,9 +125,17 @@ private static T build(final ProviderConfig config, case "bedrock": validateBedrock(config, modelType); return bedrockFn.apply(config); + case "vertex_ai": + // Non-chat model types skip config validation — the delegate builder itself + // throws UnsupportedOperationException regardless of how the config is set. + if (!"chat".equals(modelType)) { + return vertexAiFn.apply(config); + } + validateVertexAi(config, modelType); + return vertexAiFn.apply(config); default: throw new IllegalArgumentException("Unsupported " + modelType + " provider: " - + config.provider() + ". Supported: openai, azure_openai, bedrock"); + + config.provider() + ". Supported: openai, azure_openai, bedrock, vertex_ai"); } } @@ -142,6 +157,11 @@ private static void validateBedrock(final ProviderConfig config, final String mo requireNonBlank(config.region(), "region", modelType); } + private static void validateVertexAi(final ProviderConfig config, final String modelType) { + requireNonBlank(config.projectId(), "projectId", modelType); + requireNonBlank(config.location(), "location", modelType); + } + private static void requireNonBlank(final String value, final String field, final String modelType) { if (value == null || value.isBlank()) { throw new IllegalArgumentException( @@ -343,4 +363,41 @@ private static ImageModel buildBedrockImageModel(final ProviderConfig config) { "Image generation is not supported for Bedrock provider via LangChain4J"); } + // ── Vertex AI builders ──────────────────────────────────────────────────── + + private static ChatModel buildVertexAiChatModel(final ProviderConfig config) { + final VertexAiGeminiChatModel.VertexAiGeminiChatModelBuilder builder = + VertexAiGeminiChatModel.builder() + .project(config.projectId()) + .location(config.location()) + .modelName(config.model()); + if (config.maxRetries() != null) &#10;
builder.maxRetries(config.maxRetries()); + if (config.temperature() != null) builder.temperature(config.temperature().floatValue()); + if (config.maxTokens() != null) builder.maxOutputTokens(config.maxTokens()); + // timeout and streaming maxRetries not exposed by VertexAiGemini builders in LangChain4J + return builder.build(); + } + + private static StreamingChatModel buildVertexAiStreamingChatModel(final ProviderConfig config) { + final VertexAiGeminiStreamingChatModel.VertexAiGeminiStreamingChatModelBuilder builder = + VertexAiGeminiStreamingChatModel.builder() + .project(config.projectId()) + .location(config.location()) + .modelName(config.model()); + if (config.temperature() != null) builder.temperature(config.temperature().floatValue()); + if (config.maxTokens() != null) builder.maxOutputTokens(config.maxTokens()); + // maxRetries and timeout not exposed by VertexAiGeminiStreamingChatModel builder in LangChain4J + return builder.build(); + } + + private static EmbeddingModel buildVertexAiEmbeddingModel(final ProviderConfig config) { + throw new UnsupportedOperationException( + "Embeddings are not supported for Vertex AI provider via LangChain4J"); + } + + private static ImageModel buildVertexAiImageModel(final ProviderConfig config) { + throw new UnsupportedOperationException( + "Image generation is not supported for Vertex AI provider via LangChain4J"); + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/ProviderConfig.java b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/ProviderConfig.java index 5b21d240436..c56a576833f 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/ProviderConfig.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/ProviderConfig.java @@ -41,11 +41,13 @@ *

  • {@code embeddingInputType} – Cohere only: {@code search_document} (default) or {@code search_query}
  • * * - *

    Google Vertex AI: + *

    Google Vertex AI (chat only — embeddings and image not supported via LangChain4J): *

    + *

    Auth is handled automatically via Application Default Credentials (ADC). + * No API key is required. */ @Value.Immutable @JsonSerialize(as = ImmutableProviderConfig.class) diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java index c78d5dbff55..64b25b37698 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java @@ -82,6 +82,26 @@ public void test_buildChatModel_bedrock_missingRegion_throws() { assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); } + @Test + public void test_buildChatModel_vertexAi_missingProjectId_throws() { + final ProviderConfig config = ImmutableProviderConfig.builder() + .provider("vertex_ai") + .model("gemini-1.5-pro") + .location("us-central1") + .build(); + assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); + } + + @Test + public void test_buildChatModel_vertexAi_missingLocation_throws() { + final ProviderConfig config = ImmutableProviderConfig.builder() + .provider("vertex_ai") + .model("gemini-1.5-pro") + .projectId("my-gcp-project") + .build(); + assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); + } + @Test public void test_buildChatModel_unknownProvider_throws() { final ProviderConfig config = ImmutableProviderConfig.builder() @@ -121,6 +141,12 @@ public void test_buildEmbeddingModel_bedrock_cohere_returnsModel() { assertNotNull(model); } + @Test + public void test_buildEmbeddingModel_vertexAi_throws() { + assertThrows(UnsupportedOperationException.class, + () -> LangChain4jModelFactory.buildEmbeddingModel(vertexAiConfig("text-embedding-004"))); + } + @Test public void test_buildEmbeddingModel_unknownProvider_throws() { 
final ProviderConfig config = ImmutableProviderConfig.builder() @@ -154,6 +180,12 @@ public void test_buildImageModel_bedrock_throws() { () -> LangChain4jModelFactory.buildImageModel(bedrockConfig("stability.stable-diffusion-xl-v1"))); } + @Test + public void test_buildImageModel_vertexAi_throws() { + assertThrows(UnsupportedOperationException.class, + () -> LangChain4jModelFactory.buildImageModel(vertexAiConfig("imagen-3.0"))); + } + @Test public void test_buildImageModel_unknownProvider_throws() { final ProviderConfig config = ImmutableProviderConfig.builder() @@ -172,6 +204,15 @@ private static ProviderConfig openAiConfig(final String model) { .build(); } + private static ProviderConfig vertexAiConfig(final String model) { + return ImmutableProviderConfig.builder() + .provider("vertex_ai") + .model(model) + .projectId("my-gcp-project") + .location("us-central1") + .build(); + } + private static ProviderConfig bedrockConfig(final String model) { return ImmutableProviderConfig.builder() .provider("bedrock")