change folder structure

This commit is contained in:
liangyuxuan
2025-08-08 11:46:28 +08:00
parent 7238e23159
commit 2524b7db9a
2 changed files with 19 additions and 249 deletions

View File

@@ -187,6 +187,8 @@ class DetectNode(Node):
if len(valid_depths) == 0:
self.get_logger().warning(f'No valid depth in window at ({u}, {v})')
return 0.0, 0.0, 0.0
valid_depths = self.iqr(valid_depths)
# x = depth_img[v, u] / 1e3
depth = np.median(valid_depths) / 1e3
@@ -212,6 +214,23 @@ class DetectNode(Node):
int(max(0, u + w - half)):int(min(u + w + half + 1, depth_img.shape[1])):2].flatten())
return np.concatenate(patch)
def iqr(self, depths, threshold: float = 1.5):
    """Filter outlier depth samples with the interquartile-range rule.

    Samples outside [Q1 - threshold*IQR, Q3 + threshold*IQR] are dropped.
    With fewer than 7 samples the quartiles are not meaningful, so the
    input is returned unchanged.
    """
    if len(depths) < 7:
        # Too few samples for a robust quartile estimate; keep everything.
        return depths
    q1, q3 = np.percentile(depths, [25, 75])
    spread = q3 - q1
    low = q1 - threshold * spread
    high = q3 + threshold * spread
    return depths[(depths >= low) & (depths <= high)]
def main(args=None):

View File

@@ -1,249 +0,0 @@
import os
import random
from collections import defaultdict
import cv2
import numpy as np
from cv_bridge import CvBridge
import torch
from detect_part.ultralytics import YOLO
import rclpy
from rclpy.node import Node
from message_filters import ApproximateTimeSynchronizer, Subscriber
from sensor_msgs.msg import Image, CameraInfo
from geometry_msgs.msg import Point
from interfaces.msg import PointWithID, PointWithIDArray
import detect_part
class DetectNode(Node):
    """YOLO-based object detector that publishes 3D object positions.

    Consumes time-synchronized color and depth frames, runs a YOLO model on
    the color image, estimates each detection's depth from the median of a
    depth-window, and publishes the resulting camera-frame points grouped by
    class on '/pose/cv_detect_pose'. Optionally republishes the color frame
    annotated with detection boxes on '/image/detect_image'.
    """

    def __init__(self, name):
        super().__init__(name)
        self.depth_image = None
        # Camera intrinsics; populated once by _camera_info_callback.
        self.fx = self.fy = self.cx = self.cy = None
        self.checkpoint_path = None
        self._init_param()
        self._init_model()
        self._init_subscriber()
        self.cv_bridge = CvBridge()
        self.pub_pose_list = self.create_publisher(PointWithIDArray, '/pose/cv_detect_pose', 10)
        if self.output_detect_image:
            self.pub_detect_image = self.create_publisher(Image, '/image/detect_image', 10)

    def _init_param(self):
        """Declare and read the node's ROS parameters."""
        pkg_dir = os.path.dirname(detect_part.__file__)
        # Checkpoints are shipped under <package>/checkpoint/.
        self.declare_parameter('checkpoint_name', 'yolo11s.pt')
        checkpoint_name = self.get_parameter('checkpoint_name').value
        self.checkpoint_path = os.path.join(pkg_dir, 'checkpoint', checkpoint_name)
        self.declare_parameter('output_detect_image', False)
        self.output_detect_image = self.get_parameter('output_detect_image').value
        self.declare_parameter('set_confidence', 0.6)
        self.set_confidence = self.get_parameter('set_confidence').value

    def _init_model(self):
        """Load the YOLO checkpoint onto the GPU when available, else the CPU."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            self.model = YOLO(self.checkpoint_path).to(device)
        except Exception as e:
            self.get_logger().error(f'Failed to load YOLO model: {e}')
            raise
        self.get_logger().info(f'Loading checkpoint from: {self.checkpoint_path}')

    def _init_subscriber(self):
        """Create the camera-info subscription and the synchronized image pair."""
        self.sub_camera_info = self.create_subscription(
            CameraInfo,
            '/camera/color/camera_info',
            self._camera_info_callback,
            10
        )
        # Pair color and depth frames whose stamps differ by less than 0.1 s.
        self.sub_color_image = Subscriber(self, Image, '/camera/color/image_raw')
        self.sub_depth_image = Subscriber(self, Image, '/camera/depth/image_raw')
        self.sync_subscriber = ApproximateTimeSynchronizer(
            [self.sub_color_image, self.sub_depth_image],
            queue_size=10,
            slop=0.1
        )
        self.sync_subscriber.registerCallback(self._sync_callback)

    def _sync_callback(self, color_img_ros, depth_img_ros):
        """Run detection on a synchronized color/depth pair and publish poses."""
        if any(v is None for v in (self.fx, self.fy, self.cx, self.cy)):
            self.get_logger().warn('Camera intrinsics not yet received. Waiting...')
            return
        color_img_cv = self.cv_bridge.imgmsg_to_cv2(color_img_ros, "bgr8")
        depth_img_cv = self.cv_bridge.imgmsg_to_cv2(depth_img_ros, '16UC1')
        if self.output_detect_image:
            detect_img, pose_dict = self._detect_image(color_img_cv, depth_img_cv)
            self.pub_detect_image.publish(detect_img)
        else:
            # BUG FIX: the depth image was previously omitted here, which made
            # this branch raise TypeError (_detect_image always needs depth).
            pose_dict = self._detect_image(color_img_cv, depth_img_cv)
        pose_list_all = PointWithIDArray()
        for (class_id, class_name), points in pose_dict.items():
            pose_list_all.objects.append(
                PointWithID(
                    class_name=class_name,
                    class_id=class_id,
                    points=points
                )
            )
        self.pub_pose_list.publish(pose_list_all)

    def _camera_info_callback(self, msg: CameraInfo):
        """Cache the camera intrinsics, then drop the subscription (one-shot)."""
        self.fx = msg.k[0]
        self.fy = msg.k[4]
        self.cx = msg.k[2]
        self.cy = msg.k[5]
        # Intrinsics are static; no need to keep receiving them.
        self.destroy_subscription(self.sub_camera_info)

    def _detect_image(self, rgb_img: np.ndarray, depth_img: np.ndarray):
        """Detect objects in rgb_img and compute a 3D point for each one.

        Detections below the confidence threshold, or whose center lies
        within a border margin (12% of width, 10% of height) where depth is
        unreliable, are skipped.

        Returns:
            dict mapping (class_id, class_name) -> list[Point]; when
            output_detect_image is enabled, an (annotated ROS Image, dict)
            tuple instead.
        """
        pose_dict = defaultdict(list)
        img_h, img_w = rgb_img.shape[:2]
        border_x = img_w * 0.12
        border_y = img_h * 0.10
        result = self.model(rgb_img)[0]
        boxes = result.boxes.xywh.cpu().numpy()
        confidences = result.boxes.conf.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy()
        labels = result.names
        for i, box in enumerate(boxes):
            if confidences[i] < self.set_confidence:
                continue
            x_center, y_center, width, height = box[:4]
            if not (border_x <= x_center <= img_w - border_x and border_y <= y_center <= img_h - border_y):
                # Object too close to the image edge; skip it.
                continue
            x, y, z = self._calculate_coordinate(depth_img, x_center, y_center, width, height)
            if (x, y, z) != (0.0, 0.0, 0.0):  # (0,0,0) is the "no depth" sentinel
                class_id = int(class_ids[i])  # cls arrives as float; names is keyed by int
                pose_dict[class_id, labels[class_id]].append(Point(x=x, y=y, z=z))
        if self.output_detect_image:
            self._draw_box(rgb_img, boxes, confidences, class_ids, labels)
            return self.cv_bridge.cv2_to_imgmsg(rgb_img, "bgr8"), pose_dict
        return pose_dict

    def _draw_box(self, rgb_img, boxes, confidences, class_ids, labels):
        """Draw a labeled rectangle on rgb_img (in place) for each confident detection."""
        for i, box in enumerate(boxes):
            if confidences[i] < self.set_confidence:
                continue
            x_center, y_center, width, height = box[:4]
            p1 = [int(x_center - width / 2), int(y_center - height / 2)]
            p2 = [int(x_center + width / 2), int(y_center + height / 2)]
            cv2.rectangle(rgb_img, p1, p2, (255, 255, 0), 2)
            cv2.putText(rgb_img, f'{labels[int(class_ids[i])]}: {confidences[i]*100:.2f}', (p1[0], p1[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)

    def _calculate_coordinate(self, depth_img: np.ndarray, u, v, width: int, height: int) -> tuple[float, float, float]:
        """Convert pixel (u, v) plus nearby depth samples into a camera-frame point.

        Takes the median of the valid (non-zero) depths in a 39x39 window
        around (u, v); when that window holds no valid depth, falls back to
        nine smaller patches spread over the detection box. Depth values are
        millimeters and are converted to meters.

        Returns:
            (x, y, z) in meters, or (0.0, 0.0, 0.0) when no depth is available.
        """
        u = int(round(u))
        v = int(round(v))
        if not (0 <= v < depth_img.shape[0] and 0 <= u < depth_img.shape[1]):
            self.get_logger().warning(f'Calculate coordinate error: u={u}, v={v}')
            return 0.0, 0.0, 0.0
        window_size = 39
        half = window_size // 2
        # NOTE(review): columns are subsampled with step 2 but rows are not --
        # possibly meant to mirror the :2/:2 stride in _get_nine_patch_samples;
        # left as-is to preserve behavior. Confirm with the author.
        patch = depth_img[max(0, v - half):v + half + 1, max(0, u - half):u + half + 1:2].flatten()
        valid_depths = patch[patch > 0]
        if len(valid_depths) == 0:
            patch = self._get_nine_patch_samples(depth_img, u, v, width, height)
            valid_depths = patch[patch > 0]
            if len(valid_depths) == 0:
                self.get_logger().warning(f'No valid depth in window at ({u}, {v})')
                return 0.0, 0.0, 0.0
        valid_depths = self.iqr(valid_depths)  # drop outlier depths before the median
        depth = np.median(valid_depths) / 1e3  # mm -> m
        x = depth
        # Pixel offsets are negated — presumably mapping into a forward/left/up
        # body frame; TODO confirm the target frame convention.
        y = -(u - self.cx) * x / self.fx
        z = -(v - self.cy) * x / self.fy
        return x, y, z

    def _get_nine_patch_samples(self, depth_img, u, v, width, height):
        """Sample nine 25x25 depth patches spread across the detection box.

        Patch centers sit at the box center and at +/- width/4 and +/- height/4
        offsets, each jittered randomly by up to 5% of the box size; every
        patch is subsampled with stride 2 on both axes and clipped to the
        image bounds. Returns the concatenated, flattened samples.
        """
        ws = [int(round(random.uniform(-0.05*width, 0.05*width) - width/4)), 0,
              int(round(random.uniform(-0.05*width, 0.05*width) + width/4))]
        hs = [int(round(random.uniform(-0.05*height, 0.05*height) - height/4)), 0,
              int(round(random.uniform(-0.05*height, 0.05*height) + height/4))]
        window_size = 25
        half = window_size // 2
        patch = []
        for w in ws:
            for h in hs:
                patch.append(
                    depth_img[int(max(0, v + h - half)):int(min(v + h + half + 1, depth_img.shape[0])):2,
                              int(max(0, u + w - half)):int(min(u + w + half + 1, depth_img.shape[1])):2].flatten())
        return np.concatenate(patch)

    def iqr(self, depths, threshold: float = 1.5):
        """Filter depth samples with the interquartile-range outlier rule.

        Args:
            depths: 1-D numpy array of depth samples.
            threshold: IQR multiplier for the accepted band (Tukey's 1.5).

        Returns:
            The samples inside [Q1 - t*IQR, Q3 + t*IQR]; the input unchanged
            when fewer than 7 samples are available.
        """
        if len(depths) < 7:
            return depths
        q1 = np.percentile(depths, 25)
        q3 = np.percentile(depths, 75)
        iqr = q3 - q1
        lower_bound = q1 - iqr * threshold
        upper_bound = q3 + iqr * threshold
        return depths[(depths >= lower_bound) & (depths <= upper_bound)]
def main(args=None):
    """Initialize rclpy, spin the detect node, and shut down cleanly."""
    rclpy.init(args=args)
    detect_node = DetectNode('detect')
    try:
        rclpy.spin(detect_node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        # Always release the node and the rclpy context, even on interrupt.
        detect_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()