Skip to content

Commit f122efc

Browse files
No public description
PiperOrigin-RevId: 868718798
1 parent 995db79 commit f122efc

File tree

11 files changed

+32
-41
lines changed

11 files changed

+32
-41
lines changed

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/big_query_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/big_query_ops_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/inference_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.
@@ -46,9 +46,9 @@
4646
OUTPUT_DIRECTORY = flags.DEFINE_string(
4747
"output_directory", None, "The path to the directory to save the results."
4848
)
49-
MODEL = flags.DEFINE_string("model", None, "Model name")
49+
MODEL_NAME = flags.DEFINE_string("model_name", None, "Model name")
5050
PREDICTION_THRESHOLD = flags.DEFINE_float(
51-
"score", None, "Threshold to filter the prediction results"
51+
"threshold", None, "Threshold to filter the prediction results"
5252
)
5353
SEARCH_RANGE_X = flags.DEFINE_integer(
5454
"search_range_x",
@@ -102,7 +102,7 @@ def main(_) -> None:
102102
for filepath in utils.files_paths(os.path.basename(input_directory))
103103
}
104104

105-
model_manager = TritonObjectDetector(model_name=MODEL.value)
105+
model_manager = TritonObjectDetector(model_name=MODEL_NAME.value)
106106
tracking_manager = ObjectTracker(
107107
search_range=(SEARCH_RANGE_Y.value, SEARCH_RANGE_X.value),
108108
memory=MEMORY.value,

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/object_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/object_tracking_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/run_images.sh

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,20 @@ Steps Performed:
1212
inference.
1313
--output_directory : GCS directory where the model inference outputs will be
1414
saved.
15-
--height : Height to which input images are resized for the Mask
16-
R-CNN model.
17-
--width : Width to which input images are resized for the Mask
18-
R-CNN model.
19-
--model : Name of the model to download and use for inference.
20-
--score : Confidence threshold for detections during
15+
--model_name : Name of the model to download and use for inference.
16+
--threshold : Confidence threshold for detections during
2117
inference.
2218
--search_range_x : Max pixel movement allowed in the X direction for
2319
object tracking between missed frames.
2420
--search_range_y : Max pixel movement allowed in the Y direction for
2521
object tracking between missed frames.
26-
--memory : Number of frames an object can be missed and still
22+
--memory : Number of frames an object can be missed and still
2723
be tracked.
28-
--project_id : Google Cloud Project ID for BigQuery operations.
29-
--bq_dataset_id : BigQuery Dataset ID where results will be stored.
30-
--bq_table_id : BigQuery Table ID where results will be stored.
31-
--overwrite : If set to True, overwrites the pre-existing
24+
--project_id : Google Cloud Project ID for BigQuery operations.
25+
--bq_dataset_id : BigQuery Dataset ID where results will be stored.
26+
--bq_table_id : BigQuery Table ID where results will be stored.
27+
--overwrite : If set to True, overwrites the pre-existing
3228
BigQuery table.
33-
--tracking_visualization : If set to True, visualizes the tracking results
34-
from the tracking algorithm.
35-
--cropped_objects : If set to True, crops the objects per category
36-
according to the prediction and tracking results.
3729
EOF
3830
#Activate the virtual environment
3931
source myenv/bin/activate
@@ -47,12 +39,12 @@ fi
4739
python inference_pipeline.py \
4840
--input_directory=gs://recykal/TestData/SmallTestData \
4941
--output_directory=gs://recykal/TestData/SmallTestData \
50-
--model=detr_seg \
51-
--score=0.50 \
42+
--model_name=cn_segmentation_trt_model \
43+
--threshold=0.50 \
5244
--search_range_x=150 \
5345
--search_range_y=20 \
5446
--memory=3 \
5547
--project_id=waste-identification-ml-330916 \
5648
--bq_dataset_id=circularnet_dataset \
57-
--bq_table_id=vinit_test_table1 \
49+
--bq_table_id=test_table1 \
5850
--overwrite=True

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/triton_server_inference.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.
@@ -249,8 +249,9 @@ def predict(
249249
raw_outputs, confidence_threshold, max_boxes
250250
)
251251

252-
# Scale to output dimensions
253-
results = self._scale_bbox_and_masks(results, output_dims)
252+
if results['labels'].any():
253+
# Scale to output dimensions
254+
results = self._scale_bbox_and_masks(results, output_dims)
254255

255256
return results
256257

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/triton_server_inference_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
1616
#
1717
# Licensed under the Apache License, Version 2.0 (the "License");
1818
# you may not use this file except in compliance with the License.
@@ -207,10 +207,10 @@ def _crop_objects_from_masks(
207207
Returns:
208208
A list of numpy arrays, where each array is a cropped object.
209209
"""
210-
cropped_objects = []
211-
for m in masks.astype(int):
212-
cropped_object = np.where(np.expand_dims(m, -1), image, 0)
213-
cropped_objects.append(cropped_object)
210+
cropped_objects = [
211+
np.where(np.expand_dims(m, -1), image, 0) for m in masks.astype(int)
212+
]
213+
214214
return cropped_objects
215215

216216

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/server/triton_inference_server.sh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ fi
1515

1616
# Define an associative array with model names and their URLs
1717
declare -A models=(
18-
["CircularNet_rfdetr_seg_preview_v1"]="https://storage.googleapis.com/"\
18+
["CircularNet_Segmentation_Model_v1"]="https://storage.googleapis.com/"\
1919
"tf_model_garden/vision/waste_identification_ml/"\
20-
"CircularNet_rfdetr_seg_preview_v1.zip"
20+
"CircularNet_Segmentation_Model_v1.zip"
2121
)
2222

2323
# Download, unzip, and organize models
@@ -38,6 +38,4 @@ echo "Starting Triton server in a screen session."
3838
# Start Triton server
3939
screen -dmS server bash -c '
4040
sudo docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \
41-
-v ${PWD}/model_repository:/models \
42-
nvcr.io/nvidia/tritonserver:25.03-py3 \
43-
tritonserver --model-repository=/models --backend-config=pytorch,version=2'
41+
-v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:25.05-py3 tritonserver --model-repository=/models'

0 commit comments

Comments
 (0)