Skip to content

Commit f9e9297

Browse files
authored
Merge branch 'main' into feature/ollama-llm
2 parents c0fc0bc + 7d58e0d commit f9e9297

File tree

13 files changed

+287
-72
lines changed

13 files changed

+287
-72
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
4646
"google-genai>=1.56.0, <2.0.0", # Google GenAI SDK
4747
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
48+
"httpx>=0.27.0, <1.0.0", # HTTP client library
4849
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
4950
"mcp>=1.23.0, <2.0.0", # For MCP Toolset
5051
"opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package.

src/google/adk/models/anthropic_llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pydantic import BaseModel
3737
from typing_extensions import override
3838

39+
from ..utils._google_client_headers import get_tracking_headers
3940
from .base_llm import BaseLlm
4041
from .llm_response import LlmResponse
4142

@@ -345,4 +346,5 @@ def _anthropic_client(self) -> AsyncAnthropicVertex:
345346
return AsyncAnthropicVertex(
346347
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
347348
region=os.environ["GOOGLE_CLOUD_LOCATION"],
349+
default_headers=get_tracking_headers(),
348350
)

src/google/adk/models/google_llm.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from google.genai.errors import ClientError
3131
from typing_extensions import override
3232

33-
from ..utils._client_labels_utils import get_client_labels
33+
from ..utils._google_client_headers import get_tracking_headers
34+
from ..utils._google_client_headers import merge_tracking_headers
3435
from ..utils.context_utils import Aclosing
3536
from ..utils.streaming_utils import StreamingResponseAggregator
3637
from ..utils.variant_utils import GoogleLLMVariant
@@ -316,13 +317,7 @@ def _api_backend(self) -> GoogleLLMVariant:
316317
)
317318

318319
def _tracking_headers(self) -> dict[str, str]:
319-
labels = get_client_labels()
320-
header_value = ' '.join(labels)
321-
tracking_headers = {
322-
'x-goog-api-client': header_value,
323-
'user-agent': header_value,
324-
}
325-
return tracking_headers
320+
return get_tracking_headers()
326321

327322
@cached_property
328323
def _live_api_version(self) -> str:
@@ -362,8 +357,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
362357
):
363358
if not llm_request.live_connect_config.http_options.headers:
364359
llm_request.live_connect_config.http_options.headers = {}
365-
llm_request.live_connect_config.http_options.headers.update(
366-
self._tracking_headers()
360+
llm_request.live_connect_config.http_options.headers = (
361+
self._merge_tracking_headers(
362+
llm_request.live_connect_config.http_options.headers
363+
)
367364
)
368365
llm_request.live_connect_config.http_options.api_version = (
369366
self._live_api_version
@@ -456,20 +453,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
456453

457454
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
458455
"""Merge tracking headers to the given headers."""
459-
headers = headers or {}
460-
for key, tracking_header_value in self._tracking_headers().items():
461-
custom_value = headers.get(key, None)
462-
if not custom_value:
463-
headers[key] = tracking_header_value
464-
continue
465-
466-
# Merge tracking headers with existing headers and avoid duplicates.
467-
value_parts = tracking_header_value.split(' ')
468-
for custom_value_part in custom_value.split(' '):
469-
if custom_value_part not in value_parts:
470-
value_parts.append(custom_value_part)
471-
headers[key] = ' '.join(value_parts)
472-
return headers
456+
return merge_tracking_headers(headers)
473457

474458

475459
def _build_function_declaration_log(

src/google/adk/models/lite_llm.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pydantic import Field
5252
from typing_extensions import override
5353

54+
from ..utils._google_client_headers import merge_tracking_headers
5455
from .base_llm import BaseLlm
5556
from .llm_request import LlmRequest
5657
from .llm_response import LlmResponse
@@ -1699,6 +1700,18 @@ def _build_request_log(req: LlmRequest) -> str:
16991700
"""
17001701

17011702

1703+
def _is_litellm_vertex_model(model_string: str) -> bool:
1704+
"""Check if the model is a Vertex AI model accessed via LiteLLM.
1705+
1706+
Args:
1707+
model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
1708+
1709+
Returns:
1710+
True if it's a Vertex AI model accessed via LiteLLM, False otherwise
1711+
"""
1712+
return model_string.startswith("vertex_ai/")
1713+
1714+
17021715
def _is_litellm_gemini_model(model_string: str) -> bool:
17031716
"""Check if the model is a Gemini model accessed via LiteLLM.
17041717
@@ -1867,6 +1880,14 @@ async def generate_content_async(
18671880
}
18681881
completion_args.update(self._additional_args)
18691882

1883+
# merge headers
1884+
if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
1885+
effective_model
1886+
):
1887+
completion_args["headers"] = merge_tracking_headers(
1888+
completion_args.get("headers")
1889+
)
1890+
18701891
if generation_params:
18711892
completion_args.update(generation_params)
18721893

src/google/adk/tools/openapi_tool/auth/auth_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from fastapi.openapi.models import OAuth2
2929
from fastapi.openapi.models import OpenIdConnect
3030
from fastapi.openapi.models import Schema
31+
import httpx
3132
from pydantic import BaseModel
3233
from pydantic import ValidationError
33-
import requests
3434

3535
from ....auth.auth_credential import AuthCredential
3636
from ....auth.auth_credential import AuthCredentialTypes
@@ -289,14 +289,14 @@ def openid_url_to_scheme_credential(
289289
Raises:
290290
ValueError: If the OpenID URL is invalid, fetching fails, or required
291291
fields are missing.
292-
requests.exceptions.RequestException: If there's an error during the
292+
httpx.HTTPStatusError or httpx.RequestError: If there's an error during the
293293
HTTP request.
294294
"""
295295
try:
296-
response = requests.get(openid_url, timeout=10)
296+
response = httpx.get(openid_url, timeout=10)
297297
response.raise_for_status()
298298
config_dict = response.json()
299-
except requests.exceptions.RequestException as e:
299+
except httpx.RequestError as e:
300300
raise ValueError(
301301
f"Failed to fetch OpenID configuration from {openid_url}: {e}"
302302
) from e

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from fastapi.openapi.models import Operation
2929
from fastapi.openapi.models import Schema
3030
from google.genai.types import FunctionDeclaration
31-
import requests
31+
import httpx
3232
from typing_extensions import override
3333

3434
from ....agents.readonly_context import ReadonlyContext
@@ -312,7 +312,7 @@ def _prepare_request_params(
312312
313313
Returns:
314314
A dictionary containing the request parameters for the API call. This
315-
initializes a requests.request() call.
315+
initializes an httpx.AsyncClient.request() call.
316316
317317
Example:
318318
self._prepare_request_params({"input_id": "test-id"})
@@ -497,17 +497,7 @@ async def call(
497497
if provider_headers:
498498
request_params.setdefault("headers", {}).update(provider_headers)
499499

500-
# Log the API request
501-
self._logger.debug(
502-
"API Request: %s %s",
503-
request_params.get("method", "").upper(),
504-
request_params.get("url", ""),
505-
)
506-
self._logger.debug("API Request params: %s", request_params.get("params"))
507-
if "json" in request_params:
508-
self._logger.debug("API Request body: %s", request_params.get("json"))
509-
510-
response = requests.request(**request_params)
500+
response = await _request(**request_params)
511501

512502
# Log the API response
513503
self._logger.debug(
@@ -519,11 +509,9 @@ async def call(
519509

520510
# Parse API response
521511
try:
522-
response.raise_for_status() # Raise HTTPError for bad responses
523-
result = response.json() # Try to decode JSON
524-
self._logger.debug("API Response body: %s", result)
525-
return result
526-
except requests.exceptions.HTTPError:
512+
response.raise_for_status() # Raise HTTPStatusError for bad responses
513+
return response.json() # Try to decode JSON
514+
except httpx.HTTPStatusError:
527515
error_details = response.content.decode("utf-8")
528516
self._logger.warning(
529517
"API call failed for tool %s: Status %d - %s",
@@ -556,3 +544,10 @@ def __repr__(self):
556544
f' auth_scheme="{self.auth_scheme}",'
557545
f' auth_credential="{self.auth_credential}")'
558546
)
547+
548+
549+
async def _request(**request_params) -> httpx.Response:
550+
async with httpx.AsyncClient(
551+
verify=request_params.pop("verify", True)
552+
) as client:
553+
return await client.request(**request_params)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from ._client_labels_utils import get_client_labels
18+
19+
20+
def get_tracking_headers() -> dict[str, str]:
21+
"""Returns a dictionary of HTTP headers for tracking API requests.
22+
23+
These headers are used to identify HTTP calls made by ADK towards
24+
Vertex AI LLM APIs.
25+
"""
26+
labels = get_client_labels()
27+
header_value = " ".join(labels)
28+
return {
29+
"x-goog-api-client": header_value,
30+
"user-agent": header_value,
31+
}
32+
33+
34+
def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]:
35+
"""Merge tracking headers to the given headers.
36+
37+
Args:
38+
headers: headers to merge tracking headers into.
39+
40+
Returns:
41+
A dictionary of HTTP headers with tracking headers merged.
42+
"""
43+
new_headers = (headers or {}).copy()
44+
for key, tracking_header_value in get_tracking_headers().items():
45+
custom_value = new_headers.get(key, None)
46+
if not custom_value:
47+
new_headers[key] = tracking_header_value
48+
continue
49+
50+
# Merge tracking headers with existing headers and avoid duplicates.
51+
value_parts = tracking_header_value.split(" ")
52+
for custom_value_part in custom_value.split(" "):
53+
if custom_value_part not in value_parts:
54+
value_parts.append(custom_value_part)
55+
new_headers[key] = " ".join(value_parts)
56+
return new_headers

tests/unittests/models/test_anthropic_llm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,31 @@ async def mock_coro():
391391
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
392392

393393

394+
def test_claude_vertex_client_uses_tracking_headers():
395+
"""Tests that Claude vertex client is called with tracking headers."""
396+
with mock.patch.object(
397+
anthropic_llm, "AsyncAnthropicVertex", autospec=True
398+
) as mock_anthropic_vertex:
399+
with mock.patch.dict(
400+
os.environ,
401+
{
402+
"GOOGLE_CLOUD_PROJECT": "test-project",
403+
"GOOGLE_CLOUD_LOCATION": "us-central1",
404+
},
405+
):
406+
instance = Claude(model="claude-3-5-sonnet-v2@20241022")
407+
_ = instance._anthropic_client
408+
mock_anthropic_vertex.assert_called_once()
409+
_, kwargs = mock_anthropic_vertex.call_args
410+
assert "default_headers" in kwargs
411+
assert "x-goog-api-client" in kwargs["default_headers"]
412+
assert "user-agent" in kwargs["default_headers"]
413+
assert (
414+
f"google-adk/{adk_version.__version__}"
415+
in kwargs["default_headers"]["user-agent"]
416+
)
417+
418+
394419
@pytest.mark.asyncio
395420
async def test_generate_content_async_with_max_tokens(
396421
llm_request, generate_content_response, generate_llm_response

tests/unittests/models/test_google_llm.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.adk.models.llm_response import LlmResponse
3232
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
3333
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
34+
from google.adk.utils._google_client_headers import get_tracking_headers
3435
from google.adk.utils.variant_utils import GoogleLLMVariant
3536
from google.genai import types
3637
from google.genai.errors import ClientError
@@ -469,7 +470,7 @@ async def test_generate_content_async_with_custom_headers(
469470
"""Test that tracking headers are updated when custom headers are provided."""
470471
# Add custom headers to the request config
471472
custom_headers = {"custom-header": "custom-value"}
472-
tracking_headers = gemini_llm._tracking_headers()
473+
tracking_headers = get_tracking_headers()
473474
for key in tracking_headers:
474475
custom_headers[key] = "custom " + tracking_headers[key]
475476
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
@@ -494,7 +495,7 @@ async def mock_coro():
494495
config_arg = call_args.kwargs["config"]
495496

496497
for key, value in config_arg.http_options.headers.items():
497-
tracking_headers = gemini_llm._tracking_headers()
498+
tracking_headers = get_tracking_headers()
498499
if key in tracking_headers:
499500
assert value == tracking_headers[key] + " custom"
500501
else:
@@ -545,7 +546,7 @@ async def mock_coro():
545546
config_arg = call_args.kwargs["config"]
546547

547548
expected_headers = custom_headers.copy()
548-
expected_headers.update(gemini_llm._tracking_headers())
549+
expected_headers.update(get_tracking_headers())
549550
assert config_arg.http_options.headers == expected_headers
550551

551552
assert len(responses) == 2
@@ -599,7 +600,7 @@ async def mock_coro():
599600
assert final_config.http_options is not None
600601
assert (
601602
final_config.http_options.headers["x-goog-api-client"]
602-
== gemini_llm._tracking_headers()["x-goog-api-client"]
603+
== get_tracking_headers()["x-goog-api-client"]
603604
)
604605

605606
assert len(responses) == 2 if stream else 1
@@ -633,7 +634,7 @@ def test_live_api_client_properties(gemini_llm):
633634
assert http_options.api_version == "v1beta1"
634635

635636
# Check that tracking headers are included
636-
tracking_headers = gemini_llm._tracking_headers()
637+
tracking_headers = get_tracking_headers()
637638
for key, value in tracking_headers.items():
638639
assert key in http_options.headers
639640
assert value in http_options.headers[key]
@@ -671,7 +672,7 @@ async def __aexit__(self, *args):
671672

672673
# Verify that tracking headers were merged with custom headers
673674
expected_headers = custom_headers.copy()
674-
expected_headers.update(gemini_llm._tracking_headers())
675+
expected_headers.update(get_tracking_headers())
675676
assert config_arg.http_options.headers == expected_headers
676677

677678
# Verify that API version was set

0 commit comments

Comments
 (0)