Skip to content

Commit 8dbc6dc

Browse files
No public description
PiperOrigin-RevId: 859191499
1 parent 9bc1efa commit 8dbc6dc

File tree

2 files changed

+548
-0
lines changed

2 files changed

+548
-0
lines changed
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
"""Object tracking using trackpy."""
29+
30+
import os
31+
from typing import Any, Dict, List
32+
import cv2
33+
import numpy as np
34+
import pandas as pd
35+
import skimage.measure
36+
import trackpy as tp
37+
38+
39+
class ObjectTracker:
40+
"""Tracks objects across multiple frames using trackpy.
41+
42+
This class collects object detections from multiple frames, extracts features,
43+
links them using trackpy, and aggregates the tracking results.
44+
"""
45+
46+
def __init__(self, search_range: tuple[int, int] = (20, 20), memory: int = 3):
47+
"""Initializes the tracker.
48+
49+
Args:
50+
search_range: (y_range, x_range) pixels for tracking.
51+
memory: Number of frames an object can vanish and still be linked.
52+
"""
53+
self.search_range = search_range
54+
self.memory = memory
55+
self.all_detections: List[pd.DataFrame] = []
56+
57+
# Region properties to extract
58+
self._properties = (
59+
'area',
60+
'bbox',
61+
'convex_area',
62+
'bbox_area',
63+
'major_axis_length',
64+
'minor_axis_length',
65+
'eccentricity',
66+
'centroid',
67+
'label',
68+
'mean_intensity',
69+
'max_intensity',
70+
'min_intensity',
71+
'perimeter',
72+
)
73+
74+
def extract_features_for_tracking(
75+
self,
76+
image: np.ndarray,
77+
results: Dict[str, Any],
78+
tracking_image_size: tuple[int, int],
79+
image_path: str,
80+
creation_time: Any,
81+
frame_idx: int,
82+
colors: List[str],
83+
):
84+
"""Extracts features from detection results for tracking.
85+
86+
This method resizes masks, extracts region properties using skimage,
87+
and compiles a DataFrame of features for each frame, which is then
88+
stored internally for later use by the tracking algorithm.
89+
90+
Args:
91+
image: The original image as a numpy array.
92+
results: A dictionary containing detection results, including 'masks',
93+
'confidence', 'labels', and 'class_names'.
94+
tracking_image_size: The target size (width, height) for resizing masks
95+
before feature extraction.
96+
image_path: The file path of the image.
97+
creation_time: The timestamp of when the image was created.
98+
frame_idx: The index of the current frame.
99+
colors: A list of color strings corresponding to each detection.
100+
"""
101+
results['resized_masks_for_tracking'] = np.array([
102+
cv2.resize(
103+
m,
104+
tracking_image_size,
105+
interpolation=cv2.INTER_NEAREST,
106+
)
107+
for m in results['masks'].astype('int')
108+
])
109+
110+
frame_features_list = []
111+
for mask in results['resized_masks_for_tracking']:
112+
mask = np.where(mask, 1, 0)
113+
props = skimage.measure.regionprops_table(
114+
mask.astype(np.uint8),
115+
intensity_image=image,
116+
properties=self._properties,
117+
)
118+
df = pd.DataFrame(props)
119+
frame_features_list.append(df)
120+
121+
if frame_features_list:
122+
frame_df = pd.concat(frame_features_list, ignore_index=True)
123+
frame_df.rename(
124+
columns={
125+
'centroid-0': 'y',
126+
'centroid-1': 'x',
127+
'bbox-0': 'bbox_0',
128+
'bbox-1': 'bbox_1',
129+
'bbox-2': 'bbox_2',
130+
'bbox-3': 'bbox_3',
131+
},
132+
inplace=True,
133+
)
134+
135+
frame_df['source_name'] = os.path.basename(os.path.dirname(image_path))
136+
frame_df['image_name'] = os.path.basename(image_path)
137+
frame_df['creation_time'] = creation_time
138+
frame_df['frame'] = frame_idx
139+
frame_df['detection_scores'] = results['confidence']
140+
frame_df['detection_classes'] = results['labels']
141+
frame_df['detection_classes_names'] = results['class_names']
142+
frame_df['color'] = colors
143+
self.all_detections.append(frame_df)
144+
else:
145+
self.all_detections.append(pd.DataFrame(columns=self._properties))
146+
147+
def _select_class_with_model_scores(self, group: pd.DataFrame) -> pd.Series:
148+
"""Selects the most representative class for a tracked particle.
149+
150+
This method is used within a groupby operation on 'particle'. It determines
151+
the best class for a given particle by first finding the class(es) with the
152+
highest frequency. If there's a tie in frequency, it breaks the tie by
153+
selecting the class with the highest maximum detection score among the tied
154+
classes.
155+
156+
Args:
157+
group: A pandas DataFrame containing all detections associated with a
158+
single tracked particle.
159+
160+
Returns:
161+
A pandas Series containing the 'class_id', 'class_name', and
162+
'color_name' of the selected class.
163+
"""
164+
class_counts = group['detection_classes'].value_counts()
165+
tied_classes = class_counts[class_counts == class_counts.iloc[0]].index
166+
167+
max_scores = {
168+
cls: group[group['detection_classes'] == cls]['detection_scores'].max()
169+
for cls in tied_classes
170+
}
171+
best_class = max(max_scores.items(), key=lambda x: x[1])[0]
172+
173+
class_name = group[group['detection_classes'] == best_class][
174+
'detection_classes_names'
175+
].iloc[0]
176+
color_name = group[group['detection_classes'] == best_class]['color'].iloc[
177+
0
178+
]
179+
return pd.Series({
180+
'class_id': best_class,
181+
'class_name': class_name,
182+
'color_name': color_name,
183+
})
184+
185+
def run_tracking(self) -> pd.DataFrame:
186+
"""Runs the trackpy linking algorithm on all collected detections.
187+
188+
This method concatenates all extracted features from multiple frames,
189+
applies trackpy's linking to connect detections across frames into tracks
190+
(particles), and preserves additional metadata.
191+
192+
Returns:
193+
A pandas DataFrame containing the linked particles, with each row
194+
representing a detection instance and including a 'particle' ID.
195+
Returns an empty DataFrame if no detections have been collected.
196+
"""
197+
if not self.all_detections:
198+
return pd.DataFrame()
199+
200+
full_df = pd.concat(self.all_detections, ignore_index=True)
201+
202+
tracking_cols = [
203+
'x',
204+
'y',
205+
'frame',
206+
'bbox_0',
207+
'bbox_1',
208+
'bbox_2',
209+
'bbox_3',
210+
'major_axis_length',
211+
'minor_axis_length',
212+
'perimeter',
213+
]
214+
215+
track_df = tp.link_df(
216+
full_df[tracking_cols],
217+
search_range=self.search_range,
218+
memory=self.memory,
219+
)
220+
221+
additional_columns = [
222+
'source_name',
223+
'image_name',
224+
'detection_scores',
225+
'detection_classes_names',
226+
'detection_classes',
227+
'color',
228+
'creation_time',
229+
]
230+
track_df[additional_columns] = full_df[additional_columns]
231+
232+
track_df.drop(columns=['frame'], inplace=True)
233+
return track_df
234+
235+
def process_tracking_results(self, track_df):
236+
"""Aggregates tracking results by particle.
237+
238+
This method takes the DataFrame with linked particles and aggregates
239+
information such as the best class, detection scores, and initial bounding
240+
box for each unique particle.
241+
242+
Args:
243+
track_df: A pandas DataFrame containing tracking results, including a
244+
'particle' column generated by trackpy.
245+
246+
Returns:
247+
A pandas DataFrame where each row represents a unique tracked object
248+
('particle'), containing aggregated information.
249+
"""
250+
# Select best class per particle
251+
class_info = (
252+
track_df.groupby('particle')
253+
.apply(self._select_class_with_model_scores, include_groups=False)
254+
.reset_index()
255+
)
256+
257+
final_particles = (
258+
track_df.groupby('particle')
259+
.agg({
260+
'source_name': 'first',
261+
'image_name': 'first',
262+
'detection_scores': 'max',
263+
'creation_time': 'first',
264+
'bbox_0': 'first',
265+
'bbox_1': 'first',
266+
'bbox_2': 'first',
267+
'bbox_3': 'first',
268+
})
269+
.reset_index()
270+
)
271+
272+
final_particles['detected_classes'] = class_info['class_id']
273+
final_particles['detected_classes_names'] = class_info['class_name']
274+
final_particles['detected_colors'] = class_info['color_name']
275+
276+
return final_particles

0 commit comments

Comments
 (0)