add rebuild service to skill bridge
This commit is contained in:
@@ -13,6 +13,7 @@ from std_msgs.msg import String
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from interfaces.action import ExecuteBtAction
|
||||
from interfaces.srv import BtRebuild
|
||||
|
||||
|
||||
class SkillBridgeNode(Node):
|
||||
@@ -20,6 +21,8 @@ class SkillBridgeNode(Node):
|
||||
super().__init__('skill_bridge_node')
|
||||
self._action_client = ActionClient(self, ExecuteBtAction, '/execute_bt_action')
|
||||
self._current_epoch = 1
|
||||
self.run_trigger_ = self.create_client(BtRebuild, '/cerebrum/rebuild_now')
|
||||
self.rebuild_requests = 0
|
||||
self._allowed_skills = self._load_allowed_skills()
|
||||
|
||||
self.skill_seq_sub = self.create_subscription(
|
||||
@@ -44,7 +47,17 @@ class SkillBridgeNode(Node):
|
||||
if not sequence:
|
||||
self.get_logger().warning(f"Invalid skill sequence: {raw}")
|
||||
return
|
||||
self._send_skill_sequence(sequence)
|
||||
# self._send_skill_sequence(sequence)
|
||||
|
||||
#判断如果sequence 中包含VisionObjectRecognition,Arm,GripperCmd0,Arm这几个actions,则调用rebuild_now
|
||||
if any(skill in sequence for skill in ["VisionObjectRecognition", "Arm", "GripperCmd0"]):
|
||||
self.get_logger().info(f"Skill sequence contains special skills, triggering rebuild: {sequence}")
|
||||
self.rebuild_now("Trigger", "bt_vision_grasp_dual_arm", "")
|
||||
else:
|
||||
#只发送逗号分隔符的第一个action
|
||||
first_skill = sequence.split(",")[0]
|
||||
self.get_logger().info(f"Sending first skill in sequence: {first_skill}")
|
||||
self.rebuild_now("Remote", first_skill, "")
|
||||
|
||||
def _load_allowed_skills(self) -> set[str]:
|
||||
try:
|
||||
@@ -121,6 +134,35 @@ class SkillBridgeNode(Node):
|
||||
msg.data = json.dumps(payload, ensure_ascii=True)
|
||||
self.result_pub.publish(msg)
|
||||
|
||||
def rebuild_now(self, type: str, config: str, param: str) -> None:
    """Fire an asynchronous BtRebuild request at the rebuild service.

    Args:
        type: Trigger type forwarded in the request (callers in this file
            pass "Trigger" or "Remote").
        config: Behavior-tree configuration name to rebuild.
        param: Extra parameter string passed through unchanged.

    If the service is not yet available, logs an error and returns without
    sending anything. The service reply is handled asynchronously by
    ``_rebuild_done_callback``.
    """
    # Bail out early rather than block waiting for the service.
    if not self.run_trigger_.service_is_ready():
        self.get_logger().error('Rebuild service not ready')
        return

    self.rebuild_requests += 1
    logger = self.get_logger()
    logger.info(f'Rebuild BehaviorTree now. Total requests: {self.rebuild_requests}')

    # Build the request from the caller-supplied fields verbatim.
    request = BtRebuild.Request()
    request.type = type
    request.config = config
    request.param = param

    logger.info(f'Calling rebuild service... request info: {request}')

    # Non-blocking call; completion is reported via the done callback.
    self.run_trigger_.call_async(request).add_done_callback(self._rebuild_done_callback)
|
||||
|
||||
def _rebuild_done_callback(self, future):
    """Log the outcome of an asynchronous BtRebuild service call.

    Args:
        future: The rclpy future returned by ``call_async``; ``result()``
            yields the service response (or raises if the call failed).

    Never raises: any exception from the call or from reading the response
    is caught and logged. Always finishes by logging the running request
    counter.
    """
    try:
        # NOTE: the response-field access stays inside the try on purpose —
        # a malformed response is reported as an exception, not propagated.
        response = future.result()
        if not response.success:
            self.get_logger().warning(f'Rebuild request failed: {response.message}')
        else:
            self.get_logger().info('Rebuild request successful')
    except Exception as e:
        self.get_logger().error(f'Rebuild request exception: {str(e)}')

    self.get_logger().info(f"Rebuild requested. Total rebuild requests: {str(self.rebuild_requests)}")
|
||||
|
||||
|
||||
def main(args=None):
|
||||
|
||||
@@ -141,7 +141,7 @@ class RobotSpeakerNode(Node):
|
||||
self.sv_enabled = system['sv_enabled']
|
||||
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
|
||||
self.sv_threshold = system['sv_threshold']
|
||||
self.sv_speaker_db_path = system['sv_speaker_db_path']
|
||||
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
|
||||
self.sv_buffer_size = system['sv_buffer_size']
|
||||
|
||||
camera = config['camera']
|
||||
|
||||
@@ -139,8 +139,8 @@ class DashScopeASR(ASRClient):
|
||||
# 触发 commit,等待 final 结果
|
||||
self.conversation.commit()
|
||||
|
||||
# 等待 final 回调完成(最多等待1秒)
|
||||
if self._final_result_event.wait(timeout=1.0):
|
||||
# 等待 final 回调完成(最多等待3秒)
|
||||
if self._final_result_event.wait(timeout=3.0):
|
||||
self._log("debug", "已收到 final 回调,准备关闭连接")
|
||||
else:
|
||||
self._log("warning", "等待 final 回调超时,继续执行")
|
||||
|
||||
@@ -107,11 +107,19 @@ class SpeakerVerificationClient:
|
||||
|
||||
temp_wav_path = None
|
||||
try:
|
||||
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
|
||||
result = self.model.generate(input=temp_wav_path)
|
||||
|
||||
# 限制Torch在推理时使用单线程,避免在多任务环境下(尤其是一边录音一边识别)
|
||||
# 出现的极端CPU竞争和上下文切换开销
|
||||
import torch
|
||||
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
|
||||
with torch.inference_mode():
|
||||
# 临时设置,虽然全局已经设置了,但在调用前再次确保
|
||||
# 注意:set_num_threads 是全局的,这里再次确认
|
||||
if torch.get_num_threads() != 1:
|
||||
torch.set_num_threads(1)
|
||||
|
||||
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
|
||||
result = self.model.generate(input=temp_wav_path)
|
||||
|
||||
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
|
||||
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
|
||||
Reference in New Issue
Block a user