Skip to content

Commit 37ee186

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Enhance BigQuery Plugin Robustness and Schema Accuracy
This update improves the `BigQueryAgentAnalyticsPlugin` in several ways: * Corrects the PyArrow schema generation to accurately reflect BigQuery field nullability based on the `mode` attribute. * Introduces a configurable `shutdown_timeout` in `BigQueryLoggerConfig` to manage how long the plugin waits for pending logs to flush during shutdown. * Adds more robust error handling within the `shutdown` method and background write tasks, particularly for event loop closure issues. * Improves internal logging to provide better diagnostics. * Ensures consistent use of safe content formatting. PiperOrigin-RevId: 831225837
1 parent a501c59 commit 37ee186

File tree

2 files changed

+202
-62
lines changed

2 files changed

+202
-62
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 80 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from typing import Callable
2424
from typing import List
2525
from typing import Optional
26-
from typing import Set
2726
from typing import TYPE_CHECKING
2827

2928
from google.api_core.gapic_v1 import client_info as gapic_client_info
@@ -173,10 +172,11 @@ def _bq_to_arrow_field(bq_field):
173172
metadata = _BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
174173
bq_field.field_type.upper() if bq_field.field_type else ""
175174
)
175+
nullable = bq_field.mode.upper() != "REQUIRED"
176176
return pa.field(
177177
bq_field.name,
178178
arrow_type,
179-
nullable=(bq_field.mode != "REPEATED"),
179+
nullable=nullable,
180180
metadata=metadata,
181181
)
182182
logging.warning(
@@ -213,12 +213,18 @@ class BigQueryLoggerConfig:
213213
event_denylist: A list of event types to skip logging.
214214
content_formatter: An optional function to format event content before
215215
logging.
216+
shutdown_timeout: Seconds to wait for logs to flush during shutdown.
217+
client_close_timeout: Seconds to wait for BQ client to close.
218+
max_content_length: The maximum length of content parts before truncation.
216219
"""
217220

218221
enabled: bool = True
219222
event_allowlist: Optional[List[str]] = None
220223
event_denylist: Optional[List[str]] = None
221224
content_formatter: Optional[Callable[[Any], str]] = None
225+
shutdown_timeout: float = 5.0
226+
client_close_timeout: float = 2.0
227+
max_content_length: int = 500
222228

223229

224230
# --- Helper Formatters ---
@@ -313,16 +319,17 @@ def __init__(
313319
self._write_client: BigQueryWriteAsyncClient | None = None
314320
self._init_lock: asyncio.Lock | None = None
315321
self._arrow_schema: pa.Schema | None = None
316-
self._background_tasks: Set[asyncio.Task] = set() # Track pending logs
322+
self._background_tasks: set[asyncio.Task] = set()
323+
self._is_shutting_down = False
317324
self._schema = [
318-
bigquery.SchemaField("timestamp", "TIMESTAMP"),
319-
bigquery.SchemaField("event_type", "STRING"),
320-
bigquery.SchemaField("agent", "STRING"),
321-
bigquery.SchemaField("session_id", "STRING"),
322-
bigquery.SchemaField("invocation_id", "STRING"),
323-
bigquery.SchemaField("user_id", "STRING"),
324-
bigquery.SchemaField("content", "STRING"),
325-
bigquery.SchemaField("error_message", "STRING"),
325+
bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
326+
bigquery.SchemaField("event_type", "STRING", mode="NULLABLE"),
327+
bigquery.SchemaField("agent", "STRING", mode="NULLABLE"),
328+
bigquery.SchemaField("session_id", "STRING", mode="NULLABLE"),
329+
bigquery.SchemaField("invocation_id", "STRING", mode="NULLABLE"),
330+
bigquery.SchemaField("user_id", "STRING", mode="NULLABLE"),
331+
bigquery.SchemaField("content", "STRING", mode="NULLABLE"),
332+
bigquery.SchemaField("error_message", "STRING", mode="NULLABLE"),
326333
]
327334

328335
def _format_content_safely(
@@ -334,7 +341,7 @@ def _format_content_safely(
334341
try:
335342
if self._config.content_formatter:
336343
return self._config.content_formatter(content)
337-
return _format_content(content)
344+
return _format_content(content, max_len=self._config.max_content_length)
338345
except Exception as e:
339346
logging.warning(f"Content formatter failed: {e}")
340347
return "[FORMATTING FAILED]"
@@ -363,14 +370,17 @@ async def _ensure_init(self):
363370
# Ensure table exists (sync call in thread)
364371
def create_resources():
365372
if self._bq_client:
366-
dataset = self._bq_client.create_dataset(
367-
self._dataset_id, exists_ok=True
368-
)
373+
self._bq_client.create_dataset(self._dataset_id, exists_ok=True)
369374
table = bigquery.Table(
370375
f"{self._project_id}.{self._dataset_id}.{self._table_id}",
371376
schema=self._schema,
372377
)
373378
self._bq_client.create_table(table, exists_ok=True)
379+
logging.info(
380+
"BQ Plugin: Dataset %s and Table %s ensured to exist.",
381+
self._dataset_id,
382+
self._table_id,
383+
)
374384

375385
await asyncio.to_thread(create_resources)
376386

@@ -379,9 +389,12 @@ def create_resources():
379389
client_info=client_info,
380390
)
381391
self._arrow_schema = to_arrow_schema(self._schema)
392+
if not self._arrow_schema:
393+
raise RuntimeError("Failed to convert BigQuery schema to Arrow.")
394+
logging.info("BQ Plugin: Initialized successfully.")
382395
return True
383396
except Exception as e:
384-
logging.error(f"BQ Init Failed: {e}")
397+
logging.error("BQ Plugin: Init Failed:", exc_info=True)
385398
return False
386399

387400
async def _perform_write(self, row: dict):
@@ -412,14 +425,16 @@ async def _perform_write(self, row: dict):
412425
self._write_client.append_rows(iter([req]))
413426
):
414427
if resp.error.code != 0:
415-
logging.error(f"BQ Write Error: {resp.error.message}")
428+
logging.error(f"BQ Plugin: Write Error: {resp.error.message}")
416429

417430
except RuntimeError as e:
418-
# Silently ignore event loop closed errors during background writes
419-
if "Event loop is closed" not in str(e):
420-
logging.exception(f"BQ Runtime Error: {e}")
431+
if "Event loop is closed" not in str(e) and not self._is_shutting_down:
432+
logging.error("BQ Plugin: Runtime Error during write:", exc_info=True)
433+
except asyncio.CancelledError:
434+
if not self._is_shutting_down:
435+
logging.warning("BQ Plugin: Write task cancelled unexpectedly.")
421436
except Exception as e:
422-
logging.error(f"BQ Write Failed: {e}")
437+
logging.error("BQ Plugin: Write Failed:", exc_info=True)
423438

424439
async def _log(self, data: dict):
425440
"""Schedules a log entry to be written in the background."""
@@ -457,32 +472,44 @@ async def _log(self, data: dict):
457472

458473
async def close(self):
459474
"""Flushes pending logs and closes client."""
460-
# 1. Wait for pending background logs (best effort, 2s timeout)
475+
if self._is_shutting_down:
476+
return
477+
self._is_shutting_down = True
478+
logging.info("BQ Plugin: Shutdown started.")
479+
461480
if self._background_tasks:
462-
logging.info(f"Flushing {len(self._background_tasks)} pending BQ logs...")
463-
done, pending = await asyncio.wait(self._background_tasks, timeout=2.0)
464-
if pending:
465-
logging.warning(
466-
f"{len(pending)} BQ logs could not be flushed before shutdown."
481+
logging.info(
482+
f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..."
483+
)
484+
try:
485+
await asyncio.wait(
486+
self._background_tasks, timeout=self._config.shutdown_timeout
467487
)
488+
except asyncio.TimeoutError:
489+
logging.warning("BQ Plugin: Timeout waiting for logs to flush.")
490+
except Exception as e:
491+
logging.warning("BQ Plugin: Error flushing logs:", exc_info=True)
468492

469-
# 2. Close client
470-
if self._write_client and self._write_client.transport:
493+
# Use getattr for safe access in case transport is not present.
494+
if self._write_client and getattr(self._write_client, "transport", None):
471495
try:
472-
logging.info("Closing BQ Write client transport...")
496+
logging.info("BQ Plugin: Closing write client.")
473497
await asyncio.wait_for(
474-
self._write_client.transport.close(), timeout=1.0
498+
self._write_client.transport.close(),
499+
timeout=self._config.client_close_timeout,
475500
)
476501
except Exception as e:
477-
logging.warning(f"Error during BQ Write client transport close: {e}")
478-
self._write_client = None
502+
logging.warning(f"BQ Plugin: Error closing write client: {e}")
479503
if self._bq_client:
480504
try:
481-
logging.info("Closing BQ client...")
482505
self._bq_client.close()
483506
except Exception as e:
484-
logging.warning(f"Error during BQ client close: {e}")
485-
self._bq_client = None
507+
logging.warning(f"BQ Plugin: Error closing BQ client: {e}")
508+
509+
self._write_client = None
510+
self._bq_client = None
511+
self._is_shutting_down = False
512+
logging.info("BQ Plugin: Shutdown complete.")
486513

487514
# --- Streamlined Callbacks ---
488515
async def on_user_message_callback(
@@ -523,13 +550,7 @@ async def on_event_callback(
523550
"session_id": invocation_context.session.id,
524551
"invocation_id": invocation_context.invocation_id,
525552
"user_id": invocation_context.session.user_id,
526-
"content": (
527-
json.dumps(
528-
[part.model_dump(mode="json") for part in event.content.parts]
529-
)
530-
if event.content and event.content.parts
531-
else None
532-
),
553+
"content": self._format_content_safely(event.content),
533554
"error_message": event.error_message,
534555
"timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc),
535556
})
@@ -579,6 +600,11 @@ async def before_model_callback(
579600
content_parts = [
580601
f"Model: {llm_request.model or 'default'}",
581602
]
603+
if contents := getattr(llm_request, "contents", None):
604+
prompt_str = " | ".join(
605+
[f"{c.role}: {self._format_content_safely(c)}" for c in contents]
606+
)
607+
content_parts.append(f"Prompt: {prompt_str}")
582608
system_instruction_text = "None"
583609
if llm_request.config and llm_request.config.system_instruction:
584610
si = llm_request.config.system_instruction
@@ -627,6 +653,9 @@ async def before_model_callback(
627653
)
628654

629655
final_content = " | ".join(content_parts)
656+
max_len = self._config.max_content_length
657+
if len(final_content) > max_len:
658+
final_content = final_content[:max_len] + "..."
630659
await self._log({
631660
"event_type": "LLM_REQUEST",
632661
"agent": callback_context.agent_name,
@@ -702,7 +731,8 @@ async def before_tool_callback(
702731
"user_id": tool_context.session.user_id,
703732
"content": (
704733
f"Tool Name: {tool.name}, Description: {tool.description},"
705-
f" Arguments: {_format_args(tool_args)}"
734+
" Arguments:"
735+
f" {_format_args(tool_args, max_len=self._config.max_content_length)}"
706736
),
707737
})
708738

@@ -721,7 +751,10 @@ async def after_tool_callback(
721751
"session_id": tool_context.session.id,
722752
"invocation_id": tool_context.invocation_id,
723753
"user_id": tool_context.session.user_id,
724-
"content": f"Tool Name: {tool.name}, Result: {_format_args(result)}",
754+
"content": (
755+
f"Tool Name: {tool.name}, Result:"
756+
f" {_format_args(result, max_len=self._config.max_content_length)}"
757+
),
725758
})
726759

727760
async def on_model_error_callback(
@@ -757,7 +790,8 @@ async def on_tool_error_callback(
757790
"invocation_id": tool_context.invocation_id,
758791
"user_id": tool_context.session.user_id,
759792
"content": (
760-
f"Tool Name: {tool.name}, Arguments: {_format_args(tool_args)}"
793+
f"Tool Name: {tool.name}, Arguments:"
794+
f" {_format_args(tool_args, max_len=self._config.max_content_length)}"
761795
),
762796
"error_message": str(error),
763797
})

0 commit comments

Comments
 (0)