14 Commits

Author SHA1 Message Date
lxy
a0ceb934ce 修复1在WebSocket回调线程内执行stop/start竞争条件,'socket already closed'循环出现,2陈旧结果5秒复用窗口旧识别结果污染新请求,意图混乱 2026-03-06 17:29:55 +08:00
NuoDaJia02
ed861a9fb1 fix run issues 2026-01-30 10:53:07 +08:00
lxy
aaa17c10f2 修复rclpy.spin() 单线程执行器导致异步回调死锁,增加ASR WebSocket 自动重连机制 2026-01-29 17:24:49 +08:00
NuoDaJia02
c65395c50f merge remote 2026-01-28 14:45:42 +08:00
lxy
9c8bd017e1 分出asr和tts节点 2026-01-27 20:53:43 +08:00
NuoDaJia02
856c07715c Update voice configuration and skill bridge logic
- Update voice.yaml to use default audio devices and 48kHz sample rate.
- Update voice.yaml paths for voice model and interfaces.
- Improve skill_bridge_node.py JSON parsing and skill parameter handling.
- Update audio_pipeline.py warning message for device detection.
2026-01-22 17:28:28 +08:00
lxy
e8a9821ce4 配置文件增加没有图像skill_sequence/chat_camera是否推理的button,扩充kb_qa的回复,减少闲聊模式的回复长度 2026-01-21 18:04:26 +08:00
lxy
ab1fb4f3f8 修改声纹验证失败仍然执行,增加接口解析提示词 2026-01-21 15:13:31 +08:00
lxy
dd6ccf77bb 修改技能序列历史管理--不接入历史上下文 2026-01-21 11:22:25 +08:00
lxy
7324630458 修改声纹注册选择第一句话完整片段。去掉注册时多余的阈值信息,修改llm技能序列输出格式 2026-01-20 21:39:15 +08:00
NuoDaJia02
04ca80c3f9 add rebuild service to skill bridge 2026-01-20 15:20:48 +08:00
lxy
98c0eb5ca5 refactor: 删除回声消除相关代码,支持从hivecore_robot_drivers/img_dev获取图片 2026-01-20 09:28:57 +08:00
lxy
71062701e1 Merge branch 'feature-deploy' into develop
# Conflicts:
#	config/voice.yaml
#	robot_speaker/core/robot_speaker_node.py
2026-01-19 15:16:11 +08:00
lxy
6d101b9d9e 添加与行为树的桥接节点 2026-01-19 09:58:40 +08:00
48 changed files with 3652 additions and 3501 deletions

3
.gitignore vendored
View File

@@ -4,3 +4,6 @@ log/
__pycache__/
*.pyc
*.egg-info/
dist/
lib/
installed_files.txt

116
CMakeLists.txt Normal file
View File

@@ -0,0 +1,116 @@
# Minimum CMake required by this ament package.
cmake_minimum_required(VERSION 3.8)
project(robot_speaker)
# Enable a strict warning set on GCC/Clang only.
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic)
endif()
find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(interfaces REQUIRED)
# Prefer the system Python over a conda/miniconda interpreter that may be
# first on PATH: search only the system locations first, then fall back to
# a normal PATH search if nothing was found there.
find_program(PYTHON3_CMD python3 PATHS /usr/bin /usr/local/bin NO_DEFAULT_PATH)
if(NOT PYTHON3_CMD)
find_program(PYTHON3_CMD python3)
endif()
# Force the cache so a previously configured (e.g. conda) interpreter cannot
# leak into this build. NOTE(review): FORCE also stomps an explicit user
# -DPython3_EXECUTABLE override — confirm this is intended.
if(PYTHON3_CMD)
set(Python3_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python 3 executable" FORCE)
set(PYTHON_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python executable" FORCE)
endif()
# Install the Python package via pip at `make install` time, then relocate the
# artifacts into the layout ament/ROS 2 expects (lib/pythonX.Y/site-packages
# for the package, lib/<pkg>/ for entry-point executables).
#
# Fixes over the previous version:
#  * `\${install_output}` etc. are now escaped so they expand when the install
#    script runs, not at configure time (where they were always empty).
#  * The second execute_process captures ERROR_VARIABLE, so Python tracebacks
#    are no longer lost.
#  * The relocation script no longer deletes the destination before checking
#    whether source and destination are the same directory.
install(CODE "
  execute_process(
    COMMAND ${PYTHON3_CMD} -m pip install --prefix=${CMAKE_INSTALL_PREFIX} --no-deps ${CMAKE_CURRENT_SOURCE_DIR}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE install_result
    OUTPUT_VARIABLE install_output
    ERROR_VARIABLE install_error
  )
  if(NOT install_result EQUAL 0)
    message(FATAL_ERROR \"Failed to install Python package. Output: \${install_output} Error: \${install_error}\")
  endif()
  execute_process(
    COMMAND ${PYTHON3_CMD} -c \"
import os
import shutil
import glob
import sysconfig
install_prefix = '${CMAKE_INSTALL_PREFIX}'
python_version = sysconfig.get_python_version()
# Location where ROS 2 expects Python packages.
ros2_site_packages = os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages')
os.makedirs(ros2_site_packages, exist_ok=True)
# pip install --prefix may place the package elsewhere (system environments
# typically use local/lib/pythonX.Y/dist-packages).
pip_locations = [
    os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'dist-packages'),
    os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages'),
    os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'site-packages'),
]
# Find the installed robot_speaker package and copy it where ROS 2 expects it.
robot_speaker_src = None
for location in pip_locations:
    candidate = os.path.join(location, 'robot_speaker')
    if os.path.exists(candidate) and os.path.isdir(candidate):
        robot_speaker_src = candidate
        break
if robot_speaker_src:
    robot_speaker_dest = os.path.join(ros2_site_packages, 'robot_speaker')
    if robot_speaker_src != robot_speaker_dest:
        # Only remove the destination when it is a *different* directory;
        # removing first and comparing second deleted an in-place install.
        if os.path.exists(robot_speaker_dest):
            shutil.rmtree(robot_speaker_dest)
        shutil.copytree(robot_speaker_src, robot_speaker_dest)
        print(f'Copied robot_speaker from {robot_speaker_src} to {ros2_site_packages}')
    else:
        print('robot_speaker already in correct location')
# Relocate entry-point scripts into lib/robot_speaker so `ros2 run` finds them.
lib_dir = os.path.join(install_prefix, 'lib', 'robot_speaker')
os.makedirs(lib_dir, exist_ok=True)
# Scripts may land in local/bin or bin depending on the pip layout.
for bin_dir in [os.path.join(install_prefix, 'local', 'bin'), os.path.join(install_prefix, 'bin')]:
    if os.path.exists(bin_dir):
        scripts = glob.glob(os.path.join(bin_dir, '*_node'))
        for script in scripts:
            script_name = os.path.basename(script)
            dest = os.path.join(lib_dir, script_name)
            if script != dest:
                shutil.copy2(script, dest)
                os.chmod(dest, 0o755)
                print(f'Copied {script_name} to {lib_dir}')
\"
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE python_result
    OUTPUT_VARIABLE python_output
    ERROR_VARIABLE python_error
  )
  if(python_result EQUAL 0)
    message(STATUS \"\${python_output}\")
  else()
    message(WARNING \"Failed to setup Python package: \${python_output} Error: \${python_error}\")
  endif()
")
# Ship launch files into the package share directory.
install(
  DIRECTORY launch/
  DESTINATION share/${PROJECT_NAME}/launch
  FILES_MATCHING PATTERN "*.launch.py"
)

# Ship YAML/JSON configuration alongside the launch files.
install(
  DIRECTORY config/
  DESTINATION share/${PROJECT_NAME}/config
  FILES_MATCHING PATTERN "*.yaml" PATTERN "*.json"
)

# Register the standard ament linters when testing is enabled.
if(BUILD_TESTING)
  find_package(ament_lint_auto REQUIRED)
  ament_lint_auto_find_test_dependencies()
endif()

ament_package()

View File

@@ -22,19 +22,21 @@ pip3 install -r requirements.txt --break-system-packages
## 编译启动
1. 注册声纹
- 启动节点后可以说:二狗今天天气真好开始注册声纹
- 正确的注册姿势:
方法A推荐唤醒后停顿一下然后说一段长句子。
用户:"二狗"
机器:(日志提示等待声纹语音)
用户:"我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去。"(持续说 3-5 秒)
方法B连贯说一口气说很长的一句话。
用户:"二狗你好,我是你的主人,请记住我的声音,这是一段用来注册的长语音。"
- 注意要包含唤醒词语句不要停顿尽量大于1.5秒
- 启动节点后可以说:er gou我现在正在注册声纹这是一段很长的测试语音请把我的声音录进去。
- 正确的注册姿势:包含唤醒词二狗不要停顿的尽量说完3秒
- 现在的逻辑只要识别到二狗就注册,然后退出节点,识别不到二狗继续等待
- 多注册几段,换方向距离注册,可以提高识别相似度,注册方向对声纹相似性影响很大
```bash
cd ~/ros_learn/hivecore_robot_voice
colcon build
source install/setup.bash
```
```bash
# 终端1: 启动ASR节点
ros2 run robot_speaker asr_audio_node
# 终端2: 注册声纹
ros2 run robot_speaker register_speaker_node
```
@@ -49,38 +51,34 @@ source install/setup.bash
ros2 launch robot_speaker voice.launch.py
```
## 架构说明
[录音线程] - 唯一实时线程
├─ 麦克风采集 PCM
├─ VAD + 能量检测
├─ 检测到人声 → 立即中断TTS
├─ 语音 PCM → ASR 音频队列
└─ 语音 PCM → 声纹音频队列(旁路,不阻塞)
3. ASR节点
```bash
ros2 run robot_speaker asr_audio_node
```
[ASR推理线程] - 只做 audio → text
└─ 从 ASR 音频队列取音频→ 实时 / 流式 ASR → text → 文本队列
4. TTS节点
```bash
# 终端1: 启动TTS节点
ros2 run robot_speaker tts_audio_node
[声纹识别线程] - 非实时、低频CAM++
├─ 通过回调函数接收音频chunk写入缓冲区等待 speech_end 事件触发处理
├─ 累积 1~2 秒有效人声VAD 后)
├─ CAM++ 提取 speaker embedding
├─ 声纹匹配 / 注册
└─ 更新 current_speaker_id共享状态只写不控
声纹线程要求不影响录音不影响ASR不控制TTS只更新当前说话人是谁
# 终端2: 启动播放
source install/setup.bash
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
"{command: 'synthesize', text: '这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。让我继续说下去这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。让我继续说下去这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。', voice: ''}"
[主线程/处理线程] - 处理业务逻辑
├─ 从 文本队列 取 ASR 文本
├─ 读取 current_speaker_id只读
├─ 唤醒词处理(结合 speaker_id
├─ 权限 / 身份判断(是否允许继续)
├─ VLM处理文本 / 多模态)
└─ TTS播放启动TTS线程不等待
[TTS播放线程] - 只播放(可被中断)
├─ 接收 TTS 音频流
├─ 播放到输出设备
└─ 响应中断标志(由录音线程触发)
# 终端3: 立即执行中断
source install/setup.bash
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
"{command: 'interrupt', text: '', voice: ''}"
```
5. 完整运行
```bash
# 终端1启动 brain 节点
# 终端2启动 voice 节点
# 终端3启动 bridge 节点
# 终端4订阅相机
```
## 用到的命令
1. 音频设备

View File

@@ -1,11 +1,18 @@
{
"entries": [
{
"id": "robot_identity",
"id": "robot_identity_1",
"patterns": [
"ni shi shei"
"ni shi shui"
],
"answer": "我叫二狗,是蜂核科技的机器人,很高兴为你服务"
},
{
"id": "robot_identity_2",
"patterns": [
"ni jiao sha"
],
"answer": "我叫二狗呀,我是你的好帮手"
},
{
"id": "wake_word",
@@ -13,6 +20,27 @@
"ni de ming zi"
],
"answer": "我的名字是二狗"
},
{
"id": "skill_1",
"patterns": [
"tiao ge wu"
],
"answer": "这个我真不会,我怕跳起来吓到你"
},
{
"id": "skill_2",
"patterns": [
"ni neng gan"
],
"answer": "我可以陪你聊天,也能帮你干活"
},
{
"id": "skill_3",
"patterns": [
"ni hui gan"
],
"answer": "我可以陪你聊天,你也可以发布具体的指令让我干活"
}
]
}

File diff suppressed because it is too large Load Diff

View File

@@ -18,26 +18,25 @@ dashscope:
audio:
microphone:
device_index: 3 # 指向 iFLYTEK-M2 (hw:1,0)
sample_rate: 48000 # 尝试使用硬件原生采样率 48kHz避免重采样可能导致的问题
device_index: -1 # 使用系统默认输入设备
sample_rate: 48000 # 尝试使用硬件原生采样率 48kHz避免重采样可能导致的问题
channels: 1 # 输入声道数单声道MONO适合语音采集
chunk: 1024
heartbeat_interval: 2.0 # 心跳间隔(秒),用于定期输出录音状态
soundcard:
card_index: 1 # USB Audio Device (card 1)
device_index: 0 # USB Audio [USB Audio] (device 0)
# card_index: -1 # 使用默认声卡
# device_index: -1 # 使用默认输出设备
sample_rate: 48000 # 输出采样率48kHziFLYTEK 支持 48000
card_index: -1 # 使用默认声卡
device_index: -1 # 使用默认输出设备
sample_rate: 48000 # 输出采样率48kHz
channels: 2 # 输出声道数立体声2声道FL+FR
volume: 1.0 # 音量比例0.0-1.00.2表示20%音量)
echo_cancellation:
enable: false # 是否启用回声消除
max_duration_ms: 500 # 参考信号缓冲区最大时长(毫秒)
tts:
source_sample_rate: 22050 # TTS服务固定输出采样率DashScope服务固定值不可修改
source_channels: 1 # TTS服务固定输出声道数DashScope服务固定值不可修改
ffmpeg_thread_queue_size: 4096 # ffmpeg输入线程队列大小增大以减少卡顿
force_stop_delay: 0.1 # 强制停止时的延迟(秒)
cleanup_timeout: 30.0 # 清理超时(秒)
terminate_timeout: 1.0 # 终止超时(秒)
interrupt_wait: 0.1 # 中断等待时间(秒)
vad:
vad_mode: 3 # VAD模式0-33最严格
@@ -45,26 +44,24 @@ vad:
min_energy_threshold: 300 # 最小能量阈值
system:
use_llm: true # 是否使用LLM
use_wake_word: true # 是否启用唤醒词检测
wake_word: "er gou" # 唤醒词(拼音)
session_timeout: 3.0 # 会话超时时间(秒)
shutup_keywords: "bi zui" # 闭嘴指令关键词(拼音,逗号分隔)
interrupt_command_queue_depth: 10 # 中断命令订阅的队列深度QoS
sv_enabled: true # 是否启用声纹识别
sv_model_path: "~/hivecore_robot_os1/voice_model" # 声纹模型路径
sv_threshold: 0.55 # 声纹识别阈值0.0-1.0,值越小越宽松,值越大越严格)
sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径JSON格式相对于ROS2包share目录
sv_buffer_size: 240000 # 声纹验证录音缓冲区大小样本数48kHz下5秒=240000
sv_registration_silence_threshold_ms: 500 # 声纹注册状态下的静音阈值(毫秒
sv_enabled: false # 是否启用声纹识别
# sv_model_path: "~/hivecore_robot_os1/voice_model" # 声纹模型路径
sv_model_path: "~/ros_learn/speech_campplus_sv_zh-cn_16k-common" # 声纹模型路径
sv_threshold: 0.65 # 声纹识别阈值0.0-1.0,值越小越宽松,值越大越严格
# sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径JSON格式相对于ROS2包share目录
sv_speaker_db_path: "~/ros_learn/hivecore_robot_voice/config/speakers.json" # 声纹数据库保存路径JSON格式相对于ROS2包share目录
sv_buffer_size: 96000 # 声纹验证录音缓冲区大小样本数48kHz下2秒=96000
continue_without_image: true # 多模态意图skill_sequence/chat_camera未获取到图片时是否继续推理
camera:
serial_number: "405622075404" # 相机序列号Intel RealSense D435
rgb:
width: 640 # 图像宽度
height: 480 # 图像高度
fps: 30 # 帧率支持6, 10, 15, 30, 60
format: "RGB8" # 图像格式RGB8, BGR8
image:
jpeg_quality: 85 # JPEG压缩质量0-10085是质量和大小平衡点
max_size: "1280x720" # 最大尺寸
interfaces:
# root_path: "~/hivecore_robot_os1/hivecore_robot_interfaces/src" # 接口文件根目录,支持 ~ 展开和相对路径
root_path: "~/ros_learn/hivecore_robot_interfaces/src" # 接口文件根目录,支持 ~ 展开和相对路径

View File

@@ -0,0 +1,54 @@
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.actions import SetEnvironmentVariable, RegisterEventHandler
from launch.event_handlers import OnProcessExit
from launch.actions import EmitEvent
from launch.events import Shutdown
import os
def generate_launch_description():
    """Launch the voice-print registration pipeline (requires the ASR service)."""
    # Install prefix of the `interfaces` package so its message types resolve.
    interfaces_install = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
    # Extend AMENT_PREFIX_PATH only when the prefix is not already present.
    prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
    if interfaces_install not in prefix_path:
        prefix_path = f'{prefix_path}:{interfaces_install}' if prefix_path else interfaces_install

    # ASR + audio-input node: provides the ASR and AudioData services.
    asr_node = Node(
        package='robot_speaker',
        executable='asr_audio_node',
        name='asr_audio_node',
        output='screen',
    )
    # Voice-print registration node.
    register_node = Node(
        package='robot_speaker',
        executable='register_speaker_node',
        name='register_speaker_node',
        output='screen',
    )
    # Tear down the whole launch once registration exits.
    shutdown_on_exit = RegisterEventHandler(
        OnProcessExit(
            target_action=register_node,
            on_exit=[EmitEvent(event=Shutdown(reason='注册完成,关闭所有节点'))],
        )
    )

    return LaunchDescription([
        SetEnvironmentVariable('AMENT_PREFIX_PATH', prefix_path),
        asr_node,
        register_node,
        shutdown_on_exit,
    ])

View File

@@ -1,10 +1,39 @@
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.actions import SetEnvironmentVariable
import os
def generate_launch_description():
"""启动语音交互节点,所有参数从 voice.yaml 读取"""
# 获取interfaces包的install路径
interfaces_install_path = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
# 设置AMENT_PREFIX_PATH确保能找到interfaces包的消息类型
ament_prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
if interfaces_install_path not in ament_prefix_path:
if ament_prefix_path:
ament_prefix_path = f'{ament_prefix_path}:{interfaces_install_path}'
else:
ament_prefix_path = interfaces_install_path
return LaunchDescription([
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
# ASR + 音频输入设备节点同时提供VAD事件Service利用云端ASR的VAD
Node(
package='robot_speaker',
executable='asr_audio_node',
name='asr_audio_node',
output='screen'
),
# TTS + 音频输出设备节点
Node(
package='robot_speaker',
executable='tts_audio_node',
name='tts_audio_node',
output='screen'
),
# 主业务逻辑节点
Node(
package='robot_speaker',
executable='robot_speaker_node',

View File

@@ -9,6 +9,12 @@
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>ament_index_python</depend>
<depend>interfaces</depend>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_python</buildtool_depend>
<exec_depend>python3-pyaudio</exec_depend>
<exec_depend>python3-requests</exec_depend>
@@ -23,6 +29,6 @@
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@@ -1,17 +1,9 @@
dashscope>=1.20.0
openai>=1.0.0
pyaudio>=0.2.11
webrtcvad>=2.0.10
pypinyin>=0.49.0
rclpy>=3.0.0
pyrealsense2>=2.54.0
Pillow>=10.0.0
numpy>=1.24.0
numpy>=1.24.0,<2.0.0 # cv_bridge需要NumPy 1.xNumPy 2.x会导致段错误
PyYAML>=6.0
aec-audio-processing
modelscope>=1.33.0
funasr>=1.0.0
datasets==3.6.0

View File

@@ -0,0 +1,24 @@
# Bridge package for connecting LLM outputs to brain execution.

View File

@@ -0,0 +1,239 @@
#!/usr/bin/env python3
"""
桥接LLM技能序列到小脑ExecuteBtAction并转发反馈/结果。
"""
import json
import os
import re
import rclpy
from rclpy.node import Node
from rclpy.action import ActionClient
from std_msgs.msg import String
from ament_index_python.packages import get_package_share_directory
from interfaces.action import ExecuteBtAction
from interfaces.srv import BtRebuild
class SkillBridgeNode(Node):
    """Bridge between LLM skill-sequence output and the cerebrum executor.

    Subscribes to ``/llm_skill_sequence`` (JSON payloads), validates each
    skill against the whitelist loaded from the ``brain`` package, and
    forwards the sequence via the behaviour-tree rebuild service.  Action
    feedback and results are re-published as JSON strings.
    """

    def __init__(self):
        super().__init__('skill_bridge_node')
        self._action_client = ActionClient(self, ExecuteBtAction, '/execute_bt_action')
        self._current_epoch = 1
        self.run_trigger_ = self.create_client(BtRebuild, '/cerebrum/rebuild_now')
        self.rebuild_requests = 0
        self._allowed_skills = self._load_allowed_skills()
        self.skill_seq_sub = self.create_subscription(
            String, '/llm_skill_sequence', self._on_skill_sequence_received, 10
        )
        self.feedback_pub = self.create_publisher(String, '/skill_execution_feedback', 10)
        self.result_pub = self.create_publisher(String, '/skill_execution_result', 10)
        self.get_logger().info('SkillBridgeNode started')

    def _on_skill_sequence_received(self, msg: String):
        """Validate an incoming skill sequence and dispatch it for execution."""
        raw = (msg.data or "").strip()
        if not raw:
            return
        if not self._allowed_skills:
            self.get_logger().warning("No skill whitelist loaded; reject all sequences")
            return
        # Parse the JSON payload.  BUG FIX: the previous version swallowed the
        # JSONDecodeError and fell through with sequence_list=None, which then
        # raised a TypeError when iterated below.
        try:
            data = json.loads(raw)
        except (json.JSONDecodeError, ValueError) as e:
            self.get_logger().error(f"Skill sequence is not valid JSON: {e}")
            return
        sequence_list = self._parse_json_sequence(data)
        if sequence_list is None:
            self.get_logger().error("Invalid skill sequence format; must be JSON or plain text")
            return
        try:
            skill_names = [item["skill"] for item in sequence_list]
            if any(skill in skill_names for skill in ["VisionObjectRecognition", "Arm", "GripperCmd0"]):
                # Vision/arm/gripper skills use a dedicated dual-arm grasp tree.
                self.get_logger().info(f"Skill sequence contains special skills, triggering rebuild: {skill_names}")
                self.rebuild_now("Trigger", "bt_vision_grasp_dual_arm", "")
            else:
                # Flatten each skill's parameter dict into "key: value" lines.
                skill_params = []
                for item in sequence_list:
                    p = item.get("parameters")
                    params = ""
                    if isinstance(p, dict):
                        lines = [f"{k}: {v}" for k, v in p.items()]
                        if lines:
                            params = "\n".join(lines) + "\n"
                    skill_params.append(params)
                self.get_logger().info(f"Sending skill sequence: {skill_names}")
                self.get_logger().info(f"Sending skill parameters: {skill_params}")
                names_str = ", ".join(skill_names)
                params_str = ", ".join(skill_params)
                self.rebuild_now("Remote", names_str, params_str)
        except Exception as e:
            self.get_logger().error(f"Error processing skill sequence: {e}")

    def _load_allowed_skills(self) -> set[str]:
        """Load the skill-name whitelist from brain/config/robot_skills.yaml.

        Returns an empty set on any failure; an empty whitelist causes all
        incoming sequences to be rejected.
        """
        try:
            brain_share = get_package_share_directory("brain")
            skill_path = os.path.join(brain_share, "config", "robot_skills.yaml")
            if not os.path.exists(skill_path):
                return set()
            import yaml
            with open(skill_path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f) or []
            return {str(entry["name"]) for entry in data if isinstance(entry, dict) and entry.get("name")}
        except Exception as e:
            self.get_logger().warning(f"Load skills failed: {e}")
            return set()

    def _extract_skill_sequence(self, text: str) -> tuple[str, list[str]]:
        """Split a CSV/space/semicolon-separated text into CamelCase skill tokens.

        Returns ``(joined_skills, invalid_skills)`` where invalid skills are
        those not present in the whitelist.
        """
        # Accept CSV/space/semicolon and filter by CamelCase tokens
        tokens = re.split(r'[,\s;]+', text.strip())
        skills = [t for t in tokens if re.match(r'^[A-Z][A-Za-z0-9]*$', t)]
        if not skills:
            return "", []
        invalid = [s for s in skills if s not in self._allowed_skills]
        return ",".join(skills), invalid

    def _parse_json_sequence(self, data: dict) -> list[dict] | None:
        """Validate a JSON skill sequence; returns normalised items or None.

        Expects ``{"sequence": [{"skill": ..., "execution": ..., ...}, ...]}``.
        Items with unknown skills are dropped; invalid execution modes default
        to "serial"; body_id is restricted to 0/1/2/None.
        """
        if not isinstance(data, dict) or "sequence" not in data:
            return None
        sequence = data["sequence"]
        if not isinstance(sequence, list):
            return None
        validated = []
        for item in sequence:
            if not isinstance(item, dict):
                continue
            skill = item.get("skill")
            if not skill or skill not in self._allowed_skills:
                continue
            execution = item.get("execution", "serial")
            if execution not in ["serial", "parallel"]:
                execution = "serial"
            body_id = item.get("body_id")
            # Only numeric ids 0/1/2 and null are accepted, matching the
            # intent router's convention.
            if body_id not in [0, 1, 2, None]:
                body_id = None
            validated.append({
                "skill": skill,
                "execution": execution,
                "body_id": body_id,
                "parameters": item.get("parameters")
            })
        return validated if validated else None

    def _send_skill_sequence(self, skill_sequence: str):
        """Send a skill sequence goal via the ExecuteBtAction action client.

        NOTE(review): spinning the node from inside this method can deadlock
        if it is ever called from a subscription callback on a single-threaded
        executor — confirm callers before reuse.
        """
        if not self._action_client.wait_for_server(timeout_sec=2.0):
            self.get_logger().error('ExecuteBtAction server unavailable')
            return
        goal = ExecuteBtAction.Goal()
        goal.epoch = self._current_epoch
        self._current_epoch += 1
        goal.action_name = skill_sequence
        goal.calls = []
        self.get_logger().info(f"Dispatch skill sequence: {skill_sequence}")
        send_future = self._action_client.send_goal_async(goal, feedback_callback=self._feedback_callback)
        rclpy.spin_until_future_complete(self, send_future, timeout_sec=5.0)
        if not send_future.done():
            self.get_logger().warning("Send goal timed out")
            return
        goal_handle = send_future.result()
        if not goal_handle or not goal_handle.accepted:
            self.get_logger().error("Goal rejected")
            return
        result_future = goal_handle.get_result_async()
        rclpy.spin_until_future_complete(self, result_future)
        if result_future.done():
            self._handle_result(result_future.result())

    def _feedback_callback(self, feedback_msg):
        """Re-publish action feedback as a JSON string on /skill_execution_feedback."""
        fb = feedback_msg.feedback
        payload = {
            "stage": fb.stage,
            "current_skill": fb.current_skill,
            "progress": float(fb.progress),
            "detail": fb.detail,
            "epoch": int(fb.epoch),
        }
        msg = String()
        msg.data = json.dumps(payload, ensure_ascii=True)
        self.feedback_pub.publish(msg)

    def _handle_result(self, result_wrapper):
        """Re-publish the action result as a JSON string on /skill_execution_result."""
        result = result_wrapper.result
        if not result:
            return
        payload = {
            "success": bool(result.success),
            "message": result.message,
            "total_skills": int(result.total_skills),
            "succeeded_skills": int(result.succeeded_skills),
        }
        msg = String()
        msg.data = json.dumps(payload, ensure_ascii=True)
        self.result_pub.publish(msg)

    def rebuild_now(self, type: str, config: str, param: str) -> None:
        """Fire an async BtRebuild request ('type' kept for caller compatibility)."""
        if not self.run_trigger_.service_is_ready():
            self.get_logger().error('Rebuild service not ready')
            return
        self.rebuild_requests += 1
        self.get_logger().info(f'Rebuild BehaviorTree now. Total requests: {self.rebuild_requests}')
        request = BtRebuild.Request()
        request.type = type
        request.config = config
        request.param = param
        self.get_logger().info(f'Calling rebuild service... request info: {request}')
        future = self.run_trigger_.call_async(request)
        future.add_done_callback(self._rebuild_done_callback)

    def _rebuild_done_callback(self, future):
        """Log the outcome of an async rebuild request."""
        try:
            response = future.result()
            if response.success:
                self.get_logger().info('Rebuild request successful')
            else:
                self.get_logger().warning(f'Rebuild request failed: {response.message}')
        except Exception as e:
            self.get_logger().error(f'Rebuild request exception: {str(e)}')
        self.get_logger().info(f"Rebuild requested. Total rebuild requests: {str(self.rebuild_requests)}")
def main(args=None):
    """Entry point: spin the bridge node; clean up on Ctrl-C or shutdown.

    The previous version left destroy_node()/shutdown() unreachable when
    spin() raised (e.g. KeyboardInterrupt); cleanup now runs in `finally`.
    """
    rclpy.init(args=args)
    node = SkillBridgeNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the node.
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@@ -2,3 +2,27 @@

View File

@@ -1,10 +1,17 @@
"""
对话历史管理模块
"""
from robot_speaker.core.types import LLMMessage
from dataclasses import dataclass
import threading
@dataclass
class LLMMessage:
"""LLM消息"""
role: str # "user", "assistant", "system"
content: str
class ConversationHistory:
"""对话历史管理器 - 实时语音"""
@@ -109,3 +116,15 @@ class ConversationHistory:
self.summary = None
self._pending_user_message = None

View File

@@ -1,10 +0,0 @@
from enum import Enum
class ConversationState(Enum):
"""会话状态机"""
IDLE = "idle" # 等待用户唤醒或声音
CHECK_VOICE = "check_voice" # 用户说话 → 检查声纹
AUTHORIZED = "authorized" # 已注册用户

View File

@@ -1,7 +1,12 @@
from dataclasses import dataclass
from typing import Optional
import os
import yaml
import json
from ament_index_python.packages import get_package_share_directory
from pypinyin import pinyin, Style
from robot_speaker.core.skill_interface_parser import SkillInterfaceParser
@dataclass
@@ -9,7 +14,7 @@ class IntentResult:
intent: str # "skill_sequence" | "kb_qa" | "chat_text" | "chat_camera"
text: str
need_camera: bool
camera_mode: Optional[str] # "head" | "left_hand" | "right_hand" | None
camera_mode: Optional[str] # "top" | "left" | "right" | "hand_r" | None
system_prompt: Optional[str]
@@ -18,12 +23,69 @@ class IntentRouter:
self.camera_capture_keywords = [
"pai zhao", "pai ge zhao", "pai zhang zhao"
]
self.skill_keywords = [
"ban xiang zi"
# 动作词列表(拼音)- 用于检测技能序列意图
self.action_verbs = [
"zou", "zou liang bu", "zou ji bu", # 走、走两步、走几步
"na", "na qi", "na zhu", # 拿、拿起、拿住
"ban", "ban yun", # 搬、搬运
"zhua", "zhua qu", # 抓、抓取
"tui", "tui dong", # 推、推动
"la", "la dong", # 拉、拉动
"yi dong", "qian jin", "hou tui", # 移动、前进、后退
"kong zhi", "cao zuo", # 控制、操作
"fang xia", "fang zhi", # 放下、放置
"ju qi", "sheng qi", # 举起、升起
"jia zhua", "jia qi", "jia", # 夹爪、夹起、夹
"shen you bi", "shen zuo bi", "shen chu", "shen shou", # 伸右臂、伸左臂、伸出、伸手
"zhuan quan", "zhuan yi quan", "zhuan", # 转个圈、转一圈、转
]
self.kb_keywords = [
"ni shi shei", "ni de ming zi"
"ni shi shui", "ni de ming zi", "tiao ge wu", "ni jiao sha", "ni hui gan", "ni neng gan"
]
self._cached_skill_names: list[str] | None = None
self._cached_kb_data: list[dict] | None = None
interfaces_root = self._get_interfaces_root()
self.interface_parser = SkillInterfaceParser(interfaces_root)
def _get_interfaces_root(self) -> str:
"""从配置文件读取接口文件根目录"""
try:
robot_speaker_share = get_package_share_directory("robot_speaker")
config_path = os.path.join(robot_speaker_share, "config", "voice.yaml")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.safe_load(f) or {}
interfaces_config = config.get("interfaces", {})
root_path = interfaces_config.get("root_path", "")
if not root_path:
raise ValueError("interfaces.root_path 未在配置文件中配置")
if root_path.startswith("~"):
root_path = os.path.expanduser(root_path)
if not os.path.isabs(root_path):
config_dir = os.path.dirname(robot_speaker_share)
root_path = os.path.join(config_dir, root_path)
abs_path = os.path.abspath(root_path)
if not os.path.exists(abs_path):
raise ValueError(f"接口文件根目录不存在: {abs_path}")
return abs_path
except Exception as e:
raise ValueError(f"读取接口文件根目录失败: {e}")
def _load_brain_skill_names(self) -> list[str]:
"""加载技能名称(使用接口解析器,避免重复读取)"""
if self._cached_skill_names is not None:
return self._cached_skill_names
skill_names = self.interface_parser.get_skill_names()
self._cached_skill_names = skill_names
return skill_names
def to_pinyin(self, text: str) -> str:
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
@@ -32,89 +94,172 @@ class IntentRouter:
py_list = pinyin(''.join(chars), style=Style.NORMAL)
return ' '.join([item[0] for item in py_list]).lower().strip()
def is_skill_sequence_intent(self, text: str) -> bool:
text_pinyin = self.to_pinyin(text)
return any(k in text_pinyin for k in self.skill_keywords)
def is_skill_sequence_intent(self, text: str, text_pinyin: str | None = None) -> bool:
if text_pinyin is None:
text_pinyin = self.to_pinyin(text)
# 检查动作词(精确匹配:动作词必须是完整的单词序列)
text_words = text_pinyin.split()
for action in self.action_verbs:
action_words = action.split()
# 检查动作词的单词序列是否是文本单词序列的连续子序列
for i in range(len(text_words) - len(action_words) + 1):
if text_words[i:i+len(action_words)] == action_words:
return True
return False
def check_camera_command(self, text: str) -> tuple[bool, Optional[str]]:
def check_camera_command(self, text: str, text_pinyin: str | None = None) -> tuple[bool, Optional[str]]:
"""检查是否包含拍照指令,返回(是否需要相机, 相机模式)"""
if not text:
return False, None
text_pinyin = self.to_pinyin(text)
for keyword in self.camera_capture_keywords:
if keyword in text_pinyin:
return True, self.detect_camera_mode(text)
if text_pinyin is None:
text_pinyin = self.to_pinyin(text)
# 精确匹配:关键词必须作为完整短语出现在文本拼音中
if any(keyword in text_pinyin for keyword in self.camera_capture_keywords):
return True, self.detect_camera_mode(text, text_pinyin)
return False, None
def detect_camera_mode(self, text: str) -> str:
text_pinyin = self.to_pinyin(text)
left_keys = ["zuo shou", "zuo bi", "zuo bian"]
right_keys = ["you shou", "you bi", "you bian"]
head_keys = ["tou", "nao dai"]
for kw in left_keys:
if kw in text_pinyin:
return "left_hand"
for kw in right_keys:
if kw in text_pinyin:
return "right_hand"
for kw in head_keys:
if kw in text_pinyin:
return "head"
return "head"
def detect_camera_mode(self, text: str, text_pinyin: str | None = None) -> str:
"""检测相机模式返回与相机驱动匹配的position值left/right/top/hand_r"""
if text_pinyin is None:
text_pinyin = self.to_pinyin(text)
if any(kw in text_pinyin for kw in ["zuo shou", "zuo bi", "zuo bian", "zuo shou bi"]):
return "left"
if any(kw in text_pinyin for kw in ["you shou", "you bi", "you bian", "you shou bi"]):
return "right"
if any(kw in text_pinyin for kw in ["shou bu", "shou", "shou xiang ji", "shou bi xiang ji"]):
return "hand_r"
if any(kw in text_pinyin for kw in ["tou", "nao dai", "ding bu", "shang fang"]):
return "top"
return "top"
def build_skill_prompt(self) -> str:
def build_skill_prompt(self, execution_status: Optional[str] = None) -> str:
skills = self._load_brain_skill_names()
skills_text = ", ".join(skills) if skills else ""
skill_guard = (
"【技能限制】只能使用以下技能名称:" + skills_text
if skills_text
else "【技能限制】技能列表不可用,请不要输出任何技能名称。"
)
execution_hint = ""
if execution_status:
execution_hint = f"【上一轮执行状态】{execution_status}\n请参考上述执行状态,根据成功/失败信息调整本次技能序列。\n"
else:
execution_hint = "【注意】这是首次执行或没有上一轮执行状态,请根据当前图片和用户请求规划技能序列。\n"
skill_params_doc = self.interface_parser.generate_params_documentation()
return (
"你是机器人任务规划器。\n"
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照(默认头部相机),并结合当前环境信息生成简洁、可执行的技能序列。"
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照,并结合当前环境信息生成简洁、可执行的技能序列。\n"
"如果用户明确要求或者任务明显需要双手/双臂协作(如扶稳+操作、抓取大体积的物体),必须规划双手技能。\n"
+ execution_hint
+ "\n"
"【规划要求】\n"
"1. execution规划判断技能之间的执行关系\n"
" - serial串行技能必须按顺序执行前一个完成后再执行下一个\n"
" - parallel并行技能可以同时执行\n"
"2. parameters规划根据目标物距离和任务需求规划具体参数值\n"
" - parameters字典必须包含该技能接口文件目标字段的所有字段\n"
"【输出格式要求】\n"
"必须输出JSON格式包含sequence数组。每个技能对象包含3个一级字段\n"
"1. skill: 技能名称(字符串)\n"
"2. execution: 执行方式serial串行或 parallel并行\n"
"3. parameters: 参数字典包含该技能接口文件目标字段的所有字段并填入合理的预测值。如果技能无参数使用null。\n"
"\n"
"注意一级字段skill, execution, parameters是固定结构。\n"
"\n"
"【技能参数说明】\n"
+ skill_params_doc +
"\n"
"示例格式:\n"
"{\n"
' "sequence": [\n'
' {"skill": "MoveWheel", "execution": "serial", "parameters": {"move_distance": 1.5, "move_angle": 0.0}},\n'
' {"skill": "Arm", "execution": "serial", "parameters": {"body_id": 0, "data_type": 1, "data_length": 6, "command_id": 0, "frame_time_stamp": 0, "data_array": [0.1, 0.2, 0.3, 0.0, 0.0, 0.0]}},\n'
' {"skill": "GripperCmd0", "execution": "parallel", "parameters": {"loc": 128, "speed": 100, "torque": 80, "mode": 1}}\n'
" ]\n"
"}\n"
+ skill_guard
)
def build_chat_prompt(self, need_camera: bool) -> str:
if need_camera:
return (
"你是一个智能语音助手\n"
"请结合图片内容简短回答。"
"你是一个机器人视觉助理,擅长分析图片中物体的相对位置和空间关系\n"
"请结合图片内容,重点描述物体之间的相对位置(如左右、前后、上下、远近),仅基于可观察信息回答。\n"
"回答应简短、客观不要超过100个token。"
)
return (
"你是一个智能语音助手\n"
"自然、简短地与用户对话。"
)
def build_kb_prompt(self) -> str:
return (
"你是蜂核科技的员工。\n"
"请基于知识库信息回答用户问题,回答要准确简洁。"
"你是一个表达清晰、语气自然的真人助理\n"
"请简短地与用户对话不要超过100个token"
)
def _load_kb_data(self) -> list[dict]:
"""加载知识库数据"""
if self._cached_kb_data is not None:
return self._cached_kb_data
kb_data = []
try:
robot_speaker_share = get_package_share_directory("robot_speaker")
kb_path = os.path.join(robot_speaker_share, "config", "knowledge.json")
with open(kb_path, "r", encoding="utf-8") as f:
data = json.load(f)
kb_data = data["entries"]
except Exception as e:
kb_data = []
self._cached_kb_data = kb_data
return kb_data
def search_kb(self, text: str) -> Optional[str]:
    """Look up ``text`` in the knowledge base by pinyin substring match.

    Returns the answer of the first entry whose pattern occurs in the pinyin
    transcription of ``text``, or None when text is empty or nothing matches.
    """
    if not text:
        return None
    text_pinyin = self.to_pinyin(text)
    for entry in self._load_kb_data():
        # Tolerate malformed entries: a missing "patterns"/"answer" key in
        # the external JSON must not crash intent routing.
        for pattern in entry.get("patterns", []):
            if pattern in text_pinyin:
                answer = entry.get("answer")
                if answer:
                    return answer
    return None
def build_default_system_prompt(self) -> str:
return (
"你是一个智能语音助手。\n"
"你是一个工厂专业的助手。\n"
"- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
"- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
"请根据对话模式调整你的回答风格。"
)
def route(self, text: str) -> IntentResult:
need_camera, camera_mode = self.check_camera_command(text)
text_pinyin = self.to_pinyin(text)
need_camera, camera_mode = self.check_camera_command(text, text_pinyin)
if self.is_skill_sequence_intent(text):
if camera_mode is None:
camera_mode = "head"
if self.is_skill_sequence_intent(text, text_pinyin):
# 技能序列意图总是需要相机,复用 detect_camera_mode用户指定了相机就用指定的否则默认 "top"
skill_camera_mode = self.detect_camera_mode(text, text_pinyin)
return IntentResult(
intent="skill_sequence",
text=text,
need_camera=True,
camera_mode=camera_mode,
camera_mode=skill_camera_mode,
system_prompt=self.build_skill_prompt()
)
if any(k in text_pinyin for k in self.kb_keywords):
# 精确匹配:关键词必须作为完整短语出现在文本拼音中
if any(keyword in text_pinyin for keyword in self.kb_keywords):
return IntentResult(
intent="kb_qa",
text=text,
need_camera=False,
camera_mode=None,
system_prompt=self.build_kb_prompt()
system_prompt=None # kb_qa不走LLM不需要system_prompt
)
return IntentResult(

View File

@@ -1,246 +0,0 @@
import threading
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeCallbacks:
    """Callback hub wired into the recorder/ASR components.

    Holds only a back-reference to the owning node; all shared state (locks,
    events, buffers, the state machine, the logger) lives on the node object.

    NOTE(review): indentation was reconstructed — the captured source was
    flattened; nesting below is the most plausible reading, confirm against
    the repository copy.
    """

    # ==================== Initialization and internal utilities ====================
    def __init__(self, node):
        # All delegation targets live on the node.
        self.node = node

    def _mark_utterance_processed(self) -> bool:
        # Returns True exactly once per utterance id — deduplicates speaker
        # verification triggers when both the VAD path and the ASR path fire
        # for the same utterance.
        node = self.node
        with node.utterance_lock:
            if node.current_utterance_id == node.last_processed_utterance_id:
                return False
            node.last_processed_utterance_id = node.current_utterance_id
            return True

    def _trigger_sv_for_check_voice(self, source: str):
        # Trigger speaker verification while in the CHECK_VOICE state.
        # `source` is a label ("VAD" / "ASR") used only in log output.
        node = self.node
        if not (node.sv_enabled and node.sv_client):
            return
        if not self._mark_utterance_processed():
            return
        if node._handle_empty_speaker_db():
            node.get_logger().info(f"[声纹] CHECK_VOICE状态数据库为空跳过声纹验证来源: {source}")
            return
        if not node.sv_speech_end_event.is_set():
            # Stop buffering and snapshot the buffer size under the lock.
            with node.sv_lock:
                node.sv_recording = False
                buffer_size = len(node.sv_audio_buffer)
            node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
            # Only wake the SV worker when there is actually audio to verify.
            if buffer_size > 0:
                node.sv_speech_end_event.set()
        else:
            node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source}")

    # ==================== Business-logic proxies ====================
    # Thin delegations so recorder/ASR components depend only on this object,
    # not on the node's private methods.
    def handle_interrupt_command(self, msg):
        return self.node._handle_interrupt_command(msg)

    def check_interrupt_and_cancel_turn(self) -> bool:
        return self.node._check_interrupt_and_cancel_turn()

    def handle_wake_word(self, text: str) -> str:
        return self.node._handle_wake_word(text)

    def check_shutup_command(self, text: str) -> bool:
        return self.node._check_shutup_command(text)

    def check_camera_command(self, text: str):
        return self.node.intent_router.check_camera_command(text)

    def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
        return self.node._llm_process_stream_with_camera(user_text, need_camera)

    def put_tts_text(self, text: str):
        return self.node._put_tts_text(text)

    def force_stop_tts(self):
        return self.node._force_stop_tts()

    def drain_queue(self, q):
        return self.node._drain_queue(q)

    # ==================== Recording/VAD callbacks ====================
    def get_silence_threshold(self) -> int:
        """Get the dynamic silence threshold (milliseconds)."""
        node = self.node
        return node.silence_duration_ms

    def should_put_audio_to_queue(self) -> bool:
        """
        Decide whether captured audio should be queued for ASR; the
        conversation state machine gates ASR input.
        """
        node = self.node
        state = node._get_state()
        if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
                     ConversationState.AUTHORIZED]:
            return True
        return False

    def on_speech_start(self):
        """Recorder thread detected the start of speech."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到人声,开始录音")
        # New utterance: bump the id so SV triggering is deduplicated.
        with node.utterance_lock:
            node.current_utterance_id += 1
        state = node._get_state()
        if state == ConversationState.IDLE:
            # Idle -> CheckVoice
            if node.sv_enabled and node.sv_client:
                # Start buffering audio for speaker verification.
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")
                node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
            else:
                node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
        elif state == ConversationState.CHECK_VOICE:
            # CHECK_VOICE: keep recording for speaker verification.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 继续录音用于声纹验证")
        elif state == ConversationState.AUTHORIZED:
            # AUTHORIZED: record again to re-verify the current speaker.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")

    def on_audio_chunk_for_sv(self, audio_chunk: bytes):
        """Recorder audio-chunk callback — append to the SV buffer while recording."""
        node = self.node
        state = node._get_state()  # NOTE(review): fetched but unused in this method
        # SV recording happens in the CHECK_VOICE / AUTHORIZED states.
        if node.sv_enabled and node.sv_recording:
            try:
                # Raw 16-bit PCM chunk -> int16 samples.
                audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                with node.sv_lock:
                    node.sv_audio_buffer.extend(audio_array)
            except Exception as e:
                node.get_logger().debug(f"[声纹] 录音失败: {e}")

    def on_speech_end(self):
        """Recorder thread detected end of speech (sustained silence)."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到说话结束")
        state = node._get_state()
        node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
        if state == ConversationState.CHECK_VOICE:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            self._trigger_sv_for_check_voice("VAD")
            return
        elif state == ConversationState.AUTHORIZED:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = False
                    buffer_size = len(node.sv_audio_buffer)
                node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
                node.sv_speech_end_event.set()
                # If TTS is playing, wait for the SV result asynchronously and
                # only interrupt TTS when verification passes. A dedicated
                # thread avoids blocking the recorder thread / TTS playback.
                if node.tts_playing_event.is_set():
                    node.get_logger().info("[打断] TTS播放中用户说话结束异步等待声纹验证结果...")

                    def _check_sv_and_interrupt():
                        # Wait for the SV result, at most 2 seconds.
                        # NOTE(review): current_seq is sampled *after*
                        # sv_speech_end_event.set(); if the SV worker bumps
                        # sv_result_seq before this thread runs, the predicate
                        # can never become true and the wait ends by timeout —
                        # confirm this race is acceptable.
                        with node.sv_result_cv:
                            current_seq = node.sv_result_seq
                            if node.sv_result_cv.wait_for(
                                lambda: node.sv_result_seq > current_seq,
                                timeout=2.0
                            ):
                                # Verification finished — inspect the outcome.
                                with node.sv_lock:
                                    speaker_id = node.current_speaker_id
                                    speaker_state = node.current_speaker_state
                                if speaker_id and speaker_state == SpeakerState.VERIFIED:
                                    node.get_logger().info(f"[打断] 声纹验证通过({speaker_id})中断TTS播放")
                                    node._interrupt_tts("检测到人声(已授权用户,说话结束)")
                                else:
                                    node.get_logger().debug(f"[打断] 声纹验证未通过不中断TTS状态: {speaker_state.value}")
                            else:
                                node.get_logger().warning("[打断] 声纹验证超时不中断TTS")

                    # Wait on a daemon thread so the recorder thread is not blocked.
                    threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
            return

    def on_new_segment(self):
        """Recorder detected a new speech segment from an authorized user;
        starts/continues SV recording and may interrupt TTS (no immediate
        interrupt while TTS is playing)."""
        node = self.node
        state = node._get_state()
        if state == ConversationState.AUTHORIZED:
            # While TTS plays, do not interrupt immediately: record for SV and
            # interrupt only after speech_end if verification passes. This
            # avoids TTS-echo false triggers while still allowing real barge-in.
            if node.tts_playing_event.is_set():
                node.get_logger().debug("[打断] TTS播放中检测到人声开始录音用于声纹验证等待说话结束后验证")
                # Recording was already started in on_speech_start; nothing to do.
            else:
                # TTS idle: consult the latest SV result and interrupt at once.
                if node.sv_enabled and node.sv_client:
                    with node.sv_lock:
                        current_speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                    if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
                        node._interrupt_tts("检测到人声(已授权用户)")
                        node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言中断TTS播放")
                    else:
                        node.get_logger().debug(f"[打断] 检测到人声但声纹未验证或未匹配不中断TTS当前状态: {speaker_state.value}")
                else:
                    # SV disabled: interrupt directly (legacy behavior).
                    node._interrupt_tts("检测到人声(未启用声纹)")
                    node.get_logger().info("[打断] 检测到人声中断TTS播放")
        else:
            node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value}非已授权用户不允许打断TTS")

    def on_heartbeat(self):
        """Recorder-thread silence heartbeat callback."""
        self.node.get_logger().info("[录音线程] 静音中")

    # ==================== ASR callbacks ====================
    def on_asr_sentence_end(self, text: str):
        """ASR sentence_end callback — push the final text onto the queue."""
        node = self.node
        if not text or not text.strip():
            return
        text_clean = text.strip()
        node.get_logger().info(f"[ASR] 识别完成: {text_clean}")
        state = node._get_state()
        # Rule 2: in CHECK_VOICE, if ASR finishes before VAD has fired
        # speech_end, proactively trigger speaker verification.
        if state == ConversationState.CHECK_VOICE:
            if node.sv_enabled and node.sv_client:
                node.get_logger().info("[ASR] CHECK_VOICE状态ASR识别完成主动触发声纹验证")
                self._trigger_sv_for_check_voice("ASR")
        # All states: enqueue the text for the processing thread.
        node.text_queue.put(text_clean, timeout=1.0)

    def on_asr_text_update(self, text: str):
        """ASR incremental-text callback — used for multi-turn hinting only."""
        if not text or not text.strip():
            return
        self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")

View File

@@ -1,188 +0,0 @@
import queue
import time
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeWorkers:
    """Worker-thread entry points for the speaker node: recording, ASR
    streaming, main business-logic processing and speaker verification.

    NOTE(review): indentation was reconstructed — the captured source was
    flattened; confirm nesting against the repository copy.
    """

    def __init__(self, node):
        # Shared state (queues, events, locks, clients) lives on the node.
        self.node = node

    def recording_worker(self):
        """Thread 1: recording — the only realtime thread."""
        node = self.node
        node.get_logger().info("[录音线程] 启动")
        # Blocks inside record_with_vad() until the recorder is stopped.
        node.audio_recorder.record_with_vad()

    def asr_worker(self):
        """Thread 2: ASR inference — strictly audio -> text."""
        node = self.node
        node.get_logger().info("[ASR推理线程] 启动")
        while not node.stop_event.is_set():
            try:
                audio_chunk = node.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            # Drop audio while an interrupt is pending.
            if node.interrupt_event.is_set():
                continue
            # Forward to ASR only when the state machine allows it and the
            # streaming client is up.
            if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
                node.asr_client.send_audio(audio_chunk)

    def process_worker(self):
        """Thread 3: main thread — business-logic processing of recognized text."""
        node = self.node
        node.get_logger().info("[主线程] 启动")
        while not node.stop_event.is_set():
            try:
                text = node.text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            node.get_logger().info(f"[主线程] 收到识别文本: {text}")
            current_state = node._get_state()
            if current_state == ConversationState.CHECK_VOICE:
                # Optional wake-word gate before authorization.
                if node.use_wake_word:
                    node.get_logger().info(f"[主线程] CHECK_VOICE状态检查唤醒词文本: {text}")
                    processed_text = node.callbacks.handle_wake_word(text)
                    if not processed_text:
                        node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'回到Idle状态")
                        node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                        continue
                    node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
                    text = processed_text
                if node.sv_enabled and node.sv_client:
                    # Block (bounded) until the SV worker publishes a result.
                    node.get_logger().info("[主线程] CHECK_VOICE状态等待声纹验证结果...")
                    with node.sv_result_cv:
                        current_seq = node.sv_result_seq
                        if not node.sv_result_cv.wait_for(
                            lambda: node.sv_result_seq > current_seq,
                            timeout=15.0
                        ):
                            node.get_logger().warning("[主线程] CHECK_VOICE状态声纹结果未ready超时15秒拒绝本轮")
                            with node.sv_lock:
                                node.sv_audio_buffer.clear()
                            node._change_state(ConversationState.IDLE, "声纹结果未ready")
                            continue
                    node.get_logger().info("[主线程] CHECK_VOICE状态声纹结果ready继续处理")
                    # Snapshot the SV outcome under the lock.
                    with node.sv_lock:
                        speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                        score = node.current_speaker_score
                    if speaker_id and speaker_state == SpeakerState.VERIFIED:
                        node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
                        node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
                    else:
                        node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
                        node.callbacks.put_tts_text("声纹验证失败")
                        node._change_state(ConversationState.IDLE, "声纹验证失败")
                        continue
                else:
                    node._change_state(ConversationState.AUTHORIZED, "未启用声纹")
            elif current_state == ConversationState.AUTHORIZED:
                # While TTS plays, ASR text is ignored; only VAD-verified
                # speech may interrupt playback.
                if node.tts_playing_event.is_set():
                    node.get_logger().debug("[主线程] AUTHORIZED状态TTS播放中忽略ASR识别结果只有VAD检测到已授权用户人声才能中断")
                    continue
            elif current_state == ConversationState.IDLE:
                node.get_logger().warning("[主线程] Idle状态收到文本忽略")
                continue
            # Wake-word check also applies in AUTHORIZED when configured.
            if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
                processed_text = node.callbacks.handle_wake_word(text)
                if not processed_text:
                    node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                    continue
                text = processed_text
            # "Shut up" command: cancel the turn, stop TTS, go idle.
            if node.callbacks.check_shutup_command(text):
                node.get_logger().info("[主线程] 检测到闭嘴指令")
                node.interrupt_event.set()
                node.callbacks.force_stop_tts()
                node._change_state(ConversationState.IDLE, "用户闭嘴指令")
                continue
            # Route the utterance to an intent and execute it.
            intent_payload = node.intent_router.route(text)
            node._handle_intent(intent_payload)
            if current_state == ConversationState.AUTHORIZED:
                node.session_start_time = time.time()

    def sv_worker(self):
        """Thread 5: speaker verification — non-realtime, low frequency (CAM++)."""
        node = self.node
        node.get_logger().info("[声纹识别线程] 启动")
        # Minimum sample count so the clip is >= 0.5 s after downsampling to 16 kHz.
        target_sr = 16000  # CAM++ model target sample rate
        min_duration_seconds = 0.5
        min_samples_at_target_sr = int(target_sr * min_duration_seconds)  # 8000 samples @ 16 kHz
        if node.sample_rate >= target_sr:
            downsample_step = int(node.sample_rate / target_sr)
            min_audio_samples = min_samples_at_target_sr * downsample_step
        else:
            min_audio_samples = int(node.sample_rate * min_duration_seconds)
        while not node.stop_event.is_set():
            try:
                if node.sv_speech_end_event.wait(timeout=0.1):
                    node.sv_speech_end_event.clear()
                    # Drain the SV buffer atomically.
                    with node.sv_lock:
                        audio_list = list(node.sv_audio_buffer)
                        buffer_size = len(audio_list)
                        node.sv_audio_buffer.clear()
                    node.get_logger().info(f"[声纹识别] 收到speech_end事件录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
                    if node._handle_empty_speaker_db():
                        # NOTE(review): this continue skips the sv_result_seq
                        # bump below, so any waiter on sv_result_cv can only
                        # end by timeout — confirm intended.
                        node.get_logger().info("[声纹识别] 数据库为空跳过验证直接设置UNKNOWN状态")
                        continue
                    if buffer_size >= min_audio_samples:
                        audio_array = np.array(audio_list, dtype=np.int16)
                        embedding, success = node.sv_client.extract_embedding(
                            audio_array,
                            sample_rate=node.sample_rate
                        )
                        if not success or embedding is None:
                            node.get_logger().debug("[声纹识别] 提取embedding失败")
                            with node.sv_lock:
                                node.current_speaker_id = None
                                node.current_speaker_state = SpeakerState.ERROR
                                node.current_speaker_score = 0.0
                        else:
                            speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
                            with node.sv_lock:
                                node.current_speaker_id = speaker_id
                                node.current_speaker_state = match_state
                                node.current_speaker_score = score
                            if match_state == SpeakerState.VERIFIED:
                                node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
                            elif match_state == SpeakerState.REJECTED:
                                node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
                            else:
                                node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
                    else:
                        # Clip too short — mark the speaker state unknown.
                        node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
                        with node.sv_lock:
                            node.current_speaker_id = None
                            node.current_speaker_state = SpeakerState.UNKNOWN
                            node.current_speaker_score = 0.0
                    # Publish the result: bump the sequence and wake all waiters.
                    with node.sv_result_cv:
                        node.sv_result_seq += 1
                        node.sv_result_cv.notify_all()
            except Exception as e:
                node.get_logger().error(f"[声纹识别线程] 错误: {e}")
                time.sleep(0.1)

View File

@@ -1,24 +1,16 @@
"""
声纹注册独立节点:运行完成后退出
"""
import collections
"""声纹注册独立节点:运行完成后退出"""
import os
import queue
import threading
import time
import yaml
import numpy as np
import threading
import queue
import rclpy
from rclpy.node import Node
from ament_index_python.packages import get_package_share_directory
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
from robot_speaker.perception.echo_cancellation import ReferenceSignalBuffer
from robot_speaker.models.asr.dashscope import DashScopeASR
from robot_speaker.models.tts.dashscope import DashScopeTTSClient
from robot_speaker.core.types import TTSRequest
from interfaces.srv import ASRRecognize, AudioData, VADEvent
from robot_speaker.core.speaker_verifier import SpeakerVerificationClient
from pypinyin import pinyin, Style
@@ -27,81 +19,15 @@ class RegisterSpeakerNode(Node):
super().__init__('register_speaker_node')
self._load_config()
self.stop_event = threading.Event()
self.processing = False
self.buffer_lock = threading.Lock()
self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
self.asr_client = self.create_client(ASRRecognize, '/asr/recognize')
self.audio_data_client = self.create_client(AudioData, '/asr/audio_data')
self.vad_client = self.create_client(VADEvent, '/vad/event')
# 状态:等待唤醒词 -> 等待声纹语音
self.waiting_for_wake_word = True
self.waiting_for_voiceprint = False
# 音频队列和文本队列用于ASR
self.audio_queue = queue.Queue()
self.text_queue = queue.Queue()
self.vad_detector = VADDetector(
mode=self.vad_mode,
sample_rate=self.sample_rate
)
# 创建参考信号缓冲区(用于回声消除)
self.reference_signal_buffer = ReferenceSignalBuffer(
max_duration_ms=self.audio_echo_cancellation_max_duration_ms,
sample_rate=self.sample_rate,
channels=self.output_channels
) if self.audio_echo_cancellation_enabled else None
self.audio_recorder = AudioRecorder(
device_index=self.input_device_index,
sample_rate=self.sample_rate,
channels=self.channels,
chunk=self.chunk,
vad_detector=self.vad_detector,
audio_queue=self.audio_queue, # 送ASR用于唤醒词检测
silence_duration_ms=self.silence_duration_ms,
min_energy_threshold=self.min_energy_threshold,
heartbeat_interval=self.audio_microphone_heartbeat_interval,
on_heartbeat=self._on_heartbeat,
is_playing=lambda: False,
on_new_segment=None,
on_speech_start=self._on_speech_start,
on_speech_end=self._on_speech_end,
stop_flag=self.stop_event.is_set,
on_audio_chunk=self._on_audio_chunk,
should_put_to_queue=self._should_put_to_queue,
get_silence_threshold=lambda: self.silence_duration_ms,
enable_echo_cancellation=self.audio_echo_cancellation_enabled, # 启用回声消除,保持与主程序一致
reference_signal_buffer=self.reference_signal_buffer,
logger=self.get_logger()
)
# ASR客户端 - 用于唤醒词检测
self.asr_client = DashScopeASR(
api_key=self.dashscope_api_key,
sample_rate=self.sample_rate,
model=self.asr_model,
url=self.asr_url,
logger=self.get_logger()
)
self.asr_client.on_sentence_end = self._on_asr_sentence_end
self.asr_client.start()
# ASR处理线程
self.asr_thread = threading.Thread(
target=self._asr_worker,
name="RegisterASRThread",
daemon=True
)
self.asr_thread.start()
# 文本处理线程
self.text_thread = threading.Thread(
target=self._text_worker,
name="RegisterTextThread",
daemon=True
)
self.text_thread.start()
self.get_logger().info('等待服务启动...')
self.asr_client.wait_for_service(timeout_sec=10.0)
self.audio_data_client.wait_for_service(timeout_sec=10.0)
self.vad_client.wait_for_service(timeout_sec=10.0)
self.get_logger().info('所有服务已就绪')
self.sv_client = SpeakerVerificationClient(
model_path=self.sv_model_path,
@@ -110,31 +36,20 @@ class RegisterSpeakerNode(Node):
logger=self.get_logger()
)
self.tts_client = DashScopeTTSClient(
api_key=self.dashscope_api_key,
model=self.tts_model,
voice=self.tts_voice,
card_index=self.output_card_index,
device_index=self.output_device_index,
output_sample_rate=self.output_sample_rate,
output_channels=self.output_channels,
output_volume=self.output_volume,
tts_source_sample_rate=self.audio_tts_source_sample_rate,
tts_source_channels=self.audio_tts_source_channels,
tts_ffmpeg_thread_queue_size=self.audio_tts_ffmpeg_thread_queue_size,
reference_signal_buffer=self.reference_signal_buffer,
logger=self.get_logger()
)
self.registered = False
self.shutting_down = False
self.get_logger().info("声纹注册节点启动,请说唤醒词开始注册(例如:'二狗我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去'")
self.get_logger().info("声纹注册节点启动,请说'er gou......'唤醒注册")
self.recording_thread = threading.Thread(
target=self.audio_recorder.record_with_vad,
name="RegisterRecordingThread",
daemon=True
)
self.recording_thread.start()
# 使用队列在线程间传递 VAD 事件,避免在子线程中调用 spin_until_future_complete
self.vad_event_queue = queue.Queue()
self.recording = False # 录音状态标志
self.pending_asr_future = None # 待处理的 ASR future
self.pending_audio_future = None # 待处理的 AudioData future
self.state = "waiting_speech" # 状态机waiting_speech, waiting_asr, waiting_audio
self.timer = self.create_timer(0.2, self._check_done)
self.vad_thread = threading.Thread(target=self._vad_event_worker, daemon=True)
self.vad_thread.start()
self.timer = self.create_timer(0.1, self._main_loop)
def _load_config(self):
config_file = os.path.join(
@@ -145,267 +60,48 @@ class RegisterSpeakerNode(Node):
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
dashscope = config['dashscope']
audio = config['audio']
mic = audio['microphone']
soundcard = audio['soundcard']
vad = config['vad']
system = config['system']
self.dashscope_api_key = dashscope['api_key']
self.asr_model = dashscope['asr']['model']
self.asr_url = dashscope['asr']['url']
self.tts_model = dashscope['tts']['model']
self.tts_voice = dashscope['tts']['voice']
self.input_device_index = mic['device_index']
self.sample_rate = mic['sample_rate']
self.channels = mic['channels']
self.chunk = mic['chunk']
self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
self.output_card_index = soundcard['card_index']
self.output_device_index = soundcard['device_index']
self.output_sample_rate = soundcard['sample_rate']
self.output_channels = soundcard['channels']
self.output_volume = soundcard['volume']
echo = audio.get('echo_cancellation', {})
self.audio_echo_cancellation_enabled = echo['enable']
self.audio_echo_cancellation_max_duration_ms = echo.get('max_duration_ms', 200)
tts_audio = audio.get('tts', {})
self.audio_tts_source_sample_rate = tts_audio.get('source_sample_rate', 22050)
self.audio_tts_source_channels = tts_audio.get('source_channels', 1)
self.audio_tts_ffmpeg_thread_queue_size = tts_audio.get('ffmpeg_thread_queue_size', 5)
self.vad_mode = vad['vad_mode']
self.silence_duration_ms = vad['silence_duration_ms']
self.min_energy_threshold = vad['min_energy_threshold']
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
self.sv_threshold = system['sv_threshold']
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
self.sv_buffer_size = system['sv_buffer_size']
self.wake_word = system['wake_word']
def _should_put_to_queue(self) -> bool:
"""判断是否应该将音频放入ASR队列仅在等待唤醒词时"""
return self.waiting_for_wake_word
def _on_heartbeat(self):
if self.waiting_for_wake_word:
self.get_logger().info("[注册录音] 等待唤醒词'er gou'...")
elif self.waiting_for_voiceprint:
self.get_logger().info("[注册录音] 等待声纹语音...")
def _on_speech_start(self):
if self.waiting_for_wake_word:
# 等待唤醒词时,开始录音(可能包含唤醒词)
self.get_logger().info("[注册录音] 检测到人声,开始录音")
elif self.waiting_for_voiceprint:
self.get_logger().info("[注册录音] 检测到人声,继续录音(用于声纹注册)")
# 注意:不清空缓冲区,保留包含唤醒词的音频
def _on_audio_chunk(self, audio_chunk: bytes):
# 记录所有音频(包括唤醒词),用于声纹注册
try:
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
with self.buffer_lock:
self.audio_buffer.extend(audio_array)
except Exception as e:
self.get_logger().debug(f"[注册录音] 录音失败: {e}")
def _on_speech_end(self):
# 如果还在等待唤醒词,不处理
if self.waiting_for_wake_word:
return
# 如果已经在处理,不重复处理
if self.processing:
return
# 等待声纹语音时用户说话结束使用当前音频即使不足3秒
if self.waiting_for_voiceprint:
self._process_voiceprint_audio(use_current_audio_if_short=True)
return # 处理完毕后直接返回,防止重复调用
def _process_voiceprint_audio(self, use_current_audio_if_short: bool = False):
"""处理声纹音频:使用用户完整的第一段语音进行注册
Args:
use_current_audio_if_short: 如果音频不足3秒是否使用当前音频用于用户已说完的情况
"""
if self.processing:
return
self.processing = True
with self.buffer_lock:
audio_list = list(self.audio_buffer)
buffer_size = len(audio_list)
buffer_sec = buffer_size / self.sample_rate
self.get_logger().info(f"[注册录音] 当前音频长度: {buffer_sec:.2f}")
required_samples = int(self.sample_rate * 3)
# 如果音频不足3秒
if buffer_size < required_samples:
if use_current_audio_if_short:
# 用户已经说完了使用当前音频即使不足3秒
self.get_logger().info(f"[注册录音] 音频不足3秒当前{buffer_sec:.2f}秒),但用户已说完,使用当前音频进行注册")
audio_to_use = audio_list
else:
# 等待继续录音
self.get_logger().info(f"[注册录音] 音频不足3秒当前{buffer_sec:.2f}秒),等待继续录音...")
self.processing = False
return
else:
# 策略优化不再强行截取最后3秒因为唤醒词检测有延迟
# "er gou" 可能在缓冲区的中间偏后位置。
# 为了防止截取到尾部的静音,并在包含完整唤醒词,
# 我们截取最近的 3.0 秒或者全部如果不足3秒
# 这样能最大程度包含有效语音 "二狗"。
target_samples = int(self.sample_rate * 3.0)
if buffer_size > target_samples:
audio_to_use = audio_list[-target_samples:]
else:
audio_to_use = audio_list
duration = len(audio_to_use) / self.sample_rate
self.get_logger().info(f"[注册录音] 使用最近 {duration:.2f} 秒音频用于注册(覆盖唤醒词)")
# 清空缓冲区
with self.buffer_lock:
self.audio_buffer.clear()
try:
audio_array = np.array(audio_to_use, dtype=np.int16)
embedding, success = self.sv_client.extract_embedding(
audio_array,
sample_rate=self.sample_rate
)
if not success or embedding is None:
self.get_logger().error("[注册录音] 提取embedding失败")
self.processing = False
return
speaker_id = f"user_{int(time.time())}"
if self.sv_client.register_speaker(speaker_id, embedding):
self.get_logger().info(f"[注册录音] 注册成功用户ID: {speaker_id},准备退出")
# 播放成功提示
try:
self.get_logger().info("[注册录音] 播放注册成功提示")
request = TTSRequest(text="声纹注册成功", voice=self.tts_voice)
self.tts_client.synthesize(request)
time.sleep(5)
except Exception as e:
self.get_logger().error(f"[注册录音] 播放提示失败: {e}")
self.stop_event.set()
else:
self.get_logger().error("[注册录音] 注册失败")
self.processing = False
except Exception as e:
self.get_logger().error(f"[注册录音] 注册异常: {e}")
self.processing = False
def _extract_speech_segments(self, audio_array: np.ndarray, frame_size: int = 1024) -> list:
"""使用能量检测提取人声片段(过滤静音)"""
speech_segments = []
frame_samples = frame_size
total_frames = 0
speech_frames = 0
for i in range(0, len(audio_array), frame_samples):
frame = audio_array[i:i + frame_samples]
if len(frame) < frame_samples:
break
total_frames += 1
# 计算帧的能量RMS对于int16音频
frame_float = frame.astype(np.float32)
energy = np.sqrt(np.mean(frame_float ** 2))
# 使用更低的阈值来检测人声(降低阈值,避免误判静音)
# 阈值可以动态调整,或者使用自适应阈值
threshold = self.min_energy_threshold * 0.50 # 降低阈值到原来的50%
# 如果能量超过阈值,认为是人声
if energy >= threshold:
speech_segments.append((i, i + frame_samples))
speech_frames += 1
# 调试信息
if total_frames > 0:
speech_ratio = speech_frames / total_frames
self.get_logger().debug(f"[注册录音] 能量检测: 总帧数={total_frames}, 人声帧数={speech_frames}, 人声比例={speech_ratio:.2%}, 阈值={self.min_energy_threshold}")
return speech_segments
def _merge_speech_segments(self, audio_array: np.ndarray, segments: list, min_samples: int) -> np.ndarray:
"""合并人声片段,返回连续的人声音频"""
if not segments:
return np.array([], dtype=np.int16)
# 合并相邻的片段
merged_segments = []
current_start, current_end = segments[0]
for start, end in segments[1:]:
if start <= current_end + 1024: # 允许小间隙1帧
current_end = end
else:
merged_segments.append((current_start, current_end))
current_start, current_end = start, end
merged_segments.append((current_start, current_end))
# 从后往前选择片段直到达到3秒
selected_audio = []
total_samples = 0
for start, end in reversed(merged_segments):
segment_audio = audio_array[start:end]
selected_audio.insert(0, segment_audio)
total_samples += len(segment_audio)
if total_samples >= min_samples:
break
if not selected_audio:
return np.array([], dtype=np.int16)
return np.concatenate(selected_audio)
def _asr_worker(self):
"""ASR处理线程"""
while not self.stop_event.is_set():
def _vad_event_worker(self):
"""VAD 事件监听线程,只负责接收事件并放入队列,不调用 spin_until_future_complete"""
while not self.registered and not self.shutting_down:
try:
audio_chunk = self.audio_queue.get(timeout=0.1)
if self.asr_client and self.asr_client.running:
self.asr_client.send_audio(audio_chunk)
except queue.Empty:
continue
request = VADEvent.Request()
request.command = "wait"
request.timeout_ms = 1000
future = self.vad_client.call_async(request)
# 简单等待 future 完成,不使用 spin_until_future_complete
start_time = time.time()
while not future.done() and (time.time() - start_time) < 1.5:
time.sleep(0.01)
if not future.done() or self.registered or self.shutting_down:
continue
response = future.result()
if response.success and response.event in ["speech_started", "speech_stopped"]:
# 将事件放入队列,由主线程处理
try:
self.vad_event_queue.put(response.event, timeout=0.1)
except queue.Full:
self.get_logger().warn(f"[VAD] 事件队列已满,丢弃事件: {response.event}")
except Exception as e:
self.get_logger().error(f"[注册ASR] 处理异常: {e}")
if not self.shutting_down:
self.get_logger().error(f"[VAD] 线程异常: {e}")
break
def _on_asr_sentence_end(self, text: str):
"""ASR识别完成回调"""
if text and text.strip():
self.text_queue.put(text.strip())
def _text_worker(self):
"""文本处理线程:检测唤醒词"""
while not self.stop_event.is_set():
try:
text = self.text_queue.get(timeout=0.1)
if self.waiting_for_wake_word:
self._check_wake_word(text)
except queue.Empty:
continue
except Exception as e:
self.get_logger().error(f"[注册文本] 处理异常: {e}")
def _start_recording(self):
    """Ask the AudioData service to start recording.

    Returns the async future so the caller (main loop) can poll completion.
    """
    req = AudioData.Request()
    req.command = "start"
    return self.audio_data_client.call_async(req)
def _to_pinyin(self, text: str) -> str:
"""将中文文本转换为拼音"""
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
if not chars:
return ""
@@ -413,10 +109,8 @@ class RegisterSpeakerNode(Node):
return ' '.join([item[0] for item in py_list]).lower().strip()
def _check_wake_word(self, text: str):
"""检查是否包含唤醒词"""
text_pinyin = self._to_pinyin(text)
wake_word_pinyin = self.wake_word.lower().strip()
self.get_logger().info(f"[注册唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")
if not wake_word_pinyin:
return
@@ -424,40 +118,119 @@ class RegisterSpeakerNode(Node):
text_pinyin_parts = text_pinyin.split() if text_pinyin else []
wake_word_parts = wake_word_pinyin.split()
# 检查是否包含唤醒词
has_wake_word = False
for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'")
self.get_logger().info("=" * 50)
self.get_logger().info("[声纹注册] 开始注册声纹将截取3秒音频用于注册")
self.get_logger().info("=" * 50)
self.waiting_for_wake_word = False
self.waiting_for_voiceprint = True
# 停止ASR不再需要识别
if self.asr_client:
self.asr_client.stop_current_recognition()
# 立即处理当前音频缓冲区中的完整音频
# 用户可能已经说完了(包含唤醒词的整段语音)
self._process_voiceprint_audio()
return
def _check_done(self):
if self.stop_event.is_set():
has_wake_word = True
break
if has_wake_word:
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}',停止录音并获取音频")
request = AudioData.Request()
request.command = "stop"
future = self.audio_data_client.call_async(request)
future._future_type = "stop"
self.pending_audio_future = future
def _process_voiceprint_audio(self, response):
    """Process voiceprint audio returned by the AudioData service and
    register the speaker — the audio is used as-is, with no extra filtering
    (the service output is already a DashScope-VAD speech segment).

    On success sets ``self.registered`` so the main loop can shut down.
    """
    if not response or not response.success or response.samples == 0:
        self.get_logger().error(f"[注册录音] 获取音频数据失败: {response.message if response else '无响应'}")
        return
    # Raw 16-bit PCM bytes -> int16 samples.
    audio_array = np.frombuffer(response.audio_data, dtype=np.int16)
    buffer_sec = response.samples / response.sample_rate
    self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}秒")
    # Use the audio directly — no VAD filtering needed, since the AudioData
    # service is driven by DashScope VAD and already yields speech segments.
    embedding, success = self.sv_client.extract_embedding(
        audio_array,
        sample_rate=response.sample_rate
    )
    if not success or embedding is None:
        self.get_logger().error("[注册录音] 提取embedding失败")
        return
    # Speaker id derived from the registration timestamp.
    speaker_id = f"user_{int(time.time())}"
    if self.sv_client.register_speaker(speaker_id, embedding):
        # Persist immediately after successful registration.
        self.sv_client.save_speakers()
        self.get_logger().info(f"[注册录音] 注册成功用户ID: {speaker_id},已保存到文件,准备退出")
        self.registered = True
    else:
        self.get_logger().error("[注册录音] 注册失败")
def _main_loop(self):
"""主循环,在主线程中处理所有异步操作"""
# 检查是否完成注册
if self.registered:
self.get_logger().info("注册完成,节点退出")
# 清理资源
if self.asr_client:
self.asr_client.stop()
self.destroy_node()
self.shutting_down = True
self.timer.cancel()
rclpy.shutdown()
return
# 处理待处理的 ASR future
if self.pending_asr_future and self.pending_asr_future.done():
response = self.pending_asr_future.result()
self.pending_asr_future = None
if response.success and response.text:
text = response.text.strip()
if text:
self._check_wake_word(text)
self.state = "waiting_speech"
# 处理待处理的 AudioData future
if self.pending_audio_future and self.pending_audio_future.done():
response = self.pending_audio_future.result()
future_type = getattr(self.pending_audio_future, '_future_type', None)
self.pending_audio_future = None
if future_type == "start":
if response.success:
self.get_logger().info("[注册录音] 已开始录音")
self.recording = True
else:
self.get_logger().warn(f"[注册录音] 启动录音失败: {response.message}")
self.state = "waiting_speech"
elif future_type == "stop":
self.recording = False
self._process_voiceprint_audio(response)
# 处理 VAD 事件队列
try:
event = self.vad_event_queue.get_nowait()
if event == "speech_started" and self.state == "waiting_speech" and not self.recording:
self.get_logger().info("[VAD] 检测到语音开始,启动录音")
future = self._start_recording()
future._future_type = "start"
self.pending_audio_future = future
elif event == "speech_stopped" and self.recording and self.state == "waiting_speech":
self.get_logger().info("[VAD] 检测到语音结束,请求 ASR 识别")
self.state = "waiting_asr"
request = ASRRecognize.Request()
request.command = "start"
self.pending_asr_future = self.asr_client.call_async(request)
except queue.Empty:
pass
def main(args=None):
rclpy.init(args=args)
node = RegisterSpeakerNode()
rclpy.spin(node)
node.destroy_node()
try:
rclpy.shutdown()
except Exception:
pass
if __name__ == '__main__':
main()
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,185 @@
"""技能接口文件解析器"""
import os
import yaml
import json
from typing import Optional
from ament_index_python.packages import get_package_share_directory
class SkillInterfaceParser:
def __init__(self, interfaces_root: str):
    """Initialize the parser.

    Args:
        interfaces_root: root directory containing the ``action`` and ``srv``
            interface-definition subdirectories.
    """
    self.interfaces_root = interfaces_root
    # Lazy caches: None means "not loaded yet"; populated on first access.
    self._cached_skill_config: list[dict] | None = None
    self._cached_skill_interfaces: dict[str, dict] | None = None
def get_skill_names(self) -> list[str]:
    """Return the names of all configured skills.

    Reads robot_skills.yaml once via the shared loader; entries that are not
    dicts or have an empty name are ignored.
    """
    names = []
    for entry in self._load_skill_config():
        if isinstance(entry, dict) and entry.get("name"):
            names.append(entry["name"])
    return names
def _load_skill_config(self) -> list[dict]:
    """Load robot_skills.yaml from the ``brain`` package share dir, with caching.

    Returns an empty list when the file is missing/unreadable or does not
    contain a YAML list; the (possibly empty) result is cached so the file is
    read at most once.
    """
    if self._cached_skill_config is None:
        try:
            share_dir = get_package_share_directory("brain")
            config_path = os.path.join(share_dir, "config", "robot_skills.yaml")
            with open(config_path, "r", encoding="utf-8") as fh:
                loaded = yaml.safe_load(fh) or []
            self._cached_skill_config = loaded if isinstance(loaded, list) else []
        except Exception:
            # Best-effort: a missing package/file degrades to "no skills".
            self._cached_skill_config = []
    return self._cached_skill_config
def parse_skill_interfaces(self) -> dict[str, dict]:
"""解析所有技能接口文件的目标字段(带缓存)"""
if self._cached_skill_interfaces is not None:
return self._cached_skill_interfaces
result = {}
skill_config = self._load_skill_config()
for skill_entry in skill_config:
skill_name = skill_entry.get("name")
if not skill_name:
continue
interfaces = skill_entry.get("interfaces", [])
for iface in interfaces:
if isinstance(iface, dict):
iface_name = iface.get("name", "")
else:
iface_name = str(iface)
if ".action" in iface_name:
iface_type = "action"
file_path = os.path.join(self.interfaces_root, "action", iface_name)
elif ".srv" in iface_name:
iface_type = "srv"
file_path = os.path.join(self.interfaces_root, "srv", iface_name)
else:
continue
if os.path.exists(file_path):
goal_fields = self._parse_goal_fields(file_path)
result[skill_name] = {
"type": iface_type,
"goal_fields": goal_fields
}
break
self._cached_skill_interfaces = result
return result
def _parse_goal_fields(self, file_path: str) -> list[dict]:
"""解析接口文件的目标字段(第一个---之前的所有字段)"""
goal_fields = []
try:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if line.startswith("---"):
break
if not line or line.startswith("#"):
continue
parts = line.split()
if len(parts) >= 2:
field_type = parts[0]
field_name = parts[1]
comment = ""
if "#" in line:
comment = line.split("#", 1)[1].strip()
goal_fields.append({
"name": field_name,
"type": field_type,
"comment": comment
})
except Exception:
return []
return goal_fields
def generate_params_documentation(self) -> str:
"""生成技能参数说明文档"""
skill_interfaces = self.parse_skill_interfaces()
doc_lines = []
for skill_name, skill_info in skill_interfaces.items():
doc_lines.append(f"{skill_name}技能的parameters字段")
goal_fields = skill_info.get("goal_fields", [])
if not goal_fields:
doc_lines.append(" - 无参数,使用 null")
else:
doc_lines.append(" parameters字典必须包含以下字段")
for field in goal_fields:
field_name = field["name"]
field_type = field["type"]
comment = field.get("comment", "")
if field_name == "body_id":
doc_lines.append(
f" - {field_name} ({field_type}): 身体部位ID0=左臂1=右臂2=头部。"
f"根据目标物在图片中的方位选择左侧用0右侧用1中央用2。"
)
else:
type_desc = self._get_type_description(field_type)
doc_lines.append(f" - {field_name} ({field_type}): {type_desc} {comment}")
example_params = {}
for field in goal_fields:
field_name = field["name"]
field_type = field["type"]
example_params[field_name] = self._get_example_value(field_name, field_type)
doc_lines.append(f" 示例:{json.dumps(example_params, ensure_ascii=False)}")
doc_lines.append("")
return "\n".join(doc_lines)
def _get_type_description(self, field_type: str) -> str:
"""根据字段类型返回描述"""
type_map = {
"int8": "整数,范围-128到127",
"int16": "整数,范围-32768到32767",
"int32": "整数",
"int64": "整数",
"uint8": "无符号整数范围0到255",
"float32": "浮点数",
"float64": "浮点数",
"string": "字符串",
}
base_type = field_type.replace("[]", "").replace("_", "")
return type_map.get(base_type, field_type)
def _get_example_value(self, field_name: str, field_type: str) -> any:
"""根据字段名和类型生成示例值"""
if field_name == "body_id":
return 0
elif field_name == "data_array" and "float64[]" in field_type:
return [0.1, 0.2, 0.3, 0.0, 0.0, 0.0]
elif "int" in field_type:
return 0
elif "float" in field_type:
return 0.0
elif "string" in field_type:
return ""
elif "[]" in field_type:
if "int" in field_type:
return [0, 0, 0]
elif "float" in field_type:
return [0.0, 0.0, 0.0]
return []
else:
return None

View File

@@ -0,0 +1,199 @@
"""
声纹识别模块
"""
import numpy as np
import threading
import os
import time
import json
from enum import Enum
class SpeakerState(Enum):
    """Speaker-verification outcome states (see SpeakerVerificationClient.match_speaker)."""
    UNKNOWN = "unknown"    # no embedding to compare, or no speakers enrolled
    VERIFIED = "verified"  # best cosine score >= threshold
    REJECTED = "rejected"  # best cosine score < threshold
    ERROR = "error"        # matching raised an exception
class SpeakerVerificationClient:
"""声纹识别客户端 - 非实时、低频处理"""
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
self.model_path = model_path
self.threshold = threshold
self.speaker_db_path = speaker_db_path
self.logger = logger
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
self._lock = threading.Lock()
# # 优化CPU性能限制Torch使用的线程数防止多线程竞争导致性能骤降
import torch
torch.set_num_threads(1)
from funasr import AutoModel
model_path = os.path.expanduser(self.model_path)
# 禁用自动更新检查,防止每次初始化都联网检查
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
if self.logger:
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
if self.speaker_db_path:
self.load_speakers()
def _log(self, level: str, msg: str):
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
if self.logger:
try:
if level == "info":
self.logger.info(msg)
elif level == "warning":
self.logger.warning(msg)
elif level == "error":
self.logger.error(msg)
elif level == "debug":
self.logger.debug(msg)
except Exception:
pass
def load_speakers(self):
if not self.speaker_db_path:
return
db_path = os.path.expanduser(self.speaker_db_path)
if not os.path.exists(db_path):
self._log("info", f"声纹数据库文件不存在: {db_path},将创建新文件")
return
try:
with open(db_path, 'rb') as f:
data = json.load(f)
with self._lock:
self.speaker_db = {}
for speaker_id, info in data.items():
embedding_array = np.array(info["embedding"], dtype=np.float32)
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
self.speaker_db[speaker_id] = {
"embedding": embedding_array,
"env": info.get("env", ""),
"registered_at": info.get("registered_at", 0.0)
}
self._log("info", f"已加载 {len(self.speaker_db)} 个已注册说话人")
except Exception as e:
self._log("error", f"加载声纹数据库失败: {e}")
def save_speakers(self):
if not self.speaker_db_path:
return
db_path = os.path.expanduser(self.speaker_db_path)
try:
os.makedirs(os.path.dirname(db_path), exist_ok=True)
with self._lock:
data = {}
for speaker_id, info in self.speaker_db.items():
data[speaker_id] = {
"embedding": info["embedding"].tolist(),
"env": info.get("env", ""),
"registered_at": info.get("registered_at", 0.0)
}
with open(db_path, 'w') as f:
json.dump(data, f, indent=2)
self._log("info", f"已保存 {len(data)} 个已注册说话人到: {db_path}")
except Exception as e:
self._log("error", f"保存声纹数据库失败: {e}")
def extract_embedding(self, audio_array: np.ndarray, sample_rate: int = 16000) -> tuple[np.ndarray | None, bool]:
try:
if len(audio_array) == 0:
return None, False
# 确保是int16格式
if audio_array.dtype != np.int16:
audio_array = audio_array.astype(np.int16)
# 转换为float32并归一化到[-1, 1]
audio_float = audio_array.astype(np.float32) / 32768.0
# 调用模型提取embedding
result = self.model.generate(input=audio_float, cache={})
if result and len(result) > 0 and "spk_embedding" in result[0]:
embedding = result[0]["spk_embedding"]
if embedding is not None and len(embedding) > 0:
embedding_array = np.array(embedding, dtype=np.float32)
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
return embedding_array, True
return None, False
except Exception as e:
self._log("error", f"提取声纹特征失败: {e}")
return None, False
def match_speaker(self, embedding: np.ndarray) -> tuple[str | None, SpeakerState, float, float]:
if embedding is None or len(embedding) == 0:
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
with self._lock:
if len(self.speaker_db) == 0:
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
try:
best_speaker_id = None
best_score = 0.0
with self._lock:
for speaker_id, info in self.speaker_db.items():
stored_embedding = info["embedding"]
# 计算余弦相似度
dot_product = np.dot(embedding, stored_embedding)
norm_embedding = np.linalg.norm(embedding)
norm_stored = np.linalg.norm(stored_embedding)
if norm_embedding > 0 and norm_stored > 0:
score = dot_product / (norm_embedding * norm_stored)
if score > best_score:
best_score = score
best_speaker_id = speaker_id
state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
return best_speaker_id, state, float(best_score), float(self.threshold)
except Exception as e:
self._log("error", f"匹配说话人失败: {e}")
return None, SpeakerState.ERROR, 0.0, float(self.threshold)
def register_speaker(self, speaker_id: str, embedding: np.ndarray, env: str = "") -> bool:
if embedding is None or len(embedding) == 0:
return False
try:
with self._lock:
self.speaker_db[speaker_id] = {
"embedding": np.array(embedding, dtype=np.float32),
"env": env,
"registered_at": time.time()
}
self._log("info", f"已注册说话人: {speaker_id}")
return True
except Exception as e:
self._log("error", f"注册说话人失败: {e}")
return False
def get_speaker_count(self) -> int:
with self._lock:
return len(self.speaker_db)
def get_speaker_list(self) -> list[str]:
with self._lock:
return list(self.speaker_db.keys())
def remove_speaker(self, speaker_id: str) -> bool:
with self._lock:
if speaker_id in self.speaker_db:
del self.speaker_db[speaker_id]
self._log("info", f"已删除说话人: {speaker_id}")
return True
return False
def cleanup(self):
try:
self.save_speakers()
if hasattr(self, 'model') and self.model:
del self.model
except Exception as e:
self._log("error", f"清理资源失败: {e}")

View File

@@ -1,36 +0,0 @@
"""
统一数据结构定义
"""
from dataclasses import dataclass
@dataclass
class ASRResult:
    """ASR recognition result."""
    text: str  # recognized transcript
    confidence: float | None = None  # recognition confidence, when the engine reports one
    language: str | None = None  # detected language code, when reported
@dataclass
class LLMMessage:
    """One LLM chat message."""
    role: str  # "user", "assistant", "system"
    content: str  # message text
@dataclass
class TTSRequest:
    """TTS synthesis request."""
    text: str  # text to synthesize
    voice: str | None = None  # None → use the console-configured default voice
    speed: float | None = None  # optional speed setting (None = backend default)
    pitch: float | None = None  # optional pitch setting (None = backend default)
@dataclass
class ImageMessage:
    """Image message — used for multimodal LLM input."""
    image_data: bytes  # base64-encoded image data
    image_format: str = "jpeg"  # image encoding format

View File

@@ -1,4 +0,0 @@
"""模型层"""

View File

@@ -1,4 +0,0 @@
"""ASR模型"""

View File

@@ -1,12 +0,0 @@
class ASRClient:
    """Abstract ASR (speech recognition) client interface."""

    def start(self) -> bool:
        """Start the recognizer; True on success."""
        raise NotImplementedError

    def stop(self) -> bool:
        """Stop the recognizer; True on success."""
        raise NotImplementedError

    def send_audio(self, audio_data: bytes) -> bool:
        """Feed one chunk of raw audio; True when the chunk was accepted."""
        raise NotImplementedError

View File

@@ -1,229 +0,0 @@
"""
ASR语音识别模块
"""
import base64
import time
import threading
import dashscope
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
from robot_speaker.models.asr.base import ASRClient
class DashScopeASR(ASRClient):
    """DashScope realtime ASR wrapper.

    Streams PCM audio to the qwen-omni realtime endpoint over WebSocket and
    reports transcripts through the ``on_sentence_end`` / ``on_text_update``
    callbacks. Turn detection is handled server-side ('server_vad').
    """
    def __init__(self, api_key: str,
                 sample_rate: int,
                 model: str,
                 url: str,
                 logger=None):
        # NOTE(review): sets the module-global dashscope.api_key — process-wide effect.
        dashscope.api_key = api_key
        self.sample_rate = sample_rate
        self.model = model
        self.url = url
        self.logger = logger
        self.conversation = None  # active OmniRealtimeConversation, or None
        self.running = False  # True between a successful start() and stop()
        self.on_sentence_end = None  # callback(str): final transcript of a turn
        self.on_text_update = None  # callback(str): realtime partial transcript
        # Thread-synchronization machinery
        self._stop_lock = threading.Lock()  # prevents concurrent stop_current_recognition calls
        self._final_result_event = threading.Event()  # signaled when the final transcript arrives
        self._pending_commit = False  # a commit is awaiting its final callback
    def _log(self, level: str, msg: str):
        """Log at the requested level via the ROS2 logger, or print when none.

        ROS2 loggers cannot change severity dynamically, so the matching
        method must be called explicitly.
        """
        if self.logger:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "info":
                self.logger.info(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            else:
                self.logger.info(msg)  # default to info for unknown levels
        else:
            print(f"[ASR] {msg}")
    def start(self):
        """Connect to the realtime ASR service and configure the session.

        Returns True on success; False when already running or on failure
        (any half-open connection is closed first).
        """
        if self.running:
            return False
        try:
            callback = _ASRCallback(self)
            self.conversation = OmniRealtimeConversation(
                model=self.model,
                url=self.url,
                callback=callback
            )
            callback.conversation = self.conversation
            self.conversation.connect()
            transcription_params = TranscriptionParams(
                language='zh',
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )
            # Local VAD → only controls TTS barge-in.
            # Server-side turn detection → controls ASR output and LLM turn-taking.
            self.conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
                enable_turn_detection=True,
                # keep server-side turn detection
                turn_detection_type='server_vad',  # server-side VAD
                turn_detection_threshold=0.2,  # tunable
                turn_detection_silence_duration_ms=800
            )
            self.running = True
            self._log("info", "ASR已启动")
            return True
        except Exception as e:
            self.running = False
            self._log("error", f"ASR启动失败: {e}")
            if self.conversation:
                try:
                    self.conversation.close()
                except:
                    pass
                self.conversation = None
            return False
    def send_audio(self, audio_chunk: bytes):
        """Base64-encode and forward one PCM chunk.

        Returns False when not running or on send failure.
        """
        if not self.running or not self.conversation:
            return False
        try:
            audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
            self.conversation.append_audio(audio_b64)
            return True
        except Exception as e:
            # Connection closed or other error — stay silent to avoid log spam;
            # the running flag is corrected inside stop_current_recognition.
            return False
    def stop_current_recognition(self):
        """Stop the in-flight recognition (forcing a final transcript), then restart.

        Details:
        1. waits on an event (not a sleep) for the final callback to complete;
        2. a non-blocking lock rejects concurrent invocations;
        3. when the restart fails, ``running`` is left False;
        4. a 1-second timeout bounds the wait for the final result.
        Returns True only when the restart succeeded.
        """
        # Non-blocking acquire: concurrent callers bail out immediately.
        if not self._stop_lock.acquire(blocking=False):
            self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
            return False
        try:
            if not self.running or not self.conversation:
                return False
            # Arm the event before committing so the callback cannot be missed.
            self._final_result_event.clear()
            self._pending_commit = True
            # commit() makes the server emit the final transcript.
            self.conversation.commit()
            # Wait up to 1s for the final callback.
            if self._final_result_event.wait(timeout=1.0):
                self._log("debug", "已收到 final 回调,准备关闭连接")
            else:
                self._log("warning", "等待 final 回调超时,继续执行")
            # Flip running first so the audio thread stops feeding us.
            self.running = False
            # Detach the conversation before closing so send_audio cannot use it.
            old_conversation = self.conversation
            self.conversation = None
            try:
                old_conversation.close()
            except Exception as e:
                self._log("warning", f"关闭连接时出错: {e}")
            # Brief pause to let the socket close fully.
            time.sleep(0.1)
            # Restart; on failure running stays False.
            if not self.start():
                self._log("error", "ASR重启失败running状态已重置")
                return False
            # start() already set running = True on success.
            return True
        finally:
            self._pending_commit = False
            self._stop_lock.release()
    def stop(self):
        """Shut the recognizer down for good (blocks until any in-flight
        stop_current_recognition finishes)."""
        with self._stop_lock:
            self.running = False
            self._final_result_event.set()  # wake any thread waiting on the final result
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception as e:
                    self._log("warning", f"停止时关闭连接出错: {e}")
                self.conversation = None
            self._log("info", "ASR已停止")
class _ASRCallback(OmniRealtimeCallback):
    """Routes DashScope realtime events back into the owning DashScopeASR."""

    def __init__(self, asr_client: DashScopeASR):
        self.asr_client = asr_client
        self.conversation = None

    def on_open(self):
        self.asr_client._log("info", "ASR WebSocket已连接")

    def on_close(self, code, msg):
        self.asr_client._log("info", f"ASR WebSocket已关闭: code={code}, msg={msg}")

    def on_event(self, response):
        client = self.asr_client
        event_type = response.get('type', '')
        if event_type == 'session.created':
            session_id = response.get('session', {}).get('id', '')
            client._log("info", f"ASR会话已创建: {session_id}")
            return
        if event_type == 'conversation.item.input_audio_transcription.completed':
            # Final transcript for the turn.
            transcript = response.get('transcript', '')
            if transcript and transcript.strip() and client.on_sentence_end:
                client.on_sentence_end(transcript.strip())
            # Wake the thread blocked inside stop_current_recognition, if any.
            if client._pending_commit:
                client._final_result_event.set()
            return
        if event_type == 'conversation.item.input_audio_transcription.text':
            # Incremental transcript update (multi-turn hinting).
            transcript = response.get('transcript', '') or response.get('text', '')
            if transcript and transcript.strip() and client.on_text_update:
                client.on_text_update(transcript.strip())
            return
        if event_type == 'input_audio_buffer.speech_started':
            client._log("info", "ASR检测到说话开始")
        elif event_type == 'input_audio_buffer.speech_stopped':
            client._log("info", "ASR检测到说话结束")

View File

@@ -1,4 +0,0 @@
"""LLM模型"""

View File

@@ -1,14 +0,0 @@
from robot_speaker.core.types import LLMMessage
class LLMClient:
    """Abstract LLM client interface."""

    def chat(self, messages: list[LLMMessage]) -> str | None:
        """Blocking chat completion; returns the reply text, or None."""
        raise NotImplementedError

    def chat_stream(self, messages: list[LLMMessage],
                    on_token=None,
                    interrupt_check=None) -> str | None:
        """Streaming chat completion.

        Args:
            messages: conversation history.
            on_token: optional callback invoked per generated chunk.
            interrupt_check: optional callable; a truthy return aborts the stream.
        Returns the full reply text, or None.
        """
        raise NotImplementedError

View File

@@ -1,149 +0,0 @@
"""
LLM大语言模型模块
支持多模态(文本+图像)
"""
from openai import OpenAI
from typing import Optional, List
from robot_speaker.core.types import LLMMessage
from robot_speaker.models.llm.base import LLMClient
class DashScopeLLM(LLMClient):
    """DashScope LLM client (OpenAI-compatible API) with optional multimodal
    (text + base64 JPEG image) input on the streaming path."""

    def __init__(self, api_key: str,
                 model: str,
                 base_url: str,
                 temperature: float,
                 max_tokens: int,
                 name: str = "LLM",
                 logger=None):
        """
        Args:
            api_key: DashScope API key.
            model: model identifier.
            base_url: OpenAI-compatible endpoint URL.
            temperature / max_tokens: generation parameters.
            name: tag prefixed to every log line.
            logger: optional ROS2-style logger.
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.name = name
        self.logger = logger

    def _log(self, level: str, msg: str):
        """Log at the requested level through the ROS2 logger (which requires
        the matching severity method to be called explicitly); messages are
        dropped when no logger is configured."""
        msg = f"[{self.name}] {msg}"
        if self.logger:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "info":
                self.logger.info(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            else:
                self.logger.info(msg)  # default to info for unknown levels

    def chat(self, messages: list[LLMMessage]) -> str | None:
        """Non-streaming chat (used for task planning).

        Returns the stripped reply text, or None for an empty/absent reply.
        """
        payload_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=payload_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=False
        )
        # Fix: message.content may be None (e.g. refusals) — guard before strip().
        reply = (response.choices[0].message.content or "").strip()
        return reply if reply else None

    def chat_stream(self, messages: list[LLMMessage],
                    on_token=None,
                    images: Optional[List[str]] = None,
                    interrupt_check=None) -> str | None:
        """Streaming chat (used by the voice pipeline), optionally multimodal.

        Images (base64 JPEG strings) are attached only to the LAST user message.
        Args:
            on_token: callback(str) invoked for each streamed chunk.
            images: optional base64-encoded JPEG payloads.
            interrupt_check: callable; a truthy return aborts streaming.
        Returns:
            The full stripped reply; None when interrupted or empty.
        """
        payload_messages = self._build_payload(messages, images)
        full_reply = ""
        interrupted = False
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=payload_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stream=True
        )
        for chunk in stream:
            # Check the interrupt flag before handling the chunk.
            if interrupt_check and interrupt_check():
                self._log("info", "LLM流式处理被中断")
                interrupted = True
                break
            if chunk.choices and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                full_reply += content
                if on_token:
                    on_token(content)
                # Re-check after on_token — the callback itself may set the flag.
                if interrupt_check and interrupt_check():
                    self._log("info", "LLM流式处理在on_token回调后被中断")
                    interrupted = True
                    break
        if interrupted:
            return None  # None signals "did not complete"
        return full_reply.strip() if full_reply else None

    def _build_payload(self, messages: list[LLMMessage],
                       images: Optional[List[str]] = None) -> list[dict]:
        """Convert LLMMessage objects to API dicts, attaching any images to the
        last user message as image_url content parts (fix: a dead local that
        built an unused data-URL preview string was removed)."""
        last_user_idx = -1
        for i, msg in enumerate(messages):
            if msg.role == "user":
                last_user_idx = i
        payload_messages = []
        has_images_in_message = False
        for i, msg in enumerate(messages):
            msg_dict = {"role": msg.role}
            if i == last_user_idx and msg.role == "user" and images and len(images) > 0:
                content_list = [{"type": "text", "text": msg.content}]
                for img_idx, img_base64 in enumerate(images):
                    content_list.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{img_base64}"
                        }
                    })
                    self._log("info", f"[多模态] 添加图像 #{img_idx+1} 到user消息base64长度: {len(img_base64)}")
                msg_dict["content"] = content_list
                has_images_in_message = True
            else:
                msg_dict["content"] = msg.content
            payload_messages.append(msg_dict)
        # Diagnostic logging about the multimodal request shape.
        if images and len(images) > 0:
            if has_images_in_message:
                last_user_msg = payload_messages[last_user_idx] if last_user_idx >= 0 else None
                if last_user_msg and isinstance(last_user_msg.get("content"), list):
                    content_items = last_user_msg["content"]
                    text_items = [item for item in content_items if item.get("type") == "text"]
                    image_items = [item for item in content_items if item.get("type") == "image_url"]
                    self._log("info", f"[多模态] 已发送多模态请求: {len(text_items)}个文本 + {len(image_items)}张图片")
                    self._log("debug", f"[多模态] 用户文本: {text_items[0].get('text', '')[:50] if text_items else 'N/A'}")
                else:
                    self._log("warning", "[多模态] 消息格式异常,无法确认图片是否添加")
            else:
                self._log("warning", f"[多模态] 有{len(images)}张图片但未找到user消息图片未被添加")
        else:
            self._log("debug", "[多模态] 纯文本请求(无图片)")
        return payload_messages

View File

@@ -1,4 +0,0 @@
"""TTS模型"""

View File

@@ -1,13 +0,0 @@
from robot_speaker.core.types import TTSRequest
class TTSClient:
"""TTS客户端抽象基类"""
def synthesize(self, request: TTSRequest,
on_chunk=None,
interrupt_check=None) -> bool:
raise NotImplementedError

View File

@@ -1,244 +0,0 @@
"""
TTS语音合成模块
"""
import subprocess
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
from robot_speaker.core.types import TTSRequest
from robot_speaker.models.tts.base import TTSClient
class DashScopeTTSClient(TTSClient):
    """DashScope streaming TTS client.

    Requests streamed PCM (22050 Hz mono) from the DashScope TTS service and
    plays it through ALSA via an ffmpeg subprocess (_TTSCallback), which also
    performs sample-rate/channel conversion and volume scaling.
    """
    def __init__(self, api_key: str,
                 model: str,
                 voice: str,
                 card_index: int,
                 device_index: int,
                 output_sample_rate: int = 44100,
                 output_channels: int = 2,
                 output_volume: float = 1.0,
                 tts_source_sample_rate: int = 22050,  # fixed sample rate produced by the TTS service
                 tts_source_channels: int = 1,  # fixed channel count produced by the TTS service
                 tts_ffmpeg_thread_queue_size: int = 1024,  # ffmpeg input thread queue size
                 reference_signal_buffer=None,  # reference-signal buffer (for echo cancellation)
                 logger=None):
        # NOTE(review): sets the module-global dashscope.api_key — process-wide effect.
        dashscope.api_key = api_key
        self.model = model
        self.voice = voice
        self.card_index = card_index
        self.device_index = device_index
        self.output_sample_rate = output_sample_rate
        self.output_channels = output_channels
        self.output_volume = output_volume
        self.tts_source_sample_rate = tts_source_sample_rate
        self.tts_source_channels = tts_source_channels
        self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
        self.reference_signal_buffer = reference_signal_buffer  # fed from the playback callback
        self.logger = logger
        self.current_ffmpeg_pid = None  # PID of the ffmpeg process currently playing
        # ALSA device string; plughw lets ffmpeg resample/re-channel automatically.
        self.alsa_device = f"plughw:{card_index},{device_index}" if (
            card_index >= 0 and device_index >= 0
        ) else "default"
    def _log(self, level: str, msg: str):
        """Log at the requested level via the ROS2 logger (which requires the
        matching severity method to be called explicitly), or print when no
        logger is configured."""
        if self.logger:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "info":
                self.logger.info(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            else:
                self.logger.info(msg)  # default to info for unknown levels
        else:
            print(f"[TTS] {msg}")
    def synthesize(self, request: TTSRequest,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        """Main flow: stream-synthesize ``request.text`` and play it.

        Args:
            request: text plus optional per-request voice (falls back to self.voice).
            on_chunk: optional callback receiving each raw PCM chunk.
            interrupt_check: callable; a truthy return stops playback mid-stream.
        Returns:
            True when playback finished without interruption; False when
            interrupted or when no usable voice is configured.
        """
        callback = _TTSCallback(self, interrupt_check, on_chunk, self.reference_signal_buffer)
        # Voice precedence: per-request voice when non-blank, else the configured default.
        voice_to_use = request.voice if request.voice and request.voice.strip() else self.voice
        if not voice_to_use or not voice_to_use.strip():
            self._log("error", f"Voice参数无效: '{voice_to_use}'")
            return False
        self._log("info", f"TTS开始: 文本='{request.text[:50]}...', voice='{voice_to_use}'")
        synthesizer = SpeechSynthesizer(
            model=self.model,
            voice=voice_to_use,
            format=AudioFormat.PCM_22050HZ_MONO_16BIT,
            callback=callback,
        )
        try:
            synthesizer.streaming_call(request.text)
            synthesizer.streaming_complete()
        finally:
            # Always flush/close the ffmpeg player, even if synthesis raises.
            callback.cleanup()
        return not callback._interrupted
class _TTSCallback(ResultCallback):
    """TTS playback callback — pipes received PCM into an ffmpeg subprocess,
    which performs sample-rate/channel conversion and plays via ALSA."""
    def __init__(self, tts_client: DashScopeTTSClient,
                 interrupt_check=None,
                 on_chunk=None,
                 reference_signal_buffer=None):
        self.tts_client = tts_client
        self.interrupt_check = interrupt_check  # truthy return → abort playback
        self.on_chunk = on_chunk  # receives each raw PCM chunk
        self.reference_signal_buffer = reference_signal_buffer  # echo-cancellation reference sink
        self._proc = None  # ffmpeg subprocess
        self._interrupted = False  # set once playback is aborted
        self._cleaned_up = False  # guards cleanup() against double execution
    def on_open(self):
        # Play via ffmpeg, which converts the fixed TTS source format
        # (sample rate / channels) into the configured output device format.
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 's16le',  # raw PCM input
            '-ar', str(self.tts_client.tts_source_sample_rate),  # TTS output sample rate (from config)
            '-ac', str(self.tts_client.tts_source_channels),  # TTS output channel count (from config)
            '-i', 'pipe:0',  # read from stdin
            '-f', 'alsa',  # play to ALSA
            '-ar', str(self.tts_client.output_sample_rate),  # output device sample rate (from config)
            '-ac', str(self.tts_client.output_channels),  # output device channel count (from config)
            '-acodec', 'pcm_s16le',  # output encoding
            '-fflags', 'nobuffer',  # reduce buffering
            '-flags', 'low_delay',  # low latency
            '-avioflags', 'direct',  # try direct ALSA writes to cut latency
            self.tts_client.alsa_device
        ]
        # -thread_queue_size must appear before the input it applies to.
        insert_pos = ffmpeg_cmd.index('-i')
        ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
        ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
        # Insert a volume filter when the configured volume differs from 1.0.
        if self.tts_client.output_volume != 1.0:
            # The volume filter goes after the input, before the output encoder.
            acodec_idx = ffmpeg_cmd.index('-acodec')
            ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
            ffmpeg_cmd.insert(acodec_idx, '-af')
        self.tts_client._log("info", f"启动ffmpeg播放: ALSA设备={self.tts_client.alsa_device}, "
                             f"输出采样率={self.tts_client.output_sample_rate}Hz, "
                             f"输出声道数={self.tts_client.output_channels}, "
                             f"音量={self.tts_client.output_volume * 100:.0f}%")
        self._proc = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE  # PIPE so errors can be captured on failure
        )
        # Remember the ffmpeg PID so the owner can reference/kill it.
        self.tts_client.current_ffmpeg_pid = self._proc.pid
        self.tts_client._log("debug", f"ffmpeg进程已启动PID={self._proc.pid}")
    def on_complete(self):
        pass
    def on_error(self, message: str):
        self.tts_client._log("error", f"TTS错误: {message}")
    def on_close(self):
        self.cleanup()
    def on_event(self, message):
        pass
    def on_data(self, data: bytes) -> None:
        """Receive one synthesized PCM chunk: play it, mirror it into the echo
        reference buffer, and forward it to on_chunk."""
        if self._interrupted:
            return
        if self.interrupt_check and self.interrupt_check():
            # Stop playback only; the TTS stream itself is not stopped here.
            self._interrupted = True
            if self._proc:
                self._proc.terminate()
            return
        # Write to ffmpeg first to avoid stalling playback.
        if self._proc and self._proc.stdin and not self._interrupted:
            try:
                self._proc.stdin.write(data)
                self._proc.stdin.flush()
            except BrokenPipeError:
                # ffmpeg may have exited; surface its stderr.
                if self._proc.stderr:
                    error_msg = self._proc.stderr.read().decode('utf-8', errors='ignore')
                    self.tts_client._log("error", f"ffmpeg错误: {error_msg}")
                self._interrupted = True
        # Mirror the audio into the reference buffer (for echo cancellation),
        # after the ffmpeg write so playback is never blocked by it.
        if self.reference_signal_buffer and data:
            try:
                self.reference_signal_buffer.add_reference(
                    data,
                    source_sample_rate=self.tts_client.tts_source_sample_rate,
                    source_channels=self.tts_client.tts_source_channels
                )
            except Exception as e:
                # Reference-signal failures must not affect playback.
                self.tts_client._log("warning", f"参考信号处理失败: {e}")
        if self.on_chunk:
            self.on_chunk(data)
    def cleanup(self):
        """Close ffmpeg's stdin, wait for playback to drain, and clear the PID."""
        if self._cleaned_up or not self._proc:
            return
        self._cleaned_up = True
        # Close stdin so ffmpeg drains its remaining buffered audio.
        if self._proc.stdin and not self._proc.stdin.closed:
            try:
                self._proc.stdin.close()
            except:
                pass
        # Wait for the process to finish naturally; long texts keep playing
        # well after synthesis completes, hence the generous timeout.
        if self._proc.poll() is None:
            try:
                # Extended wait so ffmpeg can finish playing long texts.
                self._proc.wait(timeout=30.0)
            except:
                # On timeout, a still-running process is presumed stuck: force-terminate.
                if self._proc.poll() is None:
                    self.tts_client._log("warning", "ffmpeg播放超时强制终止")
                    try:
                        self._proc.terminate()
                        self._proc.wait(timeout=1.0)
                    except:
                        try:
                            self._proc.kill()
                            self._proc.wait(timeout=0.1)
                        except:
                            pass
        # Clear the recorded PID if it still refers to this process.
        if self.tts_client.current_ffmpeg_pid == self._proc.pid:
            self.tts_client.current_ffmpeg_pid = None

View File

@@ -1,4 +0,0 @@
"""感知层"""

View File

@@ -1,12 +1,11 @@
"""
音频处理模块:录音 + VAD + 回声消除
音频处理模块:录音 + VAD
"""
import time
import pyaudio
import webrtcvad
import struct
import queue
from .echo_cancellation import EchoCanceller, ReferenceSignalBuffer
class VADDetector:
@@ -35,8 +34,6 @@ class AudioRecorder:
on_audio_chunk=None, # 音频chunk回调用于声纹录音等可选
should_put_to_queue=None, # 检查是否应该将音频放入队列用于阻止ASR可选
get_silence_threshold=None, # 获取动态静音阈值(毫秒,可选)
enable_echo_cancellation: bool = True, # 是否启用回声消除
reference_signal_buffer: ReferenceSignalBuffer = None, # 参考信号缓冲区(可选)
logger=None):
self.device_index = device_index
self.sample_rate = sample_rate
@@ -89,7 +86,7 @@ class AudioRecorder:
self.device_index = found_index
else:
if self.logger:
self.logger.warning(f"未自动检测到 iFLYTEK 设备,将继续使用配置的索引: {self.device_index}")
self.logger.warning(f"未自动检测到 iFLYTEK 设备,请检查USB连接或执行 'arecord -l' 确认系统是否识别到录音设备,将继续使用配置的索引: {self.device_index}")
except Exception as e:
if self.logger:
@@ -97,39 +94,6 @@ class AudioRecorder:
self.format = pyaudio.paInt16
self._debug_counter = 0
# 回声消除相关
self.enable_echo_cancellation = enable_echo_cancellation
self.reference_signal_buffer = reference_signal_buffer
if enable_echo_cancellation:
# 初始化回声消除器(在录音线程中同步处理,不是单独线程)
# frame_size设置为chunk大小确保每次处理一个chunk
frame_size = chunk
try:
# 获取参考信号声道数从reference_signal_buffer获取因为它是根据播放声道数创建的
ref_channels = self.reference_signal_buffer.channels if self.reference_signal_buffer else 1
self.echo_canceller = EchoCanceller(
sample_rate=sample_rate,
frame_size=frame_size,
channels=self.channels, # 麦克风输入1声道
ref_channels=ref_channels, # 参考信号播放声道数2声道
logger=logger
)
if self.echo_canceller.aec is not None:
if logger:
logger.info(f"回声消除器已启用: sample_rate={sample_rate}, frame_size={frame_size}")
else:
if logger:
logger.warning("回声消除器初始化失败,将禁用回声消除功能")
self.enable_echo_cancellation = False
self.echo_canceller = None
except Exception as e:
if logger:
logger.warning(f"回声消除器初始化失败: {e},将禁用回声消除功能")
self.enable_echo_cancellation = False
self.echo_canceller = None
else:
self.echo_canceller = None
def record_with_vad(self):
"""录音线程VAD + 能量检测"""
@@ -163,19 +127,7 @@ class AudioRecorder:
while not self.stop_flag():
# exception_on_overflow=False, 宁可丢帧,也不阻塞
data = stream.read(self.chunk, exception_on_overflow=False)
# 回声消除处理
processed_data = data
if self.enable_echo_cancellation and self.echo_canceller and self.reference_signal_buffer:
try:
# 获取参考信号(长度与麦克风信号匹配)
ref_signal = self.reference_signal_buffer.get_reference(num_samples=self.chunk)
# 执行回声消除
processed_data = self.echo_canceller.process(data, ref_signal)
except Exception as e:
if self.logger:
self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
processed_data = data
# 检查是否应该将音频放入队列用于阻止ASR例如无声纹文件时需要注册
if self.should_put_to_queue():

View File

@@ -1,131 +0,0 @@
"""
相机模块 - RealSense相机封装
"""
import numpy as np
import contextlib
class CameraClient:
def __init__(self,
serial_number: str | None,
width: int,
height: int,
fps: int,
format: str,
logger=None):
self.serial_number = serial_number
self.width = width
self.height = height
self.fps = fps
self.format = format
self.logger = logger
self.pipeline = None
self.config = None
self._is_initialized = False
self._rs = None
def _log(self, level: str, msg: str):
if self.logger:
getattr(self.logger, level, self.logger.info)(msg)
else:
print(f"[相机] {msg}")
def initialize(self) -> bool:
"""
初始化并启动相机管道
"""
if self._is_initialized:
return True
try:
import pyrealsense2 as rs
self._rs = rs
self.pipeline = rs.pipeline()
self.config = rs.config()
if self.serial_number:
self.config.enable_device(self.serial_number)
self.config.enable_stream(
rs.stream.color,
self.width,
self.height,
rs.format.rgb8 if self.format == 'RGB8' else rs.format.bgr8,
self.fps
)
self.pipeline.start(self.config)
self._is_initialized = True
self._log("info", f"相机已启动并保持运行: {self.width}x{self.height}@{self.fps}fps")
return True
except Exception as e:
self._log("error", f"相机初始化失败: {e}")
self.cleanup()
return False
def cleanup(self):
"""停止相机管道,释放资源"""
if self.pipeline:
self.pipeline.stop()
self._log("info", "相机已停止")
self.pipeline = None
self.config = None
self._is_initialized = False
def capture_rgb(self) -> np.ndarray | None:
"""
从运行中的相机管道捕获一帧RGB图像
"""
if not self._is_initialized:
self._log("error", "相机未初始化,无法捕获图像")
return None
try:
frames = self.pipeline.wait_for_frames()
color_frame = frames.get_color_frame()
return np.asanyarray(color_frame.get_data())
except Exception as e:
self._log("error", f"捕获图像失败: {e}")
return None
@contextlib.contextmanager
def capture_context(self):
"""
上下文管理器:拍照并自动清理资源
"""
image_data = self.capture_rgb()
try:
yield image_data
finally:
if image_data is not None:
del image_data
def capture_multiple(self, count: int = 1) -> list[np.ndarray]:
"""
捕获多张图像(为未来扩展准备)
"""
images = []
for i in range(count):
img = self.capture_rgb()
if img is not None:
images.append(img)
else:
self._log("warning", f"{i+1}张图像捕获失败")
return images
@contextlib.contextmanager
def capture_multiple_context(self, count: int = 1):
"""
上下文管理器:捕获多张图像并自动清理资源
"""
images = self.capture_multiple(count)
try:
yield images
finally:
for img in images:
del img
images.clear()

View File

@@ -1,98 +0,0 @@
import collections
import numpy as np
class ReferenceSignalBuffer:
"""参考信号缓冲区"""
def __init__(self, sample_rate: int, channels: int, max_duration_ms: int | None = None,
buffer_seconds: float = 5.0):
self.sample_rate = int(sample_rate)
self.channels = int(channels)
if max_duration_ms is not None:
buffer_seconds = max(float(max_duration_ms) / 1000.0, 0.1)
self.max_samples = int(self.sample_rate * buffer_seconds)
self._buffer = collections.deque(maxlen=self.max_samples * self.channels)
def add_reference(self, data: bytes, source_sample_rate: int, source_channels: int):
if source_sample_rate != self.sample_rate or source_channels != self.channels:
return
samples = np.frombuffer(data, dtype=np.int16)
self._buffer.extend(samples.tolist())
def get_reference(self, num_samples: int) -> bytes:
needed = int(num_samples) * self.channels
if needed <= 0:
return b""
if len(self._buffer) < needed:
data = list(self._buffer) + [0] * (needed - len(self._buffer))
else:
data = list(self._buffer)[-needed:]
return np.array(data, dtype=np.int16).tobytes()
class EchoCanceller:
    """Echo canceller backed by the optional ``aec-audio-processing`` package.

    When the package is missing or any setup step fails, the instance degrades
    gracefully: ``process()`` becomes a pass-through returning the mic signal
    unchanged.
    """

    def __init__(self, sample_rate: int, frame_size: int, channels: int, ref_channels: int, logger=None):
        self.sample_rate = int(sample_rate)
        self.frame_size = int(frame_size)
        self.channels = int(channels)
        self.ref_channels = int(ref_channels)
        self.logger = logger
        self.aec = None
        self._process_reverse = None
        # The audio processor consumes 10ms int16 frames
        samples_per_10ms = int(self.sample_rate / 100)
        self._frame_bytes = samples_per_10ms * self.channels * 2
        self._ref_frame_bytes = samples_per_10ms * self.ref_channels * 2
        try:
            from aec_audio_processing import AudioProcessor
            self.aec = AudioProcessor(enable_aec=True, enable_ns=False, enable_agc=False)
            self.aec.set_stream_format(self.sample_rate, self.channels)
            if hasattr(self.aec, "set_reverse_stream_format"):
                self.aec.set_reverse_stream_format(self.sample_rate, self.ref_channels)
            if hasattr(self.aec, "set_stream_delay"):
                self.aec.set_stream_delay(0)
            # Pick whichever reverse-stream entry point this SDK version exposes
            if hasattr(self.aec, "process_reverse_stream"):
                self._process_reverse = self.aec.process_reverse_stream
            elif hasattr(self.aec, "process_reverse"):
                self._process_reverse = self.aec.process_reverse
        except Exception:
            self.aec = None

    def process(self, mic_data: bytes, ref_data: bytes) -> bytes:
        """Run AEC over ``mic_data`` in 10ms frames with ``ref_data`` as the
        far-end reference; returns the mic audio untouched on any failure."""
        if not self.aec or not mic_data:
            return mic_data
        try:
            mic_len = len(mic_data)
            out = bytearray()
            n_frames = -(-mic_len // self._frame_bytes)  # ceil division
            for idx in range(n_frames):
                near = mic_data[idx * self._frame_bytes:(idx + 1) * self._frame_bytes]
                near = near.ljust(self._frame_bytes, b"\x00")  # zero-pad short tail frame
                if ref_data:
                    far = ref_data[idx * self._ref_frame_bytes:(idx + 1) * self._ref_frame_bytes]
                    far = far.ljust(self._ref_frame_bytes, b"\x00")
                    if self._process_reverse:
                        self._process_reverse(far)
                cleaned = self.aec.process_stream(near)
                out += cleaned if cleaned is not None else near
            # Trim padding so output length always matches the input length
            return bytes(out[:mic_len])
        except Exception as e:
            if self.logger:
                self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
            return mic_data

View File

@@ -1,296 +0,0 @@
"""
声纹识别模块
"""
import numpy as np
import threading
import tempfile
import os
import wave
import time
import json
from enum import Enum
class SpeakerState(Enum):
    """Outcome of a speaker-verification attempt."""
    UNKNOWN = "unknown"    # nothing to compare against (no registered speakers)
    VERIFIED = "verified"  # best match scored at or above its threshold
    REJECTED = "rejected"  # best match scored below its threshold
    ERROR = "error"        # invalid embedding (empty or zero-norm)
class SpeakerVerificationClient:
    """Speaker verification client — non-real-time, low-frequency processing.

    Wraps a funasr speaker-embedding model plus an in-memory speaker database
    that is optionally persisted as JSON at ``speaker_db_path``. All DB access
    is guarded by ``self._lock``; ``save_speakers``/``load_speakers`` acquire
    that lock themselves, so they must never be called while it is held.
    """
    def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
        self.model_path = model_path
        self.threshold = threshold  # default match threshold (per-speaker override possible)
        self.speaker_db_path = speaker_db_path
        self.logger = logger
        self.speaker_db = {}  # {speaker_id: {"embedding": np.ndarray, "env": str, "threshold": float, "registered_at": float}}
        self._lock = threading.Lock()
        # CPU optimization: cap Torch at one thread — multi-thread contention
        # was observed to cause a sharp performance drop
        import torch
        torch.set_num_threads(1)
        from funasr import AutoModel
        model_path = os.path.expanduser(self.model_path)
        # Disable the auto-update check so initialization never goes online
        self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
        if self.logger:
            self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
        if self.speaker_db_path:
            self.load_speakers()

    def _log(self, level: str, msg: str):
        """Log via the injected logger, falling back to print().

        Works around a ROS2 logger issue seen in multi-threaded use
        ("severity cannot be changed") by retrying at info level.
        """
        if self.logger:
            try:
                log_methods = {
                    "debug": self.logger.debug,
                    "info": self.logger.info,
                    "warning": self.logger.warning,
                    "error": self.logger.error,
                    "fatal": self.logger.fatal
                }
                log_method = log_methods.get(level.lower(), self.logger.info)
                log_method(msg)
            except ValueError as e:
                if "severity cannot be changed" in str(e):
                    try:
                        self.logger.info(f"[声纹-{level.upper()}] {msg}")
                    except:
                        print(f"[声纹-{level.upper()}] {msg}")
                else:
                    raise
        else:
            print(f"[声纹] {msg}")

    def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Write a numpy audio array to a temporary mono 16-bit WAV file.

        Returns the temp file path; the caller is responsible for deleting it.
        """
        audio_int16 = audio_data.astype(np.int16)
        fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
        os.close(fd)
        with wave.open(temp_path, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())
        return temp_path

    def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Extract the speaker embedding (low-frequency: once per utterance).

        Returns ``(embedding, True)`` on success, ``(None, False)`` otherwise.
        """
        # Downsample to 16000 Hz if needed.
        # Models such as Cam++ usually only support 16k; feeding 48k makes
        # internal resampling extremely slow or explodes the compute cost.
        target_sr = 16000
        if sample_rate > target_sr:
            if sample_rate % target_sr == 0:
                step = sample_rate // target_sr
                audio_data = audio_data[::step]
                sample_rate = target_sr
            else:
                # Integer-stride decimation of a non-integer ratio is crude, but
                # for speech verification 48k->16k is normally an integer ratio;
                # otherwise we fall back to stride decimation and rely on
                # funasr's internal handling.
                step = int(sample_rate / target_sr)
                audio_data = audio_data[::step]
                sample_rate = target_sr
        if len(audio_data) < int(sample_rate * 0.5):
            # Require at least 0.5s of audio for a usable embedding
            return None, False
        temp_wav_path = None
        try:
            temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
            result = self.model.generate(input=temp_wav_path)
            import torch
            embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0]  # shape [1, 192] -> [192]
            embedding_dim = len(embedding)
            if embedding_dim == 0:
                return None, False
            return embedding, True
        except Exception as e:
            self._log("error", f"提取embedding失败: {e}")
            return None, False
        finally:
            # Always remove the temp WAV, even on failure
            if temp_wav_path and os.path.exists(temp_wav_path):
                try:
                    os.unlink(temp_wav_path)
                except:
                    pass

    def register_speaker(self, speaker_id: str, embedding: np.ndarray,
                         env: str = "near", threshold: float = None) -> bool:
        """Register a speaker: store an L2-normalized embedding and persist.

        Returns True when the speaker was registered in memory; a failed save
        to disk is logged but does not fail the registration.
        """
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return False
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            self._log("error", f"注册失败embedding范数为0")
            return False
        # Normalize so match_speaker can use a plain dot product as cosine similarity
        embedding_normalized = embedding / embedding_norm
        speaker_threshold = threshold if threshold is not None else self.threshold
        with self._lock:
            self.speaker_db[speaker_id] = {
                "embedding": embedding_normalized,
                "env": env,  # environment tag (e.g. "near")
                "threshold": speaker_threshold,
                "registered_at": time.time()
            }
        self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")
        # save_speakers acquires self._lock itself — must be called outside the lock
        save_result = self.save_speakers()
        if not save_result:
            self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
        return True

    def match_speaker(self, embedding: np.ndarray):
        """Match an embedding against the registered DB (once per utterance).

        Returns ``(speaker_id, SpeakerState, score, threshold)``.
        """
        if not self.speaker_db:
            return None, SpeakerState.UNKNOWN, 0.0, self.threshold
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return None, SpeakerState.ERROR, 0.0, self.threshold
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            return None, SpeakerState.ERROR, 0.0, self.threshold
        embedding_normalized = embedding / embedding_norm
        best_match = None
        best_score = -1.0
        best_threshold = self.threshold
        with self._lock:
            for speaker_id, speaker_data in self.speaker_db.items():
                ref_embedding = speaker_data["embedding"]
                # Cosine similarity (both sides are L2-normalized)
                score = np.dot(embedding_normalized, ref_embedding)
                if score > best_score:
                    best_score = score
                    best_match = speaker_id
                    best_threshold = speaker_data["threshold"]
        state = SpeakerState.VERIFIED if best_score >= best_threshold else SpeakerState.REJECTED
        return (best_match, state, best_score, best_threshold)

    def is_available(self) -> bool:
        # True when the embedding model was loaded
        return self.model is not None

    def cleanup(self):
        """Release resources (currently nothing to release)."""
        pass

    def get_speaker_count(self) -> int:
        with self._lock:
            return len(self.speaker_db)

    def remove_speaker(self, speaker_id: str) -> bool:
        # Remove from memory under the lock, then persist outside it
        # (save_speakers takes the same non-reentrant lock)
        with self._lock:
            if speaker_id not in self.speaker_db:
                return False
            del self.speaker_db[speaker_id]
        self.save_speakers()
        return True

    def load_speakers(self) -> bool:
        """Load registered voiceprints from the JSON DB file.

        Returns True when at least the file was read successfully.
        """
        if not self.speaker_db_path:
            return False
        if not os.path.exists(self.speaker_db_path):
            self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
            return False
        try:
            with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            with self._lock:
                for speaker_id, speaker_data in data.items():
                    embedding_list = speaker_data["embedding"]
                    embedding_array = np.array(embedding_list, dtype=np.float32)
                    embedding_dim = len(embedding_array)
                    if embedding_dim == 0:
                        self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
                        continue
                    embedding_norm = np.linalg.norm(embedding_array)
                    if embedding_norm > 0:
                        # Re-normalize on load in case stored data wasn't normalized
                        embedding_array = embedding_array / embedding_norm
                    self.speaker_db[speaker_id] = {
                        "embedding": embedding_array,
                        "env": speaker_data["env"],
                        "threshold": speaker_data["threshold"],
                        "registered_at": speaker_data["registered_at"]
                    }
                count = len(self.speaker_db)
            self._log("info", f"已加载 {count} 个已注册说话人")
            return True
        except Exception as e:
            self._log("error", f"加载声纹数据库失败: {e}")
            return False

    def save_speakers(self) -> bool:
        """Persist registered voiceprints to the JSON DB file.

        Uses write-to-temp + ``os.replace`` so a crash mid-write cannot corrupt
        the existing file. Returns True on success.
        """
        if not self.speaker_db_path:
            self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
            return False
        try:
            db_dir = os.path.dirname(self.speaker_db_path)
            if db_dir and not os.path.exists(db_dir):
                os.makedirs(db_dir, exist_ok=True)
            json_data = {}
            with self._lock:
                for speaker_id, speaker_data in self.speaker_db.items():
                    json_data[speaker_id] = {
                        "embedding": speaker_data["embedding"].tolist(),  # numpy array -> list
                        "env": speaker_data.get("env", "near"),  # legacy entries default to "near"
                        "threshold": speaker_data["threshold"],
                        "registered_at": speaker_data["registered_at"]
                    }
            temp_path = self.speaker_db_path + ".tmp"
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=2, ensure_ascii=False)
            # Atomic swap into place
            os.replace(temp_path, self.speaker_db_path)
            self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
            return True
        except Exception as e:
            import traceback
            self._log("error", f"保存声纹数据库失败: {e}")
            self._log("error", f"保存路径: {self.speaker_db_path}")
            self._log("error", f"错误详情: {traceback.format_exc()}")
            # Best-effort removal of the partially-written temp file
            temp_path = self.speaker_db_path + ".tmp"
            if os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except:
                    pass
            return False

View File

@@ -0,0 +1,22 @@
"""
Service节点模块
"""

View File

@@ -0,0 +1,703 @@
import rclpy
from rclpy.node import Node
from interfaces.srv import ASRRecognize, AudioData, VADEvent
import threading
import queue
import time
import pyaudio
import yaml
import os
import collections
import numpy as np
import base64
import dashscope
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
from ament_index_python.packages import get_package_share_directory
class AudioRecorder:
    """Pulls PCM chunks from a PyAudio input device into a bounded queue.

    Prefers an iFLYTEK microphone when one is detected; otherwise falls back
    to the configured index, or to the system default when the configured
    index is -1.
    """
    def __init__(self, device_index: int, sample_rate: int, channels: int,
                 chunk: int, audio_queue: queue.Queue, stop_event, logger=None):
        self.device_index = device_index
        self.sample_rate = sample_rate
        self.channels = channels
        self.chunk = chunk
        self.audio_queue = audio_queue
        self.stop_event = stop_event  # threading.Event set by the owner on shutdown
        self.logger = logger
        self.audio = pyaudio.PyAudio()
        original_index = self.device_index
        try:
            # Scan all devices and prefer an iFLYTEK input device when present
            for i in range(self.audio.get_device_count()):
                device_info = self.audio.get_device_info_by_index(i)
                if 'iFLYTEK' in device_info['name'] and device_info['maxInputChannels'] > 0:
                    self.device_index = i
                    if self.logger:
                        self.logger.info(f"[ASR-Recorder] 已自动定位到麦克风设备: {device_info['name']} (Index: {i})")
                    break
        except Exception as e:
            if self.logger:
                self.logger.error(f"[ASR-Recorder] 设备自动检测过程出错: {e}")
        if self.device_index == original_index and original_index == -1:
            # No iFLYTEK device found and nothing configured: use device 0
            self.device_index = 0
            if self.logger:
                self.logger.info("[ASR-Recorder] 未找到 iFLYTEK 设备,使用系统默认输入设备")
        self.format = pyaudio.paInt16

    def record(self):
        """Blocking capture loop; runs until ``stop_event`` is set.

        Drops the oldest chunk when the queue is full so the consumer always
        sees fresh audio (bounded latency).
        """
        if self.logger:
            self.logger.info(f"[ASR-Recorder] 录音线程启动,设备索引: {self.device_index}")
        stream = None
        try:
            stream = self.audio.open(
                format=self.format,
                channels=self.channels,
                rate=self.sample_rate,
                input=True,
                input_device_index=self.device_index if self.device_index >= 0 else None,
                frames_per_buffer=self.chunk
            )
            if self.logger:
                self.logger.info("[ASR-Recorder] 音频输入设备已打开")
        except Exception as e:
            if self.logger:
                self.logger.error(f"[ASR-Recorder] 无法打开音频输入设备: {e}")
            return
        try:
            while not self.stop_event.is_set():
                try:
                    data = stream.read(self.chunk, exception_on_overflow=False)
                    if self.audio_queue.full():
                        # Evict the oldest chunk rather than blocking the device
                        self.audio_queue.get_nowait()
                    self.audio_queue.put_nowait(data)
                except OSError as e:
                    # Device-level error (unplugged, ALSA failure): exit the loop
                    if self.logger:
                        self.logger.debug(f"[ASR-Recorder] 录音设备错误: {e}")
                    break
        except KeyboardInterrupt:
            if self.logger:
                self.logger.info("[ASR-Recorder] 录音线程收到中断信号")
        finally:
            if stream is not None:
                try:
                    if stream.is_active():
                        stream.stop_stream()
                    stream.close()
                except Exception as e:
                    # Best-effort cleanup; the device may already be gone
                    pass
            if self.logger:
                self.logger.info("[ASR-Recorder] 录音线程已退出")
class DashScopeASR:
    """Streaming ASR over the DashScope Omni realtime WebSocket.

    Adds connection-lifecycle management (age / idle / usage / failure based
    reconnection) to work around WebSocket timeouts that made recognition
    unstable. Callers interact via ``start``/``send_audio``/
    ``stop_current_recognition``/``stop`` and the ``on_*`` callbacks.
    """
    def __init__(self, api_key: str, sample_rate: int, model: str, url: str, logger=None):
        dashscope.api_key = api_key
        self.sample_rate = sample_rate
        self.model = model
        self.url = url
        self.logger = logger
        self.conversation = None  # active OmniRealtimeConversation, if any
        self.running = False
        # Callbacks wired by the owning node
        self.on_sentence_end = None
        self.on_speech_started = None
        self.on_speech_stopped = None
        self._stop_lock = threading.Lock()
        self._final_result_event = threading.Event()
        self._pending_commit = False  # True while waiting for a committed transcript
        # ========== Connection lifecycle management: works around DashScope ASR
        # WebSocket connection timeouts causing unstable recognition ==========
        self._connection_start_time = None  # when the connection was created
        self._last_audio_time = None  # last time audio was sent
        self._recognition_count = 0  # recognitions completed on this connection
        self._audio_send_count = 0  # audio chunks sent on this connection
        self._last_audio_send_success = True  # whether the last send succeeded
        self._consecutive_send_failures = 0  # consecutive send failures
        # Tunables
        self.MAX_CONNECTION_AGE = 300  # max connection lifetime: 5 minutes
        self.MAX_IDLE_TIME = 180  # max idle time: 3 minutes
        self.MAX_RECOGNITIONS = 30  # rebuild connection after 30 recognitions
        self.MAX_CONSECUTIVE_FAILURES = 3  # max consecutive send failures

    def _log(self, level: str, msg: str):
        # Best-effort logging; never let a logging failure disturb the audio path
        if not self.logger:
            return
        try:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            elif level == "info":
                self.logger.info(msg)
        except Exception:
            pass

    def _should_reconnect(self) -> tuple[bool, str]:
        """Decide whether the current connection should be rebuilt.

        Returns ``(should_reconnect, human-readable reason)``.
        """
        if not self.running or not self.conversation:
            return False, ""
        current_time = time.time()
        # Check 1: connection age
        if self._connection_start_time:
            connection_age = current_time - self._connection_start_time
            if connection_age > self.MAX_CONNECTION_AGE:
                return True, f"连接已存活{connection_age:.0f}秒,超过{self.MAX_CONNECTION_AGE}秒阈值"
        # Check 2: idle time
        if self._last_audio_time:
            idle_time = current_time - self._last_audio_time
            if idle_time > self.MAX_IDLE_TIME:
                return True, f"连接已空闲{idle_time:.0f}秒,超过{self.MAX_IDLE_TIME}秒阈值"
        # Check 3: recognition count
        if self._recognition_count >= self.MAX_RECOGNITIONS:
            return True, f"已完成{self._recognition_count}次识别,达到重连阈值"
        # Check 4: consecutive send failures
        if self._consecutive_send_failures >= self.MAX_CONSECUTIVE_FAILURES:
            return True, f"连续{self._consecutive_send_failures}次音频发送失败"
        return False, ""

    def _reset_connection_stats(self):
        # Called whenever a fresh connection is established
        self._connection_start_time = time.time()
        self._last_audio_time = time.time()
        self._recognition_count = 0
        self._audio_send_count = 0
        self._last_audio_send_success = True
        self._consecutive_send_failures = 0

    def start(self):
        """Open a new realtime conversation with server-side VAD enabled.

        Returns True on success; on failure the connection is closed and
        ``running`` stays False.
        """
        if self.running:
            return False
        try:
            callback = _ASRCallback(self)
            self.conversation = OmniRealtimeConversation(
                model=self.model,
                url=self.url,
                callback=callback
            )
            callback.conversation = self.conversation
            self.conversation.connect()
            transcription_params = TranscriptionParams(
                language='zh',
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )
            self.conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
                enable_turn_detection=True,
                turn_detection_type='server_vad',
                prefix_padding_ms=1000,
                turn_detection_threshold=0.3,
                turn_detection_silence_duration_ms=800,
            )
            self.running = True
            self._reset_connection_stats()
            self._log("info", f"[ASR] 已启动 | 连接ID:{id(self.conversation)}")
            return True
        except Exception as e:
            self.running = False
            self._log("error", f"[ASR] 启动失败: {e}")
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
                self.conversation = None
            return False

    def send_audio(self, audio_chunk: bytes):
        """Send one PCM chunk; transparently reconnects when lifecycle checks fire.

        Returns True when the chunk was delivered. On a detected WebSocket
        closure the connection is dropped so the next call/worker iteration
        triggers a clean restart.
        """
        should_reconnect, reason = self._should_reconnect()
        if should_reconnect:
            self._log("warning", f"[ASR] 检测到需要重连: {reason}")
            self.running = False
            try:
                if self.conversation:
                    self.conversation.close()
            except:
                pass
            self.conversation = None
            time.sleep(1.0)
            if not self.start():
                self._log("error", "[ASR] 自动重连失败")
                return False
            self._log("info", "[ASR] 自动重连成功")
        import threading
        self._log("debug", f"[ASR] send_audio 被调用 | 线程:{threading.current_thread().name} | running:{self.running} | conversation:{self.conversation is not None}")
        if not self.running or not self.conversation:
            self._log("debug", f"[ASR] send_audio 跳过 | running:{self.running} | conversation:{self.conversation is not None}")
            return False
        try:
            # The SDK expects base64-encoded PCM
            audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
            self.conversation.append_audio(audio_b64)
            self._last_audio_time = time.time()
            self._audio_send_count += 1
            self._last_audio_send_success = True
            self._consecutive_send_failures = 0
            self._log("debug", f"[ASR] 音频发送成功 | 总计:{self._audio_send_count} | 连接年龄:{time.time() - self._connection_start_time:.1f}秒")
            return True
        except Exception as e:
            self._last_audio_send_success = False
            self._consecutive_send_failures += 1
            error_msg = str(e)
            error_type = type(e).__name__
            # Classify WebSocket-closed errors and drop the connection state
            if "Connection is already closed" in error_msg or "WebSocketConnectionClosedException" in error_type or "ConnectionClosed" in error_type or "websocket" in error_msg.lower():
                self._log("warning", f"[ASR] WebSocket 连接已断开 | 错误:{error_msg} | 连续失败:{self._consecutive_send_failures}")
                self.running = False
                try:
                    if self.conversation:
                        self.conversation.close()
                except:
                    pass
                self.conversation = None
            else:
                self._log("error", f"[ASR] send_audio 异常 | 错误:{error_msg} | 类型:{error_type} | 连续失败:{self._consecutive_send_failures}")
            return False

    def stop_current_recognition(self):
        """Finish the in-flight recognition and close the connection.

        Acquires the stop lock non-blockingly so a concurrent caller (e.g. a
        WebSocket callback thread) cannot deadlock or double-close. Commits the
        audio buffer, waits briefly for the final transcript, then closes the
        connection — it is lazily re-opened on the next audio activity.
        Returns True when this call performed the stop.
        """
        import threading
        self._log("debug", f"[ASR] stop_current_recognition 被调用 | 线程:{threading.current_thread().name} | running:{self.running}")
        if not self._stop_lock.acquire(blocking=False):
            self._log("debug", f"[ASR] 锁获取失败,有其他线程正在执行 stop_current_recognition")
            return False
        self._final_result_event.clear()
        self._pending_commit = True
        try:
            self._log("debug", f"[ASR] 获得锁,开始停止识别 | conversation:{self.conversation is not None}")
            if not self.running or not self.conversation:
                self._log("debug", f"[ASR] 无法停止 | running:{self.running} | conversation:{self.conversation is not None}")
                return False
            self._recognition_count += 1
            should_reconnect, reason = self._should_reconnect()
            if should_reconnect:
                self._log("info", f"[ASR] 识别完成后检测到需要重连: {reason}")
            self._final_result_event.clear()
            self._pending_commit = True
            try:
                # Commit the audio buffer; the callback sets the event when the
                # final transcript arrives
                self.conversation.commit()
                self._final_result_event.wait(timeout=3.0)
            except Exception as e:
                self._log("debug", f"[ASR] commit 异常: {e}")
            self._log("debug", f"[ASR] 准备关闭旧连接 | conversation_id:{id(self.conversation)}")
            # Detach the conversation before closing so other threads see a
            # consistent "no connection" state during the close
            self.running = False
            old_conversation = self.conversation
            self.conversation = None
            self._log("debug", f"[ASR] conversation已设为None准备关闭旧连接")
            try:
                old_conversation.close()
                self._log("debug", f"[ASR] 旧连接已关闭")
            except Exception as e:
                self._log("warning", f"[ASR] 关闭连接异常: {e}")
            self._log("debug", f"[ASR] 连接已关闭,等待下次语音活动时重连")
            return True
        finally:
            self._pending_commit = False
            self._stop_lock.release()
            self._log("debug", f"[ASR] stop_current_recognition 完成,锁已释放")

    def stop(self):
        """Fully stop the client and release the WebSocket connection."""
        with self._stop_lock:
            self.running = False
            # Release any waiter blocked in stop_current_recognition
            self._final_result_event.set()
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
                self.conversation = None
            self._log("info", "[ASR] 已完全停止")
class _ASRCallback(OmniRealtimeCallback):
    """Bridges DashScope realtime events onto the owning DashScopeASR client."""

    def __init__(self, asr_client: DashScopeASR):
        self.asr_client = asr_client
        self.conversation = None  # backfilled by DashScopeASR.start()

    def on_event(self, response):
        client = self.asr_client
        try:
            kind = response['type']
            if kind == 'conversation.item.input_audio_transcription.completed':
                text = response['transcript'].strip()
                if text and client.on_sentence_end:
                    client.on_sentence_end(text)
                # Wake a pending stop_current_recognition() waiter
                if client._pending_commit:
                    client._final_result_event.set()
            elif kind == 'input_audio_buffer.speech_started':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音开始")
                if client.on_speech_started:
                    client.on_speech_started()
            elif kind == 'input_audio_buffer.speech_stopped':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音结束")
                if client.on_speech_stopped:
                    client.on_speech_stopped()
        except Exception:
            # Never let a callback error propagate into the SDK's socket thread
            pass
class ASRAudioNode(Node):
    """ROS2 node exposing DashScope streaming ASR and raw-audio capture.

    Services: ``/asr/recognize`` (get a transcript / stop / reset),
    ``/asr/audio_data`` (start/stop/get raw int16 capture) and ``/vad/event``
    (poll for VAD start/stop events). Two daemon threads run the microphone
    capture loop and the ASR streaming worker.
    """
    def __init__(self):
        super().__init__('asr_audio_node')
        self._load_config()
        self.audio_queue = queue.Queue(maxsize=100)
        self.stop_event = threading.Event()
        self._shutdown_in_progress = False
        self._init_components()
        self.recognize_service = self.create_service(
            ASRRecognize, '/asr/recognize', self._recognize_callback
        )
        self.audio_data_service = self.create_service(
            AudioData, '/asr/audio_data', self._audio_data_callback
        )
        self.vad_event_service = self.create_service(
            VADEvent, '/vad/event', self._vad_event_callback
        )
        # Latest recognition result (text / timestamp / signalling event)
        self._last_result = None
        self._result_event = threading.Event()
        self._last_result_time = None
        self.vad_event_queue = queue.Queue()
        # Raw-capture ring buffer: 240000 int16 samples
        # (duration depends on the configured sample rate — 15s at 16kHz)
        self.audio_buffer = collections.deque(maxlen=240000)
        self.audio_recording = False
        self.audio_lock = threading.Lock()
        # ========== Abnormal-recognition detection ==========
        self._abnormal_results = ["嗯。", "", "啊。", "哦。"]  # known junk transcripts
        self._consecutive_abnormal_count = 0  # consecutive abnormal results seen
        self.MAX_CONSECUTIVE_ABNORMAL = 5  # threshold before forcing a reconnect
        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record, name="RecordingThread", daemon=True
        )
        self.recording_thread.start()
        self.asr_thread = threading.Thread(
            target=self._asr_worker, name="ASRThread", daemon=True
        )
        self.asr_thread.start()
        self.get_logger().info("ASR Audio节点已启动")

    def _load_config(self):
        """Read microphone and DashScope settings from the package's voice.yaml."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
        mic = config['audio']['microphone']
        self.input_device_index = mic['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']
        # NOTE: this local name shadows the imported dashscope module inside
        # this method; intentional but easy to misread
        dashscope = config['dashscope']
        self.dashscope_api_key = dashscope['api_key']
        self.asr_model = dashscope['asr']['model']
        self.asr_url = dashscope['asr']['url']

    def _init_components(self):
        """Create the recorder and ASR client and wire the VAD/result callbacks."""
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            audio_queue=self.audio_queue,
            stop_event=self.stop_event,
            logger=self.get_logger()
        )
        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )
        self.asr_client.on_sentence_end = self._on_asr_result
        self.asr_client.on_speech_started = lambda: self._put_vad_event("speech_started")
        # On speech end: drop any stale result first so a new request cannot
        # be served old text, then publish the VAD event
        self.asr_client.on_speech_stopped = lambda: (self._clear_result(), self._put_vad_event("speech_stopped"))
        self.asr_client.start()

    def _on_asr_result(self, text: str):
        """Final-transcript callback from the ASR client (runs on an SDK thread).

        Caches the text, signals waiters, and tracks junk transcripts —
        repeated junk usually means the upstream connection has gone bad.
        """
        if not text or not text.strip():
            return
        self._last_result = text.strip()
        self._last_result_time = time.time()
        self._result_event.set()
        is_abnormal = self._last_result in self._abnormal_results and len(self._last_result) <= 2
        if is_abnormal:
            self._consecutive_abnormal_count += 1
            self.get_logger().warn(f"[ASR] 检测到异常识别结果: '{self._last_result}' | 连续异常:{self._consecutive_abnormal_count}")
            # After too many consecutive abnormal results, force an ASR reconnect
            if self._consecutive_abnormal_count >= self.MAX_CONSECUTIVE_ABNORMAL:
                self.get_logger().error(f"[ASR] 连续{self._consecutive_abnormal_count}次异常识别,标记需要重连")
                # Trip the client's failure threshold so its next lifecycle
                # check rebuilds the connection
                self.asr_client._consecutive_send_failures = self.asr_client.MAX_CONSECUTIVE_FAILURES
                self._consecutive_abnormal_count = 0
        else:
            # Normal result: reset the abnormal counter
            self._consecutive_abnormal_count = 0
        try:
            self.get_logger().info(f"[ASR] 识别结果: {self._last_result}")
        except Exception:
            pass

    def _put_vad_event(self, event_type):
        """Queue a VAD event for /vad/event pollers; drop it when the queue is full."""
        try:
            self.vad_event_queue.put(event_type, timeout=0.1)
        except queue.Full:
            try:
                self.get_logger().warn(f"[ASR] VAD事件队列已满丢弃{event_type}事件")
            except Exception:
                pass

    def _audio_data_callback(self, request, response):
        """/asr/audio_data service: ``start``/``stop``/``get`` raw int16 capture.

        NOTE(review): an unrecognized command falls through all three branches
        and returns None — confirm callers only ever send start/stop/get.
        """
        import threading
        self.get_logger().debug(f"[ASR-AudioData] 回调触发 | command:{request.command} | 线程:{threading.current_thread().name}")
        response.sample_rate = self.sample_rate
        response.channels = self.channels
        if request.command == "start":
            with self.audio_lock:
                self.get_logger().debug(f"[ASR-AudioData] start命令 | 旧buffer大小:{len(self.audio_buffer)} | recording:{self.audio_recording}")
                self.audio_buffer.clear()
                self.audio_recording = True
                self.get_logger().debug(f"[ASR-AudioData] buffer已清空recording=True")
            response.success = True
            response.message = "开始录音"
            response.samples = 0
            return response
        if request.command == "stop":
            self.get_logger().debug(f"[ASR-AudioData] stop命令 | recording:{self.audio_recording}")
            with self.audio_lock:
                self.audio_recording = False
                audio_list = list(self.audio_buffer)
                self.get_logger().debug(f"[ASR-AudioData] 读取buffer | 大小:{len(audio_list)}")
                self.audio_buffer.clear()
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"录音完成{len(audio_list)}样本"
                self.get_logger().debug(f"[ASR-AudioData] 返回音频 | samples:{len(audio_list)}")
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
                self.get_logger().debug(f"[ASR-AudioData] buffer为空")
            return response
        if request.command == "get":
            # Non-destructive read: snapshot the buffer without clearing it
            with self.audio_lock:
                audio_list = list(self.audio_buffer)
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"获取到{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response

    def _vad_event_callback(self, request, response):
        """/vad/event service: block up to ``timeout_ms`` for the next VAD event."""
        timeout = request.timeout_ms / 1000.0 if request.timeout_ms > 0 else None
        try:
            event = self.vad_event_queue.get(timeout=timeout)
            response.success = True
            response.event = event
            response.message = "收到VAD事件"
        except queue.Empty:
            response.success = False
            response.event = "none"
            response.message = "等待超时"
        except KeyboardInterrupt:
            try:
                self.get_logger().info("[ASR-VAD] 收到中断信号,正在关闭")
            except Exception:
                pass
            response.success = False
            response.event = "none"
            response.message = "节点正在关闭"
            self.stop_event.set()
        return response

    def _clear_result(self):
        # Drop the cached transcript so a new request cannot pick up stale text
        self._last_result = None
        self._last_result_time = None
        self._result_event.clear()

    def _return_result(self, response, text, message):
        # Fill a successful recognize response and immediately invalidate the cache
        response.success = True
        response.text = text
        response.message = message
        self._clear_result()
        return response

    def _recognize_callback(self, request, response):
        """/asr/recognize service: ``stop``/``reset`` commands, or wait for text.

        The default (recognize) path first reuses a very fresh cached result
        (<0.3s old, or one already flagged final), then waits up to 2s; failing
        that it finalizes the current recognition and waits up to 5s for a
        fresh transcript.
        """
        if request.command == "stop":
            if self.asr_client.running:
                self.asr_client.stop_current_recognition()
            response.success = True
            response.text = ""
            response.message = "识别已停止"
            return response
        if request.command == "reset":
            self.asr_client.stop_current_recognition()
            time.sleep(0.1)
            self.asr_client.start()
            response.success = True
            response.text = ""
            response.message = "识别器已重置"
            return response
        if self.asr_client.running:
            current_time = time.time()
            if (self._last_result and self._last_result_time and
                    (current_time - self._last_result_time) < 0.3) or (self._result_event.is_set() and self._last_result):
                return self._return_result(response, self._last_result, "返回最近识别结果")
            if self._result_event.wait(timeout=2.0) and self._last_result:
                return self._return_result(response, self._last_result, "识别成功(等待中)")
            # Nothing arrived: finalize the in-flight recognition and retry fresh
            self.asr_client.stop_current_recognition()
            time.sleep(0.2)
        self._clear_result()
        if not self.asr_client.running and not self.asr_client.start():
            response.success = False
            response.text = ""
            response.message = "ASR启动失败"
            return response
        if self._result_event.wait(timeout=5.0) and self._last_result:
            response.success = True
            response.text = self._last_result
            response.message = "识别成功"
        else:
            response.success = False
            response.text = ""
            response.message = "识别超时" if not self._result_event.is_set() else "识别结果为空"
        self._clear_result()
        return response

    def _asr_worker(self):
        """Background loop: drain the mic queue, mirror audio into the capture
        buffer while recording, and stream it to the ASR client (restarting
        the client when its connection has dropped)."""
        while not self.stop_event.is_set():
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except KeyboardInterrupt:
                try:
                    self.get_logger().info("[ASR-Worker] 收到中断信号")
                except Exception:
                    pass
                break
            if self.audio_recording:
                self.get_logger().debug(f"[ASR-Worker] 收到音频chunk | recording:{self.audio_recording} | buffer_size:{len(self.audio_buffer)}")
                try:
                    audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                    with self.audio_lock:
                        self.audio_buffer.extend(audio_array)
                except Exception as e:
                    self.get_logger().error(f"[ASR-Worker] buffer写入异常 | 错误:{e}")
                    pass
            if self.asr_client.running:
                self.asr_client.send_audio(audio_chunk)
            else:
                # Connection dropped: try to restart; back off on failure
                if not self.asr_client.start():
                    time.sleep(1.0)

    def destroy_node(self):
        """Idempotent shutdown: stop threads, terminate PyAudio, close ASR."""
        if self._shutdown_in_progress:
            return
        self._shutdown_in_progress = True
        try:
            self.get_logger().info("ASR Audio节点正在关闭...")
        except Exception:
            pass
        self.stop_event.set()
        if hasattr(self, 'recording_thread') and self.recording_thread.is_alive():
            self.recording_thread.join(timeout=1.0)
        if hasattr(self, 'asr_thread') and self.asr_thread.is_alive():
            self.asr_thread.join(timeout=1.0)
        try:
            if hasattr(self, 'audio_recorder'):
                self.audio_recorder.audio.terminate()
        except Exception:
            pass
        try:
            if hasattr(self, 'asr_client'):
                self.asr_client.stop()
        except Exception:
            pass
        try:
            super().destroy_node()
        except Exception:
            pass
def main(args=None):
    """Entry point: spin the ASR audio node until interrupted, then shut down cleanly."""
    rclpy.init(args=args)
    node = ASRAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        try:
            node.get_logger().info("收到中断信号,正在关闭节点")
        except Exception:
            pass
    finally:
        # Best-effort teardown: neither step may prevent the other
        for shutdown_step in (node.destroy_node, rclpy.shutdown):
            try:
                shutdown_step()
            except Exception:
                pass


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,341 @@
import rclpy
from rclpy.node import Node
from rclpy.callback_groups import ReentrantCallbackGroup
from interfaces.srv import TTSSynthesize
import threading
import yaml
import os
import signal
import subprocess
import time
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
from ament_index_python.packages import get_package_share_directory
class DashScopeTTSClient:
    """Streams DashScope TTS audio into an ffmpeg→ALSA playback pipeline.

    Supports force-stopping the current playback by interrupting the active
    callback and killing its ffmpeg process.
    """
    def __init__(self, api_key: str,
                 model: str,
                 voice: str,
                 card_index: int,
                 device_index: int,
                 output_sample_rate: int,
                 output_channels: int,
                 output_volume: float,
                 tts_source_sample_rate: int,
                 tts_source_channels: int,
                 tts_ffmpeg_thread_queue_size: int,
                 force_stop_delay: float,
                 cleanup_timeout: float,
                 terminate_timeout: float,
                 logger):
        dashscope.api_key = api_key
        self.model = model
        self.voice = voice  # default voice; per-call override possible
        self.card_index = card_index
        self.device_index = device_index
        self.output_sample_rate = output_sample_rate
        self.output_channels = output_channels
        self.output_volume = output_volume
        self.tts_source_sample_rate = tts_source_sample_rate
        self.tts_source_channels = tts_source_channels
        self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
        self.force_stop_delay = force_stop_delay
        self.cleanup_timeout = cleanup_timeout
        self.terminate_timeout = terminate_timeout
        self.logger = logger
        self.current_ffmpeg_pid = None  # pid of the ffmpeg playback process, if any
        self._current_callback = None  # in-flight _TTSCallback, if any
        # Explicit ALSA device only when both indices are configured (>= 0)
        self.alsa_device = f"plughw:{card_index},{device_index}" if (
            card_index >= 0 and device_index >= 0
        ) else "default"

    def force_stop(self):
        """Interrupt the current synthesis and kill its ffmpeg playback process.

        Sends SIGTERM, waits ``force_stop_delay``, then escalates to SIGKILL if
        the process is still alive. Always clears the pid/callback bookkeeping.
        """
        if self._current_callback:
            self._current_callback._interrupted = True
        if not self.current_ffmpeg_pid:
            if self.logger:
                self.logger.warn("[TTS] force_stop: current_ffmpeg_pid is None")
            return
        pid = self.current_ffmpeg_pid
        try:
            if self.logger:
                self.logger.info(f"[TTS] force_stop: 正在kill进程 {pid}")
            os.kill(pid, signal.SIGTERM)
            time.sleep(self.force_stop_delay)
            try:
                os.kill(pid, 0)  # signal 0 = existence probe, raises if gone
                os.kill(pid, signal.SIGKILL)
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 已发送SIGKILL到进程 {pid}")
            except ProcessLookupError:
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 进程 {pid} 已退出")
        except (ProcessLookupError, OSError) as e:
            if self.logger:
                self.logger.warn(f"[TTS] force_stop: kill进程失败 {pid}: {e}")
        finally:
            self.current_ffmpeg_pid = None
            self._current_callback = None

    def synthesize(self, text: str, voice: str = None,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        """Synthesize ``text`` and play it through the ffmpeg pipeline.

        Returns True when playback completed without interruption; False when
        interrupted or when no usable voice is configured.
        """
        callback = _TTSCallback(self, interrupt_check, on_chunk)
        self._current_callback = callback
        # Per-call voice overrides the configured default when non-blank
        voice_to_use = voice if voice and voice.strip() else self.voice
        if not voice_to_use or not voice_to_use.strip():
            if self.logger:
                self.logger.error(f"[TTS] Voice参数无效: '{voice_to_use}'")
            self._current_callback = None
            return False
        synthesizer = SpeechSynthesizer(
            model=self.model,
            voice=voice_to_use,
            format=AudioFormat.PCM_22050HZ_MONO_16BIT,
            callback=callback,
        )
        try:
            synthesizer.streaming_call(text)
            synthesizer.streaming_complete()
        finally:
            # Always tear down the playback pipeline, even if the SDK raised
            callback.cleanup()
            self._current_callback = None
        return not callback._interrupted
class _TTSCallback(ResultCallback):
    """DashScope streaming-TTS callback that pipes PCM chunks into an
    ffmpeg subprocess playing to ALSA.

    Lifecycle: on_open() spawns ffmpeg; on_data() feeds each PCM chunk to
    ffmpeg's stdin (and the optional on_chunk hook); cleanup() closes the
    pipe and terminates the subprocess. `_interrupted` is set either by
    the per-chunk `interrupt_check` poll or externally by
    DashScopeTTSClient.force_stop().
    """

    def __init__(self, tts_client: DashScopeTTSClient,
                 interrupt_check=None,
                 on_chunk=None):
        self.tts_client = tts_client          # owner; supplies audio/device config
        self.interrupt_check = interrupt_check  # polled per chunk; truthy => stop
        self.on_chunk = on_chunk              # optional raw-PCM chunk hook
        self._proc = None                     # ffmpeg subprocess handle
        self._interrupted = False             # set once playback must stop
        self._cleaned_up = False              # makes cleanup() idempotent

    def on_open(self):
        """Start the ffmpeg player when the TTS stream opens."""
        # s16le PCM from DashScope on stdin -> resampled output to ALSA.
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 's16le',
            '-ar', str(self.tts_client.tts_source_sample_rate),
            '-ac', str(self.tts_client.tts_source_channels),
            '-i', 'pipe:0',
            '-f', 'alsa',
            '-ar', str(self.tts_client.output_sample_rate),
            '-ac', str(self.tts_client.output_channels),
            '-acodec', 'pcm_s16le',
            # Low-latency flags: do not buffer ahead of playback.
            '-fflags', 'nobuffer',
            '-flags', 'low_delay',
            '-avioflags', 'direct',
            self.tts_client.alsa_device
        ]
        # -thread_queue_size is an input option, so it must precede '-i'.
        insert_pos = ffmpeg_cmd.index('-i')
        ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
        ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
        # Add a software volume filter only when volume deviates from unity.
        if self.tts_client.output_volume != 1.0:
            acodec_idx = ffmpeg_cmd.index('-acodec')
            ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
            ffmpeg_cmd.insert(acodec_idx, '-af')
        self._proc = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE
        )
        # Publish the PID so force_stop() can kill the player externally.
        self.tts_client.current_ffmpeg_pid = self._proc.pid

    def on_data(self, data: bytes) -> None:
        """Feed one PCM chunk to ffmpeg, honoring interruption requests."""
        if self._interrupted:
            return
        if self.interrupt_check and self.interrupt_check():
            self._interrupted = True
            if self._proc:
                self._proc.terminate()
            return
        if self._proc and self._proc.stdin and not self._interrupted:
            try:
                self._proc.stdin.write(data)
                self._proc.stdin.flush()
            except BrokenPipeError:
                # ffmpeg exited (killed or crashed); stop feeding it.
                self._interrupted = True
            except OSError:
                self._interrupted = True
        if self.on_chunk and not self._interrupted:
            self.on_chunk(data)

    def cleanup(self):
        """Shut down the ffmpeg player: EOF -> wait -> terminate -> kill.

        Idempotent. BUG FIX: the previous version called Popen.wait(timeout=...)
        without catching subprocess.TimeoutExpired, so a hung ffmpeg made
        cleanup raise (skipping terminate/kill and leaking the process, with
        the exception escaping through synthesize()'s finally block).
        """
        if self._cleaned_up or not self._proc:
            return
        self._cleaned_up = True
        # Closing stdin signals EOF so ffmpeg can flush and exit on its own.
        if self._proc.stdin and not self._proc.stdin.closed:
            try:
                self._proc.stdin.close()
            except (BrokenPipeError, OSError):
                pass  # player already gone; nothing left to flush
        if self._proc.poll() is None:
            try:
                self._proc.wait(timeout=self.tts_client.cleanup_timeout)
            except subprocess.TimeoutExpired:
                pass  # still running; escalate below
        if self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=self.tts_client.terminate_timeout)
            except subprocess.TimeoutExpired:
                pass  # still running; escalate to SIGKILL
        if self._proc.poll() is None:
            self._proc.kill()
        # Clear the published PID only if it is still ours (force_stop may
        # have already cleared or replaced it).
        if self.tts_client.current_ffmpeg_pid == self._proc.pid:
            self.tts_client.current_ffmpeg_pid = None
class TTSAudioNode(Node):
    """ROS2 node exposing the /tts/synthesize service (TTSSynthesize).

    command == "interrupt" aborts any ongoing playback; any other command
    starts asynchronous synthesis of request.text in a daemon worker
    thread, so the service responds immediately with status "playing".
    """

    def __init__(self):
        super().__init__('tts_audio_node')
        self._load_config()
        self._init_tts_client()
        # Reentrant group: lets an "interrupt" request be handled while a
        # previous "synthesize" request's worker is still playing.
        self.callback_group = ReentrantCallbackGroup()
        self.synthesize_service = self.create_service(
            TTSSynthesize, '/tts/synthesize', self._synthesize_callback,
            callback_group=self.callback_group
        )
        self.interrupt_event = threading.Event()  # set to ask the worker to stop
        self.playing_lock = threading.Lock()      # guards is_playing
        self.is_playing = False                   # True while a worker is playing
        self.get_logger().info("[TTS] TTS Audio节点已启动")

    def _load_config(self):
        """Load soundcard/TTS/DashScope settings from the installed voice.yaml.

        Missing keys raise KeyError, failing fast at node startup.
        """
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        # NOTE(review): opened without an explicit encoding; relies on the
        # locale default being UTF-8 for non-ASCII YAML content — confirm.
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
        audio = config['audio']
        soundcard = audio['soundcard']
        tts_audio = audio['tts']
        # Local name shadows the imported `dashscope` module within this
        # method only; harmless since the module is not used in this scope.
        dashscope = config['dashscope']
        self.output_card_index = soundcard['card_index']
        self.output_device_index = soundcard['device_index']
        self.output_sample_rate = soundcard['sample_rate']
        self.output_channels = soundcard['channels']
        self.output_volume = soundcard['volume']
        self.tts_source_sample_rate = tts_audio['source_sample_rate']
        self.tts_source_channels = tts_audio['source_channels']
        self.tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']
        self.force_stop_delay = tts_audio['force_stop_delay']
        self.cleanup_timeout = tts_audio['cleanup_timeout']
        self.terminate_timeout = tts_audio['terminate_timeout']
        self.interrupt_wait = tts_audio['interrupt_wait']
        self.dashscope_api_key = dashscope['api_key']
        self.tts_model = dashscope['tts']['model']
        self.tts_voice = dashscope['tts']['voice']

    def _init_tts_client(self):
        """Construct the DashScopeTTSClient from the loaded configuration."""
        self.tts_client = DashScopeTTSClient(
            api_key=self.dashscope_api_key,
            model=self.tts_model,
            voice=self.tts_voice,
            card_index=self.output_card_index,
            device_index=self.output_device_index,
            output_sample_rate=self.output_sample_rate,
            output_channels=self.output_channels,
            output_volume=self.output_volume,
            tts_source_sample_rate=self.tts_source_sample_rate,
            tts_source_channels=self.tts_source_channels,
            tts_ffmpeg_thread_queue_size=self.tts_ffmpeg_thread_queue_size,
            force_stop_delay=self.force_stop_delay,
            cleanup_timeout=self.cleanup_timeout,
            terminate_timeout=self.terminate_timeout,
            logger=self.get_logger()
        )

    def _synthesize_callback(self, request, response):
        """Handle /tts/synthesize requests.

        request.command: "interrupt" stops current playback; anything else
        (including empty, which defaults to "synthesize") starts playback
        of request.text with optional request.voice override.

        response.status values used here: "interrupted", "none", "error",
        "playing".
        """
        command = request.command if request.command else "synthesize"
        if command == "interrupt":
            with self.playing_lock:
                was_playing = self.is_playing
                # Also honor an ffmpeg PID published outside is_playing
                # bookkeeping (e.g. playback still draining).
                has_pid = self.tts_client.current_ffmpeg_pid is not None
                if was_playing or has_pid:
                    self.interrupt_event.set()
                    self.tts_client.force_stop()
                    self.is_playing = False
                    response.success = True
                    response.message = "已中断播放"
                    response.status = "interrupted"
                else:
                    response.success = False
                    response.message = "没有正在播放的内容"
                    response.status = "none"
            return response
        # Reject empty/whitespace-only text before spawning a worker.
        if not request.text or not request.text.strip():
            response.success = False
            response.message = "文本为空"
            response.status = "error"
            return response
        with self.playing_lock:
            if self.is_playing:
                # Preempt the current playback before starting the new one.
                self.tts_client.force_stop()
                # NOTE(review): sleeping while holding playing_lock briefly
                # blocks concurrent service calls; presumably gives the
                # killed ffmpeg time to release the ALSA device — confirm.
                time.sleep(self.interrupt_wait)
            self.is_playing = True
            self.interrupt_event.clear()

        def synthesize_worker():
            # Runs in a daemon thread; owns the blocking synthesis/playback.
            try:
                success = self.tts_client.synthesize(
                    request.text.strip(),
                    voice=request.voice if request.voice else None,
                    interrupt_check=lambda: self.interrupt_event.is_set()
                )
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    if success:
                        self.get_logger().info("[TTS] 合成并播放成功")
                    else:
                        self.get_logger().info("[TTS] 播放被中断")
            except Exception as e:
                # Never let a worker exception leave is_playing stuck True.
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    self.get_logger().error(f"[TTS] 合成失败: {e}")

        thread = threading.Thread(target=synthesize_worker, daemon=True)
        thread.start()
        # Respond immediately; playback continues in the background.
        response.success = True
        response.message = "合成任务已启动"
        response.status = "playing"
        return response
def main(args=None):
    """Entry point: initialize rclpy, spin the TTS node, then clean up.

    BUG FIX: the previous version ran destroy_node()/shutdown() only on a
    clean return from spin(); a KeyboardInterrupt (Ctrl+C) or any exception
    skipped cleanup and left the rclpy context initialized. Cleanup now
    runs unconditionally in a finally block.
    """
    rclpy.init(args=args)
    node = TTSAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl+C: fall through to cleanup
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@@ -1,4 +0,0 @@
"""理解层"""

View File

@@ -15,6 +15,7 @@ setup(
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
(os.path.join('share', package_name, 'srv'), glob('srv/*.srv')),
],
install_requires=[
'setuptools',
@@ -25,11 +26,13 @@ setup(
maintainer_email='mzebra@foxmail.com',
description='语音识别和合成ROS2包',
license='Apache-2.0',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
'skill_bridge_node = robot_speaker.bridge.skill_bridge_node:main',
'asr_audio_node = robot_speaker.services.asr_audio_node:main',
'tts_audio_node = robot_speaker.services.tts_audio_node:main',
],
},
)

10
srv/ASRRecognize.srv Normal file
View File

@@ -0,0 +1,10 @@
# 请求:识别控制命令
string command # "start" (默认), "stop", "reset"
---
# 响应:识别结果
bool success
string text # 识别文本(空字符串表示未识别到)
string message # 状态消息

27
srv/AudioData.srv Normal file
View File

@@ -0,0 +1,27 @@
# 请求:获取音频数据
string command # "start" (开始录音), "stop" (停止并返回), "get" (获取当前缓冲区)
int32 duration_ms # 录音时长毫秒仅用于start命令
---
# 响应:音频数据
bool success
uint8[] audio_data # PCM音频数据int16格式
int32 sample_rate
int32 channels
int32 samples # 样本数
string message

14
srv/TTSSynthesize.srv Normal file
View File

@@ -0,0 +1,14 @@
# 请求:合成文本或中断命令
string command # "synthesize" (默认), "interrupt"
string text
string voice # 可选,默认使用配置
---
# 响应:合成状态
bool success
string message
string status # "playing", "completed", "interrupted", "error", "none"

11
srv/VADEvent.srv Normal file
View File

@@ -0,0 +1,11 @@
# 请求等待VAD事件
string command # "wait" (等待下一个事件)
int32 timeout_ms # 超时时间毫秒0表示无限等待
---
# 响应VAD事件
bool success
string event # "speech_started", "speech_stopped", "none"
string message

View File

@@ -1,68 +0,0 @@
#!/usr/bin/env python3
"""
查看相机画面的简单脚本
按空格键保存当前帧,按'q'键退出
"""
import sys
import cv2
import numpy as np
try:
import pyrealsense2 as rs
except ImportError:
print("错误: 未安装pyrealsense2请运行: pip install pyrealsense2")
sys.exit(1)
def main():
    """Preview the RealSense color stream and save frames on demand.

    Keys: space saves the current frame as camera_frame_NNNN.jpg;
    'q' quits. Ctrl+C also exits cleanly.
    """
    # Configure the camera pipeline.
    pipeline = rs.pipeline()
    config = rs.config()
    # Enable the 640x480 RGB color stream at 30 FPS.
    config.enable_stream(rs.stream.color, 640, 480, rs.format.rgb8, 30)
    # Start streaming.
    pipeline.start(config)
    print("相机已启动,按空格键保存图片,按'q'键退出")
    frame_count = 0
    try:
        while True:
            # Block until a frameset arrives.
            frames = pipeline.wait_for_frames()
            color_frame = frames.get_color_frame()
            if not color_frame:
                continue
            # RealSense delivers RGB; OpenCV expects BGR for display/saving.
            color_image = np.asanyarray(color_frame.get_data())
            bgr_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
            cv2.imshow('Camera View', bgr_image)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                print("退出...")
                break
            elif key == ord(' '):  # space: save current frame
                frame_count += 1
                filename = f'camera_frame_{frame_count:04d}.jpg'
                cv2.imwrite(filename, bgr_image)
                # BUG FIX: previously printed a literal placeholder instead
                # of interpolating the saved file's name.
                print(f"已保存: {filename}")
    except KeyboardInterrupt:
        print("\n中断...")
    finally:
        # Always release the camera and close windows.
        pipeline.stop()
        cv2.destroyAllWindows()
        print("相机已关闭")


if __name__ == '__main__':
    main()