-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMain_perception.py
More file actions
80 lines (67 loc) · 2.55 KB
/
Main_perception.py
File metadata and controls
80 lines (67 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
DFOD perception with primary detection, auxiliary objectness scoring,
crop-quality estimation, and adapter-aware embedding extraction.
"""
from __future__ import annotations

import logging
from typing import Dict, List, Tuple

import numpy as np

from dfod_config import normalize_config
from Baseline_Perception.cropper import crop_detections
from Baseline_Perception.detector import (
    attach_objectness_scores,
    postprocess_detections,
    run_detector,
    run_objectness_detector,
)
from Baseline_Perception.embedder import extract_embeddings
from Baseline_Perception.io import load_image
from Baseline_Perception.quality import score_and_rank_detections
from Baseline_Perception.validation import validate_outputs
# Alias for one detection record: a dict with string keys and arbitrary values.
# Within this module the "conf" key is read, and the "objectness" and
# "source_image_path" keys are written.
Detection = Dict[str, object]
_logger = logging.getLogger(__name__)


def _objectness_from_confidence(detections: List[Detection]) -> List[Detection]:
    """Return copies of *detections* whose "objectness" is their "conf" value.

    A detection without a "conf" key gets an objectness of 0.0. The input
    dicts are not modified; each result is a new dict.
    """
    return [
        dict(det, objectness=float(det.get("conf", 0.0)))
        for det in detections
    ]


def run_perception(
    image_path: str,
    device: str = "cuda",
    conf_thresh: float = 0.25,
    max_detections: int = 100,
    config: Dict[str, object] | None = None,
) -> Tuple[List[Detection], np.ndarray]:
    """Run the perception pipeline on one image.

    Steps, in order: load the image, run the primary detector, post-process
    its detections, attach an objectness score to each detection, score and
    rank the detections, crop them, extract embeddings for the crops, and
    validate the outputs.

    Args:
        image_path: Path handed to ``load_image``; also stored on every
            returned detection under the "source_image_path" key.
        device: Device string forwarded to the detector, the objectness
            detector, and the embedder.
        conf_thresh: Confidence threshold forwarded to
            ``postprocess_detections``.
        max_detections: Forwarded to ``postprocess_detections`` and
            ``run_objectness_detector``, and used to truncate the ranked
            detection list.
        config: Optional configuration passed through ``normalize_config``.
            The keys read from its "perception" section are
            "detector_model" (default "yolov8n"), "objectness_enabled"
            (default True) and "objectness_iou_threshold" (default 0.3).

    Returns:
        A ``(detections, embeddings)`` pair: the ranked detections and the
        value returned by ``extract_embeddings`` for their crops.

    Note:
        Objectness scoring is best-effort. If it is disabled, or if the
        objectness detector or score attachment raises, each detection's
        "conf" value is used as its objectness instead. A failure is logged
        as a warning with its traceback rather than propagated.
        ``validate_outputs`` runs just before returning; presumably it
        raises on inconsistent outputs -- confirm against its definition.
    """
    cfg = normalize_config(config)
    image = load_image(image_path)
    perception_cfg = cfg["perception"]

    raw_detections = run_detector(
        image=image,
        device=device,
        model_name=str(perception_cfg.get("detector_model", "yolov8n")),
    )
    detections = postprocess_detections(
        detections=raw_detections,
        image_shape=image.shape,
        conf_thresh=conf_thresh,
        max_detections=max_detections,
    )

    if perception_cfg.get("objectness_enabled", True):
        try:
            objectness_detections = run_objectness_detector(
                image=image,
                device=device,
                max_detections=max_detections,
            )
            detections = attach_objectness_scores(
                detections=detections,
                objectness_detections=objectness_detections,
                iou_threshold=float(
                    perception_cfg.get("objectness_iou_threshold", 0.3)
                ),
            )
        except Exception:
            # Deliberate best-effort: the auxiliary detector must not take
            # down the whole pipeline, but the failure has to be visible.
            _logger.warning(
                "Objectness scoring failed for %s; "
                "falling back to detector confidence",
                image_path,
                exc_info=True,
            )
            detections = _objectness_from_confidence(detections)
    else:
        detections = _objectness_from_confidence(detections)

    # Ranking may reorder the detections, so the cap is re-applied afterwards.
    detections = score_and_rank_detections(
        detections=detections,
        image_shape=image.shape,
        config=cfg,
    )[:max_detections]

    for detection in detections:
        detection["source_image_path"] = image_path

    crops = crop_detections(image=image, detections=detections)
    embeddings = extract_embeddings(crops=crops, device=device, config=cfg)
    validate_outputs(detections=detections, embeddings=embeddings)
    return detections, embeddings