
Commit 7d58e0d

google-genai-bot authored and copybara-github committed
feat: Mark Vertex calls made from non-gemini models
PiperOrigin-RevId: 864253424
1 parent 9290b96 commit 7d58e0d

8 files changed: +231 additions, −32 deletions


src/google/adk/models/anthropic_llm.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@
 from pydantic import BaseModel
 from typing_extensions import override
 
+from ..utils._google_client_headers import get_tracking_headers
 from .base_llm import BaseLlm
 from .llm_response import LlmResponse
 
@@ -345,4 +346,5 @@ def _anthropic_client(self) -> AsyncAnthropicVertex:
     return AsyncAnthropicVertex(
         project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
         region=os.environ["GOOGLE_CLOUD_LOCATION"],
+        default_headers=get_tracking_headers(),
     )
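For context, a minimal sketch of what the updated client construction resolves to. The header values shown are assumptions (the real string is built by get_tracking_headers() from get_client_labels()), and the sketch assumes the anthropic SDK exposes AsyncAnthropicVertex at the package top level:

import os

from anthropic import AsyncAnthropicVertex  # assumed top-level export

# Hypothetical label string; the actual value comes from get_tracking_headers().
_LABELS = "google-adk/1.0.0 gl-python/3.11"

client = AsyncAnthropicVertex(
    project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
    region=os.environ["GOOGLE_CLOUD_LOCATION"],
    # Every Vertex call made through this client is now marked as ADK traffic.
    default_headers={"x-goog-api-client": _LABELS, "user-agent": _LABELS},
)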

src/google/adk/models/google_llm.py

Lines changed: 8 additions & 24 deletions
@@ -30,7 +30,8 @@
 from google.genai.errors import ClientError
 from typing_extensions import override
 
-from ..utils._client_labels_utils import get_client_labels
+from ..utils._google_client_headers import get_tracking_headers
+from ..utils._google_client_headers import merge_tracking_headers
 from ..utils.context_utils import Aclosing
 from ..utils.streaming_utils import StreamingResponseAggregator
 from ..utils.variant_utils import GoogleLLMVariant
@@ -316,13 +317,7 @@ def _api_backend(self) -> GoogleLLMVariant:
     )
 
   def _tracking_headers(self) -> dict[str, str]:
-    labels = get_client_labels()
-    header_value = ' '.join(labels)
-    tracking_headers = {
-        'x-goog-api-client': header_value,
-        'user-agent': header_value,
-    }
-    return tracking_headers
+    return get_tracking_headers()
 
   @cached_property
   def _live_api_version(self) -> str:
@@ -362,8 +357,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     ):
       if not llm_request.live_connect_config.http_options.headers:
         llm_request.live_connect_config.http_options.headers = {}
-      llm_request.live_connect_config.http_options.headers.update(
-          self._tracking_headers()
+      llm_request.live_connect_config.http_options.headers = (
+          self._merge_tracking_headers(
+              llm_request.live_connect_config.http_options.headers
+          )
       )
       llm_request.live_connect_config.http_options.api_version = (
           self._live_api_version
@@ -456,20 +453,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
 
   def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
     """Merge tracking headers to the given headers."""
-    headers = headers or {}
-    for key, tracking_header_value in self._tracking_headers().items():
-      custom_value = headers.get(key, None)
-      if not custom_value:
-        headers[key] = tracking_header_value
-        continue
-
-      # Merge tracking headers with existing headers and avoid duplicates.
-      value_parts = tracking_header_value.split(' ')
-      for custom_value_part in custom_value.split(' '):
-        if custom_value_part not in value_parts:
-          value_parts.append(custom_value_part)
-      headers[key] = ' '.join(value_parts)
-    return headers
+    return merge_tracking_headers(headers)
 
 
 def _build_function_declaration_log(
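The connect() change is behavioral, not just a refactor: the old dict.update() replaced any caller-supplied 'user-agent' or 'x-goog-api-client' value wholesale, while _merge_tracking_headers keeps the custom parts and skips duplicates. A minimal illustration with made-up label values (the real ones come from get_client_labels()):

tracking = {"user-agent": "google-adk/1.0.0 gl-python/3.11"}  # hypothetical values
custom = {"user-agent": "my-app/2.0 google-adk/1.0.0"}

# Old behavior: update() discards the caller's value.
old = dict(custom)
old.update(tracking)
print(old["user-agent"])  # 'google-adk/1.0.0 gl-python/3.11' -- 'my-app/2.0' is lost

# New behavior: tracking parts first, custom parts appended, duplicates skipped.
parts = tracking["user-agent"].split(" ")
parts += [p for p in custom["user-agent"].split(" ") if p not in parts]
print(" ".join(parts))  # 'google-adk/1.0.0 gl-python/3.11 my-app/2.0'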

src/google/adk/models/lite_llm.py

Lines changed: 21 additions & 0 deletions
@@ -51,6 +51,7 @@
 from pydantic import Field
 from typing_extensions import override
 
+from ..utils._google_client_headers import merge_tracking_headers
 from .base_llm import BaseLlm
 from .llm_request import LlmRequest
 from .llm_response import LlmResponse
@@ -1699,6 +1700,18 @@ def _build_request_log(req: LlmRequest) -> str:
   """
 
 
+def _is_litellm_vertex_model(model_string: str) -> bool:
+  """Check if the model is a Vertex AI model accessed via LiteLLM.
+
+  Args:
+    model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
+
+  Returns:
+    True if it's a Vertex AI model accessed via LiteLLM, False otherwise
+  """
+  return model_string.startswith("vertex_ai/")
+
+
 def _is_litellm_gemini_model(model_string: str) -> bool:
   """Check if the model is a Gemini model accessed via LiteLLM.
 
@@ -1867,6 +1880,14 @@ async def generate_content_async(
     }
     completion_args.update(self._additional_args)
 
+    # merge headers
+    if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
+        effective_model
+    ):
+      completion_args["headers"] = merge_tracking_headers(
+          completion_args.get("headers")
+      )
+
     if generation_params:
       completion_args.update(generation_params)
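The net effect is that LiteLLM requests only carry the Google tracking headers when they target a Google backend. A rough sketch of the condition, assuming _is_litellm_gemini_model (whose body is not shown in this diff) matches the "gemini/" prefix; the helper name below is hypothetical:

# Hypothetical helper mirroring the condition added above.
def _should_attach_tracking_headers(effective_model: str) -> bool:
  return effective_model.startswith("vertex_ai/") or effective_model.startswith(
      "gemini/"  # assumption about what _is_litellm_gemini_model matches
  )

for model in ("vertex_ai/claude-sonnet", "gemini/gemini-2.0-flash", "openai/gpt-4o"):
  print(model, _should_attach_tracking_headers(model))
# vertex_ai/claude-sonnet True
# gemini/gemini-2.0-flash True
# openai/gpt-4o False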

src/google/adk/utils/_google_client_headers.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from ._client_labels_utils import get_client_labels
+
+
+def get_tracking_headers() -> dict[str, str]:
+  """Returns a dictionary of HTTP headers for tracking API requests.
+
+  These headers are used to identify HTTP calls made by ADK towards
+  Vertex AI LLM APIs.
+  """
+  labels = get_client_labels()
+  header_value = " ".join(labels)
+  return {
+      "x-goog-api-client": header_value,
+      "user-agent": header_value,
+  }
+
+
+def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]:
+  """Merge tracking headers to the given headers.
+
+  Args:
+    headers: headers to merge tracking headers into.
+
+  Returns:
+    A dictionary of HTTP headers with tracking headers merged.
+  """
+  new_headers = (headers or {}).copy()
+  for key, tracking_header_value in get_tracking_headers().items():
+    custom_value = new_headers.get(key, None)
+    if not custom_value:
+      new_headers[key] = tracking_header_value
+      continue
+
+    # Merge tracking headers with existing headers and avoid duplicates.
+    value_parts = tracking_header_value.split(" ")
+    for custom_value_part in custom_value.split(" "):
+      if custom_value_part not in value_parts:
+        value_parts.append(custom_value_part)
+    new_headers[key] = " ".join(value_parts)
+  return new_headers
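A hedged usage sketch of the new helpers (module path inferred from the imports above); the concrete header values depend on the installed ADK and Python versions, so nothing below should be read as exact output:

from google.adk.utils._google_client_headers import get_tracking_headers
from google.adk.utils._google_client_headers import merge_tracking_headers

# Both keys carry the same space-separated label string, something like
# 'google-adk/<version> gl-python/<version>' (indicative only).
print(get_tracking_headers())

# With no existing headers, the result is exactly the tracking headers.
print(merge_tracking_headers(None))

# With existing headers, unrelated keys pass through untouched; for
# 'user-agent' and 'x-goog-api-client' the custom parts are appended after the
# tracking labels and duplicate parts are skipped. The input dict is not mutated.
print(merge_tracking_headers({"custom": "header", "user-agent": "my-app/2.0"}))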

tests/unittests/models/test_anthropic_llm.py

Lines changed: 25 additions & 0 deletions
@@ -391,6 +391,31 @@ async def mock_coro():
   assert responses[0].content.parts[0].text == "Hello, how can I help you?"
 
 
+def test_claude_vertex_client_uses_tracking_headers():
+  """Tests that Claude vertex client is called with tracking headers."""
+  with mock.patch.object(
+      anthropic_llm, "AsyncAnthropicVertex", autospec=True
+  ) as mock_anthropic_vertex:
+    with mock.patch.dict(
+        os.environ,
+        {
+            "GOOGLE_CLOUD_PROJECT": "test-project",
+            "GOOGLE_CLOUD_LOCATION": "us-central1",
+        },
+    ):
+      instance = Claude(model="claude-3-5-sonnet-v2@20241022")
+      _ = instance._anthropic_client
+      mock_anthropic_vertex.assert_called_once()
+      _, kwargs = mock_anthropic_vertex.call_args
+      assert "default_headers" in kwargs
+      assert "x-goog-api-client" in kwargs["default_headers"]
+      assert "user-agent" in kwargs["default_headers"]
+      assert (
+          f"google-adk/{adk_version.__version__}"
+          in kwargs["default_headers"]["user-agent"]
+      )
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async_with_max_tokens(
     llm_request, generate_content_response, generate_llm_response

tests/unittests/models/test_google_llm.py

Lines changed: 7 additions & 6 deletions
@@ -31,6 +31,7 @@
 from google.adk.models.llm_response import LlmResponse
 from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
 from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
+from google.adk.utils._google_client_headers import get_tracking_headers
 from google.adk.utils.variant_utils import GoogleLLMVariant
 from google.genai import types
 from google.genai.errors import ClientError
@@ -469,7 +470,7 @@ async def test_generate_content_async_with_custom_headers(
   """Test that tracking headers are updated when custom headers are provided."""
   # Add custom headers to the request config
   custom_headers = {"custom-header": "custom-value"}
-  tracking_headers = gemini_llm._tracking_headers()
+  tracking_headers = get_tracking_headers()
   for key in tracking_headers:
     custom_headers[key] = "custom " + tracking_headers[key]
   llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
@@ -494,7 +495,7 @@ async def mock_coro():
   config_arg = call_args.kwargs["config"]
 
   for key, value in config_arg.http_options.headers.items():
-    tracking_headers = gemini_llm._tracking_headers()
+    tracking_headers = get_tracking_headers()
     if key in tracking_headers:
       assert value == tracking_headers[key] + " custom"
     else:
@@ -545,7 +546,7 @@ async def mock_coro():
   config_arg = call_args.kwargs["config"]
 
   expected_headers = custom_headers.copy()
-  expected_headers.update(gemini_llm._tracking_headers())
+  expected_headers.update(get_tracking_headers())
   assert config_arg.http_options.headers == expected_headers
 
   assert len(responses) == 2
@@ -599,7 +600,7 @@ async def mock_coro():
   assert final_config.http_options is not None
   assert (
       final_config.http_options.headers["x-goog-api-client"]
-      == gemini_llm._tracking_headers()["x-goog-api-client"]
+      == get_tracking_headers()["x-goog-api-client"]
   )
 
   assert len(responses) == 2 if stream else 1
@@ -633,7 +634,7 @@ def test_live_api_client_properties(gemini_llm):
   assert http_options.api_version == "v1beta1"
 
   # Check that tracking headers are included
-  tracking_headers = gemini_llm._tracking_headers()
+  tracking_headers = get_tracking_headers()
   for key, value in tracking_headers.items():
     assert key in http_options.headers
     assert value in http_options.headers[key]
@@ -671,7 +672,7 @@ async def __aexit__(self, *args):
 
   # Verify that tracking headers were merged with custom headers
   expected_headers = custom_headers.copy()
-  expected_headers.update(gemini_llm._tracking_headers())
+  expected_headers.update(get_tracking_headers())
   assert config_arg.http_options.headers == expected_headers
 
   # Verify that API version was set

tests/unittests/models/test_litellm.py

Lines changed: 33 additions & 2 deletions
@@ -2849,11 +2849,12 @@ def test_model_response_to_chunk(
 async def test_acompletion_additional_args(mock_acompletion, mock_client):
   lite_llm_instance = LiteLlm(
       # valid args
-      model="test_model",
+      model="vertex_ai/test_model",
      llm_client=mock_client,
       api_key="test_key",
       api_base="some://url",
       api_version="2024-09-12",
+      headers={"custom": "header"},  # Add custom header to test merge
       # invalid args (ignored)
       stream=True,
       messages=[{"role": "invalid", "content": "invalid"}],
@@ -2880,13 +2881,43 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
 
   _, kwargs = mock_acompletion.call_args
 
-  assert kwargs["model"] == "test_model"
+  assert kwargs["model"] == "vertex_ai/test_model"
   assert kwargs["messages"][0]["role"] == "user"
   assert kwargs["messages"][0]["content"] == "Test prompt"
   assert kwargs["tools"][0]["function"]["name"] == "test_function"
   assert "stream" not in kwargs
   assert "llm_client" not in kwargs
   assert kwargs["api_base"] == "some://url"
+  assert "headers" in kwargs
+  assert kwargs["headers"]["custom"] == "header"
+  assert "x-goog-api-client" in kwargs["headers"]
+  assert "user-agent" in kwargs["headers"]
+
+
+@pytest.mark.asyncio
+async def test_acompletion_additional_args_non_vertex(
+    mock_acompletion, mock_client
+):
+  """Test that tracking headers are not added for non-Vertex AI models."""
+  lite_llm_instance = LiteLlm(
+      model="openai/gpt-4o",
+      llm_client=mock_client,
+      api_key="test_key",
+      headers={"custom": "header"},
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(
+      LLM_REQUEST_WITH_FUNCTION_DECLARATION
+  ):
+    pass
+
+  mock_acompletion.assert_called_once()
+  _, kwargs = mock_acompletion.call_args
+  assert kwargs["model"] == "openai/gpt-4o"
+  assert "headers" in kwargs
+  assert kwargs["headers"]["custom"] == "header"
+  assert "x-goog-api-client" not in kwargs["headers"]
+  assert "user-agent" not in kwargs["headers"]
 
 
 @pytest.mark.asyncio
