Merge branch 'feature' into feature_cpp_test

This commit is contained in:
liangyuxuan
2025-12-24 17:52:42 +08:00
8 changed files with 486 additions and 342 deletions

View File

@@ -0,0 +1,18 @@
{
"save_path": "/collect_data",
"left": {
"color": "/camera/camera/color/image_raw",
"depth": "/camera/camera/aligned_depth_to_color/image_raw",
"info": "/camera/camera/color/camera_info"
},
"right": {
"color": "/camera/color/image_raw",
"depth": "/camera/depth/image_raw",
"info": "/camera/color/camera_info"
},
"head": {
"color": "",
"depth": "",
"info": ""
}
}

View File

@@ -15,6 +15,7 @@ setup(
('share/' + package_name + '/configs/flexiv_configs', glob('configs/flexiv_configs/*.json')),
('share/' + package_name + '/configs/hand_eye_mat', glob('configs/hand_eye_mat/*.json')),
('share/' + package_name + '/configs/launch_configs', glob('configs/launch_configs/*.json')),
('share/' + package_name + '/configs/other_configs', glob('configs/other_configs/*.json')),
('share/' + package_name + '/checkpoints', glob('checkpoints/*.pt')),
('share/' + package_name + '/checkpoints', glob('checkpoints/*.onnx')),
@@ -42,6 +43,9 @@ setup(
'get_camera_pose_node = vision_detect.get_camera_pose:main',
'detect_node = vision_detect.detect_node:main',
'collect_data_node = vision_detect.collect_data_node:main',
'collect_data_client = vision_detect.collect_data_client:main',
],
},
)

View File

@@ -171,6 +171,11 @@ class DetectNode(InitBase):
class_ids = result.boxes.cls.cpu().numpy()
labels = result.names
x_centers, y_centers = boxes[:, 0], boxes[:, 1]
sorted_index = np.lexsort((-y_centers, x_centers))
masks = masks[sorted_index]
boxes = boxes[sorted_index]
time3 = time.time()
self.get_logger().info(f"Detect object num: {len(masks)}")
@@ -202,6 +207,9 @@ class DetectNode(InitBase):
self.get_logger().warning("Object point cloud have too many noise")
continue
if np.allclose(rmat, np.eye(4)):
continue
self.get_logger().info(f"grab_width: {grab_width}")
rmat = self.hand_eye_mat @ rmat

View File

@@ -58,10 +58,13 @@ def calculate_pose_pca(
return None, [0.0, 0.0, 0.0]
if len(point_cloud.points) == 0:
logging.warning("clean_pcd is empty")
logging.warning("point_cloud is empty")
return np.eye(4), [0.0, 0.0, 0.0]
if calculate_grab_width:
if np.asarray(point_cloud.points).shape[0] < 4:
logging.warning("点数不足,不能算 OBB")
return np.eye(4), [0.0, 0.0, 0.0]
obb = point_cloud.get_oriented_bounding_box()
x, y, z = obb.center
extent = obb.extent

View File

@@ -7,7 +7,6 @@ import numpy as np
import open3d as o3d
__all__ = [
"draw_box", "draw_mask", "draw_pointcloud",
]

View File

@@ -0,0 +1,44 @@
import rclpy
from rclpy.node import Node
from rclpy.parameter import Parameter
from std_srvs.srv import SetBool
def main(args=None):
    """Parameter-driven client for the /collect_data SetBool service.

    Polls two boolean parameters every 0.5 s:
      * ``collect_sign`` -> send SetBool(data=True) to start recording
      * ``end_sign``     -> send SetBool(data=False) to stop recording
    Each flag is reset to False immediately after its request is sent.
    """
    rclpy.init(args=args)
    node = Node("collect_data_client")
    client = node.create_client(SetBool, "/collect_data")
    while not client.wait_for_service(timeout_sec=1.0):
        node.get_logger().info('Service not available, waiting again...')
    node.declare_parameter("collect_sign", False)
    node.declare_parameter("end_sign", False)

    def result_callback(future):
        # Fix: add_done_callback delivers the Future, not the response;
        # the original accessed .success on the Future and would raise.
        response = future.result()
        if response.success:
            node.get_logger().info(response.message)
        else:
            node.get_logger().error(response.message)

    def timer_callback():
        start_sign = node.get_parameter("collect_sign").value
        end_sign = node.get_parameter("end_sign").value
        if start_sign:
            request = SetBool.Request()
            request.data = True
            client.call_async(request).add_done_callback(result_callback)
            # Reset the flag so the request fires only once per toggle.
            node.set_parameters([Parameter('collect_sign', Parameter.Type.BOOL, False)])
        if end_sign:
            request = SetBool.Request()
            request.data = False
            client.call_async(request).add_done_callback(result_callback)
            node.set_parameters([Parameter('end_sign', Parameter.Type.BOOL, False)])

    node.create_timer(0.5, timer_callback)
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()  # fix: release the ROS context (was missing)

View File

@@ -0,0 +1,408 @@
import os
import json
import cv2
import rclpy
from rclpy.node import Node
from cv_bridge import CvBridge
from ament_index_python import get_package_share_directory
from message_filters import Subscriber, ApproximateTimeSynchronizer
from sensor_msgs.msg import Image
from std_srvs.srv import SetBool
share_dir = get_package_share_directory('vision_detect')
class CollectDataNode(Node):
    """Record synchronized color/depth streams from up to three cameras.

    For each configured camera prefix ("left"/"right"/"head") the node
    subscribes to a color and a depth image topic (topics come from
    configs/other_configs/collect_data_node.json) and, while recording
    is active, writes per episode index NNNN:
      * <prefix>_color_video_NNNN.mp4  - BGR color stream,
      * <prefix>_depth_video_NNNN.mp4  - 8-bit depth preview,
      * <prefix>_depth_raw_NNNN.raw    - raw uint16 depth frames,
      * <prefix>_timestamp_NNNN.txt    - one depth timestamp (ns) per frame,
      * <prefix>_depth_meta_NNNN.json  - metadata describing the raw stream.
    Recording is controlled through the /collect_data SetBool service:
    data=True starts an episode, data=False stops it.
    """

    # Camera prefixes; per-camera state lives in attributes named <prefix>_*.
    CAMERAS = ("left", "right", "head")

    def __init__(self):
        super().__init__("collect_data_node")
        self.collect_sign = False   # True while a recording episode is active
        self.save_sign = False      # set once an episode has been finalized
        self.cv_bridge = CvBridge()
        # Per-camera state, kept as flat attributes (same names as before).
        self.left_sign = self.right_sign = self.head_sign = False
        self.left_meta = self.right_meta = self.head_meta = False
        self.left_depth_raw_file = None
        self.right_depth_raw_file = None
        self.head_depth_raw_file = None
        self.left_timestamp_file = None
        self.right_timestamp_file = None
        self.head_timestamp_file = None
        self.index = 0              # episode index used in all file names
        self.fps = 30.0             # assumed stream rate -- TODO confirm per camera
        self.left_color_writer = None
        self.right_color_writer = None
        self.head_color_writer = None
        self.left_depth_writer = None
        self.right_depth_writer = None
        self.head_depth_writer = None
        self.right_writer_initialized = False
        self.left_writer_initialized = False
        self.head_writer_initialized = False
        with open(os.path.join(share_dir, 'configs/other_configs/collect_data_node.json'), 'r') as f:
            configs = json.load(f)
        home_path = os.path.expanduser("~")
        # Fix: the config value starts with '/'; os.path.join discards every
        # preceding component when given an absolute path, so the original
        # silently wrote to /collect_data instead of ~/collect_data.
        self.save_path = os.path.join(home_path, configs['save_path'].lstrip(os.sep))
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        for prefix in self.CAMERAS:
            self._setup_camera(prefix, configs[prefix])
        self.service = self.create_service(SetBool, "/collect_data", self.service_callback)

    def _setup_camera(self, prefix, cfg):
        """Subscribe to one camera's color+depth topics if both are live."""
        if not (self.topic_exists(cfg["color"]) and self.topic_exists(cfg["depth"])):
            return
        setattr(self, f"{prefix}_sign", True)
        color_sub = Subscriber(self, Image, cfg["color"])
        depth_sub = Subscriber(self, Image, cfg["depth"])
        sync = ApproximateTimeSynchronizer([color_sub, depth_sub], queue_size=10, slop=0.1)
        sync.registerCallback(getattr(self, f"{prefix}_callback"))
        # Keep references alive on the node (same attribute names as before).
        setattr(self, f"sub_{prefix}_color", color_sub)
        setattr(self, f"sub_{prefix}_depth", depth_sub)
        setattr(self, f"{prefix}_sync_subscriber", sync)

    def service_callback(self, request, response):
        """SetBool handler: data=True starts an episode, data=False stops it."""
        if request.data:
            return self._start_episode(response)
        return self._stop_episode(response)

    def _start_episode(self, response):
        """Pick a fresh episode index and open raw/timestamp files."""
        self.left_meta = self.right_meta = self.head_meta = False
        # Advance to the first index not used by any existing recording.
        while any(
            os.path.exists(os.path.join(self.save_path, f"{p}_color_video_{self.index:04d}.mp4"))
            for p in self.CAMERAS
        ):
            self.index += 1
        for prefix in self.CAMERAS:
            if not getattr(self, f"{prefix}_sign"):
                continue
            setattr(self, f"{prefix}_depth_raw_file", open(
                os.path.join(self.save_path, f"{prefix}_depth_raw_{self.index:04d}.raw"), 'ab'))
            setattr(self, f"{prefix}_timestamp_file", open(
                os.path.join(self.save_path, f"{prefix}_timestamp_{self.index:04d}.txt"), 'a'))
        # Video writers are created lazily on the first synchronized frame.
        self.left_writer_initialized = False
        self.right_writer_initialized = False
        self.head_writer_initialized = False
        self.collect_sign = True
        response.success = True
        response.message = "start collecting data"
        return response

    def _stop_episode(self, response):
        """Close all per-camera files and release all video writers."""
        for prefix in self.CAMERAS:
            if getattr(self, f"{prefix}_sign"):
                for attr in (f"{prefix}_depth_raw_file", f"{prefix}_timestamp_file"):
                    handle = getattr(self, attr)
                    if handle is not None:
                        handle.close()
                        setattr(self, attr, None)
            for attr in (f"{prefix}_color_writer", f"{prefix}_depth_writer"):
                writer = getattr(self, attr)
                if writer is not None:
                    writer.release()
                    setattr(self, attr, None)
        self.collect_sign = False
        self.save_sign = True
        # Fix: a clean stop is a success; the original returned success=False,
        # which made the client log every stop as an error.
        response.success = True
        response.message = "stop collecting data"
        return response

    def left_callback(self, color_msg, depth_msg):
        """Synchronized color/depth callback for the left camera."""
        self._record_frames("left", color_msg, depth_msg)

    def right_callback(self, color_msg, depth_msg):
        """Synchronized color/depth callback for the right camera."""
        self._record_frames("right", color_msg, depth_msg)

    def head_callback(self, color_msg, depth_msg):
        """Synchronized color/depth callback for the head camera."""
        self._record_frames("head", color_msg, depth_msg)

    def _record_frames(self, prefix, color_msg, depth_msg):
        """Convert, persist, and visualize one synchronized frame pair.

        Deduplicates the three formerly identical per-camera callbacks;
        per-camera state is resolved via getattr/setattr on <prefix>_* names.
        """
        if not self.collect_sign:
            return
        try:
            color_frame = self.cv_bridge.imgmsg_to_cv2(color_msg, desired_encoding='bgr8')
            depth_frame = self.cv_bridge.imgmsg_to_cv2(depth_msg, desired_encoding='uint16')
        except Exception as e:
            self.get_logger().error(f'cv_bridge error: {e}')
            return
        if not getattr(self, f"{prefix}_writer_initialized"):
            ch, cw = color_frame.shape[:2]
            dh, dw = depth_frame.shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'MP4V')
            setattr(self, f"{prefix}_color_writer", cv2.VideoWriter(
                os.path.join(self.save_path, f"{prefix}_color_video_{self.index:04d}.mp4"),
                fourcc,
                self.fps,
                (cw, ch)
            ))
            setattr(self, f"{prefix}_depth_writer", cv2.VideoWriter(
                os.path.join(self.save_path, f"{prefix}_depth_video_{self.index:04d}.mp4"),
                fourcc,
                self.fps,
                (dw, dh)
            ))
            if (not getattr(self, f"{prefix}_color_writer").isOpened()
                    or not getattr(self, f"{prefix}_depth_writer").isOpened()):
                self.get_logger().error('Failed to open VideoWriter')
                return
            setattr(self, f"{prefix}_writer_initialized", True)
            self.get_logger().info('VideoWriter initialized')
        # Raw depth dump: tobytes() needs a contiguous buffer.
        if not depth_frame.flags['C_CONTIGUOUS']:
            depth_frame = depth_frame.copy()
        raw_file = getattr(self, f"{prefix}_depth_raw_file")
        if raw_file is not None:
            raw_file.write(depth_frame.tobytes())
        ts_file = getattr(self, f"{prefix}_timestamp_file")
        if ts_file is not None:
            ts = depth_msg.header.stamp.sec * 1_000_000_000 + depth_msg.header.stamp.nanosec
            ts_file.write(f"{ts}\n")
        if not getattr(self, f"{prefix}_meta"):
            # One-time sidecar describing how to parse the .raw file.
            meta = {
                "width": int(depth_frame.shape[1]),
                "height": int(depth_frame.shape[0]),
                "dtype": "uint16",
                "endianness": "little",
                "unit": "mm",
                "frame_order": "row-major",
                "fps": self.fps
            }
            with open(os.path.join(self.save_path, f"{prefix}_depth_meta_{self.index:04d}.json"), "w") as f:
                json.dump(meta, f, indent=2)
            setattr(self, f"{prefix}_meta", True)
        # MP4 streams: color as-is, depth scaled to 8-bit for preview.
        getattr(self, f"{prefix}_color_writer").write(color_frame)
        depth_vis = cv2.convertScaleAbs(
            depth_frame,
            alpha=255.0 / 10000.0  # maps 0-10 m to 0-255; adjust per camera
        )
        depth_vis = cv2.cvtColor(depth_vis, cv2.COLOR_GRAY2BGR)
        getattr(self, f"{prefix}_depth_writer").write(depth_vis)

    def topic_exists(self, topic_name: str) -> bool:
        """Return True if `topic_name` is currently advertised."""
        topics = self.get_topic_names_and_types()
        return any(name == topic_name for name, _ in topics)
def main(args=None):
    """Entry point: spin the data-collection node until interrupted."""
    rclpy.init(args=args)
    collector = CollectDataNode()
    try:
        rclpy.spin(collector)
    except KeyboardInterrupt:
        pass
    finally:
        collector.destroy_node()
        rclpy.shutdown()

View File

@@ -1,340 +0,0 @@
import os
import time
import json
import sys
from datetime import datetime
from collections import defaultdict
from ament_index_python.packages import get_package_share_directory
import cv2
import numpy as np
from cv_bridge import CvBridge
import torch
from ultralytics import YOLO
import rclpy
from rclpy.node import Node
from message_filters import ApproximateTimeSynchronizer, Subscriber
from sensor_msgs.msg import Image, CameraInfo
from geometry_msgs.msg import Pose, Point, Quaternion
from interfaces.msg import PoseClassAndID, PoseArrayClassAndID
sys.path.insert(0, "/home/nvidia/aidk/AIDK/release/lib_py/")
import flexivaidk as aidk
share_dir = get_package_share_directory("vision_detect")
def get_map(K, D, camera_size):
    """Build undistortion remap tables from camera intrinsics.

    Args:
        K: flat 3x3 intrinsic matrix (9 values, row-major).
        D: distortion coefficients.
        camera_size: (width, height) of the image.

    Returns:
        (map1, map2, new_K_flat): remap tables for cv2.remap plus the
        flattened optimal new intrinsic matrix.
    """
    width, height = camera_size
    intrinsic = np.array(K).reshape(3, 3)
    distortion = np.array(D)
    new_K, _ = cv2.getOptimalNewCameraMatrix(
        intrinsic, distortion, (width, height), 1, (width, height))
    map1, map2 = cv2.initUndistortRectifyMap(
        intrinsic, distortion, None, new_K, (width, height), cv2.CV_32FC1)
    return map1, map2, new_K.flatten()
def distortion_correction(color_image, depth_image, map1, map2):
    """Undistort a color/depth image pair with precomputed remap tables.

    Color is resampled bilinearly; depth uses nearest-neighbour so the
    remap never interpolates between valid and invalid depth values.
    """
    undistorted_color = cv2.remap(
        color_image, map1, map2, cv2.INTER_LINEAR).astype(color_image.dtype)
    undistorted_depth = cv2.remap(
        depth_image, map1, map2, cv2.INTER_NEAREST).astype(depth_image.dtype)
    return undistorted_color, undistorted_depth
def crop_mask_bbox(rgb_img, depth_img, mask, box):
    """Crop rgb/depth/mask arrays to a detection bounding box.

    Args:
        rgb_img: H x W x 3 image.
        depth_img: H x W depth map.
        mask: H x W instance mask (0/1 or bool).
        box: (x_center, y_center, w, h) plus optional extra fields.

    Returns:
        (rgb_crop, depth_crop, mask_crop, (x0, y0)) where (x0, y0) is the
        crop's top-left corner in the original image. The crop edges are
        inclusive (one extra pixel past the box's right/bottom).
    """
    img_h, img_w = depth_img.shape
    xc, yc, bw, bh = box[:4]
    x0 = max(0, int(round(xc - bw / 2)))
    y0 = max(0, int(round(yc - bh / 2)))
    x1 = min(int(round(xc + bw / 2)), img_w) + 1  # inclusive right edge
    y1 = min(int(round(yc + bh / 2)), img_h) + 1  # inclusive bottom edge
    window = (slice(y0, y1), slice(x0, x1))
    return rgb_img[window], depth_img[window], mask[window], (x0, y0)
class DetectNode(Node):
    """YOLO segmentation + AIDK pose-estimation node.

    Per synchronized color/depth frame the node undistorts both images,
    runs YOLO instance segmentation, crops each detection (RGB plus
    mask-filtered depth), forwards the crops to the AIDK detector, and
    publishes the parsed object poses on /pose/cv_detect_pose.
    """

    def __init__(self, name):
        super().__init__(name)
        self.checkpoint_path = None
        self.checkpoint_name = None
        self.function = None          # checkpoint-specific inference routine
        self.output_boxes = None
        self.output_masks = None
        self.K = self.D = None        # intrinsics / distortion from CameraInfo
        self.map1 = self.map2 = None  # undistortion remap tables
        self.fx = self.fy = 1.0
        self.cv_bridge = CvBridge()
        self.aidk_client = aidk.AIDKClient('127.0.0.1', 10)
        # Block startup until the AIDK service is reachable.
        while not self.aidk_client.is_ready():
            time.sleep(0.5)
        self._init_param()
        self._init_model()
        self._init_config()
        self._init_publisher()
        self._init_subscriber()
        self.get_logger().info("Init done")

    def _init_param(self):
        """Declare and read all node parameters."""
        self.declare_parameter('checkpoint_name', 'yolo11s-seg.pt')
        self.checkpoint_name = self.get_parameter('checkpoint_name').value
        self.checkpoint_path = os.path.join(share_dir, 'checkpoints', self.checkpoint_name)
        self.declare_parameter('config_name', 'default_config.json')
        self.config_name = self.get_parameter('config_name').value
        self.config_dir = os.path.join(share_dir, 'configs/flexiv_configs', self.config_name)
        self.declare_parameter('output_boxes', True)
        self.output_boxes = self.get_parameter('output_boxes').value
        self.declare_parameter('output_masks', True)
        self.output_masks = self.get_parameter('output_masks').value
        self.declare_parameter('set_confidence', 0.25)
        self.set_confidence = self.get_parameter('set_confidence').value
        self.declare_parameter('color_image_topic', '/camera/camera/color/image_raw')
        self.color_image_topic = self.get_parameter('color_image_topic').value
        self.declare_parameter('depth_image_topic', '/camera/camera/aligned_depth_to_color/image_raw')
        self.depth_image_topic = self.get_parameter('depth_image_topic').value
        self.declare_parameter('camera_info_topic', '/camera/camera/color/camera_info')
        self.camera_info_topic = self.get_parameter('camera_info_topic').value

    def _init_model(self):
        """Load the YOLO checkpoint and select the inference routine."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            self.model = YOLO(self.checkpoint_path).to(device)
        except Exception as e:
            self.get_logger().error(f'Failed to load YOLO model: {e}')
            raise
        self.get_logger().info(f'Loading checkpoint from: {self.checkpoint_path}')
        if self.checkpoint_name.endswith('-seg.pt'):
            self.function = self._seg_image
        else:
            # Unsupported checkpoint: _sync_callback skips frames while None.
            self.function = None
            self.get_logger().error(f'Unknown checkpoint: {self.checkpoint_name}')

    def _init_config(self):
        """Load the AIDK command configuration."""
        with open(self.config_dir, "r") as f:
            self.config = json.load(f)

    def _init_publisher(self):
        """Create the pose publisher and an optional debug-image publisher."""
        self.pub_pose_list = self.create_publisher(PoseArrayClassAndID, '/pose/cv_detect_pose', 10)
        if self.output_boxes or self.output_masks:
            self.pub_detect_image = self.create_publisher(Image, '/image/detect_image', 10)

    def _init_subscriber(self):
        """Subscribe to camera info and a synchronized color/depth pair."""
        self.sub_camera_info = self.create_subscription(
            CameraInfo,
            self.camera_info_topic,
            self._camera_info_callback,
            10
        )
        time.sleep(1)  # give the CameraInfo callback a chance to fire first
        self.sub_color_image = Subscriber(self, Image, self.color_image_topic)
        self.sub_depth_image = Subscriber(self, Image, self.depth_image_topic)
        self.sync_subscriber = ApproximateTimeSynchronizer(
            [self.sub_color_image, self.sub_depth_image],
            queue_size=10,
            slop=0.1
        )
        self.sync_subscriber.registerCallback(self._sync_callback)

    def _camera_info_callback(self, msg: CameraInfo):
        """Cache intrinsics, build undistortion maps, then unsubscribe."""
        self.K = msg.k
        self.D = msg.d
        self.camera_size = [msg.width, msg.height]
        if self.K is not None and self.D is not None:
            self.map1, self.map2, self.K = get_map(msg.k, msg.d, self.camera_size)
            if len(self.D) == 0:
                # No distortion reported: fall back to an all-zero model.
                self.D = [0, 0, 0, 0, 0, 0, 0, 0]
            # Intrinsics are static; one message is enough.
            self.destroy_subscription(self.sub_camera_info)
        else:
            # Fix: the original raised a bare string, a TypeError in Python 3.
            raise RuntimeError("K and D are not defined")

    def _sync_callback(self, color_img_ros, depth_img_ros):
        """Synchronized color/depth callback: run detection and publish."""
        if self.map1 is None or self.function is None:
            # Fix: intrinsics not yet received or no usable checkpoint;
            # the original crashed here (None remap / None call).
            return
        color_img_cv = self.cv_bridge.imgmsg_to_cv2(color_img_ros, "bgr8")
        depth_img_cv = self.cv_bridge.imgmsg_to_cv2(depth_img_ros, '16UC1')
        color_img_cv, depth_img_cv = distortion_correction(
            color_img_cv, depth_img_cv, self.map1, self.map2)
        img, pose_dict = self.function(color_img_cv, depth_img_cv)
        # No detections: republish the plain color frame instead.
        if img is None:
            img = self.cv_bridge.cv2_to_imgmsg(color_img_cv, "bgr8")
        if self.output_boxes or self.output_masks:
            self.pub_detect_image.publish(img)
        if pose_dict is not None:
            pose_list_all = PoseArrayClassAndID()
            for (class_id, class_name), pose_list in pose_dict.items():
                pose_list_all.objects.append(
                    PoseClassAndID(
                        class_name=class_name,
                        class_id=class_id,
                        pose_list=pose_list
                    )
                )
            pose_list_all.header.stamp = self.get_clock().now().to_msg()
            pose_list_all.header.frame_id = "pose_list"
            self.pub_pose_list.publish(pose_list_all)

    def _seg_image(self, rgb_img: np.ndarray, depth_img: np.ndarray):
        """Segment objects, query AIDK per crop, and collect poses by class.

        Returns:
            (annotated_image_or_None, pose_dict_or_None) where pose_dict
            maps (class_id, class_name) -> list[Pose]. Returns (None, None)
            when nothing is detected.
        """
        pose_dict = defaultdict(list)
        time1 = time.time()
        # classes=[39] restricts detection to a single COCO class id.
        results = self.model(rgb_img, retina_masks=True, classes=[39], conf=self.set_confidence)
        time2 = time.time()
        result = results[0]
        if result.masks is None or len(result.masks) == 0:
            return None, None
        masks = result.masks.data.cpu().numpy()
        orig_shape = result.masks.orig_shape
        boxes = result.boxes.xywh.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy()
        labels = result.names
        time3 = time.time()
        for i, (mask, box) in enumerate(zip(masks, boxes)):
            mask = cv2.resize(mask.astype(np.uint8), orig_shape[::-1], interpolation=cv2.INTER_NEAREST)
            rgb_crop, depth_crop, mask_crop, (x_min, y_min) = crop_mask_bbox(rgb_img, depth_img, mask, box)
            if depth_crop is None:
                continue
            # Keep depth only where the instance mask is set.
            depth_img_crop_mask = np.zeros_like(depth_crop)
            depth_img_crop_mask[mask_crop > 0] = depth_crop[mask_crop > 0]
            # Fix: dropped leftover debug print() calls of crop shape/dtype.
            rgb_bytes = cv2.imencode('.png', rgb_crop)[1]
            depth_bytes = cv2.imencode('.png', depth_img_crop_mask)[1]
            # Shift the principal point into the crop's coordinate frame.
            intrinsics = [
                int(self.camera_size[0]),
                int(self.camera_size[1]),
                self.K[2] - x_min,
                self.K[5] - y_min,
                self.K[0],
                self.K[4]
            ]
            state = self.aidk_client.detect_with_image(
                obj_name=self.config["command"]["obj_name"],
                camera_id=self.config["command"]["camera_id"],
                coordinate_id=self.config["command"]["coordinate_id"],
                camera_pose=self.config["command"]["camera_pose"],
                camera_intrinsic=intrinsics,
                rgb_input=aidk.ImageBuffer(rgb_bytes),
                depth_input=aidk.ImageBuffer(depth_bytes),
                custom=self.config["command"]["custom"],
            )
            self.get_logger().info(f"state: {state}")
            self.get_logger().info(
                f"current detected object names: {self.aidk_client.get_detected_obj_names()}, current detected object nums: {self.aidk_client.get_detected_obj_nums()}"
            )
            for key in self.config["keys"]:
                parse_state, result_list = self.aidk_client.parse_result(self.config["command"]["obj_name"], key, -1)
                self.get_logger().info(
                    "detected time stamp: {}".format(
                        datetime.fromtimestamp(self.aidk_client.get_detected_time())
                    )
                )
                if not parse_state:
                    self.get_logger().error("Parse result error!!!")
                    continue
                if key in ["bbox", "keypoints", "positions", "obj_pose"]:
                    # Fix: loop variable renamed from `result`, which
                    # shadowed the outer YOLO result object.
                    for parsed in result_list:
                        for vec in parsed.vect:
                            self.get_logger().info(f"vec: {vec}")
                            # vec layout assumed x, y, z, qw, qx, qy, qz
                            # -- TODO confirm against the AIDK docs.
                            x, y, z, rw, rx, ry, rz = vec
                            pose = Pose()
                            pose.position = Point(x=x, y=y, z=z)
                            pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
                            pose_dict[int(class_ids[i]), labels[int(class_ids[i])]].append(pose)
                elif key in ["valid", "double_value", "int_value", "name"]:
                    for parsed in result_list:
                        self.get_logger().info(f"{key}: {getattr(parsed, key)}")
        time4 = time.time()
        self.get_logger().info('start')
        self.get_logger().info(f'{(time2 - time1) * 1000} ms, model predict')
        self.get_logger().info(f'{(time3 - time2) * 1000} ms, get mask and some param')
        self.get_logger().info(f'{(time4 - time3) * 1000} ms, calculate all mask PCA')
        self.get_logger().info(f'{(time4 - time1) * 1000} ms, completing a picture entire process')
        self.get_logger().info('end')
        return None, pose_dict
def main(args=None):
    """Entry point: spin the detection node until interrupted."""
    rclpy.init(args=args)
    detect_node = DetectNode('detect')
    try:
        rclpy.spin(detect_node)
    except KeyboardInterrupt:
        pass
    finally:
        detect_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()