Skip to content

Commit 1a6452b

Browse files
cleop-googlecopybara-github
authored andcommitted
fix: GenAI SDK client(multimodal) - Fix Pydantic validation errors when using create_* in some cases
PiperOrigin-RevId: 901268856
1 parent 6332d33 commit 1a6452b

2 files changed

Lines changed: 33 additions & 16 deletions

File tree

vertexai/_genai/_datasets_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import google.auth.credentials
2323
from vertexai._genai.types import common
24-
from pydantic import BaseModel
24+
from google.genai import _common
2525

2626

2727
METADATA_SCHEMA_URI = (
@@ -31,18 +31,27 @@
3131
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
3232
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
3333

34-
T = TypeVar("T", bound=BaseModel)
34+
T = TypeVar("T", bound=_common.BaseModel)
3535

3636

37-
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
37+
def create_from_response(
38+
model_type: Type[T],
39+
response: dict[str, Any],
40+
config: Any | None = None,
41+
) -> T:
3842
"""Creates a model from a response."""
39-
model_field_names = model_type.model_fields.keys()
40-
filtered_response = {}
41-
for key, value in response.items():
42-
snake_key = common.camel_to_snake(key)
43-
if snake_key in model_field_names:
44-
filtered_response[snake_key] = value
45-
return model_type(**filtered_response)
43+
kwargs = (
44+
{
45+
"config": {
46+
"response_schema": getattr(config, "response_schema", None),
47+
"response_json_schema": getattr(config, "response_json_schema", None),
48+
"include_all_fields": getattr(config, "include_all_fields", None),
49+
}
50+
}
51+
if config
52+
else {}
53+
)
54+
return model_type._from_response(response=response, kwargs=kwargs)
4655

4756

4857
def validate_multimodal_dataset_bigquery_uri(

vertexai/_genai/datasets.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,9 @@ def create_from_bigquery(
963963
operation=multimodal_dataset_operation,
964964
timeout_seconds=config.timeout,
965965
)
966-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
966+
return _datasets_utils.create_from_response(
967+
types.MultimodalDataset, response, config
968+
)
967969

968970
def create_from_pandas(
969971
self,
@@ -1286,6 +1288,7 @@ def assess_tuning_resources(
12861288
return _datasets_utils.create_from_response(
12871289
types.TuningResourceUsageAssessmentResult,
12881290
response["tuningResourceUsageAssessmentResult"],
1291+
config,
12891292
)
12901293

12911294
def assess_tuning_validity(
@@ -1348,6 +1351,7 @@ def assess_tuning_validity(
13481351
return _datasets_utils.create_from_response(
13491352
types.TuningValidationAssessmentResult,
13501353
response["tuningValidationAssessmentResult"],
1354+
config,
13511355
)
13521356

13531357
def assess_batch_prediction_resources(
@@ -1406,7 +1410,7 @@ def assess_batch_prediction_resources(
14061410
)
14071411
result = response["batchPredictionResourceUsageAssessmentResult"]
14081412
return _datasets_utils.create_from_response(
1409-
types.BatchPredictionResourceUsageAssessmentResult, result
1413+
types.BatchPredictionResourceUsageAssessmentResult, result, config
14101414
)
14111415

14121416
def assess_batch_prediction_validity(
@@ -1465,7 +1469,7 @@ def assess_batch_prediction_validity(
14651469
)
14661470
result = response["batchPredictionValidationAssessmentResult"]
14671471
return _datasets_utils.create_from_response(
1468-
types.BatchPredictionValidationAssessmentResult, result
1472+
types.BatchPredictionValidationAssessmentResult, result, config
14691473
)
14701474

14711475

@@ -2203,7 +2207,9 @@ async def create_from_bigquery(
22032207
operation=multimodal_dataset_operation,
22042208
timeout_seconds=config.timeout,
22052209
)
2206-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
2210+
return _datasets_utils.create_from_response(
2211+
types.MultimodalDataset, response, config
2212+
)
22072213

22082214
async def create_from_pandas(
22092215
self,
@@ -2524,6 +2530,7 @@ async def assess_tuning_resources(
25242530
return _datasets_utils.create_from_response(
25252531
types.TuningResourceUsageAssessmentResult,
25262532
response["tuningResourceUsageAssessmentResult"],
2533+
config,
25272534
)
25282535

25292536
async def assess_tuning_validity(
@@ -2586,6 +2593,7 @@ async def assess_tuning_validity(
25862593
return _datasets_utils.create_from_response(
25872594
types.TuningValidationAssessmentResult,
25882595
response["tuningValidationAssessmentResult"],
2596+
config,
25892597
)
25902598

25912599
async def assess_batch_prediction_resources(
@@ -2644,7 +2652,7 @@ async def assess_batch_prediction_resources(
26442652
)
26452653
result = response["batchPredictionResourceUsageAssessmentResult"]
26462654
return _datasets_utils.create_from_response(
2647-
types.BatchPredictionResourceUsageAssessmentResult, result
2655+
types.BatchPredictionResourceUsageAssessmentResult, result, config
26482656
)
26492657

26502658
async def assess_batch_prediction_validity(
@@ -2703,5 +2711,5 @@ async def assess_batch_prediction_validity(
27032711
)
27042712
result = response["batchPredictionValidationAssessmentResult"]
27052713
return _datasets_utils.create_from_response(
2706-
types.BatchPredictionValidationAssessmentResult, result
2714+
types.BatchPredictionValidationAssessmentResult, result, config
27072715
)

0 commit comments

Comments
 (0)