1
This commit is contained in:
@@ -297,7 +297,7 @@ class DetectNode(Node):
|
||||
|
||||
color_img_cv, depth_img_cv = distortion_correction(color_img_cv, depth_img_cv, self.map1, self.map2)
|
||||
|
||||
img, pose_dict = self.function(color_img_cv, depth_img_cv)
|
||||
img, pose_list = self.function(color_img_cv, depth_img_cv)
|
||||
|
||||
"""masks为空,结束这一帧"""
|
||||
if img is None:
|
||||
@@ -306,14 +306,15 @@ class DetectNode(Node):
|
||||
if self.output_boxes or self.output_masks:
|
||||
self.pub_detect_image.publish(img)
|
||||
|
||||
if pose_dict:
|
||||
if pose_list:
|
||||
pose_list_all = PoseArrayClassAndID()
|
||||
for (class_id, class_name), pose_list in pose_dict.items():
|
||||
for item in pose_list:
|
||||
pose_list_all.objects.append(
|
||||
PoseClassAndID(
|
||||
class_name = class_name,
|
||||
class_id = class_id,
|
||||
pose_list = pose_list
|
||||
class_name = item["class_name"],
|
||||
class_id = item["class_id"],
|
||||
pose = item["pose"],
|
||||
grab_width = item["grab_width"]
|
||||
)
|
||||
)
|
||||
pose_list_all.header.stamp = self.get_clock().now().to_msg()
|
||||
@@ -349,7 +350,7 @@ class DetectNode(Node):
|
||||
map1, map2, self.K = get_map(self.K, D, self.camera_size)
|
||||
color_img_cv, depth_img_cv = distortion_correction(color_img_cv, depth_img_cv, map1, map2)
|
||||
|
||||
img, pose_dict = self.function(color_img_cv, depth_img_cv)
|
||||
img, pose_list = self.function(color_img_cv, depth_img_cv)
|
||||
|
||||
"""masks为空,结束这一帧"""
|
||||
if self.output_boxes or self.output_masks:
|
||||
@@ -357,15 +358,16 @@ class DetectNode(Node):
|
||||
img = color_img_ros
|
||||
self.pub_detect_image.publish(img)
|
||||
|
||||
if pose_dict:
|
||||
if pose_list:
|
||||
response.info = "Success get pose"
|
||||
response.success = True
|
||||
for (class_id, class_name), pose_list in pose_dict.items():
|
||||
for item in pose_list:
|
||||
response.objects.append(
|
||||
PoseClassAndID(
|
||||
class_name=class_name,
|
||||
class_id=class_id,
|
||||
pose_list=pose_list
|
||||
class_name = item["class_name"],
|
||||
class_id = item["class_id"],
|
||||
pose = item["pose"],
|
||||
grab_width = item["grab_width"]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -380,7 +382,7 @@ class DetectNode(Node):
|
||||
|
||||
def _seg_image(self, rgb_img: np.ndarray, depth_img: np.ndarray):
|
||||
"""Use segmentation model"""
|
||||
pose_dict = defaultdict(list)
|
||||
pose_list = []
|
||||
|
||||
'''Get Predict Results'''
|
||||
time1 = time.time()
|
||||
@@ -429,7 +431,14 @@ class DetectNode(Node):
|
||||
pose = Pose()
|
||||
pose.position = Point(x=x, y=y, z=z)
|
||||
pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
|
||||
pose_dict[int(class_ids[i]), labels[class_ids[i]]].append(pose)
|
||||
pose_list.append(
|
||||
{
|
||||
"class_id": int(class_ids[i]),
|
||||
"class_name": labels[class_ids[i]],
|
||||
"pose": pose,
|
||||
"grap_width": getattr(result, "double_value")
|
||||
}
|
||||
)
|
||||
|
||||
time4 = time.time()
|
||||
|
||||
@@ -443,20 +452,20 @@ class DetectNode(Node):
|
||||
'''mask_img and box_img is or not output'''
|
||||
if self.output_boxes and not self.output_masks:
|
||||
draw_box(rgb_img, result)
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
|
||||
elif self.output_boxes and self.output_masks:
|
||||
draw_box(rgb_img, result)
|
||||
draw_mask(rgb_img, result)
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
|
||||
elif not self.output_boxes and self.output_masks:
|
||||
draw_mask(rgb_img, result)
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
|
||||
else:
|
||||
return None, pose_dict
|
||||
return None, pose_list
|
||||
|
||||
def _seg_color(self, rgb_img: np.ndarray, depth_img: np.ndarray):
|
||||
"""Use segmentation model"""
|
||||
pose_dict = defaultdict(list)
|
||||
pose_list = []
|
||||
|
||||
hsv_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2HSV)
|
||||
|
||||
@@ -500,12 +509,18 @@ class DetectNode(Node):
|
||||
pose = Pose()
|
||||
pose.position = Point(x=x, y=y, z=z)
|
||||
pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
|
||||
pose_dict[int(98), "red"].append(pose)
|
||||
pose_list.append(
|
||||
{
|
||||
"class_id": int(98),
|
||||
"class_name": "red",
|
||||
"pose": pose,
|
||||
}
|
||||
)
|
||||
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
|
||||
|
||||
def _seg_crossboard(self, rgb_img, depth_img):
|
||||
pose_dict = defaultdict(list)
|
||||
pose_list = []
|
||||
rgb_img_gray = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY)
|
||||
ret, corners = cv2.findChessboardCorners(rgb_img_gray, self.pattern_size, cv2.CALIB_CB_FAST_CHECK)
|
||||
if ret:
|
||||
@@ -558,7 +573,13 @@ class DetectNode(Node):
|
||||
pose = Pose()
|
||||
pose.position = Point(x=x, y=y, z=z)
|
||||
pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
|
||||
pose_dict[int(99), 'crossboard'].append(pose)
|
||||
pose_list.append(
|
||||
{
|
||||
"class_id": int(99),
|
||||
"class_name": 'crossboard',
|
||||
"pose": pose,
|
||||
}
|
||||
)
|
||||
|
||||
cv2.putText(
|
||||
rgb_img,
|
||||
@@ -576,7 +597,7 @@ class DetectNode(Node):
|
||||
(255, 255, 0),
|
||||
2
|
||||
)
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_list
|
||||
else:
|
||||
return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), None
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
import cv2
|
||||
|
||||
@@ -5,11 +5,9 @@ import json
|
||||
import sys
|
||||
import ast
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
import cv2
|
||||
import open3d as o3d
|
||||
import numpy as np
|
||||
import transforms3d as tfs
|
||||
from cv_bridge import CvBridge
|
||||
@@ -271,7 +269,7 @@ class DetectNode(Node):
|
||||
if pose_list:
|
||||
response.info = "Success get pose"
|
||||
response.success = True
|
||||
for item in pose_list.items():
|
||||
for item in pose_list:
|
||||
response.objects.append(
|
||||
PoseClassAndID(
|
||||
class_name = item["class_name"],
|
||||
|
||||
Reference in New Issue
Block a user