修改穹彻节点的response的结构

This commit is contained in:
liangyuxuan
2025-11-19 11:40:48 +08:00
parent 20fab18c35
commit f4167e8560
2 changed files with 35 additions and 21 deletions

View File

@@ -308,7 +308,7 @@ class DetectNode(Node):
color_img_cv, depth_img_cv = distortion_correction(color_img_cv, depth_img_cv, self.map1, self.map2)
img, pose_dict = self.function(color_img_cv, depth_img_cv)
img, pose_list = self.function(color_img_cv, depth_img_cv)
"""masks为空结束这一帧"""
if img is None:
@@ -317,14 +317,14 @@ class DetectNode(Node):
if self.output_boxes or self.output_masks:
self.pub_detect_image.publish(img)
if pose_dict is not None:
if pose_list:
pose_list_all = PoseArrayClassAndID()
for (class_id, class_name), pose_list in pose_dict.items():
for item in pose_list:
pose_list_all.objects.append(
PoseClassAndID(
class_name = class_name,
class_id = class_id,
pose_list = pose_list
class_name = item["class_name"],
class_id = item["class_id"],
pose = item["pose"]
)
)
pose_list_all.header.stamp = self.get_clock().now().to_msg()
@@ -333,7 +333,7 @@ class DetectNode(Node):
def _seg_image(self, rgb_img: np.ndarray, depth_img: np.ndarray):
"""Use segmentation model"""
pose_dict = defaultdict(list)
pose_list = []
depth_filter_mask = np.zeros_like(depth_img, dtype=np.uint8)
depth_filter_mask[(depth_img > 0) & (depth_img < 2000)] = 1
@@ -382,7 +382,13 @@ class DetectNode(Node):
pose = Pose()
pose.position = Point(x=x, y=y, z=z)
pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
pose_dict[int(class_ids[i]), labels[class_ids[i]]].append(pose)
pose_list.append(
{
"class_id": int(class_ids[i]),
"class_name": labels[class_ids[i]],
"pose": pose
}
)
time4 = time.time()
@@ -396,16 +402,16 @@ class DetectNode(Node):
'''mask_img and box_img is or not output'''
if self.output_boxes and not self.output_masks:
draw_box(self.set_confidence, rgb_img, result)
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
elif self.output_boxes and self.output_masks:
draw_box(self.set_confidence, rgb_img, result)
draw_mask(self.set_confidence, rgb_img, result)
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
elif not self.output_boxes and self.output_masks:
draw_mask(self.set_confidence, rgb_img, result)
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
else:
return None, pose_dict
return None, pose_list
def main(args=None):

View File

@@ -260,7 +260,7 @@ class DetectNode(Node):
map1, map2, self.K = get_map(self.K, D, self.camera_size)
color_img_cv, depth_img_cv = distortion_correction(color_img_cv, depth_img_cv, map1, map2)
img, pose_dict = self.function(color_img_cv, depth_img_cv, hand_eye_mat)
img, pose_list = self.function(color_img_cv, depth_img_cv, hand_eye_mat)
"""masks为空结束这一帧"""
if self.output_boxes or self.output_masks:
@@ -268,15 +268,16 @@ class DetectNode(Node):
img = color_img_ros
self.pub_detect_image.publish(img)
if pose_dict:
if pose_list:
response.info = "Success get pose"
response.success = True
for (class_id, class_name), pose_list in pose_dict.items():
for item in pose_list.items():
response.objects.append(
PoseClassAndID(
class_name = class_name,
class_id = class_id,
pose_list = pose_list
class_name = item["class_name"],
class_id = item["class_id"],
pose = item["pose"],
grab_width = item["grab_width"]
)
)
else:
@@ -288,7 +289,7 @@ class DetectNode(Node):
def _seg_image(self, rgb_img: np.ndarray, depth_img: np.ndarray, hand_eye_mat):
"""Use segmentation model"""
pose_dict = defaultdict(list)
pose_list = []
'''Get Predict Results'''
time1 = time.time()
@@ -371,7 +372,14 @@ class DetectNode(Node):
pose = Pose()
pose.position = Point(x=x, y=y, z=z)
pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
pose_dict[int(class_ids[i]), labels[class_ids[i]]].append(pose)
pose_list.append(
{
"class_id": int(class_ids[i]),
"class_name": labels[class_ids[i]],
"pose": pose,
"grap_width": getattr(result, "double_value")
}
)
elif key in ["valid", "double_value", "int_value", "name"]:
for result in result_list:
@@ -386,7 +394,7 @@ class DetectNode(Node):
self.get_logger().info(f'{(time4 - time1) * 1000} ms, completing a picture entire process')
self.get_logger().info(f'end')
return None, pose_dict
return None, pose_list
def main(args=None):