add rebuild service to skill bridge

This commit is contained in:
NuoDaJia02
2026-01-20 15:20:48 +08:00
parent 98c0eb5ca5
commit 04ca80c3f9
4 changed files with 58 additions and 8 deletions

View File

@@ -13,6 +13,7 @@ from std_msgs.msg import String
from ament_index_python.packages import get_package_share_directory
from interfaces.action import ExecuteBtAction
from interfaces.srv import BtRebuild
class SkillBridgeNode(Node):
@@ -20,6 +21,8 @@ class SkillBridgeNode(Node):
super().__init__('skill_bridge_node')
self._action_client = ActionClient(self, ExecuteBtAction, '/execute_bt_action')
self._current_epoch = 1
self.run_trigger_ = self.create_client(BtRebuild, '/cerebrum/rebuild_now')
self.rebuild_requests = 0
self._allowed_skills = self._load_allowed_skills()
self.skill_seq_sub = self.create_subscription(
@@ -44,7 +47,17 @@ class SkillBridgeNode(Node):
if not sequence:
self.get_logger().warning(f"Invalid skill sequence: {raw}")
return
self._send_skill_sequence(sequence)
# self._send_skill_sequence(sequence)
#判断如果sequence 中包含VisionObjectRecognition,Arm,GripperCmd0,Arm这几个actions则调用rebuild_now
if any(skill in sequence for skill in ["VisionObjectRecognition", "Arm", "GripperCmd0"]):
self.get_logger().info(f"Skill sequence contains special skills, triggering rebuild: {sequence}")
self.rebuild_now("Trigger", "bt_vision_grasp_dual_arm", "")
else:
#只发送逗号分隔符的第一个action
first_skill = sequence.split(",")[0]
self.get_logger().info(f"Sending first skill in sequence: {first_skill}")
self.rebuild_now("Remote", first_skill, "")
def _load_allowed_skills(self) -> set[str]:
try:
@@ -121,6 +134,35 @@ class SkillBridgeNode(Node):
msg.data = json.dumps(payload, ensure_ascii=True)
self.result_pub.publish(msg)
def rebuild_now(self, type: str, config: str, param: str) -> None:
if not self.run_trigger_.service_is_ready():
self.get_logger().error('Rebuild service not ready')
return
self.rebuild_requests += 1
self.get_logger().info(f'Rebuild BehaviorTree now. Total requests: {self.rebuild_requests}')
request = BtRebuild.Request()
request.type = type
request.config = config
request.param = param
self.get_logger().info(f'Calling rebuild service... request info: {request}')
future = self.run_trigger_.call_async(request)
future.add_done_callback(self._rebuild_done_callback)
def _rebuild_done_callback(self, future):
try:
response = future.result()
if response.success:
self.get_logger().info('Rebuild request successful')
else:
self.get_logger().warning(f'Rebuild request failed: {response.message}')
except Exception as e:
self.get_logger().error(f'Rebuild request exception: {str(e)}')
self.get_logger().info(f"Rebuild requested. Total rebuild requests: {str(self.rebuild_requests)}")
def main(args=None):

View File

@@ -141,7 +141,7 @@ class RobotSpeakerNode(Node):
self.sv_enabled = system['sv_enabled']
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
self.sv_threshold = system['sv_threshold']
self.sv_speaker_db_path = system['sv_speaker_db_path']
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
self.sv_buffer_size = system['sv_buffer_size']
camera = config['camera']

View File

@@ -139,8 +139,8 @@ class DashScopeASR(ASRClient):
# 触发 commit等待 final 结果
self.conversation.commit()
# 等待 final 回调完成(最多等待1秒)
if self._final_result_event.wait(timeout=1.0):
# 等待 final 回调完成(最多等待3秒)
if self._final_result_event.wait(timeout=3.0):
self._log("debug", "已收到 final 回调,准备关闭连接")
else:
self._log("warning", "等待 final 回调超时,继续执行")

View File

@@ -107,11 +107,19 @@ class SpeakerVerificationClient:
temp_wav_path = None
try:
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
result = self.model.generate(input=temp_wav_path)
# 限制Torch在推理时使用单线程避免在多任务环境下尤其是一边录音一边识别
# 出现的极端CPU竞争和上下文切换开销
import torch
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
with torch.inference_mode():
# 临时设置,虽然全局已经设置了,但在调用前再次确保
# 注意set_num_threads 是全局的,这里再次确认
if torch.get_num_threads() != 1:
torch.set_num_threads(1)
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
result = self.model.generate(input=temp_wav_path)
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
embedding_dim = len(embedding)
if embedding_dim == 0: