如何更改deepsort中的update函数,使outpus增加目标的类别?
引用chatgpt部分指引作答:
要在DeepSORT的update函数中添加目标的类别信息,您需要进行以下更改:
1、在目标的跟踪状态中添加一个新的变量来保存目标的类别。您可以在TrackState类中添加一个新的属性来存储目标的类别,例如class_label。
class TrackState:
def __init__(self, track_id):
self.track_id = track_id
self.class_label = None # 新增的类别属性
# ...
2、在DeepSORT类的update函数中,当更新跟踪目标时,将目标的类别信息传递给相应的TrackState对象。
class DeepSORT:
def __init__(self):
# 初始化...
def update(self, detections):
# ...
# 更新跟踪目标
for track in self.tracks:
if not track.is_confirmed() or track.is_deleted():
continue
track.predict(self.kf)
# 查找最匹配的检测结果
detection_idx = self._match_detection(track, detections)
if detection_idx is not None:
# 更新跟踪状态
track.update(self.kf, detections[detection_idx])
track.state.class_label = detections[detection_idx].class_label # 将类别信息赋值给目标的class_label属性
# ...
在上述代码中,假设您的检测结果对象具有class_label属性来存储目标的类别信息。请根据您的具体数据结构进行相应的更改。
通过以上修改,DeepSORT的update函数将在更新跟踪目标时将目标的类别信息添加到outputs中。
你是不是没有看官方文档嘞!!
看下这个 demo 有木有用
def create_detection(bbox, confidence, category):
"""
创建一个包含检测结果的字典。
参数:
bbox (list): 目标的边界框坐标,形如[x_min, y_min, x_max, y_max]。
confidence (float): 检测结果的置信度。
category (int): 目标的类别信息。
返回:
detection (dict): 包含目标检测结果的字典。
"""
return {'bbox': bbox, 'confidence': confidence, 'category': category}
def update(self, detections):
"""
对给定的检测结果进行单次更新。
参数:
detections (list): 当前帧中的目标检测结果列表。
返回:
updated_tracks (list): 更新后的跟踪目标状态列表。
"""
# ... 省略部分代码
# 将检测结果与跟踪目标进行匹配
if self.tracks:
track_indices = list(range(len(self.tracks)))
detection_indices = nn_matching.matching_cascade(
self.similarity_metric, self.neighborhood_size,
self.detection_threshold, self.track_threshold, self.fallback_iou_threshold,
self.tracks, detections, track_indices, detection_indices)
# 更新匹配成功的目标状态
for track_idx, detection_idx in zip(track_indices, detection_indices):
self.tracks[track_idx].update(detections[detection_idx], self.frame_num)
# 将目标的类别信息添加到检测结果中
detections[detection_idx]['category'] = self.tracks[track_idx].category
# ... 省略部分代码
# 返回更新后的跟踪目标状态列表
return [track.to_dict() for track in self.tracks] + detections
摘取自 https://github.com/nwojke/deep_sort
def update(self, bbox_xywh, confidences, ori_img):
self.height, self.width = ori_img.shape[:2]
# generate detections
features = self._get_features(bbox_xywh, ori_img)
bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
detections = [Detection(bbox_tlwh[i], conf, features[i]) for i, conf in enumerate(
confidences) if conf > self.min_confidence]
# run on non-maximum supression
boxes = np.array([d.tlwh for d in detections])
scores = np.array([d.confidence for d in detections])
indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
detections = [detections[i] for i in indices]
# update tracker
self.tracker.predict()
self.tracker.update(detections)
# output bbox identities
outputs = []
for track in self.tracker.tracks:
if not track.is_confirmed() or track.time_since_update > 1:
continue
box = track.to_tlwh()
x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
track_id = track.track_id
outputs.append(np.array([x1, y1, x2, y2,track_id], dtype=np.int))
if len(outputs) > 0:
outputs = np.stack(outputs, axis=0)
return np.array(outputs)
怎么在这个update函数中再加入类别矩阵,跟踪后把结果加入到outputs中?
要修改DeepSORT中的update
函数以增加目标的类别,您需要进行以下步骤:
deep_sort/tracker.py
文件中找到update
函数。class_label
。track.class_label = class_label
将目标的类别信息赋值给跟踪器。update
函数。请注意,以上是一般的指导步骤。具体修改update
函数的细节可能取决于DeepSORT的代码结构和实现方式,您可能需要根据代码的具体情况进行适当的修改和调整。