添加相机数据采集节点
This commit is contained in:
18
vision_detect/configs/other_configs/collect_data_node.json
Normal file
18
vision_detect/configs/other_configs/collect_data_node.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"save_path": "/collect_data",
|
||||
"left": {
|
||||
"color": "/camera/camera/color/image_raw",
|
||||
"depth": "/camera/camera/aligned_depth_to_color/image_raw",
|
||||
"info": "/camera/camera/color/camera_info"
|
||||
},
|
||||
"right": {
|
||||
"color": "/camera/color/image_raw",
|
||||
"depth": "/camera/depth/image_raw",
|
||||
"info": "/camera/color/camera_info"
|
||||
},
|
||||
"head": {
|
||||
"color": "",
|
||||
"depth": "",
|
||||
"info": ""
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ setup(
|
||||
('share/' + package_name + '/configs/flexiv_configs', glob('configs/flexiv_configs/*.json')),
|
||||
('share/' + package_name + '/configs/hand_eye_mat', glob('configs/hand_eye_mat/*.json')),
|
||||
('share/' + package_name + '/configs/launch_configs', glob('configs/launch_configs/*.json')),
|
||||
('share/' + package_name + '/configs/other_configs', glob('configs/other_configs/*.json')),
|
||||
|
||||
('share/' + package_name + '/checkpoints', glob('checkpoints/*.pt')),
|
||||
('share/' + package_name + '/checkpoints', glob('checkpoints/*.onnx')),
|
||||
|
||||
@@ -62,6 +62,9 @@ def calculate_pose_pca(
|
||||
return np.eye(4), [0.0, 0.0, 0.0]
|
||||
|
||||
if calculate_grab_width:
|
||||
if np.asarray(point_cloud.points).shape[0] < 4:
|
||||
logging.warning("点数不足,不能算 OBB")
|
||||
return np.eye(4), [0.0, 0.0, 0.0]
|
||||
obb = point_cloud.get_oriented_bounding_box()
|
||||
x, y, z = obb.center
|
||||
extent = obb.extent
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
import open3d as o3d
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"draw_box", "draw_mask", "draw_pointcloud",
|
||||
]
|
||||
|
||||
409
vision_detect/vision_detect/collect_data_node.py
Normal file
409
vision_detect/vision_detect/collect_data_node.py
Normal file
@@ -0,0 +1,409 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from cv_bridge import CvBridge
|
||||
from ament_index_python import get_package_share_directory
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from std_srvs.srv import SetBool
|
||||
|
||||
|
||||
share_dir = get_package_share_directory('vision_detect')
|
||||
|
||||
class CollectDataNode(Node):
    """Records synchronized color/depth streams from up to three cameras.

    For every active camera ("left", "right", "head") the node writes, per
    recording session ``index``:

    * ``<cam>_color_video_<index>.mp4``  - BGR color stream
    * ``<cam>_depth_video_<index>.mp4``  - 8-bit visualization of the depth
    * ``<cam>_depth_raw_<index>.raw``    - raw uint16 depth frames, appended
    * ``<cam>_timestamp_<index>.txt``    - one depth timestamp (ns) per frame
    * ``<cam>_depth_meta_<index>.json``  - layout metadata for the raw dump

    Recording is toggled through the ``/collect_data`` SetBool service
    (data=True starts a session, data=False stops and finalizes all files).
    """

    CAMERAS = ("left", "right", "head")
    # Depth range mapped onto 0..255 in the visualization video (10 m in mm).
    # Adjust to the camera's actual range.
    DEPTH_CLIP_MM = 10000.0

    def __init__(self):
        super().__init__("collect_data_node")
        self.collect_sign = False   # True while a recording session is running
        self.save_sign = False      # set once a session has been finalized
        self.cv_bridge = CvBridge()

        self.index = 0              # session counter used in all file names
        self.fps = 30.0             # assumed camera frame rate for the videos

        # Per-camera recording state; replaces the previous triplicated
        # left_/right_/head_ attribute sets.
        self._cams = {
            name: {
                "active": False,              # both topics found -> subscribed
                "meta_written": False,        # depth meta json emitted this session
                "writer_initialized": False,  # video writers opened this session
                "color_writer": None,
                "depth_writer": None,
                "raw_file": None,             # raw uint16 depth dump
                "ts_file": None,              # per-frame timestamps (ns)
                "subs": None,                 # keeps subscriber/sync objects alive
            }
            for name in self.CAMERAS
        }

        with open(os.path.join(share_dir, 'configs/other_configs/collect_data_node.json'), 'r') as f:
            configs = json.load(f)

        home_path = os.path.expanduser("~")
        # BUG FIX: the configured save_path ("/collect_data") is absolute, and
        # os.path.join(home, '/x') discards ``home`` entirely; strip the
        # leading separator so data really lands under the user's home.
        self.save_path = os.path.join(home_path, configs['save_path'].lstrip(os.sep))
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)

        for name in self.CAMERAS:
            self._subscribe_camera(name, configs[name])

        self.service = self.create_service(SetBool, "/collect_data", self.service_callback)

    def _subscribe_camera(self, name, topics):
        """Subscribe color+depth for camera ``name`` if both topics exist."""
        if not (self.topic_exists(topics["color"]) and self.topic_exists(topics["depth"])):
            return
        state = self._cams[name]
        state["active"] = True
        color_sub = Subscriber(self, Image, topics["color"])
        depth_sub = Subscriber(self, Image, topics["depth"])
        sync = ApproximateTimeSynchronizer(
            [color_sub, depth_sub],
            queue_size=10,
            slop=0.1
        )
        sync.registerCallback(getattr(self, f"{name}_callback"))
        state["subs"] = (color_sub, depth_sub, sync)

    def service_callback(self, request, response):
        """SetBool handler: data=True starts a session, data=False stops it."""
        if request.data:
            self._start_session()
            response.success = True
            response.message = "start collecting data"
            return response

        self._stop_session()
        # BUG FIX: the stop branch used to report success=False even though
        # stopping succeeded; SetBool's ``success`` flags whether the request
        # was honored, not the new recording state.
        response.success = True
        response.message = "stop collecting data"
        return response

    def _start_session(self):
        """Open the per-camera raw/timestamp files for a fresh session index."""
        # Pick the first index not used by any camera's color video.
        while any(
            os.path.exists(os.path.join(
                self.save_path, f"{name}_color_video_{self.index:04d}.mp4"))
            for name in self.CAMERAS
        ):
            self.index += 1

        for name in self.CAMERAS:
            state = self._cams[name]
            state["meta_written"] = False
            state["writer_initialized"] = False
            if state["active"]:
                state["raw_file"] = open(
                    os.path.join(self.save_path, f"{name}_depth_raw_{self.index:04d}.raw"), 'ab')
                state["ts_file"] = open(
                    os.path.join(self.save_path, f"{name}_timestamp_{self.index:04d}.txt"), 'a')

        self.collect_sign = True

    def _stop_session(self):
        """Close every open file/writer and mark the session as saved."""
        # Flip the flag first so the callbacks stop writing before handles close.
        self.collect_sign = False
        for state in self._cams.values():
            for key in ("raw_file", "ts_file"):
                if state[key] is not None:
                    state[key].close()
                    state[key] = None
            for key in ("color_writer", "depth_writer"):
                if state[key] is not None:
                    state[key].release()
                    state[key] = None
        self.save_sign = True

    # The three public callbacks are kept (they are what the synchronizers
    # invoke) and now share a single implementation.
    def left_callback(self, color_msg, depth_msg):
        self._handle_frames("left", color_msg, depth_msg)

    def right_callback(self, color_msg, depth_msg):
        self._handle_frames("right", color_msg, depth_msg)

    def head_callback(self, color_msg, depth_msg):
        self._handle_frames("head", color_msg, depth_msg)

    def _handle_frames(self, name, color_msg, depth_msg):
        """Convert, dump and encode one synchronized color/depth pair."""
        if not self.collect_sign:
            return
        try:
            color_frame = self.cv_bridge.imgmsg_to_cv2(color_msg, desired_encoding='bgr8')
            depth_frame = self.cv_bridge.imgmsg_to_cv2(depth_msg, desired_encoding='uint16')
        except Exception as e:
            self.get_logger().error(f'cv_bridge error: {e}')
            return

        state = self._cams[name]
        if not state["writer_initialized"] and not self._init_writers(name, color_frame, depth_frame):
            return

        # RAW dump: tobytes() needs a C-contiguous buffer.
        if not depth_frame.flags['C_CONTIGUOUS']:
            depth_frame = np.ascontiguousarray(depth_frame)
        if state["raw_file"] is not None:
            state["raw_file"].write(depth_frame.tobytes())
        if state["ts_file"] is not None:
            ts = depth_msg.header.stamp.sec * 1_000_000_000 + depth_msg.header.stamp.nanosec
            state["ts_file"].write(f"{ts}\n")

        if not state["meta_written"]:
            self._write_depth_meta(name, depth_frame)
            state["meta_written"] = True

        # MP4 encoding: color as-is, depth scaled to 8-bit grayscale BGR.
        state["color_writer"].write(color_frame)
        depth_vis = cv2.convertScaleAbs(depth_frame, alpha=255.0 / self.DEPTH_CLIP_MM)
        state["depth_writer"].write(cv2.cvtColor(depth_vis, cv2.COLOR_GRAY2BGR))

    def _init_writers(self, name, color_frame, depth_frame):
        """Open the color/depth mp4 writers for ``name``; True on success."""
        ch, cw = color_frame.shape[:2]
        dh, dw = depth_frame.shape[:2]
        # BUG FIX: the FOURCC tag for the .mp4 container is lowercase 'mp4v';
        # the uppercase 'MP4V' tag is not accepted by OpenCV's FFmpeg backend.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        state = self._cams[name]
        state["color_writer"] = cv2.VideoWriter(
            os.path.join(self.save_path, f"{name}_color_video_{self.index:04d}.mp4"),
            fourcc,
            self.fps,
            (cw, ch)
        )
        state["depth_writer"] = cv2.VideoWriter(
            os.path.join(self.save_path, f"{name}_depth_video_{self.index:04d}.mp4"),
            fourcc,
            self.fps,
            (dw, dh)
        )

        if not state["color_writer"].isOpened() or not state["depth_writer"].isOpened():
            self.get_logger().error('Failed to open VideoWriter')
            return False

        state["writer_initialized"] = True
        self.get_logger().info('VideoWriter initialized')
        return True

    def _write_depth_meta(self, name, depth_frame):
        """Emit the json describing the layout of the raw depth dump."""
        meta = {
            "width": int(depth_frame.shape[1]),
            "height": int(depth_frame.shape[0]),
            "dtype": "uint16",
            "endianness": "little",
            "unit": "mm",
            "frame_order": "row-major",
            "fps": self.fps
        }
        with open(os.path.join(self.save_path, f"{name}_depth_meta_{self.index:04d}.json"), "w") as f:
            json.dump(meta, f, indent=2)

    def topic_exists(self, topic_name: str) -> bool:
        """Return True if ``topic_name`` is currently advertised on the graph."""
        topics = self.get_topic_names_and_types()
        return any(name == topic_name for name, _ in topics)
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the data-collection node until interrupted."""
    rclpy.init(args=args)
    collector = CollectDataNode()
    try:
        rclpy.spin(collector)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        collector.destroy_node()
        rclpy.shutdown()
|
||||
@@ -1,340 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from message_filters import ApproximateTimeSynchronizer, Subscriber
|
||||
|
||||
from sensor_msgs.msg import Image, CameraInfo
|
||||
from geometry_msgs.msg import Pose, Point, Quaternion
|
||||
from interfaces.msg import PoseClassAndID, PoseArrayClassAndID
|
||||
|
||||
sys.path.insert(0, "/home/nvidia/aidk/AIDK/release/lib_py/")
|
||||
import flexivaidk as aidk
|
||||
|
||||
share_dir = get_package_share_directory("vision_detect")
|
||||
|
||||
|
||||
def get_map(K, D, camera_size):
    """Build undistortion remap tables for a pinhole camera.

    K is a flat 3x3 intrinsic matrix, D the distortion coefficients and
    camera_size is (width, height).  Returns (map1, map2, new_K_flat) where
    new_K_flat is the flattened optimal camera matrix after undistortion.
    """
    w, h = camera_size
    intrinsic = np.array(K).reshape(3, 3)
    distortion = np.array(D)

    # alpha=1 keeps all source pixels (black borders possible).
    new_K, _ = cv2.getOptimalNewCameraMatrix(intrinsic, distortion, (w, h), 1, (w, h))
    map1, map2 = cv2.initUndistortRectifyMap(
        intrinsic, distortion, None, new_K, (w, h), cv2.CV_32FC1)

    return map1, map2, new_K.flatten()
|
||||
|
||||
|
||||
def distortion_correction(color_image, depth_image, map1, map2):
    """Undistort a color/depth pair with precomputed remap tables.

    Color is resampled bilinearly; depth uses nearest-neighbour so that no
    invalid intermediate depth values are invented.  Both outputs keep their
    input dtypes.
    """
    undistorted_color = cv2.remap(
        color_image, map1, map2, cv2.INTER_LINEAR).astype(color_image.dtype)
    undistorted_depth = cv2.remap(
        depth_image, map1, map2, cv2.INTER_NEAREST).astype(depth_image.dtype)
    return undistorted_color, undistorted_depth
|
||||
|
||||
|
||||
def crop_mask_bbox(rgb_img, depth_img, mask, box):
    """Crop rgb/depth/mask to a (clipped) detection bounding box.

    Args:
        rgb_img:   full color image.
        depth_img: full depth image, H x W.
        mask:      full-resolution mask, H x W (0/1 or bool).
        box:       (x_center, y_center, w, h) in pixels; extra elements ignored.

    Returns:
        (rgb_crop, depth_crop, mask_crop, (x_offset, y_offset)) where the
        offsets are the crop's top-left corner in full-image coordinates.
    """
    img_h, img_w = depth_img.shape
    cx, cy, bw, bh = box[:4]

    left = int(round(cx - bw / 2))
    right = int(round(cx + bw / 2))
    top = int(round(cy - bh / 2))
    bottom = int(round(cy + bh / 2))

    # Clip to the image and make the box inclusive of its far edge.
    x0, y0 = max(0, left), max(0, top)
    x1, y1 = min(right, img_w) + 1, min(bottom, img_h) + 1

    return (
        rgb_img[y0:y1, x0:x1],
        depth_img[y0:y1, x0:x1],
        mask[y0:y1, x0:x1],
        (x0, y0),
    )
|
||||
|
||||
class DetectNode(Node):
    """ROS2 node: runs YOLO segmentation on synchronized color/depth frames,
    forwards the per-object crops to the Flexiv AIDK service for pose
    estimation, and publishes the resulting poses on /pose/cv_detect_pose.
    """

    def __init__(self, name):
        super().__init__(name)
        # Model / parameter state, filled in by _init_param / _init_model.
        self.checkpoint_path = None
        self.checkpoint_name = None
        self.function = None          # per-checkpoint inference entry (e.g. self._seg_image)
        self.output_boxes = None
        self.output_masks = None
        # Camera intrinsics / distortion and undistortion maps (from CameraInfo).
        self.K = self.D = None
        self.map1 = self.map2 = None

        self.fx = self.fy = 1.0
        self.cv_bridge = CvBridge()
        # Blocks here until the local AIDK service is reachable.
        self.aidk_client = aidk.AIDKClient('127.0.0.1', 10)
        while not self.aidk_client.is_ready():
            time.sleep(0.5)

        '''init'''
        self._init_param()
        self._init_model()
        self._init_config()
        self._init_publisher()
        self._init_subscriber()

        self.get_logger().info("Init done")

    def _init_param(self):
        """Declare and read all ROS parameters (checkpoint, config, topics)."""
        self.declare_parameter('checkpoint_name', 'yolo11s-seg.pt')
        self.checkpoint_name = self.get_parameter('checkpoint_name').value
        self.checkpoint_path = os.path.join(share_dir, 'checkpoints', self.checkpoint_name)

        self.declare_parameter('config_name', 'default_config.json')
        self.config_name = self.get_parameter('config_name').value
        self.config_dir = os.path.join(share_dir, 'configs/flexiv_configs', self.config_name)

        self.declare_parameter('output_boxes', True)
        self.output_boxes = self.get_parameter('output_boxes').value

        self.declare_parameter('output_masks', True)
        self.output_masks = self.get_parameter('output_masks').value

        # YOLO confidence threshold.
        self.declare_parameter('set_confidence', 0.25)
        self.set_confidence = self.get_parameter('set_confidence').value

        self.declare_parameter('color_image_topic', '/camera/camera/color/image_raw')
        self.color_image_topic = self.get_parameter('color_image_topic').value

        self.declare_parameter('depth_image_topic', '/camera/camera/aligned_depth_to_color/image_raw')
        self.depth_image_topic = self.get_parameter('depth_image_topic').value

        self.declare_parameter('camera_info_topic', '/camera/camera/color/camera_info')
        self.camera_info_topic = self.get_parameter('camera_info_topic').value

    def _init_model(self):
        """Load the YOLO checkpoint and select the matching inference method."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            self.model = YOLO(self.checkpoint_path).to(device)
        except Exception as e:
            self.get_logger().error(f'Failed to load YOLO model: {e}')
            raise
        self.get_logger().info(f'Loading checkpoint from: {self.checkpoint_path}')

        # Only segmentation checkpoints ("*-seg.pt") are supported.
        if self.checkpoint_name.endswith('-seg.pt'):
            self.function = self._seg_image
        else:
            # NOTE(review): self.function stays None here, so _sync_callback
            # will raise TypeError on the first frame for unknown checkpoints.
            self.function = None
            self.get_logger().error(f'Unknown checkpoint: {self.checkpoint_name}')

    def _init_config(self):
        """Load the AIDK command configuration (json) selected by config_name."""
        with open(self.config_dir, "r") as f:
            self.config = json.load(f)

    def _init_publisher(self):
        """Create the pose publisher and, optionally, the debug-image publisher."""
        self.pub_pose_list = self.create_publisher(PoseArrayClassAndID, '/pose/cv_detect_pose', 10)

        if self.output_boxes or self.output_masks:
            self.pub_detect_image = self.create_publisher(Image, '/image/detect_image', 10)

    def _init_subscriber(self):
        """Subscribe to camera info and the synchronized color/depth streams."""
        self.sub_camera_info = self.create_subscription(
            CameraInfo,
            self.camera_info_topic,
            self._camera_info_callback,
            10
        )
        # Give the camera-info callback a chance to run before frames arrive.
        time.sleep(1)

        '''sync get color and depth img'''
        self.sub_color_image = Subscriber(self, Image, self.color_image_topic)
        self.sub_depth_image = Subscriber(self, Image, self.depth_image_topic)

        self.sync_subscriber = ApproximateTimeSynchronizer(
            [self.sub_color_image, self.sub_depth_image],
            queue_size=10,
            slop=0.1
        )
        self.sync_subscriber.registerCallback(self._sync_callback)

    def _camera_info_callback(self, msg: CameraInfo):
        """Capture intrinsics once, build undistortion maps, then unsubscribe."""
        self.K = msg.k
        self.D = msg.d

        self.camera_size = [msg.width, msg.height]

        if self.K is not None and self.D is not None:
            # self.K is replaced by the flattened optimal matrix from get_map.
            self.map1, self.map2, self.K = get_map(msg.k, msg.d, self.camera_size)

            if len(self.D) != 0:
                self.destroy_subscription(self.sub_camera_info)
            else:
                # No distortion reported: assume a zero-distortion model.
                self.D = [0, 0, 0, 0, 0, 0, 0, 0]
                self.destroy_subscription(self.sub_camera_info)
        else:
            # NOTE(review): raising a str is a TypeError in Python 3; this
            # should be e.g. ``raise RuntimeError("K and D are not defined")``.
            raise "K and D are not defined"


    def _sync_callback(self, color_img_ros, depth_img_ros):
        """Synchronized color/depth callback: undistort, detect, publish."""
        color_img_cv = self.cv_bridge.imgmsg_to_cv2(color_img_ros, "bgr8")
        depth_img_cv = self.cv_bridge.imgmsg_to_cv2(depth_img_ros, '16UC1')

        color_img_cv, depth_img_cv = distortion_correction(color_img_cv, depth_img_cv, self.map1, self.map2)

        # self.function is set by _init_model (currently only _seg_image).
        img, pose_dict = self.function(color_img_cv, depth_img_cv)

        """masks为空,结束这一帧"""
        # (If no masks were found, fall back to republishing the raw frame.)
        if img is None:
            img = self.cv_bridge.cv2_to_imgmsg(color_img_cv, "bgr8")

        if self.output_boxes or self.output_masks:
            self.pub_detect_image.publish(img)

        if pose_dict is not None:
            # Flatten {(class_id, class_name): [Pose, ...]} into one message.
            pose_list_all = PoseArrayClassAndID()
            for (class_id, class_name), pose_list in pose_dict.items():
                pose_list_all.objects.append(
                    PoseClassAndID(
                        class_name = class_name,
                        class_id = class_id,
                        pose_list = pose_list
                    )
                )
            pose_list_all.header.stamp = self.get_clock().now().to_msg()
            pose_list_all.header.frame_id = "pose_list"
            self.pub_pose_list.publish(pose_list_all)

    def _seg_image(self, rgb_img: np.ndarray, depth_img: np.ndarray):
        """Segment objects, query AIDK per detection, and collect poses.

        Returns (None, pose_dict) on success or (None, None) when the model
        produced no masks; pose_dict maps (class_id, class_name) to a list of
        geometry_msgs Pose.
        """
        pose_dict = defaultdict(list)

        '''Get Predict Results'''
        time1 = time.time()
        # classes=[39] restricts detection to a single COCO class ("bottle").
        results = self.model(rgb_img, retina_masks=True, classes=[39], conf=self.set_confidence)
        time2 = time.time()
        result = results[0]

        '''Get masks'''
        if result.masks is None or len(result.masks) == 0:
            return None, None
        masks = result.masks.data.cpu().numpy()
        orig_shape = result.masks.orig_shape

        '''Get boxes'''
        boxes = result.boxes.xywh.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy()
        labels = result.names

        time3 = time.time()

        for i, (mask, box) in enumerate(zip(masks, boxes)):
            # Resize the mask back to the original resolution; nearest keeps
            # it binary.
            mask = cv2.resize(mask.astype(np.uint8), orig_shape[::-1], interpolation=cv2.INTER_NEAREST)
            rgb_crop, depth_crop, mask_crop, (x_min, y_min) = crop_mask_bbox(rgb_img, depth_img, mask, box)

            # NOTE(review): crop_mask_bbox never returns None; an empty crop
            # would be a zero-size array, so this guard is effectively dead.
            if depth_crop is None:
                continue

            # Keep depth only where the mask is set.
            depth_img_crop_mask = np.zeros_like(depth_crop)
            depth_img_crop_mask[mask_crop > 0] = depth_crop[mask_crop > 0]

            # NOTE(review): debug prints; consider get_logger().debug instead.
            print(rgb_crop.shape)
            print(rgb_crop.dtype)


            # PNG-encode the crops for transport to the AIDK service.
            rgb_bytes = cv2.imencode('.png', rgb_crop)[1]
            depth_bytes = cv2.imencode('.png', depth_img_crop_mask)[1]

            # [width, height, cx, cy, fx, fy] with the principal point shifted
            # into crop coordinates (self.K is the flattened optimal matrix).
            intrinsics = [
                int(self.camera_size[0]),
                int(self.camera_size[1]),
                self.K[2] - x_min,
                self.K[5] - y_min,
                self.K[0],
                self.K[4]
            ]

            state = self.aidk_client.detect_with_image(
                obj_name=self.config["command"]["obj_name"],
                camera_id=self.config["command"]["camera_id"],
                coordinate_id=self.config["command"]["coordinate_id"],
                camera_pose=self.config["command"]["camera_pose"],
                camera_intrinsic=intrinsics,
                rgb_input=aidk.ImageBuffer(rgb_bytes),
                depth_input=aidk.ImageBuffer(depth_bytes),
                custom=self.config["command"]["custom"],
            )

            self.get_logger().info(f"state: {state}")
            self.get_logger().info(
                f"current detected object names: {self.aidk_client.get_detected_obj_names()}, current detected object nums: {self.aidk_client.get_detected_obj_nums()}"
            )

            # Parse each requested result key from the AIDK response.
            for key in self.config["keys"]:
                parse_state, result_list = self.aidk_client.parse_result(self.config["command"]["obj_name"], key, -1)
                self.get_logger().info(
                    "detected time stamp: {}".format(
                        datetime.fromtimestamp(self.aidk_client.get_detected_time())
                    )
                )
                if not parse_state:
                    self.get_logger().error("Parse result error!!!")
                    continue
                else:
                    if key in ["bbox", "keypoints", "positions", "obj_pose"]:
                        for result in result_list:
                            for vec in result.vect:
                                self.get_logger().info(f"vec: {vec}")
                                # vec layout: position (x, y, z) + quaternion (w, x, y, z).
                                x, y, z, rw, rx, ry, rz = vec

                                pose = Pose()
                                pose.position = Point(x=x, y=y, z=z)
                                pose.orientation = Quaternion(w=rw, x=rx, y=ry, z=rz)
                                # NOTE(review): class_ids[i] is a numpy float;
                                # labels (result.names) is keyed by int — this
                                # lookup presumably relies on float hashing
                                # equal to int; verify.
                                pose_dict[int(class_ids[i]), labels[class_ids[i]]].append(pose)

                    elif key in ["valid", "double_value", "int_value", "name"]:
                        for result in result_list:
                            self.get_logger().info(f"{key}: {getattr(result, key)}")

        time4 = time.time()

        # Per-stage timing report.
        self.get_logger().info(f'start')
        self.get_logger().info(f'{(time2 - time1) * 1000} ms, model predict')
        self.get_logger().info(f'{(time3 - time2) * 1000} ms, get mask and some param')
        self.get_logger().info(f'{(time4 - time3) * 1000} ms, calculate all mask PCA')
        self.get_logger().info(f'{(time4 - time1) * 1000} ms, completing a picture entire process')
        self.get_logger().info(f'end')

        # Image output is not produced by this path; poses only.
        return None, pose_dict
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the detection node until interrupted."""
    rclpy.init(args=args)
    detector = DetectNode('detect')
    try:
        rclpy.spin(detector)
    except KeyboardInterrupt:
        # Ctrl-C is the normal shutdown path; not an error.
        pass
    finally:
        detector.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user