Skip to content

Commit 66c49b2

Browse files
examples/traffic_analysis: improve CLI argument handling (#2059)
* refactor(traffic_analysis): improve CLI argument handling * fix(pre_commit): 🎨 auto format pre-commit hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0efb55e commit 66c49b2

File tree

3 files changed

+64
-90
lines changed

3 files changed

+64
-90
lines changed

examples/traffic_analysis/inference_example.py

Lines changed: 34 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import argparse
43
import os
54
from collections.abc import Iterable
65

@@ -180,62 +179,47 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray:
180179
return self.annotate_frame(frame, detections)
181180

182181

183-
if __name__ == "__main__":
184-
parser = argparse.ArgumentParser(
185-
description="Traffic Flow Analysis with Inference and ByteTrack"
186-
)
187-
188-
parser.add_argument(
189-
"--model_id",
190-
default="vehicle-count-in-drone-video/6",
191-
help="Roboflow model ID",
192-
type=str,
193-
)
194-
parser.add_argument(
195-
"--roboflow_api_key",
196-
default=None,
197-
help="Roboflow API KEY",
198-
type=str,
199-
)
200-
parser.add_argument(
201-
"--source_video_path",
202-
required=True,
203-
help="Path to the source video file",
204-
type=str,
205-
)
206-
parser.add_argument(
207-
"--target_video_path",
208-
default=None,
209-
help="Path to the target video file (output)",
210-
type=str,
211-
)
212-
parser.add_argument(
213-
"--confidence_threshold",
214-
default=0.3,
215-
help="Confidence threshold for the model",
216-
type=float,
217-
)
218-
parser.add_argument(
219-
"--iou_threshold", default=0.7, help="IOU threshold for the model", type=float
220-
)
221-
222-
args = parser.parse_args()
223-
224-
api_key = args.roboflow_api_key
182+
def main(
183+
source_video_path: str,
184+
target_video_path: str,
185+
roboflow_api_key: str,
186+
model_id: str = "vehicle-count-in-drone-video/6",
187+
confidence_threshold: float = 0.3,
188+
iou_threshold: float = 0.7,
189+
) -> None:
190+
"""
191+
Traffic Flow Analysis with Inference and ByteTrack.
192+
193+
Args:
194+
source_video_path: Path to the source video file
195+
target_video_path: Path to the target video file (output)
196+
roboflow_api_key: Roboflow API key
197+
model_id: Roboflow model ID
198+
confidence_threshold: Confidence threshold for the model
199+
iou_threshold: IOU threshold for the model
200+
"""
201+
api_key = roboflow_api_key
225202
api_key = os.environ.get("ROBOFLOW_API_KEY", api_key)
226203
if api_key is None:
227204
raise ValueError(
228205
"Roboflow API KEY is missing. Please provide it as an argument or set the "
229206
"ROBOFLOW_API_KEY environment variable."
230207
)
231-
args.roboflow_api_key = api_key
208+
roboflow_api_key = api_key
232209

233210
processor = VideoProcessor(
234-
roboflow_api_key=args.roboflow_api_key,
235-
model_id=args.model_id,
236-
source_video_path=args.source_video_path,
237-
target_video_path=args.target_video_path,
238-
confidence_threshold=args.confidence_threshold,
239-
iou_threshold=args.iou_threshold,
211+
roboflow_api_key=roboflow_api_key,
212+
model_id=model_id,
213+
source_video_path=source_video_path,
214+
target_video_path=target_video_path,
215+
confidence_threshold=confidence_threshold,
216+
iou_threshold=iou_threshold,
240217
)
241218
processor.process_video()
219+
220+
221+
if __name__ == "__main__":
222+
from jsonargparse import auto_cli, set_parsing_settings
223+
224+
set_parsing_settings(parse_optionals_as_positionals=True)
225+
auto_cli(main, as_positional=False)

examples/traffic_analysis/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ inference
33
supervision
44
tqdm
55
ultralytics
6+
jsonargparse[signatures]

examples/traffic_analysis/ultralytics_example.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import argparse
43
from collections.abc import Iterable
54

65
import cv2
@@ -177,45 +176,35 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray:
177176
return self.annotate_frame(frame, detections)
178177

179178

180-
if __name__ == "__main__":
181-
parser = argparse.ArgumentParser(
182-
description="Traffic Flow Analysis with YOLO and ByteTrack"
183-
)
184-
185-
parser.add_argument(
186-
"--source_weights_path",
187-
required=True,
188-
help="Path to the source weights file",
189-
type=str,
190-
)
191-
parser.add_argument(
192-
"--source_video_path",
193-
required=True,
194-
help="Path to the source video file",
195-
type=str,
196-
)
197-
parser.add_argument(
198-
"--target_video_path",
199-
default=None,
200-
help="Path to the target video file (output)",
201-
type=str,
202-
)
203-
parser.add_argument(
204-
"--confidence_threshold",
205-
default=0.3,
206-
help="Confidence threshold for the model",
207-
type=float,
208-
)
209-
parser.add_argument(
210-
"--iou_threshold", default=0.7, help="IOU threshold for the model", type=float
211-
)
212-
213-
args = parser.parse_args()
179+
def main(
180+
source_weights_path: str,
181+
source_video_path: str,
182+
target_video_path: str,
183+
confidence_threshold: float = 0.3,
184+
iou_threshold: float = 0.7,
185+
) -> None:
186+
"""
187+
Traffic Flow Analysis with YOLO and ByteTrack.
188+
189+
Args:
190+
source_weights_path: Path to the source weights file
191+
source_video_path: Path to the source video file
192+
target_video_path: Path to the target video file (output)
193+
confidence_threshold: Confidence threshold for the model
194+
iou_threshold: IOU threshold for the model
195+
"""
214196
processor = VideoProcessor(
215-
source_weights_path=args.source_weights_path,
216-
source_video_path=args.source_video_path,
217-
target_video_path=args.target_video_path,
218-
confidence_threshold=args.confidence_threshold,
219-
iou_threshold=args.iou_threshold,
197+
source_weights_path=source_weights_path,
198+
source_video_path=source_video_path,
199+
target_video_path=target_video_path,
200+
confidence_threshold=confidence_threshold,
201+
iou_threshold=iou_threshold,
220202
)
221203
processor.process_video()
204+
205+
206+
if __name__ == "__main__":
207+
from jsonargparse import auto_cli, set_parsing_settings
208+
209+
set_parsing_settings(parse_optionals_as_positionals=True)
210+
auto_cli(main, as_positional=False)

0 commit comments

Comments
 (0)