Skip to content

Commit 8639823

Browse files
No public description
PiperOrigin-RevId: 859664322
1 parent c4c20a5 commit 8639823

File tree

2 files changed

+371
-0
lines changed

2 files changed

+371
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
"""Designed to interact with Google BigQuery.
30+
31+
For the purpose of dataset and table management, as well as data ingestion
32+
from pandas DataFrames.
33+
"""
34+
35+
import logging
36+
import os
37+
import subprocess
38+
from google.cloud import bigquery
39+
from google.cloud import exceptions
40+
import pandas as pd
41+
import pandas_gbq
42+
43+
# Configure logging
# Root-logger setup applied at import time: INFO level with a
# "timestamp - LEVEL - message" line format shared by this module's loggers.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
48+
# Centralized schema definition shared by every table this module creates.
# Declaration order matters: BigQuery lays out columns in the order the
# SchemaField entries appear. All fields are REQUIRED (non-nullable).
_BIGQUERY_SCHEMA = [
    bigquery.SchemaField(field_name, field_type, mode="REQUIRED")
    for field_name, field_type in (
        ("particle", "INTEGER"),
        ("source_name", "STRING"),
        ("image_name", "STRING"),
        ("detection_scores", "FLOAT"),
        ("creation_time", "STRING"),
        ("bbox_0", "INTEGER"),
        ("bbox_1", "INTEGER"),
        ("bbox_2", "INTEGER"),
        ("bbox_3", "INTEGER"),
        ("detected_classes", "INTEGER"),
        ("detected_classes_names", "STRING"),
        ("detected_colors", "STRING"),
    )
]
66+
67+
class BigQueryManager:
  """Manages interactions with Google BigQuery for dataset and table operations.

  This class provides methods to create datasets and tables, ingest data from
  pandas DataFrames, and manage related file operations in Google Cloud Storage.
  """

  def __init__(self, project_id: str, dataset_id: str, table_id: str):
    """Initializes the BigQuery client and storage coordinates.

    Args:
      project_id: Google Cloud project that owns the dataset and table.
      dataset_id: BigQuery dataset ID within the project.
      table_id: BigQuery table ID within the dataset.
    """
    self.client = bigquery.Client(project=project_id)
    self.project_id = project_id
    self.dataset_id = dataset_id
    self.table_id = table_id
    # Fully-qualified "project.dataset.table" reference used by all
    # table-level operations below.
    self.table_ref = f"{project_id}.{dataset_id}.{table_id}"

  def _ensure_dataset(self):
    """Checks if dataset exists, creates it if not."""
    # NOTE(review): Client.dataset() is deprecated in recent
    # google-cloud-bigquery releases in favor of
    # bigquery.DatasetReference(project, dataset_id). Kept as-is because the
    # unit tests patch and assert on this exact call.
    dataset_ref = self.client.dataset(self.dataset_id)
    try:
      self.client.get_dataset(dataset_ref)
    except exceptions.NotFound:
      logging.info("Dataset %s not found. Creating...", self.dataset_id)
      dataset = bigquery.Dataset(dataset_ref)
      self.client.create_dataset(dataset, timeout=30)

  def create_table(self, overwrite: bool = False) -> None:
    """Creates the table with the defined schema.

    Ensures the parent dataset exists first. If the table already exists it is
    either left untouched (default) or dropped and recreated.

    Args:
      overwrite: If True, delete and recreate an existing table; if False,
        leave an existing table unchanged.
    """
    self._ensure_dataset()

    try:
      self.client.get_table(self.table_ref)
      if overwrite:
        logging.info("Overwriting table %s...", self.table_id)
        self.client.delete_table(self.table_ref)
      else:
        logging.info("Table %s already exists. Skipping.", self.table_id)
        return
    except exceptions.NotFound:
      # Table does not exist yet; fall through to creation.
      pass

    table = bigquery.Table(self.table_ref, schema=_BIGQUERY_SCHEMA)
    self.client.create_table(table)
    logging.info("Table %s created successfully.", self.table_id)

  def ingest_data(self, df: pd.DataFrame) -> None:
    """Ingests data from a pandas DataFrame into BigQuery using pandas_gbq.

    Rows are appended to the existing table (if_exists="append").

    Args:
      df: DataFrame whose columns are expected to match the table schema —
        TODO confirm callers always supply matching columns.
    """
    pandas_gbq.to_gbq(
        df,
        destination_table=self.table_ref,
        project_id=self.project_id,
        if_exists="append",
    )
    logging.info("Data ingested successfully into %s", self.table_ref)

  def upload_image_results_to_storage_bucket(
      self, input_directory: str, prediction_folder: str, output_directory: str
  ) -> None:
    """Moves folders to the destination bucket and cleans up local directories.

    Args:
      input_directory: Path to the local input directory.
      prediction_folder: Path to the local folder containing results.
      output_directory: The GCS path (gs://...) for output.
    """
    try:
      # SECURITY: the paths are interpolated into a shell command string
      # (shell=True is needed to chain the steps with '&&'), so only trusted,
      # shell-safe paths may be passed to this method.
      commands = [
          f"rm -r {os.path.basename(input_directory)}",
          f"gsutil -m cp -r {prediction_folder} {output_directory}",
          f"rm -r {prediction_folder}",
      ]
      subprocess.run(" && ".join(commands), shell=True, check=True)
      logging.info("Successfully moved to destination bucket")
    except subprocess.CalledProcessError as e:
      # Fixed: the previous catch list also named KeyError, IndexError,
      # TypeError and ValueError, none of which this try body can raise.
      # Log at ERROR (was INFO) so a failed upload is not buried in the
      # routine INFO stream; this is deliberately best-effort and does not
      # re-raise.
      logging.error(
          "Issue in moving folders to destination bucket, due to error : %s", e
      )
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import subprocess
30+
import unittest
31+
from unittest import mock
32+
33+
from google.cloud import exceptions
34+
import pandas as pd
35+
36+
from official.projects.waste_identification_ml.Deploy.detr_cloud_deployment.client import big_query_ops
37+
38+
# Dotted import path of the module under test; used to build mock.patch
# target strings so patches hit the names as seen inside big_query_ops.
MODULE_PATH = big_query_ops.__name__
39+
40+
41+
class BigQueryManagerTest(unittest.TestCase):
  """Unit tests for big_query_ops.BigQueryManager with all I/O mocked out."""

  def setUp(self):
    super().setUp()
    # Patch the BigQuery client, pandas_gbq and subprocess inside the module
    # under test so no test touches the network or the local shell.
    self.mock_bigquery_client_patch = mock.patch(
        f"{MODULE_PATH}.bigquery.Client"
    )
    self.mock_bigquery_client = self.mock_bigquery_client_patch.start()
    self.mock_pandas_gbq_patch = mock.patch(f"{MODULE_PATH}.pandas_gbq.to_gbq")
    self.mock_pandas_gbq = self.mock_pandas_gbq_patch.start()
    self.mock_subprocess_run_patch = mock.patch(f"{MODULE_PATH}.subprocess.run")
    self.mock_subprocess_run = self.mock_subprocess_run_patch.start()

    self.project_id = "test-project"
    self.dataset_id = "test-dataset"
    self.table_id = "test-table"
    self.manager = big_query_ops.BigQueryManager(
        self.project_id, self.dataset_id, self.table_id
    )

  def tearDown(self):
    super().tearDown()
    # Stops every patch started with .start() above in one call.
    mock.patch.stopall()

  def test_init_sets_attributes(self):
    self.assertEqual(self.manager.project_id, self.project_id)
    self.assertEqual(self.manager.dataset_id, self.dataset_id)
    self.assertEqual(self.manager.table_id, self.table_id)
    self.assertEqual(
        self.manager.table_ref,
        f"{self.project_id}.{self.dataset_id}.{self.table_id}",
    )

  def test_init_creates_client(self):
    self.mock_bigquery_client.assert_called_once_with(project=self.project_id)

  def test_ensure_dataset_exists(self):
    self.manager.client.get_dataset.return_value = True

    self.manager._ensure_dataset()

    self.manager.client.get_dataset.assert_called_once_with(
        self.manager.client.dataset(self.dataset_id)
    )
    self.manager.client.create_dataset.assert_not_called()

  @mock.patch(f"{MODULE_PATH}.bigquery.Dataset")
  def test_ensure_dataset_not_found(self, mock_bq_dataset):
    self.manager.client.get_dataset.side_effect = exceptions.NotFound(
        "Dataset not found"
    )
    mock_dataset_ref = self.manager.client.dataset.return_value
    mock_bq_dataset.return_value = "dataset_obj"

    self.manager._ensure_dataset()

    self.manager.client.get_dataset.assert_called_once_with(mock_dataset_ref)
    self.manager.client.dataset.assert_called_once_with(self.dataset_id)
    mock_bq_dataset.assert_called_once_with(mock_dataset_ref)
    self.manager.client.create_dataset.assert_called_once_with(
        "dataset_obj", timeout=30
    )

  @mock.patch(f"{MODULE_PATH}.bigquery.Table")
  def test_create_table_new_table_created_if_not_exists(self, mock_bq_table):
    self.manager.client.get_table.side_effect = exceptions.NotFound(
        "Table not found"
    )
    table_obj = "table_obj"
    mock_bq_table.return_value = table_obj
    with mock.patch.object(self.manager, "_ensure_dataset"):

      self.manager.create_table()

    # Fixed: BigQueryManager has no `_schema` attribute (it raised
    # AttributeError); the schema lives at module level as _BIGQUERY_SCHEMA.
    mock_bq_table.assert_called_once_with(
        self.manager.table_ref, schema=big_query_ops._BIGQUERY_SCHEMA
    )
    self.manager.client.create_table.assert_called_once_with(table_obj)

  def test_create_table_new_ensures_dataset_exists_if_table_not_exists(self):
    self.manager.client.get_table.side_effect = exceptions.NotFound(
        "Table not found"
    )
    with mock.patch.object(
        self.manager, "_ensure_dataset"
    ) as mock_ensure_dataset:

      self.manager.create_table()

    mock_ensure_dataset.assert_called_once()

  def test_create_table_exists_no_overwrite_does_not_recreate(self):
    self.manager.client.get_table.return_value = True
    with mock.patch.object(self.manager, "_ensure_dataset"):

      self.manager.create_table(overwrite=False)

    self.manager.client.delete_table.assert_not_called()
    self.manager.client.create_table.assert_not_called()

  def test_create_table_exists_no_overwrite_ensures_dataset_exists(self):
    self.manager.client.get_table.return_value = True
    with mock.patch.object(
        self.manager, "_ensure_dataset"
    ) as mock_ensure_dataset:

      self.manager.create_table(overwrite=False)

    mock_ensure_dataset.assert_called_once()

  @mock.patch(f"{MODULE_PATH}.bigquery.Table")
  def test_create_table_exists_overwrite_recreates_table(self, mock_bq_table):
    self.manager.client.get_table.return_value = True
    table_obj = "table_obj"
    mock_bq_table.return_value = table_obj
    with mock.patch.object(self.manager, "_ensure_dataset"):

      self.manager.create_table(overwrite=True)

    self.manager.client.delete_table.assert_called_once_with(
        self.manager.table_ref
    )
    # Fixed: was `self.manager._schema`, which does not exist on the class;
    # the schema is the module-level constant _BIGQUERY_SCHEMA.
    mock_bq_table.assert_called_once_with(
        self.manager.table_ref, schema=big_query_ops._BIGQUERY_SCHEMA
    )
    self.manager.client.create_table.assert_called_once_with(table_obj)

  def test_create_table_exists_overwrite_ensures_dataset_exists(self):
    self.manager.client.get_table.return_value = True
    with mock.patch.object(
        self.manager, "_ensure_dataset"
    ) as mock_ensure_dataset:

      self.manager.create_table(overwrite=True)

    mock_ensure_dataset.assert_called_once()

  def test_ingest_data(self):
    df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})

    self.manager.ingest_data(df)

    self.mock_pandas_gbq.assert_called_once_with(
        df,
        destination_table=self.manager.table_ref,
        project_id=self.project_id,
        if_exists="append",
    )

  def test_upload_image_results_to_storage_bucket_success(self):
    input_dir = "/tmp/input"
    pred_dir = "/tmp/pred"
    output_dir = "gs://bucket/output"

    self.manager.upload_image_results_to_storage_bucket(
        input_dir, pred_dir, output_dir
    )

    self.mock_subprocess_run.assert_called_once()
    args, _ = self.mock_subprocess_run.call_args
    self.assertIn(f"gsutil -m cp -r {pred_dir} {output_dir}", args[0])

  def test_upload_image_results_to_storage_bucket_failure(self):
    input_dir = "/tmp/input"
    pred_dir = "/tmp/pred"
    output_dir = "gs://bucket/output"
    self.mock_subprocess_run.side_effect = subprocess.CalledProcessError(
        1, "cmd"
    )

    # assertLogs(level="INFO") also captures higher severities, so this holds
    # whether the module logs the failure at INFO or ERROR.
    with self.assertLogs(level="INFO") as cm:
      self.manager.upload_image_results_to_storage_bucket(
          input_dir, pred_dir, output_dir
      )

    self.mock_subprocess_run.assert_called_once()
    self.assertIn(
        "Issue in moving folders to destination bucket", cm.output[-1]
    )
221+
222+
if __name__ == "__main__":
223+
unittest.main()

0 commit comments

Comments
 (0)