Skip to content

Commit 9a9908a

Browse files
sararob authored and copybara-github committed
feat: Add UnifiedMetric support to Vertex Tuning evaluation config
PiperOrigin-RevId: 869192469
1 parent e15ad64 commit 9a9908a

File tree

3 files changed

+532
-132
lines changed

3 files changed

+532
-132
lines changed

google/genai/_transformers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,22 @@ def t_metrics(
12951295
metrics_payload = []
12961296

12971297
for metric in metrics:
1298+
1299+
if isinstance(metric, dict):
1300+
try:
1301+
metric = types.UnifiedMetric.model_validate(metric)
1302+
except pydantic.ValidationError:
1303+
pass
1304+
1305+
if isinstance(metric, types.UnifiedMetric):
1306+
unified_metric_payload: dict[str, Any] = metric.model_dump()
1307+
unified_metric_payload['aggregation_metrics'] = [
1308+
'AVERAGE',
1309+
'STANDARD_DEVIATION',
1310+
]
1311+
metrics_payload.append(unified_metric_payload)
1312+
continue
1313+
12981314
metric_payload_item: dict[str, Any] = {}
12991315
metric_payload_item['aggregation_metrics'] = [
13001316
'AVERAGE',

google/genai/tests/tunings/test_tune.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,114 @@ def test_eval_config_with_metrics(client):
344344
assert tuning_job.state == genai_types.JobState.JOB_STATE_PENDING
345345

346346

347+
@pytest.mark.skipif(
    "config.getoption('--private')",
    reason="Skipping in pre-public tests"
)
def test_eval_config_with_unified_metrics(client):
  """Tests tuning with an eval config built from typed UnifiedMetric objects.

  Covers the three UnifiedMetric spec variants: an LLM-judged pointwise
  metric (with system instruction and raw-output passthrough) plus the
  computation-based BLEU and ROUGE specs.
  """
  # UnifiedMetric eval configs are Vertex-only; the test is a no-op on the
  # Gemini Developer API backend.
  if client._api_client.vertexai:
    evaluation_config = genai_types.EvaluationConfig(
        metrics=[
            # Model-judged pointwise metric; {request}/{response} are filled
            # in by the eval service per example.
            genai_types.UnifiedMetric(
                pointwise_metric_spec=genai_types.PointwiseMetricSpec(
                    metric_prompt_template=(
                        "How well does the response address the prompt?: "
                        "PROMPT: {request}\n RESPONSE: {response}\n"
                    ),
                    system_instruction=(
                        "You are a cat. Make all evaluations from this perspective."
                    ),
                    custom_output_format_config=genai_types.CustomOutputFormatConfig(
                        return_raw_output=True
                    ),
                )
            ),
            # Deterministic, computation-based metrics.
            genai_types.UnifiedMetric(
                bleu_spec=genai_types.BleuSpec(use_effective_order=True)
            ),
            genai_types.UnifiedMetric(
                rouge_spec=genai_types.RougeSpec(rouge_type="rouge1")
            ),
        ],
        output_config=genai_types.OutputConfig(
            gcs_destination=genai_types.GcsDestination(
                output_uri_prefix="gs://sararob_test/"
            )
        ),
        autorater_config=genai_types.AutoraterConfig(
            sampling_count=1,
            autorater_model="test-model",
        ),
    )
    tuning_job = client.tunings.tune(
        base_model="gemini-2.5-flash",
        training_dataset=genai_types.TuningDataset(
            gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
        ),
        config=genai_types.CreateTuningJobConfig(
            tuned_model_display_name="tuning job with eval config",
            epoch_count=1,
            learning_rate_multiplier=1.0,
            adapter_size="ADAPTER_SIZE_ONE",
            validation_dataset=genai_types.TuningValidationDataset(
                gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
            ),
            evaluation_config=evaluation_config,
        ),
    )
    # A freshly submitted job should be pending before any state transition.
    assert tuning_job.state == genai_types.JobState.JOB_STATE_PENDING
402+
403+
404+
@pytest.mark.skipif(
    "config.getoption('--private')",
    reason="Skipping in pre-public tests"
)
def test_eval_config_with_metrics_dict(client):
  """Tests tuning with eval config metrics passed as plain dicts.

  Exercises the transformer path that validates raw dict metrics into
  UnifiedMetric where possible: named metrics with/without extra fields,
  and a dict carrying an explicit spec (bleu_spec).
  """
  # Metric-based eval configs are Vertex-only; the test is a no-op on the
  # Gemini Developer API backend.
  if client._api_client.vertexai:
    evaluation_config = genai_types.EvaluationConfig(
        metrics=[
            # Named judge-model metric supplied entirely as a dict.
            {
                "name": "prompt-relevance",
                "prompt_template": "How well does the response address the prompt?: PROMPT: {request}\n RESPONSE: {response}\n",
                "return_raw_output": True,
                "judge_model_system_instruction": "You are a cat. Make all evaluations from this perspective.",
            },
            # Name-only computation metrics.
            {"name": "bleu"},
            {"name": "rouge_1"},
            # Spec-only dict, expected to validate as a UnifiedMetric.
            {
                "bleu_spec": {
                    "use_effective_order": True
                }
            },
        ],
        output_config=genai_types.OutputConfig(
            gcs_destination=genai_types.GcsDestination(
                output_uri_prefix="gs://sararob_test/"
            )
        ),
        autorater_config=genai_types.AutoraterConfig(
            sampling_count=1,
            autorater_model="test-model",
        ),
    )
    tuning_job = client.tunings.tune(
        base_model="gemini-2.5-flash",
        training_dataset=genai_types.TuningDataset(
            gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl"
        ),
        config=genai_types.CreateTuningJobConfig(
            tuned_model_display_name="tuning job with eval config",
            epoch_count=1,
            learning_rate_multiplier=1.0,
            adapter_size="ADAPTER_SIZE_ONE",
            validation_dataset=genai_types.TuningValidationDataset(
                gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
            ),
            evaluation_config=evaluation_config,
        ),
    )
    # A freshly submitted job should be pending before any state transition.
    assert tuning_job.state == genai_types.JobState.JOB_STATE_PENDING
452+
453+
454+
347455
@pytest.mark.skipif(
348456
"config.getoption('--private')",
349457
reason="Skipping in pre-public tests"

0 commit comments

Comments
 (0)