Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion java-vertexai/google-cloud-vertexai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,18 @@
<artifactId>truth</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</dependency>
<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>grpc-google-cloud-vertexai-v1</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.google.cloud.vertexai.generativeai;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.google.cloud.vertexai.api.Candidate;
import com.google.cloud.vertexai.api.Candidate.FinishReason;
import com.google.cloud.vertexai.api.Citation;
Expand All @@ -25,10 +28,12 @@
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.util.JsonFormat;

import java.io.IOException;
import java.util.*;

/**
* Helper class to post-process GenerateContentResponse.
Expand All @@ -39,6 +44,13 @@
*/
@Deprecated
public class ResponseHandler {
private static final ObjectMapper DEFAULT_MAPPER = JsonMapper.builder()
.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS)
.enable(DeserializationFeature.USE_BIG_INTEGER_FOR_INTS)
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
.enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS)
.build();

/**
* Gets the text message in a GenerateContentResponse.
Expand All @@ -60,6 +72,46 @@ public static String getText(GenerateContentResponse response) {
return text;
}

/**
* Deserialises the text of a {@link GenerateContentResponse} into a strongly-typed value.
* Behaviour:
* - If {@code clazz} extends {@code com.google.protobuf.Message}, the JSON is merged into a
* new protobuf builder via {@link JsonFormat}.</li>
* - Otherwise the JSON is read with Jackson. If {@code customMapper} is present its
* configuration is used; when empty the library’s default {@code MAPPER} is applied.</li>
* @param response the Vertex AI response whose first candidate contains JSON
* @param clazz target class (protobuf message or POJO)
* @param customMapper optional Jackson {@link JsonMapper} to override the default settings
* @param <T> concrete return type
* @return an instance of {@code T} populated from the model output
* @throws IllegalArgumentException if reflection fails, the JSON is invalid, or the payload
* cannot be merged into the protobuf builder
*/
@SuppressWarnings("unchecked")
public static <T> T getStructuredResponse(
GenerateContentResponse response,
Class<T> clazz,
Optional<JsonMapper> customMapper) {

String text = getText(response);
if (com.google.protobuf.Message.class.isAssignableFrom(clazz)) {
try {
Message.Builder builder =
(Message.Builder) clazz.getMethod("newBuilder").invoke(null);
JsonFormat.parser().ignoringUnknownFields().merge(text, builder);
return (T) builder.build();
} catch (ReflectiveOperationException | InvalidProtocolBufferException e) {
throw new IllegalArgumentException("Parsing as protobuf failed", e);
}
}
ObjectMapper mapper = customMapper.orElse((JsonMapper) DEFAULT_MAPPER);
try {
return mapper.readValue(text, clazz);
} catch (IOException e) {
throw new IllegalArgumentException("JSON parse failed", e);
}
}

/**
* Gets the list of function calls in a GenerateContentResponse.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.google.cloud.vertexai.api.Candidate;
import com.google.cloud.vertexai.api.Candidate.FinishReason;
import com.google.cloud.vertexai.api.Citation;
Expand All @@ -31,6 +33,8 @@
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Optional;

import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -83,6 +87,29 @@ public final class ResponseHandlerTest {
.addCandidates(CANDIDATE_1)
.addCandidates(CANDIDATE_2)
.build();
private static final class ExampleDto {
public String name;
public int count;

// Default constructor for Jackson
public ExampleDto() {
}

ExampleDto(String name, int count) {
this.name = name;
this.count = count;
}
}

private static final String DTO_JSON = "{\"name\":\"vertex\",\"count\":42}";
private static final GenerateContentResponse DTO_RESPONSE =
GenerateContentResponse.newBuilder()
.addCandidates(
Candidate.newBuilder()
.setContent(
Content.newBuilder()
.addParts(Part.newBuilder().setText(DTO_JSON))))
.build();

@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();

Expand Down Expand Up @@ -166,4 +193,29 @@ public void testAggregateStreamIntoResponse() {
assertThat(response.getCandidates(0).getCitationMetadata().getCitationsList())
.isEqualTo(Arrays.asList(CITATION_1, CITATION_2));
}

@Test
public void testGetStructuredResponseWithCustomMapper() {
JsonMapper strictMapper =
JsonMapper.builder()
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
.build();

ExampleDto dto =
ResponseHandler.getStructuredResponse(
DTO_RESPONSE, ExampleDto.class, Optional.of(strictMapper));

assertThat(dto.name).isEqualTo("vertex");
assertThat(dto.count).isEqualTo(42);
}

@Test
public void testGetStructuredResponseWithDefaultMapper() {
ExampleDto dto =
ResponseHandler.getStructuredResponse(
DTO_RESPONSE, ExampleDto.class, Optional.empty());

assertThat(dto.name).isEqualTo("vertex");
assertThat(dto.count).isEqualTo(42);
}
}
Loading