Compare commits
14 Commits
feature-de
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
| a0ceb934ce | |||
|
|
ed861a9fb1 | ||
| aaa17c10f2 | |||
|
|
c65395c50f | ||
| 9c8bd017e1 | |||
|
|
856c07715c | ||
| e8a9821ce4 | |||
| ab1fb4f3f8 | |||
| dd6ccf77bb | |||
| 7324630458 | |||
|
|
04ca80c3f9 | ||
| 98c0eb5ca5 | |||
| 71062701e1 | |||
| 6d101b9d9e |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,3 +4,6 @@ log/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
dist/
|
||||
lib/
|
||||
installed_files.txt
|
||||
|
||||
116
CMakeLists.txt
Normal file
116
CMakeLists.txt
Normal file
@@ -0,0 +1,116 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(robot_speaker)
|
||||
|
||||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
endif()
|
||||
|
||||
find_package(ament_cmake REQUIRED)
|
||||
find_package(ament_cmake_python REQUIRED)
|
||||
find_package(interfaces REQUIRED)
|
||||
|
||||
# 确保使用系统 Python(而不是 conda/miniconda 的 Python)
|
||||
find_program(PYTHON3_CMD python3 PATHS /usr/bin /usr/local/bin NO_DEFAULT_PATH)
|
||||
if(NOT PYTHON3_CMD)
|
||||
find_program(PYTHON3_CMD python3)
|
||||
endif()
|
||||
if(PYTHON3_CMD)
|
||||
set(Python3_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python 3 executable" FORCE)
|
||||
set(PYTHON_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python executable" FORCE)
|
||||
endif()
|
||||
|
||||
install(CODE "
|
||||
execute_process(
|
||||
COMMAND ${PYTHON3_CMD} -m pip install --prefix=${CMAKE_INSTALL_PREFIX} --no-deps ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE install_result
|
||||
OUTPUT_VARIABLE install_output
|
||||
ERROR_VARIABLE install_error
|
||||
)
|
||||
if(NOT install_result EQUAL 0)
|
||||
message(FATAL_ERROR \"Failed to install Python package. Output: ${install_output} Error: ${install_error}\")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${PYTHON3_CMD} -c \"
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
import sysconfig
|
||||
|
||||
install_prefix = '${CMAKE_INSTALL_PREFIX}'
|
||||
build_dir = '${CMAKE_CURRENT_BINARY_DIR}'
|
||||
python_version = f'{sysconfig.get_python_version()}'
|
||||
|
||||
# ROS2 期望的 Python 包位置
|
||||
ros2_site_packages = os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages')
|
||||
os.makedirs(ros2_site_packages, exist_ok=True)
|
||||
|
||||
# pip install --prefix 可能将包安装到不同位置(系统环境通常是 local/lib/pythonX/dist-packages)
|
||||
pip_locations = [
|
||||
os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'dist-packages'),
|
||||
os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages'),
|
||||
os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'site-packages'),
|
||||
]
|
||||
|
||||
# 查找并复制 robot_speaker 包到 ROS2 期望的位置
|
||||
robot_speaker_src = None
|
||||
for location in pip_locations:
|
||||
candidate = os.path.join(location, 'robot_speaker')
|
||||
if os.path.exists(candidate) and os.path.isdir(candidate):
|
||||
robot_speaker_src = candidate
|
||||
break
|
||||
|
||||
if robot_speaker_src:
|
||||
robot_speaker_dest = os.path.join(ros2_site_packages, 'robot_speaker')
|
||||
if os.path.exists(robot_speaker_dest):
|
||||
shutil.rmtree(robot_speaker_dest)
|
||||
if robot_speaker_src != robot_speaker_dest:
|
||||
shutil.copytree(robot_speaker_src, robot_speaker_dest)
|
||||
print(f'Copied robot_speaker from {robot_speaker_src} to {ros2_site_packages}')
|
||||
else:
|
||||
print(f'robot_speaker already in correct location')
|
||||
|
||||
# 处理 entry_points 脚本
|
||||
lib_dir = os.path.join(install_prefix, 'lib', 'robot_speaker')
|
||||
os.makedirs(lib_dir, exist_ok=True)
|
||||
|
||||
# 脚本可能在 local/bin 或 bin 中
|
||||
for bin_dir in [os.path.join(install_prefix, 'local', 'bin'), os.path.join(install_prefix, 'bin')]:
|
||||
if os.path.exists(bin_dir):
|
||||
scripts = glob.glob(os.path.join(bin_dir, '*_node'))
|
||||
for script in scripts:
|
||||
script_name = os.path.basename(script)
|
||||
dest = os.path.join(lib_dir, script_name)
|
||||
if script != dest:
|
||||
shutil.copy2(script, dest)
|
||||
os.chmod(dest, 0o755)
|
||||
print(f'Copied {script_name} to {lib_dir}')
|
||||
\"
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE python_result
|
||||
OUTPUT_VARIABLE python_output
|
||||
)
|
||||
if(python_result EQUAL 0)
|
||||
message(STATUS \"${python_output}\")
|
||||
else()
|
||||
message(WARNING \"Failed to setup Python package: ${python_output}\")
|
||||
endif()
|
||||
")
|
||||
|
||||
install(DIRECTORY launch/
|
||||
DESTINATION share/${PROJECT_NAME}/launch
|
||||
FILES_MATCHING PATTERN "*.launch.py"
|
||||
)
|
||||
|
||||
install(DIRECTORY config/
|
||||
DESTINATION share/${PROJECT_NAME}/config
|
||||
FILES_MATCHING PATTERN "*.yaml" PATTERN "*.json"
|
||||
)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
ament_package()
|
||||
72
README.md
72
README.md
@@ -22,19 +22,21 @@ pip3 install -r requirements.txt --break-system-packages
|
||||
|
||||
## 编译启动
|
||||
1. 注册声纹
|
||||
- 启动节点后可以说:二狗今天天气真好开始注册声纹
|
||||
- 正确的注册姿势:
|
||||
方法A(推荐):唤醒后停顿一下,然后说一段长句子。
|
||||
用户:"二狗"
|
||||
机器:(日志提示等待声纹语音)
|
||||
用户:"我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去。"(持续说 3-5 秒)
|
||||
方法B(连贯说):一口气说很长的一句话。
|
||||
用户:"二狗你好,我是你的主人,请记住我的声音,这是一段用来注册的长语音。"
|
||||
- 注意:要包含唤醒词,语句不要停顿,尽量大于1.5秒
|
||||
- 启动节点后可以说:er gou我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去。
|
||||
- 正确的注册姿势:包含唤醒词二狗,不要停顿的尽量说完3秒
|
||||
|
||||
- 现在的逻辑只要识别到二狗就注册,然后退出节点,识别不到二狗继续等待
|
||||
- 多注册几段,换方向距离注册,可以提高识别相似度,注册方向对声纹相似性影响很大
|
||||
```bash
|
||||
cd ~/ros_learn/hivecore_robot_voice
|
||||
colcon build
|
||||
source install/setup.bash
|
||||
```
|
||||
|
||||
```bash
|
||||
# 终端1: 启动ASR节点
|
||||
ros2 run robot_speaker asr_audio_node
|
||||
# 终端2: 注册声纹
|
||||
ros2 run robot_speaker register_speaker_node
|
||||
```
|
||||
|
||||
@@ -49,38 +51,34 @@ source install/setup.bash
|
||||
ros2 launch robot_speaker voice.launch.py
|
||||
```
|
||||
|
||||
## 架构说明
|
||||
[录音线程] - 唯一实时线程
|
||||
├─ 麦克风采集 PCM
|
||||
├─ VAD + 能量检测
|
||||
├─ 检测到人声 → 立即中断TTS
|
||||
├─ 语音 PCM → ASR 音频队列
|
||||
└─ 语音 PCM → 声纹音频队列(旁路,不阻塞)
|
||||
3. ASR节点
|
||||
```bash
|
||||
ros2 run robot_speaker asr_audio_node
|
||||
```
|
||||
|
||||
[ASR推理线程] - 只做 audio → text
|
||||
└─ 从 ASR 音频队列取音频→ 实时 / 流式 ASR → text → 文本队列
|
||||
4. TTS节点
|
||||
```bash
|
||||
# 终端1: 启动TTS节点
|
||||
ros2 run robot_speaker tts_audio_node
|
||||
|
||||
[声纹识别线程] - 非实时、低频(CAM++)
|
||||
├─ 通过回调函数接收音频chunk,写入缓冲区,等待 speech_end 事件触发处理
|
||||
├─ 累积 1~2 秒有效人声(VAD 后)
|
||||
├─ CAM++ 提取 speaker embedding
|
||||
├─ 声纹匹配 / 注册
|
||||
└─ 更新 current_speaker_id(共享状态,只写不控)
|
||||
声纹线程要求:不影响录音,不影响ASR,不控制TTS,只更新当前说话人是谁
|
||||
# 终端2: 启动播放
|
||||
source install/setup.bash
|
||||
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
|
||||
"{command: 'synthesize', text: '这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。让我继续说下去,这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。让我继续说下去,这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。', voice: ''}"
|
||||
|
||||
[主线程/处理线程] - 处理业务逻辑
|
||||
├─ 从 文本队列 取 ASR 文本
|
||||
├─ 读取 current_speaker_id(只读)
|
||||
├─ 唤醒词处理(结合 speaker_id)
|
||||
├─ 权限 / 身份判断(是否允许继续)
|
||||
├─ VLM处理(文本 / 多模态)
|
||||
└─ TTS播放(启动TTS线程,不等待)
|
||||
|
||||
[TTS播放线程] - 只播放(可被中断)
|
||||
├─ 接收 TTS 音频流
|
||||
├─ 播放到输出设备
|
||||
└─ 响应中断标志(由录音线程触发)
|
||||
# 终端3: 立即执行中断
|
||||
source install/setup.bash
|
||||
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
|
||||
"{command: 'interrupt', text: '', voice: ''}"
|
||||
```
|
||||
|
||||
5. 完整运行
|
||||
```bash
|
||||
# 终端1:启动 brain 节点
|
||||
# 终端2:启动 voice 节点
|
||||
# 终端3:启动 bridge 节点
|
||||
# 终端4:订阅相机
|
||||
```
|
||||
|
||||
## 用到的命令
|
||||
1. 音频设备
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
{
|
||||
"entries": [
|
||||
{
|
||||
"id": "robot_identity",
|
||||
"id": "robot_identity_1",
|
||||
"patterns": [
|
||||
"ni shi shei"
|
||||
"ni shi shui"
|
||||
],
|
||||
"answer": "我叫二狗,是蜂核科技的机器人,很高兴为你服务"
|
||||
},
|
||||
{
|
||||
"id": "robot_identity_2",
|
||||
"patterns": [
|
||||
"ni jiao sha"
|
||||
],
|
||||
"answer": "我叫二狗呀,我是你的好帮手"
|
||||
},
|
||||
{
|
||||
"id": "wake_word",
|
||||
@@ -13,6 +20,27 @@
|
||||
"ni de ming zi"
|
||||
],
|
||||
"answer": "我的名字是二狗"
|
||||
},
|
||||
{
|
||||
"id": "skill_1",
|
||||
"patterns": [
|
||||
"tiao ge wu"
|
||||
],
|
||||
"answer": "这个我真不会,我怕跳起来吓到你"
|
||||
},
|
||||
{
|
||||
"id": "skill_2",
|
||||
"patterns": [
|
||||
"ni neng gan"
|
||||
],
|
||||
"answer": "我可以陪你聊天,也能帮你干活"
|
||||
},
|
||||
{
|
||||
"id": "skill_3",
|
||||
"patterns": [
|
||||
"ni hui gan"
|
||||
],
|
||||
"answer": "我可以陪你聊天,你也可以发布具体的指令让我干活"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
1173
config/speakers.json
1173
config/speakers.json
File diff suppressed because it is too large
Load Diff
@@ -18,26 +18,25 @@ dashscope:
|
||||
|
||||
audio:
|
||||
microphone:
|
||||
device_index: 3 # 指向 iFLYTEK-M2 (hw:1,0)
|
||||
sample_rate: 48000 # 尝试使用硬件原生采样率 48kHz,避免重采样可能导致的问题
|
||||
device_index: -1 # 使用系统默认输入设备
|
||||
sample_rate: 48000 # 尝试使用硬件原生采样率 48kHz,避免重采样可能导致的问题
|
||||
channels: 1 # 输入声道数:单声道(MONO,适合语音采集)
|
||||
chunk: 1024
|
||||
heartbeat_interval: 2.0 # 心跳间隔(秒),用于定期输出录音状态
|
||||
soundcard:
|
||||
card_index: 1 # USB Audio Device (card 1)
|
||||
device_index: 0 # USB Audio [USB Audio] (device 0)
|
||||
# card_index: -1 # 使用默认声卡
|
||||
# device_index: -1 # 使用默认输出设备
|
||||
sample_rate: 48000 # 输出采样率:48kHz(iFLYTEK 支持 48000)
|
||||
card_index: -1 # 使用默认声卡
|
||||
device_index: -1 # 使用默认输出设备
|
||||
sample_rate: 48000 # 输出采样率:默认 44100
|
||||
channels: 2 # 输出声道数:立体声(2声道,FL+FR)
|
||||
volume: 1.0 # 音量比例(0.0-1.0,0.2表示20%音量)
|
||||
echo_cancellation:
|
||||
enable: false # 是否启用回声消除
|
||||
max_duration_ms: 500 # 参考信号缓冲区最大时长(毫秒)
|
||||
tts:
|
||||
source_sample_rate: 22050 # TTS服务固定输出采样率(DashScope服务固定值,不可修改)
|
||||
source_channels: 1 # TTS服务固定输出声道数(DashScope服务固定值,不可修改)
|
||||
ffmpeg_thread_queue_size: 4096 # ffmpeg输入线程队列大小(增大以减少卡顿)
|
||||
force_stop_delay: 0.1 # 强制停止时的延迟(秒)
|
||||
cleanup_timeout: 30.0 # 清理超时(秒)
|
||||
terminate_timeout: 1.0 # 终止超时(秒)
|
||||
interrupt_wait: 0.1 # 中断等待时间(秒)
|
||||
|
||||
vad:
|
||||
vad_mode: 3 # VAD模式:0-3,3最严格
|
||||
@@ -45,26 +44,24 @@ vad:
|
||||
min_energy_threshold: 300 # 最小能量阈值
|
||||
|
||||
system:
|
||||
use_llm: true # 是否使用LLM
|
||||
use_wake_word: true # 是否启用唤醒词检测
|
||||
wake_word: "er gou" # 唤醒词(拼音)
|
||||
session_timeout: 3.0 # 会话超时时间(秒)
|
||||
shutup_keywords: "bi zui" # 闭嘴指令关键词(拼音,逗号分隔)
|
||||
interrupt_command_queue_depth: 10 # 中断命令订阅的队列深度(QoS)
|
||||
sv_enabled: true # 是否启用声纹识别
|
||||
sv_model_path: "~/hivecore_robot_os1/voice_model" # 声纹模型路径
|
||||
sv_threshold: 0.55 # 声纹识别阈值(0.0-1.0,值越小越宽松,值越大越严格)
|
||||
sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
sv_buffer_size: 240000 # 声纹验证录音缓冲区大小(样本数,48kHz下5秒=240000)
|
||||
sv_registration_silence_threshold_ms: 500 # 声纹注册状态下的静音阈值(毫秒)
|
||||
|
||||
sv_enabled: false # 是否启用声纹识别
|
||||
# sv_model_path: "~/hivecore_robot_os1/voice_model" # 声纹模型路径
|
||||
sv_model_path: "~/ros_learn/speech_campplus_sv_zh-cn_16k-common" # 声纹模型路径
|
||||
sv_threshold: 0.65 # 声纹识别阈值(0.0-1.0,值越小越宽松,值越大越严格)
|
||||
# sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
sv_speaker_db_path: "~/ros_learn/hivecore_robot_voice/config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
sv_buffer_size: 96000 # 声纹验证录音缓冲区大小(样本数,48kHz下2秒=96000)
|
||||
continue_without_image: true # 多模态意图(skill_sequence/chat_camera)未获取到图片时是否继续推理
|
||||
|
||||
camera:
|
||||
serial_number: "405622075404" # 相机序列号(Intel RealSense D435)
|
||||
rgb:
|
||||
width: 640 # 图像宽度
|
||||
height: 480 # 图像高度
|
||||
fps: 30 # 帧率(支持:6, 10, 15, 30, 60)
|
||||
format: "RGB8" # 图像格式:RGB8, BGR8
|
||||
image:
|
||||
jpeg_quality: 85 # JPEG压缩质量(0-100,85是质量和大小平衡点)
|
||||
max_size: "1280x720" # 最大尺寸
|
||||
|
||||
interfaces:
|
||||
# root_path: "~/hivecore_robot_os1/hivecore_robot_interfaces/src" # 接口文件根目录,支持 ~ 展开和相对路径
|
||||
root_path: "~/ros_learn/hivecore_robot_interfaces/src" # 接口文件根目录,支持 ~ 展开和相对路径
|
||||
|
||||
54
launch/register_speaker.launch.py
Normal file
54
launch/register_speaker.launch.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import SetEnvironmentVariable, RegisterEventHandler
|
||||
from launch.event_handlers import OnProcessExit
|
||||
from launch.actions import EmitEvent
|
||||
from launch.events import Shutdown
|
||||
import os
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""启动声纹注册节点(需要ASR服务)"""
|
||||
# 获取interfaces包的install路径
|
||||
interfaces_install_path = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
|
||||
|
||||
# 设置AMENT_PREFIX_PATH,确保能找到interfaces包的消息类型
|
||||
ament_prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
|
||||
if interfaces_install_path not in ament_prefix_path:
|
||||
if ament_prefix_path:
|
||||
ament_prefix_path = f'{ament_prefix_path}:{interfaces_install_path}'
|
||||
else:
|
||||
ament_prefix_path = interfaces_install_path
|
||||
|
||||
# ASR + 音频输入设备节点(提供ASR和AudioData服务)
|
||||
asr_audio_node = Node(
|
||||
package='robot_speaker',
|
||||
executable='asr_audio_node',
|
||||
name='asr_audio_node',
|
||||
output='screen'
|
||||
)
|
||||
|
||||
# 声纹注册节点
|
||||
register_speaker_node = Node(
|
||||
package='robot_speaker',
|
||||
executable='register_speaker_node',
|
||||
name='register_speaker_node',
|
||||
output='screen'
|
||||
)
|
||||
|
||||
# 当注册节点退出时,关闭整个 launch
|
||||
register_exit_handler = RegisterEventHandler(
|
||||
OnProcessExit(
|
||||
target_action=register_speaker_node,
|
||||
on_exit=[
|
||||
EmitEvent(event=Shutdown(reason='注册完成,关闭所有节点'))
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return LaunchDescription([
|
||||
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
|
||||
asr_audio_node,
|
||||
register_speaker_node,
|
||||
register_exit_handler,
|
||||
])
|
||||
@@ -1,10 +1,39 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import SetEnvironmentVariable
|
||||
import os
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""启动语音交互节点,所有参数从 voice.yaml 读取"""
|
||||
# 获取interfaces包的install路径
|
||||
interfaces_install_path = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
|
||||
|
||||
# 设置AMENT_PREFIX_PATH,确保能找到interfaces包的消息类型
|
||||
ament_prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
|
||||
if interfaces_install_path not in ament_prefix_path:
|
||||
if ament_prefix_path:
|
||||
ament_prefix_path = f'{ament_prefix_path}:{interfaces_install_path}'
|
||||
else:
|
||||
ament_prefix_path = interfaces_install_path
|
||||
|
||||
return LaunchDescription([
|
||||
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
|
||||
# ASR + 音频输入设备节点(同时提供VAD事件Service,利用云端ASR的VAD)
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='asr_audio_node',
|
||||
name='asr_audio_node',
|
||||
output='screen'
|
||||
),
|
||||
# TTS + 音频输出设备节点
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='tts_audio_node',
|
||||
name='tts_audio_node',
|
||||
output='screen'
|
||||
),
|
||||
# 主业务逻辑节点
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='robot_speaker_node',
|
||||
|
||||
@@ -9,6 +9,12 @@
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>ament_index_python</depend>
|
||||
<depend>interfaces</depend>
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<buildtool_depend>ament_cmake_python</buildtool_depend>
|
||||
|
||||
<exec_depend>python3-pyaudio</exec_depend>
|
||||
<exec_depend>python3-requests</exec_depend>
|
||||
@@ -23,6 +29,6 @@
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
</package>
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
dashscope>=1.20.0
|
||||
openai>=1.0.0
|
||||
pyaudio>=0.2.11
|
||||
webrtcvad>=2.0.10
|
||||
pypinyin>=0.49.0
|
||||
rclpy>=3.0.0
|
||||
pyrealsense2>=2.54.0
|
||||
Pillow>=10.0.0
|
||||
numpy>=1.24.0
|
||||
numpy>=1.24.0,<2.0.0 # cv_bridge需要NumPy 1.x,NumPy 2.x会导致段错误
|
||||
PyYAML>=6.0
|
||||
aec-audio-processing
|
||||
modelscope>=1.33.0
|
||||
funasr>=1.0.0
|
||||
datasets==3.6.0
|
||||
|
||||
|
||||
|
||||
|
||||
24
robot_speaker/bridge/__init__.py
Normal file
24
robot_speaker/bridge/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Bridge package for connecting LLM outputs to brain execution.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
239
robot_speaker/bridge/skill_bridge_node.py
Normal file
239
robot_speaker/bridge/skill_bridge_node.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
桥接LLM技能序列到小脑ExecuteBtAction,并转发反馈/结果。
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.action import ActionClient
|
||||
from std_msgs.msg import String
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from interfaces.action import ExecuteBtAction
|
||||
from interfaces.srv import BtRebuild
|
||||
|
||||
|
||||
class SkillBridgeNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__('skill_bridge_node')
|
||||
self._action_client = ActionClient(self, ExecuteBtAction, '/execute_bt_action')
|
||||
self._current_epoch = 1
|
||||
self.run_trigger_ = self.create_client(BtRebuild, '/cerebrum/rebuild_now')
|
||||
self.rebuild_requests = 0
|
||||
self._allowed_skills = self._load_allowed_skills()
|
||||
|
||||
self.skill_seq_sub = self.create_subscription(
|
||||
String, '/llm_skill_sequence', self._on_skill_sequence_received, 10
|
||||
)
|
||||
self.feedback_pub = self.create_publisher(String, '/skill_execution_feedback', 10)
|
||||
self.result_pub = self.create_publisher(String, '/skill_execution_result', 10)
|
||||
|
||||
self.get_logger().info('SkillBridgeNode started')
|
||||
|
||||
def _on_skill_sequence_received(self, msg: String):
|
||||
raw = (msg.data or "").strip()
|
||||
if not raw:
|
||||
return
|
||||
if not self._allowed_skills:
|
||||
self.get_logger().warning("No skill whitelist loaded; reject all sequences")
|
||||
return
|
||||
|
||||
# 尝试解析JSON格式
|
||||
sequence_list = None
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
sequence_list = self._parse_json_sequence(data)
|
||||
if sequence_list is None:
|
||||
self.get_logger().error("Invalid skill sequence format; must be JSON or plain text")
|
||||
return
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
self.get_logger().debug(f"JSON解析失败,尝试文本解析: {e}")
|
||||
|
||||
# JSON格式处理
|
||||
try:
|
||||
skill_names = [item["skill"] for item in sequence_list]
|
||||
if any(skill in skill_names for skill in ["VisionObjectRecognition", "Arm", "GripperCmd0"]):
|
||||
self.get_logger().info(f"Skill sequence contains special skills, triggering rebuild: {skill_names}")
|
||||
self.rebuild_now("Trigger", "bt_vision_grasp_dual_arm", "")
|
||||
else:
|
||||
skill_params = []
|
||||
for item in sequence_list:
|
||||
p = item.get("parameters")
|
||||
params = ""
|
||||
if isinstance(p, dict):
|
||||
lines = []
|
||||
for k, v in p.items():
|
||||
lines.append(f"{k}: {v}")
|
||||
if lines:
|
||||
params = "\n".join(lines) + "\n"
|
||||
skill_params.append(params)
|
||||
|
||||
self.get_logger().info(f"Sending skill sequence: {skill_names}")
|
||||
self.get_logger().info(f"Sending skill parameters: {skill_params}")
|
||||
|
||||
# 将技能名和参数列表分别用单引号包括,并用逗号隔开
|
||||
# names_str = ", ".join([f"'{name}'" for name in skill_names])
|
||||
# params_str = ", ".join([f"'{param}'" for param in skill_params])
|
||||
names_str = ", ".join(skill_names)
|
||||
params_str = ", ".join(skill_params)
|
||||
|
||||
self.rebuild_now("Remote", names_str, params_str)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Error processing skill sequence: {e}")
|
||||
|
||||
def _load_allowed_skills(self) -> set[str]:
|
||||
try:
|
||||
brain_share = get_package_share_directory("brain")
|
||||
skill_path = os.path.join(brain_share, "config", "robot_skills.yaml")
|
||||
if not os.path.exists(skill_path):
|
||||
return set()
|
||||
import yaml
|
||||
with open(skill_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or []
|
||||
return {str(entry["name"]) for entry in data if isinstance(entry, dict) and entry.get("name")}
|
||||
except Exception as e:
|
||||
self.get_logger().warning(f"Load skills failed: {e}")
|
||||
return set()
|
||||
|
||||
def _extract_skill_sequence(self, text: str) -> tuple[str, list[str]]:
|
||||
# Accept CSV/space/semicolon and filter by CamelCase tokens
|
||||
tokens = re.split(r'[,\s;]+', text.strip())
|
||||
skills = [t for t in tokens if re.match(r'^[A-Z][A-Za-z0-9]*$', t)]
|
||||
if not skills:
|
||||
return "", []
|
||||
invalid = [s for s in skills if s not in self._allowed_skills]
|
||||
return ",".join(skills), invalid
|
||||
|
||||
def _parse_json_sequence(self, data: dict) -> list[dict] | None:
|
||||
"""解析JSON格式的技能序列"""
|
||||
if not isinstance(data, dict) or "sequence" not in data:
|
||||
return None
|
||||
|
||||
sequence = data["sequence"]
|
||||
if not isinstance(sequence, list):
|
||||
return None
|
||||
|
||||
validated = []
|
||||
for item in sequence:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
skill = item.get("skill")
|
||||
if not skill or skill not in self._allowed_skills:
|
||||
continue
|
||||
|
||||
execution = item.get("execution", "serial")
|
||||
if execution not in ["serial", "parallel"]:
|
||||
execution = "serial"
|
||||
|
||||
body_id = item.get("body_id")
|
||||
# 只支持数字格式(0,1,2)和null,与意图路由对齐
|
||||
if body_id not in [0, 1, 2, None]:
|
||||
body_id = None
|
||||
|
||||
validated.append({
|
||||
"skill": skill,
|
||||
"execution": execution,
|
||||
"body_id": body_id,
|
||||
"parameters": item.get("parameters")
|
||||
})
|
||||
|
||||
return validated if validated else None
|
||||
|
||||
def _send_skill_sequence(self, skill_sequence: str):
|
||||
if not self._action_client.wait_for_server(timeout_sec=2.0):
|
||||
self.get_logger().error('ExecuteBtAction server unavailable')
|
||||
return
|
||||
goal = ExecuteBtAction.Goal()
|
||||
goal.epoch = self._current_epoch
|
||||
self._current_epoch += 1
|
||||
goal.action_name = skill_sequence
|
||||
goal.calls = []
|
||||
|
||||
self.get_logger().info(f"Dispatch skill sequence: {skill_sequence}")
|
||||
send_future = self._action_client.send_goal_async(goal, feedback_callback=self._feedback_callback)
|
||||
rclpy.spin_until_future_complete(self, send_future, timeout_sec=5.0)
|
||||
if not send_future.done():
|
||||
self.get_logger().warning("Send goal timed out")
|
||||
return
|
||||
goal_handle = send_future.result()
|
||||
if not goal_handle or not goal_handle.accepted:
|
||||
self.get_logger().error("Goal rejected")
|
||||
return
|
||||
result_future = goal_handle.get_result_async()
|
||||
rclpy.spin_until_future_complete(self, result_future)
|
||||
if result_future.done():
|
||||
self._handle_result(result_future.result())
|
||||
|
||||
def _feedback_callback(self, feedback_msg):
|
||||
fb = feedback_msg.feedback
|
||||
payload = {
|
||||
"stage": fb.stage,
|
||||
"current_skill": fb.current_skill,
|
||||
"progress": float(fb.progress),
|
||||
"detail": fb.detail,
|
||||
"epoch": int(fb.epoch),
|
||||
}
|
||||
msg = String()
|
||||
msg.data = json.dumps(payload, ensure_ascii=True)
|
||||
self.feedback_pub.publish(msg)
|
||||
|
||||
def _handle_result(self, result_wrapper):
|
||||
result = result_wrapper.result
|
||||
if not result:
|
||||
return
|
||||
payload = {
|
||||
"success": bool(result.success),
|
||||
"message": result.message,
|
||||
"total_skills": int(result.total_skills),
|
||||
"succeeded_skills": int(result.succeeded_skills),
|
||||
}
|
||||
msg = String()
|
||||
msg.data = json.dumps(payload, ensure_ascii=True)
|
||||
self.result_pub.publish(msg)
|
||||
|
||||
def rebuild_now(self, type: str, config: str, param: str) -> None:
|
||||
if not self.run_trigger_.service_is_ready():
|
||||
self.get_logger().error('Rebuild service not ready')
|
||||
return
|
||||
|
||||
self.rebuild_requests += 1
|
||||
self.get_logger().info(f'Rebuild BehaviorTree now. Total requests: {self.rebuild_requests}')
|
||||
|
||||
request = BtRebuild.Request()
|
||||
request.type = type
|
||||
request.config = config
|
||||
request.param = param
|
||||
|
||||
self.get_logger().info(f'Calling rebuild service... request info: {request}')
|
||||
|
||||
future = self.run_trigger_.call_async(request)
|
||||
future.add_done_callback(self._rebuild_done_callback)
|
||||
|
||||
def _rebuild_done_callback(self, future):
|
||||
try:
|
||||
response = future.result()
|
||||
if response.success:
|
||||
self.get_logger().info('Rebuild request successful')
|
||||
else:
|
||||
self.get_logger().warning(f'Rebuild request failed: {response.message}')
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Rebuild request exception: {str(e)}')
|
||||
|
||||
self.get_logger().info(f"Rebuild requested. Total rebuild requests: {str(self.rebuild_requests)}")
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = SkillBridgeNode()
|
||||
rclpy.spin(node)
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -2,3 +2,27 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""
|
||||
对话历史管理模块
|
||||
"""
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""LLM消息"""
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
|
||||
|
||||
class ConversationHistory:
|
||||
"""对话历史管理器 - 实时语音"""
|
||||
|
||||
@@ -109,3 +116,15 @@ class ConversationHistory:
|
||||
self.summary = None
|
||||
self._pending_user_message = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""会话状态机"""
|
||||
IDLE = "idle" # 等待用户唤醒或声音
|
||||
CHECK_VOICE = "check_voice" # 用户说话 → 检查声纹
|
||||
AUTHORIZED = "authorized" # 已注册用户
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from pypinyin import pinyin, Style
|
||||
from robot_speaker.core.skill_interface_parser import SkillInterfaceParser
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -9,7 +14,7 @@ class IntentResult:
|
||||
intent: str # "skill_sequence" | "kb_qa" | "chat_text" | "chat_camera"
|
||||
text: str
|
||||
need_camera: bool
|
||||
camera_mode: Optional[str] # "head" | "left_hand" | "right_hand" | None
|
||||
camera_mode: Optional[str] # "top" | "left" | "right" | "hand_r" | None
|
||||
system_prompt: Optional[str]
|
||||
|
||||
|
||||
@@ -18,12 +23,69 @@ class IntentRouter:
|
||||
self.camera_capture_keywords = [
|
||||
"pai zhao", "pai ge zhao", "pai zhang zhao"
|
||||
]
|
||||
self.skill_keywords = [
|
||||
"ban xiang zi"
|
||||
# 动作词列表(拼音)- 用于检测技能序列意图
|
||||
self.action_verbs = [
|
||||
"zou", "zou liang bu", "zou ji bu", # 走、走两步、走几步
|
||||
"na", "na qi", "na zhu", # 拿、拿起、拿住
|
||||
"ban", "ban yun", # 搬、搬运
|
||||
"zhua", "zhua qu", # 抓、抓取
|
||||
"tui", "tui dong", # 推、推动
|
||||
"la", "la dong", # 拉、拉动
|
||||
"yi dong", "qian jin", "hou tui", # 移动、前进、后退
|
||||
"kong zhi", "cao zuo", # 控制、操作
|
||||
"fang xia", "fang zhi", # 放下、放置
|
||||
"ju qi", "sheng qi", # 举起、升起
|
||||
"jia zhua", "jia qi", "jia", # 夹爪、夹起、夹
|
||||
"shen you bi", "shen zuo bi", "shen chu", "shen shou", # 伸右臂、伸左臂、伸出、伸手
|
||||
"zhuan quan", "zhuan yi quan", "zhuan", # 转个圈、转一圈、转
|
||||
]
|
||||
self.kb_keywords = [
|
||||
"ni shi shei", "ni de ming zi"
|
||||
"ni shi shui", "ni de ming zi", "tiao ge wu", "ni jiao sha", "ni hui gan", "ni neng gan"
|
||||
]
|
||||
self._cached_skill_names: list[str] | None = None
|
||||
self._cached_kb_data: list[dict] | None = None
|
||||
|
||||
interfaces_root = self._get_interfaces_root()
|
||||
self.interface_parser = SkillInterfaceParser(interfaces_root)
|
||||
|
||||
def _get_interfaces_root(self) -> str:
|
||||
"""从配置文件读取接口文件根目录"""
|
||||
try:
|
||||
robot_speaker_share = get_package_share_directory("robot_speaker")
|
||||
config_path = os.path.join(robot_speaker_share, "config", "voice.yaml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
|
||||
interfaces_config = config.get("interfaces", {})
|
||||
root_path = interfaces_config.get("root_path", "")
|
||||
|
||||
if not root_path:
|
||||
raise ValueError("interfaces.root_path 未在配置文件中配置")
|
||||
|
||||
if root_path.startswith("~"):
|
||||
root_path = os.path.expanduser(root_path)
|
||||
|
||||
if not os.path.isabs(root_path):
|
||||
config_dir = os.path.dirname(robot_speaker_share)
|
||||
root_path = os.path.join(config_dir, root_path)
|
||||
|
||||
abs_path = os.path.abspath(root_path)
|
||||
|
||||
if not os.path.exists(abs_path):
|
||||
raise ValueError(f"接口文件根目录不存在: {abs_path}")
|
||||
|
||||
return abs_path
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取接口文件根目录失败: {e}")
|
||||
|
||||
def _load_brain_skill_names(self) -> list[str]:
|
||||
"""加载技能名称(使用接口解析器,避免重复读取)"""
|
||||
if self._cached_skill_names is not None:
|
||||
return self._cached_skill_names
|
||||
|
||||
skill_names = self.interface_parser.get_skill_names()
|
||||
self._cached_skill_names = skill_names
|
||||
return skill_names
|
||||
|
||||
def to_pinyin(self, text: str) -> str:
|
||||
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
|
||||
@@ -32,89 +94,172 @@ class IntentRouter:
|
||||
py_list = pinyin(''.join(chars), style=Style.NORMAL)
|
||||
return ' '.join([item[0] for item in py_list]).lower().strip()
|
||||
|
||||
def is_skill_sequence_intent(self, text: str) -> bool:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
return any(k in text_pinyin for k in self.skill_keywords)
|
||||
def is_skill_sequence_intent(self, text: str, text_pinyin: str | None = None) -> bool:
|
||||
if text_pinyin is None:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
|
||||
# 检查动作词(精确匹配:动作词必须是完整的单词序列)
|
||||
text_words = text_pinyin.split()
|
||||
for action in self.action_verbs:
|
||||
action_words = action.split()
|
||||
# 检查动作词的单词序列是否是文本单词序列的连续子序列
|
||||
for i in range(len(text_words) - len(action_words) + 1):
|
||||
if text_words[i:i+len(action_words)] == action_words:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_camera_command(self, text: str) -> tuple[bool, Optional[str]]:
|
||||
def check_camera_command(self, text: str, text_pinyin: str | None = None) -> tuple[bool, Optional[str]]:
|
||||
"""检查是否包含拍照指令,返回(是否需要相机, 相机模式)"""
|
||||
if not text:
|
||||
return False, None
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
for keyword in self.camera_capture_keywords:
|
||||
if keyword in text_pinyin:
|
||||
return True, self.detect_camera_mode(text)
|
||||
if text_pinyin is None:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
# 精确匹配:关键词必须作为完整短语出现在文本拼音中
|
||||
if any(keyword in text_pinyin for keyword in self.camera_capture_keywords):
|
||||
return True, self.detect_camera_mode(text, text_pinyin)
|
||||
return False, None
|
||||
|
||||
def detect_camera_mode(self, text: str) -> str:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
left_keys = ["zuo shou", "zuo bi", "zuo bian"]
|
||||
right_keys = ["you shou", "you bi", "you bian"]
|
||||
head_keys = ["tou", "nao dai"]
|
||||
for kw in left_keys:
|
||||
if kw in text_pinyin:
|
||||
return "left_hand"
|
||||
for kw in right_keys:
|
||||
if kw in text_pinyin:
|
||||
return "right_hand"
|
||||
for kw in head_keys:
|
||||
if kw in text_pinyin:
|
||||
return "head"
|
||||
return "head"
|
||||
def detect_camera_mode(self, text: str, text_pinyin: str | None = None) -> str:
|
||||
"""检测相机模式,返回与相机驱动匹配的position值:left/right/top/hand_r"""
|
||||
if text_pinyin is None:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
if any(kw in text_pinyin for kw in ["zuo shou", "zuo bi", "zuo bian", "zuo shou bi"]):
|
||||
return "left"
|
||||
if any(kw in text_pinyin for kw in ["you shou", "you bi", "you bian", "you shou bi"]):
|
||||
return "right"
|
||||
if any(kw in text_pinyin for kw in ["shou bu", "shou", "shou xiang ji", "shou bi xiang ji"]):
|
||||
return "hand_r"
|
||||
if any(kw in text_pinyin for kw in ["tou", "nao dai", "ding bu", "shang fang"]):
|
||||
return "top"
|
||||
return "top"
|
||||
|
||||
def build_skill_prompt(self) -> str:
|
||||
def build_skill_prompt(self, execution_status: Optional[str] = None) -> str:
|
||||
skills = self._load_brain_skill_names()
|
||||
skills_text = ", ".join(skills) if skills else ""
|
||||
skill_guard = (
|
||||
"【技能限制】只能使用以下技能名称:" + skills_text
|
||||
if skills_text
|
||||
else "【技能限制】技能列表不可用,请不要输出任何技能名称。"
|
||||
)
|
||||
|
||||
execution_hint = ""
|
||||
if execution_status:
|
||||
execution_hint = f"【上一轮执行状态】{execution_status}\n请参考上述执行状态,根据成功/失败信息调整本次技能序列。\n"
|
||||
else:
|
||||
execution_hint = "【注意】这是首次执行或没有上一轮执行状态,请根据当前图片和用户请求规划技能序列。\n"
|
||||
|
||||
skill_params_doc = self.interface_parser.generate_params_documentation()
|
||||
|
||||
return (
|
||||
"你是机器人任务规划器。\n"
|
||||
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照(默认头部相机),并结合当前环境信息生成简洁、可执行的技能序列。"
|
||||
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照,并结合当前环境信息生成简洁、可执行的技能序列。\n"
|
||||
"如果用户明确要求或者任务明显需要双手/双臂协作(如扶稳+操作、抓取大体积的物体),必须规划双手技能。\n"
|
||||
+ execution_hint
|
||||
+ "\n"
|
||||
"【规划要求】\n"
|
||||
"1. execution规划:判断技能之间的执行关系\n"
|
||||
" - serial(串行):技能必须按顺序执行,前一个完成后再执行下一个\n"
|
||||
" - parallel(并行):技能可以同时执行\n"
|
||||
"2. parameters规划:根据目标物距离和任务需求,规划具体参数值\n"
|
||||
" - parameters字典必须包含该技能接口文件目标字段的所有字段\n"
|
||||
"【输出格式要求】\n"
|
||||
"必须输出JSON格式,包含sequence数组。每个技能对象包含3个一级字段:\n"
|
||||
"1. skill: 技能名称(字符串)\n"
|
||||
"2. execution: 执行方式,serial(串行)或 parallel(并行)\n"
|
||||
"3. parameters: 参数字典,包含该技能接口文件目标字段的所有字段,并填入合理的预测值。如果技能无参数,使用null。\n"
|
||||
"\n"
|
||||
"注意:一级字段(skill, execution, parameters)是固定结构。\n"
|
||||
"\n"
|
||||
"【技能参数说明】\n"
|
||||
+ skill_params_doc +
|
||||
"\n"
|
||||
"示例格式:\n"
|
||||
"{\n"
|
||||
' "sequence": [\n'
|
||||
' {"skill": "MoveWheel", "execution": "serial", "parameters": {"move_distance": 1.5, "move_angle": 0.0}},\n'
|
||||
' {"skill": "Arm", "execution": "serial", "parameters": {"body_id": 0, "data_type": 1, "data_length": 6, "command_id": 0, "frame_time_stamp": 0, "data_array": [0.1, 0.2, 0.3, 0.0, 0.0, 0.0]}},\n'
|
||||
' {"skill": "GripperCmd0", "execution": "parallel", "parameters": {"loc": 128, "speed": 100, "torque": 80, "mode": 1}}\n'
|
||||
" ]\n"
|
||||
"}\n"
|
||||
+ skill_guard
|
||||
)
|
||||
|
||||
def build_chat_prompt(self, need_camera: bool) -> str:
|
||||
if need_camera:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请结合图片内容简短回答。"
|
||||
"你是一个机器人视觉助理,擅长分析图片中物体的相对位置和空间关系。\n"
|
||||
"请结合图片内容,重点描述物体之间的相对位置(如左右、前后、上下、远近),仅基于可观察信息回答。\n"
|
||||
"回答应简短、客观,不要超过100个token。"
|
||||
)
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请自然、简短地与用户对话。"
|
||||
)
|
||||
|
||||
def build_kb_prompt(self) -> str:
|
||||
return (
|
||||
"你是蜂核科技的员工。\n"
|
||||
"请基于知识库信息回答用户问题,回答要准确简洁。"
|
||||
"你是一个表达清晰、语气自然的真人助理。\n"
|
||||
"请简短地与用户对话,不要超过100个token。"
|
||||
)
|
||||
|
||||
def _load_kb_data(self) -> list[dict]:
|
||||
"""加载知识库数据"""
|
||||
if self._cached_kb_data is not None:
|
||||
return self._cached_kb_data
|
||||
kb_data = []
|
||||
try:
|
||||
robot_speaker_share = get_package_share_directory("robot_speaker")
|
||||
kb_path = os.path.join(robot_speaker_share, "config", "knowledge.json")
|
||||
with open(kb_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
kb_data = data["entries"]
|
||||
except Exception as e:
|
||||
kb_data = []
|
||||
self._cached_kb_data = kb_data
|
||||
return kb_data
|
||||
|
||||
def search_kb(self, text: str) -> Optional[str]:
|
||||
"""检索知识库,返回匹配的答案"""
|
||||
if not text:
|
||||
return None
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
kb_data = self._load_kb_data()
|
||||
|
||||
for entry in kb_data:
|
||||
patterns = entry["patterns"]
|
||||
for pattern in patterns:
|
||||
if pattern in text_pinyin:
|
||||
answer = entry["answer"]
|
||||
if answer:
|
||||
return answer
|
||||
return None
|
||||
|
||||
def build_default_system_prompt(self) -> str:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"你是一个工厂专业的助手。\n"
|
||||
"- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
|
||||
"- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
|
||||
"请根据对话模式调整你的回答风格。"
|
||||
)
|
||||
|
||||
def route(self, text: str) -> IntentResult:
|
||||
need_camera, camera_mode = self.check_camera_command(text)
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
need_camera, camera_mode = self.check_camera_command(text, text_pinyin)
|
||||
|
||||
if self.is_skill_sequence_intent(text):
|
||||
if camera_mode is None:
|
||||
camera_mode = "head"
|
||||
if self.is_skill_sequence_intent(text, text_pinyin):
|
||||
# 技能序列意图总是需要相机,复用 detect_camera_mode:用户指定了相机就用指定的,否则默认 "top"
|
||||
skill_camera_mode = self.detect_camera_mode(text, text_pinyin)
|
||||
return IntentResult(
|
||||
intent="skill_sequence",
|
||||
text=text,
|
||||
need_camera=True,
|
||||
camera_mode=camera_mode,
|
||||
camera_mode=skill_camera_mode,
|
||||
system_prompt=self.build_skill_prompt()
|
||||
)
|
||||
|
||||
if any(k in text_pinyin for k in self.kb_keywords):
|
||||
# 精确匹配:关键词必须作为完整短语出现在文本拼音中
|
||||
if any(keyword in text_pinyin for keyword in self.kb_keywords):
|
||||
return IntentResult(
|
||||
intent="kb_qa",
|
||||
text=text,
|
||||
need_camera=False,
|
||||
camera_mode=None,
|
||||
system_prompt=self.build_kb_prompt()
|
||||
system_prompt=None # kb_qa不走LLM,不需要system_prompt
|
||||
)
|
||||
|
||||
return IntentResult(
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeCallbacks:
|
||||
# ==================== 初始化与内部工具 ====================
|
||||
def __init__(self, node):
|
||||
self.node = node
|
||||
|
||||
def _mark_utterance_processed(self) -> bool:
|
||||
node = self.node
|
||||
with node.utterance_lock:
|
||||
if node.current_utterance_id == node.last_processed_utterance_id:
|
||||
return False
|
||||
node.last_processed_utterance_id = node.current_utterance_id
|
||||
return True
|
||||
|
||||
def _trigger_sv_for_check_voice(self, source: str):
|
||||
node = self.node
|
||||
if not (node.sv_enabled and node.sv_client):
|
||||
return
|
||||
if not self._mark_utterance_processed():
|
||||
return
|
||||
if node._handle_empty_speaker_db():
|
||||
node.get_logger().info(f"[声纹] CHECK_VOICE状态,数据库为空,跳过声纹验证(来源: {source})")
|
||||
return
|
||||
if not node.sv_speech_end_event.is_set():
|
||||
with node.sv_lock:
|
||||
node.sv_recording = False
|
||||
buffer_size = len(node.sv_audio_buffer)
|
||||
node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
|
||||
if buffer_size > 0:
|
||||
node.sv_speech_end_event.set()
|
||||
else:
|
||||
node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source})")
|
||||
|
||||
# ==================== 业务逻辑代理 ====================
|
||||
def handle_interrupt_command(self, msg):
|
||||
return self.node._handle_interrupt_command(msg)
|
||||
|
||||
def check_interrupt_and_cancel_turn(self) -> bool:
|
||||
return self.node._check_interrupt_and_cancel_turn()
|
||||
|
||||
def handle_wake_word(self, text: str) -> str:
|
||||
return self.node._handle_wake_word(text)
|
||||
|
||||
def check_shutup_command(self, text: str) -> bool:
|
||||
return self.node._check_shutup_command(text)
|
||||
|
||||
def check_camera_command(self, text: str):
|
||||
return self.node.intent_router.check_camera_command(text)
|
||||
|
||||
def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
|
||||
return self.node._llm_process_stream_with_camera(user_text, need_camera)
|
||||
|
||||
def put_tts_text(self, text: str):
|
||||
return self.node._put_tts_text(text)
|
||||
|
||||
def force_stop_tts(self):
|
||||
return self.node._force_stop_tts()
|
||||
|
||||
def drain_queue(self, q):
|
||||
return self.node._drain_queue(q)
|
||||
|
||||
# ==================== 录音/VAD回调 ====================
|
||||
def get_silence_threshold(self) -> int:
|
||||
"""获取动态静音阈值(毫秒)"""
|
||||
node = self.node
|
||||
return node.silence_duration_ms
|
||||
|
||||
def should_put_audio_to_queue(self) -> bool:
|
||||
"""
|
||||
检查是否应该将音频放入队列(用于ASR),根据状态机决定是否允许ASR
|
||||
"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
|
||||
ConversationState.AUTHORIZED]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_speech_start(self):
|
||||
"""录音线程检测到人声开始"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 检测到人声,开始录音")
|
||||
|
||||
with node.utterance_lock:
|
||||
node.current_utterance_id += 1
|
||||
|
||||
state = node._get_state()
|
||||
|
||||
if state == ConversationState.IDLE:
|
||||
# Idle -> CheckVoice
|
||||
if node.sv_enabled and node.sv_client:
|
||||
# 开始录音用于声纹验证
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
|
||||
else:
|
||||
node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
|
||||
|
||||
elif state == ConversationState.CHECK_VOICE:
|
||||
# CheckVoice状态,继续录音用于声纹验证
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 继续录音用于声纹验证")
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
# Authorized状态,开始录音用于声纹验证(验证当前用户)
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
|
||||
def on_audio_chunk_for_sv(self, audio_chunk: bytes):
|
||||
"""录音线程音频chunk回调 - 仅在需要时录音到声纹缓冲区"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
|
||||
# 声纹验证录音(CHECK_VOICE, AUTHORIZED状态)
|
||||
if node.sv_enabled and node.sv_recording:
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
with node.sv_lock:
|
||||
node.sv_audio_buffer.extend(audio_array)
|
||||
except Exception as e:
|
||||
node.get_logger().debug(f"[声纹] 录音失败: {e}")
|
||||
|
||||
def on_speech_end(self):
|
||||
"""录音线程检测到说话结束(静音一段时间)"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 检测到说话结束")
|
||||
|
||||
state = node._get_state()
|
||||
node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
|
||||
|
||||
if state == ConversationState.CHECK_VOICE:
|
||||
if node.asr_client and node.asr_client.running:
|
||||
node.asr_client.stop_current_recognition()
|
||||
self._trigger_sv_for_check_voice("VAD")
|
||||
return
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
if node.asr_client and node.asr_client.running:
|
||||
node.asr_client.stop_current_recognition()
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = False
|
||||
buffer_size = len(node.sv_audio_buffer)
|
||||
node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
|
||||
node.sv_speech_end_event.set()
|
||||
|
||||
# 如果TTS正在播放,异步等待声纹验证结果,如果通过才中断TTS
|
||||
# 使用独立线程避免阻塞录音线程,影响TTS播放
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().info("[打断] TTS播放中,用户说话结束,异步等待声纹验证结果...")
|
||||
def _check_sv_and_interrupt():
|
||||
# 等待声纹验证结果(最多等待2秒)
|
||||
with node.sv_result_cv:
|
||||
current_seq = node.sv_result_seq
|
||||
if node.sv_result_cv.wait_for(
|
||||
lambda: node.sv_result_seq > current_seq,
|
||||
timeout=2.0
|
||||
):
|
||||
# 声纹验证完成,检查结果
|
||||
with node.sv_lock:
|
||||
speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
|
||||
if speaker_id and speaker_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[打断] 声纹验证通过({speaker_id}),中断TTS播放")
|
||||
node._interrupt_tts("检测到人声(已授权用户,说话结束)")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 声纹验证未通过,不中断TTS(状态: {speaker_state.value})")
|
||||
else:
|
||||
node.get_logger().warning("[打断] 声纹验证超时,不中断TTS")
|
||||
# 在独立线程中等待,避免阻塞录音线程
|
||||
threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
|
||||
return
|
||||
|
||||
def on_new_segment(self):
|
||||
"""录音线程检测到新的已授权用户声段,开始录音用于声纹验证(不立即中断)"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
if state == ConversationState.AUTHORIZED:
|
||||
# TTS播放期间,检测到人声时不立即中断,而是开始录音用于声纹验证
|
||||
# 等待用户说话结束(speech_end)后,如果声纹验证通过,才中断TTS
|
||||
# 这样可以避免TTS回声误触发,但支持真正的用户打断
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().debug("[打断] TTS播放中,检测到人声,开始录音用于声纹验证(等待说话结束后验证)")
|
||||
# 录音已经在 on_speech_start 中开始了,这里不需要额外操作
|
||||
else:
|
||||
# TTS未播放时,检查声纹验证结果并立即中断
|
||||
if node.sv_enabled and node.sv_client:
|
||||
with node.sv_lock:
|
||||
current_speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
|
||||
node._interrupt_tts("检测到人声(已授权用户)")
|
||||
node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言,中断TTS播放")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 检测到人声,但声纹未验证或未匹配,不中断TTS(当前状态: {speaker_state.value})")
|
||||
else:
|
||||
# 未启用声纹,直接中断(保持原有行为)
|
||||
node._interrupt_tts("检测到人声(未启用声纹)")
|
||||
node.get_logger().info("[打断] 检测到人声,中断TTS播放")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value},非已授权用户,不允许打断TTS")
|
||||
|
||||
def on_heartbeat(self):
|
||||
"""录音线程静音心跳回调"""
|
||||
self.node.get_logger().info("[录音线程] 静音中")
|
||||
|
||||
# ==================== ASR回调 ====================
|
||||
def on_asr_sentence_end(self, text: str):
|
||||
"""ASR sentence_end回调 - 将文本放入队列"""
|
||||
node = self.node
|
||||
if not text or not text.strip():
|
||||
return
|
||||
text_clean = text.strip()
|
||||
node.get_logger().info(f"[ASR] 识别完成: {text_clean}")
|
||||
|
||||
state = node._get_state()
|
||||
|
||||
# 规则2:CHECK_VOICE状态下,如果ASR识别完成但VAD还没有触发speech_end,主动触发声纹验证
|
||||
if state == ConversationState.CHECK_VOICE:
|
||||
if node.sv_enabled and node.sv_client:
|
||||
node.get_logger().info("[ASR] CHECK_VOICE状态,ASR识别完成,主动触发声纹验证")
|
||||
self._trigger_sv_for_check_voice("ASR")
|
||||
|
||||
# 其他状态,将文本放入队列
|
||||
node.text_queue.put(text_clean, timeout=1.0)
|
||||
|
||||
def on_asr_text_update(self, text: str):
|
||||
"""ASR 实时文本更新回调 - 用于多轮提示"""
|
||||
if not text or not text.strip():
|
||||
return
|
||||
self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")
|
||||
@@ -1,188 +0,0 @@
|
||||
import queue
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeWorkers:
|
||||
def __init__(self, node):
|
||||
self.node = node
|
||||
|
||||
def recording_worker(self):
|
||||
"""线程1: 录音线程 - 唯一实时线程"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 启动")
|
||||
node.audio_recorder.record_with_vad()
|
||||
|
||||
def asr_worker(self):
|
||||
"""线程2: ASR推理线程 - 只做 audio → text"""
|
||||
node = self.node
|
||||
node.get_logger().info("[ASR推理线程] 启动")
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
audio_chunk = node.audio_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if node.interrupt_event.is_set():
|
||||
continue
|
||||
if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
|
||||
node.asr_client.send_audio(audio_chunk)
|
||||
|
||||
def process_worker(self):
|
||||
"""线程3: 主线程 - 处理业务逻辑"""
|
||||
node = self.node
|
||||
node.get_logger().info("[主线程] 启动")
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
text = node.text_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
node.get_logger().info(f"[主线程] 收到识别文本: {text}")
|
||||
current_state = node._get_state()
|
||||
|
||||
if current_state == ConversationState.CHECK_VOICE:
|
||||
if node.use_wake_word:
|
||||
node.get_logger().info(f"[主线程] CHECK_VOICE状态,检查唤醒词,文本: {text}")
|
||||
processed_text = node.callbacks.handle_wake_word(text)
|
||||
if not processed_text:
|
||||
node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'),回到Idle状态")
|
||||
node._change_state(ConversationState.IDLE, "未检测到唤醒词")
|
||||
continue
|
||||
node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
|
||||
text = processed_text
|
||||
|
||||
if node.sv_enabled and node.sv_client:
|
||||
node.get_logger().info("[主线程] CHECK_VOICE状态:等待声纹验证结果...")
|
||||
with node.sv_result_cv:
|
||||
current_seq = node.sv_result_seq
|
||||
if not node.sv_result_cv.wait_for(
|
||||
lambda: node.sv_result_seq > current_seq,
|
||||
timeout=15.0
|
||||
):
|
||||
node.get_logger().warning("[主线程] CHECK_VOICE状态:声纹结果未ready(超时15秒),拒绝本轮")
|
||||
with node.sv_lock:
|
||||
node.sv_audio_buffer.clear()
|
||||
node._change_state(ConversationState.IDLE, "声纹结果未ready")
|
||||
continue
|
||||
node.get_logger().info("[主线程] CHECK_VOICE状态:声纹结果ready,继续处理")
|
||||
|
||||
with node.sv_lock:
|
||||
speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
score = node.current_speaker_score
|
||||
|
||||
if speaker_id and speaker_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
|
||||
node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
|
||||
else:
|
||||
node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
|
||||
node.callbacks.put_tts_text("声纹验证失败")
|
||||
node._change_state(ConversationState.IDLE, "声纹验证失败")
|
||||
continue
|
||||
else:
|
||||
node._change_state(ConversationState.AUTHORIZED, "未启用声纹")
|
||||
|
||||
elif current_state == ConversationState.AUTHORIZED:
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().debug("[主线程] AUTHORIZED状态,TTS播放中,忽略ASR识别结果(只有VAD检测到已授权用户人声才能中断)")
|
||||
continue
|
||||
|
||||
elif current_state == ConversationState.IDLE:
|
||||
node.get_logger().warning("[主线程] Idle状态收到文本,忽略")
|
||||
continue
|
||||
|
||||
if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
|
||||
processed_text = node.callbacks.handle_wake_word(text)
|
||||
if not processed_text:
|
||||
node._change_state(ConversationState.IDLE, "未检测到唤醒词")
|
||||
continue
|
||||
text = processed_text
|
||||
|
||||
if node.callbacks.check_shutup_command(text):
|
||||
node.get_logger().info("[主线程] 检测到闭嘴指令")
|
||||
node.interrupt_event.set()
|
||||
node.callbacks.force_stop_tts()
|
||||
node._change_state(ConversationState.IDLE, "用户闭嘴指令")
|
||||
continue
|
||||
|
||||
intent_payload = node.intent_router.route(text)
|
||||
node._handle_intent(intent_payload)
|
||||
|
||||
if current_state == ConversationState.AUTHORIZED:
|
||||
node.session_start_time = time.time()
|
||||
|
||||
def sv_worker(self):
|
||||
"""线程5: 声纹识别线程 - 非实时、低频(CAM++)"""
|
||||
node = self.node
|
||||
node.get_logger().info("[声纹识别线程] 启动")
|
||||
|
||||
# 动态计算最小音频样本数,确保降采样到16kHz后≥0.5秒
|
||||
target_sr = 16000 # CAM++模型目标采样率
|
||||
min_duration_seconds = 0.5
|
||||
min_samples_at_target_sr = int(target_sr * min_duration_seconds) # 8000样本@16kHz
|
||||
|
||||
if node.sample_rate >= target_sr:
|
||||
downsample_step = int(node.sample_rate / target_sr)
|
||||
min_audio_samples = min_samples_at_target_sr * downsample_step
|
||||
else:
|
||||
min_audio_samples = int(node.sample_rate * min_duration_seconds)
|
||||
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
if node.sv_speech_end_event.wait(timeout=0.1):
|
||||
node.sv_speech_end_event.clear()
|
||||
with node.sv_lock:
|
||||
audio_list = list(node.sv_audio_buffer)
|
||||
buffer_size = len(audio_list)
|
||||
node.sv_audio_buffer.clear()
|
||||
|
||||
node.get_logger().info(f"[声纹识别] 收到speech_end事件,录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
|
||||
|
||||
if node._handle_empty_speaker_db():
|
||||
node.get_logger().info("[声纹识别] 数据库为空,跳过验证,直接设置UNKNOWN状态")
|
||||
continue
|
||||
|
||||
if buffer_size >= min_audio_samples:
|
||||
audio_array = np.array(audio_list, dtype=np.int16)
|
||||
embedding, success = node.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=node.sample_rate
|
||||
)
|
||||
|
||||
if not success or embedding is None:
|
||||
node.get_logger().debug("[声纹识别] 提取embedding失败")
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = None
|
||||
node.current_speaker_state = SpeakerState.ERROR
|
||||
node.current_speaker_score = 0.0
|
||||
else:
|
||||
speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = speaker_id
|
||||
node.current_speaker_state = match_state
|
||||
node.current_speaker_score = score
|
||||
|
||||
if match_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
|
||||
elif match_state == SpeakerState.REJECTED:
|
||||
node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
|
||||
else:
|
||||
node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
|
||||
else:
|
||||
node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = None
|
||||
node.current_speaker_state = SpeakerState.UNKNOWN
|
||||
node.current_speaker_score = 0.0
|
||||
|
||||
with node.sv_result_cv:
|
||||
node.sv_result_seq += 1
|
||||
node.sv_result_cv.notify_all()
|
||||
|
||||
except Exception as e:
|
||||
node.get_logger().error(f"[声纹识别线程] 错误: {e}")
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -1,24 +1,16 @@
|
||||
"""
|
||||
声纹注册独立节点:运行完成后退出
|
||||
"""
|
||||
import collections
|
||||
"""声纹注册独立节点:运行完成后退出"""
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import threading
|
||||
import queue
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
|
||||
from robot_speaker.perception.echo_cancellation import ReferenceSignalBuffer
|
||||
from robot_speaker.models.asr.dashscope import DashScopeASR
|
||||
from robot_speaker.models.tts.dashscope import DashScopeTTSClient
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
from interfaces.srv import ASRRecognize, AudioData, VADEvent
|
||||
from robot_speaker.core.speaker_verifier import SpeakerVerificationClient
|
||||
from pypinyin import pinyin, Style
|
||||
|
||||
|
||||
@@ -27,81 +19,15 @@ class RegisterSpeakerNode(Node):
|
||||
super().__init__('register_speaker_node')
|
||||
self._load_config()
|
||||
|
||||
self.stop_event = threading.Event()
|
||||
self.processing = False
|
||||
self.buffer_lock = threading.Lock()
|
||||
self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
|
||||
self.asr_client = self.create_client(ASRRecognize, '/asr/recognize')
|
||||
self.audio_data_client = self.create_client(AudioData, '/asr/audio_data')
|
||||
self.vad_client = self.create_client(VADEvent, '/vad/event')
|
||||
|
||||
# 状态:等待唤醒词 -> 等待声纹语音
|
||||
self.waiting_for_wake_word = True
|
||||
self.waiting_for_voiceprint = False
|
||||
|
||||
# 音频队列和文本队列(用于ASR)
|
||||
self.audio_queue = queue.Queue()
|
||||
self.text_queue = queue.Queue()
|
||||
|
||||
self.vad_detector = VADDetector(
|
||||
mode=self.vad_mode,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
# 创建参考信号缓冲区(用于回声消除)
|
||||
self.reference_signal_buffer = ReferenceSignalBuffer(
|
||||
max_duration_ms=self.audio_echo_cancellation_max_duration_ms,
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.output_channels
|
||||
) if self.audio_echo_cancellation_enabled else None
|
||||
|
||||
self.audio_recorder = AudioRecorder(
|
||||
device_index=self.input_device_index,
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels,
|
||||
chunk=self.chunk,
|
||||
vad_detector=self.vad_detector,
|
||||
audio_queue=self.audio_queue, # 送ASR用于唤醒词检测
|
||||
silence_duration_ms=self.silence_duration_ms,
|
||||
min_energy_threshold=self.min_energy_threshold,
|
||||
heartbeat_interval=self.audio_microphone_heartbeat_interval,
|
||||
on_heartbeat=self._on_heartbeat,
|
||||
is_playing=lambda: False,
|
||||
on_new_segment=None,
|
||||
on_speech_start=self._on_speech_start,
|
||||
on_speech_end=self._on_speech_end,
|
||||
stop_flag=self.stop_event.is_set,
|
||||
on_audio_chunk=self._on_audio_chunk,
|
||||
should_put_to_queue=self._should_put_to_queue,
|
||||
get_silence_threshold=lambda: self.silence_duration_ms,
|
||||
enable_echo_cancellation=self.audio_echo_cancellation_enabled, # 启用回声消除,保持与主程序一致
|
||||
reference_signal_buffer=self.reference_signal_buffer,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
|
||||
# ASR客户端 - 用于唤醒词检测
|
||||
self.asr_client = DashScopeASR(
|
||||
api_key=self.dashscope_api_key,
|
||||
sample_rate=self.sample_rate,
|
||||
model=self.asr_model,
|
||||
url=self.asr_url,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
self.asr_client.on_sentence_end = self._on_asr_sentence_end
|
||||
self.asr_client.start()
|
||||
|
||||
# ASR处理线程
|
||||
self.asr_thread = threading.Thread(
|
||||
target=self._asr_worker,
|
||||
name="RegisterASRThread",
|
||||
daemon=True
|
||||
)
|
||||
self.asr_thread.start()
|
||||
|
||||
# 文本处理线程
|
||||
self.text_thread = threading.Thread(
|
||||
target=self._text_worker,
|
||||
name="RegisterTextThread",
|
||||
daemon=True
|
||||
)
|
||||
self.text_thread.start()
|
||||
self.get_logger().info('等待服务启动...')
|
||||
self.asr_client.wait_for_service(timeout_sec=10.0)
|
||||
self.audio_data_client.wait_for_service(timeout_sec=10.0)
|
||||
self.vad_client.wait_for_service(timeout_sec=10.0)
|
||||
self.get_logger().info('所有服务已就绪')
|
||||
|
||||
self.sv_client = SpeakerVerificationClient(
|
||||
model_path=self.sv_model_path,
|
||||
@@ -110,31 +36,20 @@ class RegisterSpeakerNode(Node):
|
||||
logger=self.get_logger()
|
||||
)
|
||||
|
||||
self.tts_client = DashScopeTTSClient(
|
||||
api_key=self.dashscope_api_key,
|
||||
model=self.tts_model,
|
||||
voice=self.tts_voice,
|
||||
card_index=self.output_card_index,
|
||||
device_index=self.output_device_index,
|
||||
output_sample_rate=self.output_sample_rate,
|
||||
output_channels=self.output_channels,
|
||||
output_volume=self.output_volume,
|
||||
tts_source_sample_rate=self.audio_tts_source_sample_rate,
|
||||
tts_source_channels=self.audio_tts_source_channels,
|
||||
tts_ffmpeg_thread_queue_size=self.audio_tts_ffmpeg_thread_queue_size,
|
||||
reference_signal_buffer=self.reference_signal_buffer,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
self.registered = False
|
||||
self.shutting_down = False
|
||||
self.get_logger().info("声纹注册节点启动,请说唤醒词开始注册(例如:'二狗我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去')")
|
||||
|
||||
self.get_logger().info("声纹注册节点启动,请说'er gou......'唤醒注册")
|
||||
self.recording_thread = threading.Thread(
|
||||
target=self.audio_recorder.record_with_vad,
|
||||
name="RegisterRecordingThread",
|
||||
daemon=True
|
||||
)
|
||||
self.recording_thread.start()
|
||||
# 使用队列在线程间传递 VAD 事件,避免在子线程中调用 spin_until_future_complete
|
||||
self.vad_event_queue = queue.Queue()
|
||||
self.recording = False # 录音状态标志
|
||||
self.pending_asr_future = None # 待处理的 ASR future
|
||||
self.pending_audio_future = None # 待处理的 AudioData future
|
||||
self.state = "waiting_speech" # 状态机:waiting_speech, waiting_asr, waiting_audio
|
||||
|
||||
self.timer = self.create_timer(0.2, self._check_done)
|
||||
self.vad_thread = threading.Thread(target=self._vad_event_worker, daemon=True)
|
||||
self.vad_thread.start()
|
||||
self.timer = self.create_timer(0.1, self._main_loop)
|
||||
|
||||
def _load_config(self):
|
||||
config_file = os.path.join(
|
||||
@@ -145,267 +60,48 @@ class RegisterSpeakerNode(Node):
|
||||
with open(config_file, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
dashscope = config['dashscope']
|
||||
audio = config['audio']
|
||||
mic = audio['microphone']
|
||||
soundcard = audio['soundcard']
|
||||
vad = config['vad']
|
||||
system = config['system']
|
||||
|
||||
self.dashscope_api_key = dashscope['api_key']
|
||||
self.asr_model = dashscope['asr']['model']
|
||||
self.asr_url = dashscope['asr']['url']
|
||||
self.tts_model = dashscope['tts']['model']
|
||||
self.tts_voice = dashscope['tts']['voice']
|
||||
|
||||
self.input_device_index = mic['device_index']
|
||||
self.sample_rate = mic['sample_rate']
|
||||
self.channels = mic['channels']
|
||||
self.chunk = mic['chunk']
|
||||
self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
|
||||
|
||||
self.output_card_index = soundcard['card_index']
|
||||
self.output_device_index = soundcard['device_index']
|
||||
self.output_sample_rate = soundcard['sample_rate']
|
||||
self.output_channels = soundcard['channels']
|
||||
self.output_volume = soundcard['volume']
|
||||
|
||||
echo = audio.get('echo_cancellation', {})
|
||||
self.audio_echo_cancellation_enabled = echo['enable']
|
||||
self.audio_echo_cancellation_max_duration_ms = echo.get('max_duration_ms', 200)
|
||||
|
||||
tts_audio = audio.get('tts', {})
|
||||
self.audio_tts_source_sample_rate = tts_audio.get('source_sample_rate', 22050)
|
||||
self.audio_tts_source_channels = tts_audio.get('source_channels', 1)
|
||||
self.audio_tts_ffmpeg_thread_queue_size = tts_audio.get('ffmpeg_thread_queue_size', 5)
|
||||
|
||||
self.vad_mode = vad['vad_mode']
|
||||
self.silence_duration_ms = vad['silence_duration_ms']
|
||||
self.min_energy_threshold = vad['min_energy_threshold']
|
||||
|
||||
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
|
||||
self.sv_threshold = system['sv_threshold']
|
||||
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
|
||||
self.sv_buffer_size = system['sv_buffer_size']
|
||||
self.wake_word = system['wake_word']
|
||||
|
||||
def _should_put_to_queue(self) -> bool:
|
||||
"""判断是否应该将音频放入ASR队列(仅在等待唤醒词时)"""
|
||||
return self.waiting_for_wake_word
|
||||
|
||||
def _on_heartbeat(self):
|
||||
if self.waiting_for_wake_word:
|
||||
self.get_logger().info("[注册录音] 等待唤醒词'er gou'...")
|
||||
elif self.waiting_for_voiceprint:
|
||||
self.get_logger().info("[注册录音] 等待声纹语音...")
|
||||
|
||||
def _on_speech_start(self):
|
||||
if self.waiting_for_wake_word:
|
||||
# 等待唤醒词时,开始录音(可能包含唤醒词)
|
||||
self.get_logger().info("[注册录音] 检测到人声,开始录音")
|
||||
elif self.waiting_for_voiceprint:
|
||||
self.get_logger().info("[注册录音] 检测到人声,继续录音(用于声纹注册)")
|
||||
# 注意:不清空缓冲区,保留包含唤醒词的音频
|
||||
|
||||
def _on_audio_chunk(self, audio_chunk: bytes):
|
||||
# 记录所有音频(包括唤醒词),用于声纹注册
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
with self.buffer_lock:
|
||||
self.audio_buffer.extend(audio_array)
|
||||
except Exception as e:
|
||||
self.get_logger().debug(f"[注册录音] 录音失败: {e}")
|
||||
|
||||
def _on_speech_end(self):
|
||||
# 如果还在等待唤醒词,不处理
|
||||
if self.waiting_for_wake_word:
|
||||
return
|
||||
# 如果已经在处理,不重复处理
|
||||
if self.processing:
|
||||
return
|
||||
|
||||
# 等待声纹语音时,用户说话结束,使用当前音频(即使不足3秒)
|
||||
if self.waiting_for_voiceprint:
|
||||
self._process_voiceprint_audio(use_current_audio_if_short=True)
|
||||
return # 处理完毕后直接返回,防止重复调用
|
||||
|
||||
def _process_voiceprint_audio(self, use_current_audio_if_short: bool = False):
|
||||
"""处理声纹音频:使用用户完整的第一段语音进行注册
|
||||
|
||||
Args:
|
||||
use_current_audio_if_short: 如果音频不足3秒,是否使用当前音频(用于用户已说完的情况)
|
||||
"""
|
||||
if self.processing:
|
||||
return
|
||||
self.processing = True
|
||||
with self.buffer_lock:
|
||||
audio_list = list(self.audio_buffer)
|
||||
|
||||
buffer_size = len(audio_list)
|
||||
buffer_sec = buffer_size / self.sample_rate
|
||||
self.get_logger().info(f"[注册录音] 当前音频长度: {buffer_sec:.2f}秒")
|
||||
|
||||
required_samples = int(self.sample_rate * 3)
|
||||
|
||||
# 如果音频不足3秒
|
||||
if buffer_size < required_samples:
|
||||
if use_current_audio_if_short:
|
||||
# 用户已经说完了,使用当前音频(即使不足3秒)
|
||||
self.get_logger().info(f"[注册录音] 音频不足3秒(当前{buffer_sec:.2f}秒),但用户已说完,使用当前音频进行注册")
|
||||
audio_to_use = audio_list
|
||||
else:
|
||||
# 等待继续录音
|
||||
self.get_logger().info(f"[注册录音] 音频不足3秒(当前{buffer_sec:.2f}秒),等待继续录音...")
|
||||
self.processing = False
|
||||
return
|
||||
else:
|
||||
# 策略优化:不再强行截取最后3秒,因为唤醒词检测有延迟,
|
||||
# "er gou" 可能在缓冲区的中间偏后位置。
|
||||
# 为了防止截取到尾部的静音,并在包含完整唤醒词,
|
||||
# 我们截取最近的 3.0 秒(或者全部,如果不足3秒),
|
||||
# 这样能最大程度包含有效语音 "二狗"。
|
||||
target_samples = int(self.sample_rate * 3.0)
|
||||
if buffer_size > target_samples:
|
||||
audio_to_use = audio_list[-target_samples:]
|
||||
else:
|
||||
audio_to_use = audio_list
|
||||
|
||||
duration = len(audio_to_use) / self.sample_rate
|
||||
self.get_logger().info(f"[注册录音] 使用最近 {duration:.2f} 秒音频用于注册(覆盖唤醒词)")
|
||||
|
||||
# 清空缓冲区
|
||||
with self.buffer_lock:
|
||||
self.audio_buffer.clear()
|
||||
|
||||
try:
|
||||
audio_array = np.array(audio_to_use, dtype=np.int16)
|
||||
embedding, success = self.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
if not success or embedding is None:
|
||||
self.get_logger().error("[注册录音] 提取embedding失败")
|
||||
self.processing = False
|
||||
return
|
||||
|
||||
speaker_id = f"user_{int(time.time())}"
|
||||
if self.sv_client.register_speaker(speaker_id, embedding):
|
||||
self.get_logger().info(f"[注册录音] 注册成功,用户ID: {speaker_id},准备退出")
|
||||
|
||||
# 播放成功提示
|
||||
try:
|
||||
self.get_logger().info("[注册录音] 播放注册成功提示")
|
||||
request = TTSRequest(text="声纹注册成功", voice=self.tts_voice)
|
||||
self.tts_client.synthesize(request)
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册录音] 播放提示失败: {e}")
|
||||
|
||||
self.stop_event.set()
|
||||
else:
|
||||
self.get_logger().error("[注册录音] 注册失败")
|
||||
self.processing = False
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册录音] 注册异常: {e}")
|
||||
self.processing = False
|
||||
|
||||
def _extract_speech_segments(self, audio_array: np.ndarray, frame_size: int = 1024) -> list:
|
||||
"""使用能量检测提取人声片段(过滤静音)"""
|
||||
speech_segments = []
|
||||
frame_samples = frame_size
|
||||
total_frames = 0
|
||||
speech_frames = 0
|
||||
|
||||
for i in range(0, len(audio_array), frame_samples):
|
||||
frame = audio_array[i:i + frame_samples]
|
||||
if len(frame) < frame_samples:
|
||||
break
|
||||
|
||||
total_frames += 1
|
||||
# 计算帧的能量(RMS,对于int16音频)
|
||||
frame_float = frame.astype(np.float32)
|
||||
energy = np.sqrt(np.mean(frame_float ** 2))
|
||||
|
||||
# 使用更低的阈值来检测人声(降低阈值,避免误判静音)
|
||||
# 阈值可以动态调整,或者使用自适应阈值
|
||||
threshold = self.min_energy_threshold * 0.50 # 降低阈值到原来的50%
|
||||
|
||||
# 如果能量超过阈值,认为是人声
|
||||
if energy >= threshold:
|
||||
speech_segments.append((i, i + frame_samples))
|
||||
speech_frames += 1
|
||||
|
||||
# 调试信息
|
||||
if total_frames > 0:
|
||||
speech_ratio = speech_frames / total_frames
|
||||
self.get_logger().debug(f"[注册录音] 能量检测: 总帧数={total_frames}, 人声帧数={speech_frames}, 人声比例={speech_ratio:.2%}, 阈值={self.min_energy_threshold}")
|
||||
|
||||
return speech_segments
|
||||
|
||||
def _merge_speech_segments(self, audio_array: np.ndarray, segments: list, min_samples: int) -> np.ndarray:
|
||||
"""合并人声片段,返回连续的人声音频"""
|
||||
if not segments:
|
||||
return np.array([], dtype=np.int16)
|
||||
|
||||
# 合并相邻的片段
|
||||
merged_segments = []
|
||||
current_start, current_end = segments[0]
|
||||
|
||||
for start, end in segments[1:]:
|
||||
if start <= current_end + 1024: # 允许小间隙(1帧)
|
||||
current_end = end
|
||||
else:
|
||||
merged_segments.append((current_start, current_end))
|
||||
current_start, current_end = start, end
|
||||
merged_segments.append((current_start, current_end))
|
||||
|
||||
# 从后往前选择片段,直到达到3秒
|
||||
selected_audio = []
|
||||
total_samples = 0
|
||||
|
||||
for start, end in reversed(merged_segments):
|
||||
segment_audio = audio_array[start:end]
|
||||
selected_audio.insert(0, segment_audio)
|
||||
total_samples += len(segment_audio)
|
||||
if total_samples >= min_samples:
|
||||
break
|
||||
|
||||
if not selected_audio:
|
||||
return np.array([], dtype=np.int16)
|
||||
|
||||
return np.concatenate(selected_audio)
|
||||
|
||||
def _asr_worker(self):
|
||||
"""ASR处理线程"""
|
||||
while not self.stop_event.is_set():
|
||||
def _vad_event_worker(self):
|
||||
"""VAD 事件监听线程,只负责接收事件并放入队列,不调用 spin_until_future_complete"""
|
||||
while not self.registered and not self.shutting_down:
|
||||
try:
|
||||
audio_chunk = self.audio_queue.get(timeout=0.1)
|
||||
if self.asr_client and self.asr_client.running:
|
||||
self.asr_client.send_audio(audio_chunk)
|
||||
except queue.Empty:
|
||||
continue
|
||||
request = VADEvent.Request()
|
||||
request.command = "wait"
|
||||
request.timeout_ms = 1000
|
||||
future = self.vad_client.call_async(request)
|
||||
|
||||
# 简单等待 future 完成,不使用 spin_until_future_complete
|
||||
start_time = time.time()
|
||||
while not future.done() and (time.time() - start_time) < 1.5:
|
||||
time.sleep(0.01)
|
||||
|
||||
if not future.done() or self.registered or self.shutting_down:
|
||||
continue
|
||||
|
||||
response = future.result()
|
||||
if response.success and response.event in ["speech_started", "speech_stopped"]:
|
||||
# 将事件放入队列,由主线程处理
|
||||
try:
|
||||
self.vad_event_queue.put(response.event, timeout=0.1)
|
||||
except queue.Full:
|
||||
self.get_logger().warn(f"[VAD] 事件队列已满,丢弃事件: {response.event}")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册ASR] 处理异常: {e}")
|
||||
if not self.shutting_down:
|
||||
self.get_logger().error(f"[VAD] 线程异常: {e}")
|
||||
break
|
||||
|
||||
def _on_asr_sentence_end(self, text: str):
|
||||
"""ASR识别完成回调"""
|
||||
if text and text.strip():
|
||||
self.text_queue.put(text.strip())
|
||||
|
||||
def _text_worker(self):
|
||||
"""文本处理线程:检测唤醒词"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
text = self.text_queue.get(timeout=0.1)
|
||||
if self.waiting_for_wake_word:
|
||||
self._check_wake_word(text)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册文本] 处理异常: {e}")
|
||||
def _start_recording(self):
|
||||
"""启动录音,返回 future 供主线程处理"""
|
||||
request = AudioData.Request()
|
||||
request.command = "start"
|
||||
return self.audio_data_client.call_async(request)
|
||||
|
||||
def _to_pinyin(self, text: str) -> str:
|
||||
"""将中文文本转换为拼音"""
|
||||
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
|
||||
if not chars:
|
||||
return ""
|
||||
@@ -413,10 +109,8 @@ class RegisterSpeakerNode(Node):
|
||||
return ' '.join([item[0] for item in py_list]).lower().strip()
|
||||
|
||||
def _check_wake_word(self, text: str):
|
||||
"""检查是否包含唤醒词"""
|
||||
text_pinyin = self._to_pinyin(text)
|
||||
wake_word_pinyin = self.wake_word.lower().strip()
|
||||
self.get_logger().info(f"[注册唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")
|
||||
|
||||
if not wake_word_pinyin:
|
||||
return
|
||||
@@ -424,40 +118,119 @@ class RegisterSpeakerNode(Node):
|
||||
text_pinyin_parts = text_pinyin.split() if text_pinyin else []
|
||||
wake_word_parts = wake_word_pinyin.split()
|
||||
|
||||
# 检查是否包含唤醒词
|
||||
has_wake_word = False
|
||||
for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
|
||||
if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
|
||||
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'")
|
||||
self.get_logger().info("=" * 50)
|
||||
self.get_logger().info("[声纹注册] 开始注册声纹,将截取3秒音频用于注册")
|
||||
self.get_logger().info("=" * 50)
|
||||
self.waiting_for_wake_word = False
|
||||
self.waiting_for_voiceprint = True
|
||||
# 停止ASR,不再需要识别
|
||||
if self.asr_client:
|
||||
self.asr_client.stop_current_recognition()
|
||||
# 立即处理当前音频缓冲区中的完整音频
|
||||
# 用户可能已经说完了(包含唤醒词的整段语音)
|
||||
self._process_voiceprint_audio()
|
||||
return
|
||||
|
||||
def _check_done(self):
|
||||
if self.stop_event.is_set():
|
||||
has_wake_word = True
|
||||
break
|
||||
|
||||
if has_wake_word:
|
||||
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}',停止录音并获取音频")
|
||||
request = AudioData.Request()
|
||||
request.command = "stop"
|
||||
future = self.audio_data_client.call_async(request)
|
||||
future._future_type = "stop"
|
||||
self.pending_audio_future = future
|
||||
|
||||
def _process_voiceprint_audio(self, response):
|
||||
"""处理声纹音频数据 - 直接使用 AudioData 返回的音频,不再过滤"""
|
||||
if not response or not response.success or response.samples == 0:
|
||||
self.get_logger().error(f"[注册录音] 获取音频数据失败: {response.message if response else '无响应'}")
|
||||
return
|
||||
|
||||
audio_array = np.frombuffer(response.audio_data, dtype=np.int16)
|
||||
buffer_sec = response.samples / response.sample_rate
|
||||
self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}秒")
|
||||
|
||||
# 直接使用音频,不再进行 VAD 过滤
|
||||
# 因为 AudioData 服务基于 DashScope VAD,已经是语音活动片段
|
||||
embedding, success = self.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=response.sample_rate
|
||||
)
|
||||
if not success or embedding is None:
|
||||
self.get_logger().error("[注册录音] 提取embedding失败")
|
||||
return
|
||||
|
||||
speaker_id = f"user_{int(time.time())}"
|
||||
if self.sv_client.register_speaker(speaker_id, embedding):
|
||||
# 注册成功后立即保存到文件
|
||||
self.sv_client.save_speakers()
|
||||
self.get_logger().info(f"[注册录音] 注册成功,用户ID: {speaker_id},已保存到文件,准备退出")
|
||||
self.registered = True
|
||||
else:
|
||||
self.get_logger().error("[注册录音] 注册失败")
|
||||
|
||||
def _main_loop(self):
|
||||
"""主循环,在主线程中处理所有异步操作"""
|
||||
# 检查是否完成注册
|
||||
if self.registered:
|
||||
self.get_logger().info("注册完成,节点退出")
|
||||
# 清理资源
|
||||
if self.asr_client:
|
||||
self.asr_client.stop()
|
||||
self.destroy_node()
|
||||
self.shutting_down = True
|
||||
self.timer.cancel()
|
||||
rclpy.shutdown()
|
||||
return
|
||||
|
||||
# 处理待处理的 ASR future
|
||||
if self.pending_asr_future and self.pending_asr_future.done():
|
||||
response = self.pending_asr_future.result()
|
||||
self.pending_asr_future = None
|
||||
|
||||
if response.success and response.text:
|
||||
text = response.text.strip()
|
||||
if text:
|
||||
self._check_wake_word(text)
|
||||
|
||||
self.state = "waiting_speech"
|
||||
|
||||
# 处理待处理的 AudioData future
|
||||
if self.pending_audio_future and self.pending_audio_future.done():
|
||||
response = self.pending_audio_future.result()
|
||||
future_type = getattr(self.pending_audio_future, '_future_type', None)
|
||||
self.pending_audio_future = None
|
||||
|
||||
if future_type == "start":
|
||||
if response.success:
|
||||
self.get_logger().info("[注册录音] 已开始录音")
|
||||
self.recording = True
|
||||
else:
|
||||
self.get_logger().warn(f"[注册录音] 启动录音失败: {response.message}")
|
||||
self.state = "waiting_speech"
|
||||
elif future_type == "stop":
|
||||
self.recording = False
|
||||
self._process_voiceprint_audio(response)
|
||||
|
||||
# 处理 VAD 事件队列
|
||||
try:
|
||||
event = self.vad_event_queue.get_nowait()
|
||||
|
||||
if event == "speech_started" and self.state == "waiting_speech" and not self.recording:
|
||||
self.get_logger().info("[VAD] 检测到语音开始,启动录音")
|
||||
future = self._start_recording()
|
||||
future._future_type = "start"
|
||||
self.pending_audio_future = future
|
||||
|
||||
elif event == "speech_stopped" and self.recording and self.state == "waiting_speech":
|
||||
self.get_logger().info("[VAD] 检测到语音结束,请求 ASR 识别")
|
||||
self.state = "waiting_asr"
|
||||
request = ASRRecognize.Request()
|
||||
request.command = "start"
|
||||
self.pending_asr_future = self.asr_client.call_async(request)
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = RegisterSpeakerNode()
|
||||
rclpy.spin(node)
|
||||
node.destroy_node()
|
||||
try:
|
||||
rclpy.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
185
robot_speaker/core/skill_interface_parser.py
Normal file
185
robot_speaker/core/skill_interface_parser.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""技能接口文件解析器"""
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
from typing import Optional
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
class SkillInterfaceParser:
|
||||
def __init__(self, interfaces_root: str):
|
||||
"""初始化解析器"""
|
||||
self.interfaces_root = interfaces_root
|
||||
self._cached_skill_config: list[dict] | None = None
|
||||
self._cached_skill_interfaces: dict[str, dict] | None = None
|
||||
|
||||
def get_skill_names(self) -> list[str]:
|
||||
"""获取所有技能名称(统一读取 robot_skills.yaml,避免重复)"""
|
||||
skill_config = self._load_skill_config()
|
||||
return [entry["name"] for entry in skill_config if isinstance(entry, dict) and entry.get("name")]
|
||||
|
||||
def _load_skill_config(self) -> list[dict]:
|
||||
"""加载 robot_skills.yaml(带缓存,避免重复读取)"""
|
||||
if self._cached_skill_config is not None:
|
||||
return self._cached_skill_config
|
||||
|
||||
try:
|
||||
brain_share = get_package_share_directory("brain")
|
||||
skill_path = os.path.join(brain_share, "config", "robot_skills.yaml")
|
||||
with open(skill_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or []
|
||||
self._cached_skill_config = data if isinstance(data, list) else []
|
||||
return self._cached_skill_config
|
||||
except Exception:
|
||||
self._cached_skill_config = []
|
||||
return []
|
||||
|
||||
def parse_skill_interfaces(self) -> dict[str, dict]:
|
||||
"""解析所有技能接口文件的目标字段(带缓存)"""
|
||||
if self._cached_skill_interfaces is not None:
|
||||
return self._cached_skill_interfaces
|
||||
|
||||
result = {}
|
||||
skill_config = self._load_skill_config()
|
||||
|
||||
for skill_entry in skill_config:
|
||||
skill_name = skill_entry.get("name")
|
||||
if not skill_name:
|
||||
continue
|
||||
|
||||
interfaces = skill_entry.get("interfaces", [])
|
||||
for iface in interfaces:
|
||||
if isinstance(iface, dict):
|
||||
iface_name = iface.get("name", "")
|
||||
else:
|
||||
iface_name = str(iface)
|
||||
|
||||
if ".action" in iface_name:
|
||||
iface_type = "action"
|
||||
file_path = os.path.join(self.interfaces_root, "action", iface_name)
|
||||
elif ".srv" in iface_name:
|
||||
iface_type = "srv"
|
||||
file_path = os.path.join(self.interfaces_root, "srv", iface_name)
|
||||
else:
|
||||
continue
|
||||
|
||||
if os.path.exists(file_path):
|
||||
goal_fields = self._parse_goal_fields(file_path)
|
||||
result[skill_name] = {
|
||||
"type": iface_type,
|
||||
"goal_fields": goal_fields
|
||||
}
|
||||
break
|
||||
|
||||
self._cached_skill_interfaces = result
|
||||
return result
|
||||
|
||||
def _parse_goal_fields(self, file_path: str) -> list[dict]:
|
||||
"""解析接口文件的目标字段(第一个---之前的所有字段)"""
|
||||
goal_fields = []
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("---"):
|
||||
break
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
field_type = parts[0]
|
||||
field_name = parts[1]
|
||||
|
||||
comment = ""
|
||||
if "#" in line:
|
||||
comment = line.split("#", 1)[1].strip()
|
||||
|
||||
goal_fields.append({
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"comment": comment
|
||||
})
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return goal_fields
|
||||
|
||||
def generate_params_documentation(self) -> str:
|
||||
"""生成技能参数说明文档"""
|
||||
skill_interfaces = self.parse_skill_interfaces()
|
||||
doc_lines = []
|
||||
|
||||
for skill_name, skill_info in skill_interfaces.items():
|
||||
doc_lines.append(f"{skill_name}技能的parameters字段:")
|
||||
goal_fields = skill_info.get("goal_fields", [])
|
||||
|
||||
if not goal_fields:
|
||||
doc_lines.append(" - 无参数,使用 null")
|
||||
else:
|
||||
doc_lines.append(" parameters字典必须包含以下字段:")
|
||||
for field in goal_fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
comment = field.get("comment", "")
|
||||
|
||||
if field_name == "body_id":
|
||||
doc_lines.append(
|
||||
f" - {field_name} ({field_type}): 身体部位ID,0=左臂,1=右臂,2=头部。"
|
||||
f"根据目标物在图片中的方位选择:左侧用0,右侧用1,中央用2。"
|
||||
)
|
||||
else:
|
||||
type_desc = self._get_type_description(field_type)
|
||||
doc_lines.append(f" - {field_name} ({field_type}): {type_desc} {comment}")
|
||||
|
||||
example_params = {}
|
||||
for field in goal_fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
example_params[field_name] = self._get_example_value(field_name, field_type)
|
||||
|
||||
doc_lines.append(f" 示例:{json.dumps(example_params, ensure_ascii=False)}")
|
||||
|
||||
doc_lines.append("")
|
||||
|
||||
return "\n".join(doc_lines)
|
||||
|
||||
def _get_type_description(self, field_type: str) -> str:
|
||||
"""根据字段类型返回描述"""
|
||||
type_map = {
|
||||
"int8": "整数,范围-128到127",
|
||||
"int16": "整数,范围-32768到32767",
|
||||
"int32": "整数",
|
||||
"int64": "整数",
|
||||
"uint8": "无符号整数,范围0到255",
|
||||
"float32": "浮点数",
|
||||
"float64": "浮点数",
|
||||
"string": "字符串",
|
||||
}
|
||||
|
||||
base_type = field_type.replace("[]", "").replace("_", "")
|
||||
return type_map.get(base_type, field_type)
|
||||
|
||||
def _get_example_value(self, field_name: str, field_type: str) -> any:
|
||||
"""根据字段名和类型生成示例值"""
|
||||
if field_name == "body_id":
|
||||
return 0
|
||||
elif field_name == "data_array" and "float64[]" in field_type:
|
||||
return [0.1, 0.2, 0.3, 0.0, 0.0, 0.0]
|
||||
elif "int" in field_type:
|
||||
return 0
|
||||
elif "float" in field_type:
|
||||
return 0.0
|
||||
elif "string" in field_type:
|
||||
return ""
|
||||
elif "[]" in field_type:
|
||||
if "int" in field_type:
|
||||
return [0, 0, 0]
|
||||
elif "float" in field_type:
|
||||
return [0.0, 0.0, 0.0]
|
||||
return []
|
||||
else:
|
||||
return None
|
||||
|
||||
199
robot_speaker/core/speaker_verifier.py
Normal file
199
robot_speaker/core/speaker_verifier.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
声纹识别模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpeakerState(Enum):
|
||||
"""说话人识别状态"""
|
||||
UNKNOWN = "unknown"
|
||||
VERIFIED = "verified"
|
||||
REJECTED = "rejected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SpeakerVerificationClient:
|
||||
"""声纹识别客户端 - 非实时、低频处理"""
|
||||
|
||||
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
|
||||
self.model_path = model_path
|
||||
self.threshold = threshold
|
||||
self.speaker_db_path = speaker_db_path
|
||||
self.logger = logger
|
||||
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# # 优化CPU性能:限制Torch使用的线程数,防止多线程竞争导致性能骤降
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
||||
from funasr import AutoModel
|
||||
model_path = os.path.expanduser(self.model_path)
|
||||
# 禁用自动更新检查,防止每次初始化都联网检查
|
||||
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
|
||||
if self.logger:
|
||||
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
|
||||
|
||||
if self.speaker_db_path:
|
||||
self.load_speakers()
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
|
||||
if self.logger:
|
||||
try:
|
||||
if level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warning(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
elif level == "debug":
|
||||
self.logger.debug(msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def load_speakers(self):
|
||||
if not self.speaker_db_path:
|
||||
return
|
||||
|
||||
db_path = os.path.expanduser(self.speaker_db_path)
|
||||
if not os.path.exists(db_path):
|
||||
self._log("info", f"声纹数据库文件不存在: {db_path},将创建新文件")
|
||||
return
|
||||
try:
|
||||
with open(db_path, 'rb') as f:
|
||||
data = json.load(f)
|
||||
with self._lock:
|
||||
self.speaker_db = {}
|
||||
for speaker_id, info in data.items():
|
||||
embedding_array = np.array(info["embedding"], dtype=np.float32)
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_array,
|
||||
"env": info.get("env", ""),
|
||||
"registered_at": info.get("registered_at", 0.0)
|
||||
}
|
||||
self._log("info", f"已加载 {len(self.speaker_db)} 个已注册说话人")
|
||||
except Exception as e:
|
||||
self._log("error", f"加载声纹数据库失败: {e}")
|
||||
|
||||
def save_speakers(self):
|
||||
if not self.speaker_db_path:
|
||||
return
|
||||
db_path = os.path.expanduser(self.speaker_db_path)
|
||||
try:
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
with self._lock:
|
||||
data = {}
|
||||
for speaker_id, info in self.speaker_db.items():
|
||||
data[speaker_id] = {
|
||||
"embedding": info["embedding"].tolist(),
|
||||
"env": info.get("env", ""),
|
||||
"registered_at": info.get("registered_at", 0.0)
|
||||
}
|
||||
with open(db_path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
self._log("info", f"已保存 {len(data)} 个已注册说话人到: {db_path}")
|
||||
except Exception as e:
|
||||
self._log("error", f"保存声纹数据库失败: {e}")
|
||||
|
||||
def extract_embedding(self, audio_array: np.ndarray, sample_rate: int = 16000) -> tuple[np.ndarray | None, bool]:
|
||||
try:
|
||||
if len(audio_array) == 0:
|
||||
return None, False
|
||||
# 确保是int16格式
|
||||
if audio_array.dtype != np.int16:
|
||||
audio_array = audio_array.astype(np.int16)
|
||||
# 转换为float32并归一化到[-1, 1]
|
||||
audio_float = audio_array.astype(np.float32) / 32768.0
|
||||
# 调用模型提取embedding
|
||||
result = self.model.generate(input=audio_float, cache={})
|
||||
if result and len(result) > 0 and "spk_embedding" in result[0]:
|
||||
embedding = result[0]["spk_embedding"]
|
||||
if embedding is not None and len(embedding) > 0:
|
||||
embedding_array = np.array(embedding, dtype=np.float32)
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
return embedding_array, True
|
||||
return None, False
|
||||
except Exception as e:
|
||||
self._log("error", f"提取声纹特征失败: {e}")
|
||||
return None, False
|
||||
|
||||
def match_speaker(self, embedding: np.ndarray) -> tuple[str | None, SpeakerState, float, float]:
|
||||
if embedding is None or len(embedding) == 0:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
|
||||
|
||||
with self._lock:
|
||||
if len(self.speaker_db) == 0:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
|
||||
try:
|
||||
best_speaker_id = None
|
||||
best_score = 0.0
|
||||
with self._lock:
|
||||
for speaker_id, info in self.speaker_db.items():
|
||||
stored_embedding = info["embedding"]
|
||||
# 计算余弦相似度
|
||||
dot_product = np.dot(embedding, stored_embedding)
|
||||
norm_embedding = np.linalg.norm(embedding)
|
||||
norm_stored = np.linalg.norm(stored_embedding)
|
||||
|
||||
if norm_embedding > 0 and norm_stored > 0:
|
||||
score = dot_product / (norm_embedding * norm_stored)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_speaker_id = speaker_id
|
||||
|
||||
state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
|
||||
return best_speaker_id, state, float(best_score), float(self.threshold)
|
||||
except Exception as e:
|
||||
self._log("error", f"匹配说话人失败: {e}")
|
||||
return None, SpeakerState.ERROR, 0.0, float(self.threshold)
|
||||
|
||||
def register_speaker(self, speaker_id: str, embedding: np.ndarray, env: str = "") -> bool:
|
||||
if embedding is None or len(embedding) == 0:
|
||||
return False
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": np.array(embedding, dtype=np.float32),
|
||||
"env": env,
|
||||
"registered_at": time.time()
|
||||
}
|
||||
self._log("info", f"已注册说话人: {speaker_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log("error", f"注册说话人失败: {e}")
|
||||
return False
|
||||
|
||||
def get_speaker_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self.speaker_db)
|
||||
|
||||
def get_speaker_list(self) -> list[str]:
|
||||
with self._lock:
|
||||
return list(self.speaker_db.keys())
|
||||
|
||||
def remove_speaker(self, speaker_id: str) -> bool:
|
||||
with self._lock:
|
||||
if speaker_id in self.speaker_db:
|
||||
del self.speaker_db[speaker_id]
|
||||
self._log("info", f"已删除说话人: {speaker_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
try:
|
||||
self.save_speakers()
|
||||
if hasattr(self, 'model') and self.model:
|
||||
del self.model
|
||||
except Exception as e:
|
||||
self._log("error", f"清理资源失败: {e}")
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
统一数据结构定义
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRResult:
|
||||
"""ASR识别结果"""
|
||||
text: str
|
||||
confidence: float | None = None
|
||||
language: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""LLM消息"""
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSRequest:
|
||||
"""TTS请求"""
|
||||
text: str
|
||||
voice: str | None = None # 如果为None,使用控制台配置的默认音色
|
||||
speed: float | None = None
|
||||
pitch: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageMessage:
|
||||
"""图像消息 - 用于多模态LLM"""
|
||||
image_data: bytes # base64编码的图像数据
|
||||
image_format: str = "jpeg"
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""模型层"""
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""ASR模型"""
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
class ASRClient:
|
||||
def start(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def send_audio(self, audio_data: bytes) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
ASR语音识别模块
|
||||
"""
|
||||
import base64
|
||||
import time
|
||||
import threading
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
|
||||
from robot_speaker.models.asr.base import ASRClient
|
||||
|
||||
|
||||
class DashScopeASR(ASRClient):
|
||||
"""DashScope实时ASR识别器封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
sample_rate: int,
|
||||
model: str,
|
||||
url: str,
|
||||
logger=None):
|
||||
dashscope.api_key = api_key
|
||||
self.sample_rate = sample_rate
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.logger = logger
|
||||
|
||||
self.conversation = None
|
||||
self.running = False
|
||||
self.on_sentence_end = None
|
||||
self.on_text_update = None # 实时文本更新回调
|
||||
|
||||
# 线程同步机制
|
||||
self._stop_lock = threading.Lock() # 防止并发调用 stop_current_recognition
|
||||
self._final_result_event = threading.Event() # 等待 final 回调完成
|
||||
self._pending_commit = False # 标记是否有待处理的 commit
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
else:
|
||||
print(f"[ASR] {msg}")
|
||||
|
||||
def start(self):
|
||||
"""启动ASR识别器"""
|
||||
if self.running:
|
||||
return False
|
||||
|
||||
try:
|
||||
callback = _ASRCallback(self)
|
||||
self.conversation = OmniRealtimeConversation(
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
callback=callback
|
||||
)
|
||||
callback.conversation = self.conversation
|
||||
|
||||
self.conversation.connect()
|
||||
|
||||
transcription_params = TranscriptionParams(
|
||||
language='zh',
|
||||
sample_rate=self.sample_rate,
|
||||
input_audio_format="pcm",
|
||||
)
|
||||
|
||||
# 本地 VAD → 只控制 TTS 打断
|
||||
# 服务端 turn detection → 只控制 ASR 输出、LLM 生成轮次
|
||||
|
||||
self.conversation.update_session(
|
||||
output_modalities=[MultiModality.TEXT],
|
||||
enable_input_audio_transcription=True,
|
||||
transcription_params=transcription_params,
|
||||
enable_turn_detection=True,
|
||||
# 保留服务端 turn detection
|
||||
turn_detection_type='server_vad', # 服务端VAD
|
||||
turn_detection_threshold=0.2, # 可调
|
||||
turn_detection_silence_duration_ms=800
|
||||
)
|
||||
|
||||
self.running = True
|
||||
self._log("info", "ASR已启动")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.running = False
|
||||
self._log("error", f"ASR启动失败: {e}")
|
||||
if self.conversation:
|
||||
try:
|
||||
self.conversation.close()
|
||||
except:
|
||||
pass
|
||||
self.conversation = None
|
||||
return False
|
||||
|
||||
def send_audio(self, audio_chunk: bytes):
|
||||
"""发送音频chunk到ASR"""
|
||||
if not self.running or not self.conversation:
|
||||
return False
|
||||
try:
|
||||
audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
|
||||
self.conversation.append_audio(audio_b64)
|
||||
return True
|
||||
except Exception as e:
|
||||
# 连接已关闭或其他错误,静默处理(避免日志过多)
|
||||
# running状态会在stop_current_recognition中正确设置
|
||||
return False
|
||||
|
||||
def stop_current_recognition(self):
|
||||
"""
|
||||
停止当前识别,触发final结果,然后重新启动
|
||||
优化:
|
||||
1. 使用事件代替 sleep,等待 final 回调完成
|
||||
2. 使用锁防止并发调用
|
||||
3. 处理 start() 失败的情况,确保 running 状态正确
|
||||
4. 添加超时机制,避免无限等待
|
||||
"""
|
||||
# 使用锁防止并发调用
|
||||
if not self._stop_lock.acquire(blocking=False):
|
||||
self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
|
||||
return False
|
||||
|
||||
try:
|
||||
if not self.running or not self.conversation:
|
||||
return False
|
||||
|
||||
# 重置事件,准备等待 final 回调
|
||||
self._final_result_event.clear()
|
||||
self._pending_commit = True
|
||||
|
||||
# 触发 commit,等待 final 结果
|
||||
self.conversation.commit()
|
||||
|
||||
# 等待 final 回调完成(最多等待1秒)
|
||||
if self._final_result_event.wait(timeout=1.0):
|
||||
self._log("debug", "已收到 final 回调,准备关闭连接")
|
||||
else:
|
||||
self._log("warning", "等待 final 回调超时,继续执行")
|
||||
|
||||
# 先设置running=False,防止ASR线程继续发送音频
|
||||
self.running = False
|
||||
|
||||
# 关闭当前连接
|
||||
old_conversation = self.conversation
|
||||
self.conversation = None # 立即清空,防止send_audio继续使用
|
||||
try:
|
||||
old_conversation.close()
|
||||
except Exception as e:
|
||||
self._log("warning", f"关闭连接时出错: {e}")
|
||||
|
||||
# 短暂等待,确保连接完全关闭
|
||||
time.sleep(0.1)
|
||||
|
||||
# 重新启动,如果失败则保持 running=False
|
||||
if not self.start():
|
||||
self._log("error", "ASR重启失败,running状态已重置")
|
||||
return False
|
||||
|
||||
# 启动成功,running已在start()中设置为True
|
||||
return True
|
||||
finally:
|
||||
self._pending_commit = False
|
||||
self._stop_lock.release()
|
||||
|
||||
def stop(self):
|
||||
"""停止ASR识别器"""
|
||||
# 等待正在执行的 stop_current_recognition 完成
|
||||
with self._stop_lock:
|
||||
self.running = False
|
||||
self._final_result_event.set() # 唤醒可能正在等待的线程
|
||||
if self.conversation:
|
||||
try:
|
||||
self.conversation.close()
|
||||
except Exception as e:
|
||||
self._log("warning", f"停止时关闭连接出错: {e}")
|
||||
self.conversation = None
|
||||
self._log("info", "ASR已停止")
|
||||
|
||||
|
||||
class _ASRCallback(OmniRealtimeCallback):
|
||||
"""ASR回调处理"""
|
||||
|
||||
def __init__(self, asr_client: DashScopeASR):
|
||||
self.asr_client = asr_client
|
||||
self.conversation = None
|
||||
|
||||
def on_open(self):
|
||||
self.asr_client._log("info", "ASR WebSocket已连接")
|
||||
|
||||
def on_close(self, code, msg):
|
||||
self.asr_client._log("info", f"ASR WebSocket已关闭: code={code}, msg={msg}")
|
||||
|
||||
def on_event(self, response):
|
||||
event_type = response.get('type', '')
|
||||
|
||||
if event_type == 'session.created':
|
||||
session_id = response.get('session', {}).get('id', '')
|
||||
self.asr_client._log("info", f"ASR会话已创建: {session_id}")
|
||||
|
||||
elif event_type == 'conversation.item.input_audio_transcription.completed':
|
||||
# 最终识别结果
|
||||
transcript = response.get('transcript', '')
|
||||
if transcript and transcript.strip() and self.asr_client.on_sentence_end:
|
||||
self.asr_client.on_sentence_end(transcript.strip())
|
||||
|
||||
# 如果有待处理的 commit,通知等待的线程
|
||||
if self.asr_client._pending_commit:
|
||||
self.asr_client._final_result_event.set()
|
||||
|
||||
elif event_type == 'conversation.item.input_audio_transcription.text':
|
||||
# 实时识别文本更新(多轮提示)
|
||||
transcript = response.get('transcript', '') or response.get('text', '')
|
||||
if transcript and transcript.strip() and self.asr_client.on_text_update:
|
||||
self.asr_client.on_text_update(transcript.strip())
|
||||
|
||||
elif event_type == 'input_audio_buffer.speech_started':
|
||||
self.asr_client._log("info", "ASR检测到说话开始")
|
||||
|
||||
elif event_type == 'input_audio_buffer.speech_stopped':
|
||||
self.asr_client._log("info", "ASR检测到说话结束")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""LLM模型"""
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def chat(self, messages: list[LLMMessage]) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def chat_stream(self, messages: list[LLMMessage],
|
||||
on_token=None,
|
||||
interrupt_check=None) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
LLM大语言模型模块
|
||||
支持多模态(文本+图像)
|
||||
"""
|
||||
from openai import OpenAI
|
||||
from typing import Optional, List
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
from robot_speaker.models.llm.base import LLMClient
|
||||
|
||||
|
||||
class DashScopeLLM(LLMClient):
|
||||
"""DashScope LLM客户端封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
model: str,
|
||||
base_url: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
name: str = "LLM",
|
||||
logger=None):
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.name = name
|
||||
self.logger = logger
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
msg = f"[{self.name}] {msg}"
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
|
||||
def chat(self, messages: list[LLMMessage]) -> str | None:
|
||||
"""非流式聊天:任务规划"""
|
||||
payload_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=payload_messages,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
stream=False
|
||||
)
|
||||
reply = response.choices[0].message.content.strip()
|
||||
return reply if reply else None
|
||||
|
||||
def chat_stream(self, messages: list[LLMMessage],
|
||||
on_token=None,
|
||||
images: Optional[List[str]] = None,
|
||||
interrupt_check=None) -> str | None:
|
||||
"""
|
||||
流式聊天:语音系统
|
||||
支持多模态(文本+图像)
|
||||
支持中断检查(interrupt_check: 返回True表示需要中断)
|
||||
"""
|
||||
# 转换消息格式,支持多模态
|
||||
# 图像只添加到最后一个user消息中
|
||||
payload_messages = []
|
||||
last_user_idx = -1
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role == "user":
|
||||
last_user_idx = i
|
||||
|
||||
has_images_in_message = False
|
||||
for i, msg in enumerate(messages):
|
||||
msg_dict = {"role": msg.role}
|
||||
|
||||
# 如果当前消息是最后一个user消息且有图像,构建多模态content
|
||||
if i == last_user_idx and msg.role == "user" and images and len(images) > 0:
|
||||
content_list = [{"type": "text", "text": msg.content}]
|
||||
# 添加所有图像
|
||||
for img_idx, img_base64 in enumerate(images):
|
||||
image_url = f"data:image/jpeg;base64,{img_base64[:50]}..." if len(img_base64) > 50 else f"data:image/jpeg;base64,{img_base64}"
|
||||
content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||
}
|
||||
})
|
||||
self._log("info", f"[多模态] 添加图像 #{img_idx+1} 到user消息,base64长度: {len(img_base64)}")
|
||||
msg_dict["content"] = content_list
|
||||
has_images_in_message = True
|
||||
else:
|
||||
msg_dict["content"] = msg.content
|
||||
|
||||
payload_messages.append(msg_dict)
|
||||
|
||||
# 记录多模态信息
|
||||
if images and len(images) > 0:
|
||||
if has_images_in_message:
|
||||
# 找到最后一个user消息,记录其content结构
|
||||
last_user_msg = payload_messages[last_user_idx] if last_user_idx >= 0 else None
|
||||
if last_user_msg and isinstance(last_user_msg.get("content"), list):
|
||||
content_items = last_user_msg["content"]
|
||||
text_items = [item for item in content_items if item.get("type") == "text"]
|
||||
image_items = [item for item in content_items if item.get("type") == "image_url"]
|
||||
self._log("info", f"[多模态] 已发送多模态请求: {len(text_items)}个文本 + {len(image_items)}张图片")
|
||||
self._log("debug", f"[多模态] 用户文本: {text_items[0].get('text', '')[:50] if text_items else 'N/A'}")
|
||||
else:
|
||||
self._log("warning", "[多模态] 消息格式异常,无法确认图片是否添加")
|
||||
else:
|
||||
self._log("warning", f"[多模态] 有{len(images)}张图片,但未找到user消息,图片未被添加")
|
||||
else:
|
||||
self._log("debug", "[多模态] 纯文本请求(无图片)")
|
||||
|
||||
full_reply = ""
|
||||
interrupted = False
|
||||
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=payload_messages,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
# 检查中断标志
|
||||
if interrupt_check and interrupt_check():
|
||||
self._log("info", "LLM流式处理被中断")
|
||||
interrupted = True
|
||||
break
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
full_reply += content
|
||||
if on_token:
|
||||
on_token(content)
|
||||
# 在on_token回调后再次检查中断(on_token可能设置中断标志)
|
||||
if interrupt_check and interrupt_check():
|
||||
self._log("info", "LLM流式处理在on_token回调后被中断")
|
||||
interrupted = True
|
||||
break
|
||||
|
||||
if interrupted:
|
||||
return None # 被中断时返回None,表示未完成
|
||||
|
||||
return full_reply.strip() if full_reply else None
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""TTS模型"""
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
|
||||
|
||||
class TTSClient:
|
||||
"""TTS客户端抽象基类"""
|
||||
|
||||
def synthesize(self, request: TTSRequest,
|
||||
on_chunk=None,
|
||||
interrupt_check=None) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
"""
|
||||
TTS语音合成模块
|
||||
"""
|
||||
import subprocess
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
from robot_speaker.models.tts.base import TTSClient
|
||||
|
||||
|
||||
class DashScopeTTSClient(TTSClient):
|
||||
"""DashScope流式TTS客户端封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
model: str,
|
||||
voice: str,
|
||||
card_index: int,
|
||||
device_index: int,
|
||||
output_sample_rate: int = 44100,
|
||||
output_channels: int = 2,
|
||||
output_volume: float = 1.0,
|
||||
tts_source_sample_rate: int = 22050, # TTS服务固定输出采样率
|
||||
tts_source_channels: int = 1, # TTS服务固定输出声道数
|
||||
tts_ffmpeg_thread_queue_size: int = 1024, # ffmpeg输入线程队列大小
|
||||
reference_signal_buffer=None, # 参考信号缓冲区(用于回声消除)
|
||||
logger=None):
|
||||
dashscope.api_key = api_key
|
||||
self.model = model
|
||||
self.voice = voice
|
||||
self.card_index = card_index
|
||||
self.device_index = device_index
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_channels = output_channels
|
||||
self.output_volume = output_volume
|
||||
self.tts_source_sample_rate = tts_source_sample_rate
|
||||
self.tts_source_channels = tts_source_channels
|
||||
self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
|
||||
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
|
||||
self.logger = logger
|
||||
self.current_ffmpeg_pid = None # 当前ffmpeg进程的PID
|
||||
|
||||
# 构建ALSA设备, 允许 ffmpeg 自动重采样 / 重声道
|
||||
self.alsa_device = f"plughw:{card_index},{device_index}" if (
|
||||
card_index >= 0 and device_index >= 0
|
||||
) else "default"
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
else:
|
||||
print(f"[TTS] {msg}")
|
||||
|
||||
def synthesize(self, request: TTSRequest,
|
||||
on_chunk=None,
|
||||
interrupt_check=None) -> bool:
|
||||
"""主流程:流式合成并播放"""
|
||||
callback = _TTSCallback(self, interrupt_check, on_chunk, self.reference_signal_buffer)
|
||||
# 使用配置的voice,request.voice为None或空时使用self.voice
|
||||
voice_to_use = request.voice if request.voice and request.voice.strip() else self.voice
|
||||
|
||||
if not voice_to_use or not voice_to_use.strip():
|
||||
self._log("error", f"Voice参数无效: '{voice_to_use}'")
|
||||
return False
|
||||
|
||||
self._log("info", f"TTS开始: 文本='{request.text[:50]}...', voice='{voice_to_use}'")
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=self.model,
|
||||
voice=voice_to_use,
|
||||
format=AudioFormat.PCM_22050HZ_MONO_16BIT,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
try:
|
||||
synthesizer.streaming_call(request.text)
|
||||
synthesizer.streaming_complete()
|
||||
finally:
|
||||
callback.cleanup()
|
||||
|
||||
return not callback._interrupted
|
||||
|
||||
|
||||
class _TTSCallback(ResultCallback):
|
||||
"""TTS回调处理 - 使用ffmpeg播放,自动处理采样率转换"""
|
||||
|
||||
def __init__(self, tts_client: DashScopeTTSClient,
|
||||
interrupt_check=None,
|
||||
on_chunk=None,
|
||||
reference_signal_buffer=None):
|
||||
self.tts_client = tts_client
|
||||
self.interrupt_check = interrupt_check
|
||||
self.on_chunk = on_chunk
|
||||
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
|
||||
self._proc = None
|
||||
self._interrupted = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def on_open(self):
|
||||
# 使用ffmpeg播放,自动处理采样率转换(TTS源采样率 -> 设备采样率)
|
||||
# TTS服务输出固定采样率和声道数,ffmpeg会自动转换为设备采样率和声道数
|
||||
ffmpeg_cmd = [
|
||||
'ffmpeg',
|
||||
'-f', 's16le', # 原始 PCM
|
||||
'-ar', str(self.tts_client.tts_source_sample_rate), # TTS输出采样率(从配置文件读取)
|
||||
'-ac', str(self.tts_client.tts_source_channels), # TTS输出声道数(从配置文件读取)
|
||||
'-i', 'pipe:0', # stdin
|
||||
'-f', 'alsa', # 输出到 ALSA
|
||||
'-ar', str(self.tts_client.output_sample_rate), # 输出设备采样率(从配置文件读取)
|
||||
'-ac', str(self.tts_client.output_channels), # 输出设备声道数(从配置文件读取)
|
||||
'-acodec', 'pcm_s16le', # 输出编码
|
||||
'-fflags', 'nobuffer', # 减少缓冲
|
||||
'-flags', 'low_delay', # 低延迟
|
||||
'-avioflags', 'direct', # 尝试直通写入 ALSA,减少延迟
|
||||
self.tts_client.alsa_device
|
||||
]
|
||||
|
||||
# 将 -thread_queue_size 放到输入文件之前
|
||||
insert_pos = ffmpeg_cmd.index('-i')
|
||||
ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
|
||||
ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
|
||||
|
||||
# 添加音量调节filter(如果音量不是1.0)
|
||||
if self.tts_client.output_volume != 1.0:
|
||||
# 在输出编码前插入音量filter
|
||||
# volume filter放在输入之后、输出编码之前
|
||||
acodec_idx = ffmpeg_cmd.index('-acodec')
|
||||
ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
|
||||
ffmpeg_cmd.insert(acodec_idx, '-af')
|
||||
|
||||
self.tts_client._log("info", f"启动ffmpeg播放: ALSA设备={self.tts_client.alsa_device}, "
|
||||
f"输出采样率={self.tts_client.output_sample_rate}Hz, "
|
||||
f"输出声道数={self.tts_client.output_channels}, "
|
||||
f"音量={self.tts_client.output_volume * 100:.0f}%")
|
||||
self._proc = subprocess.Popen(
|
||||
ffmpeg_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE # 改为PIPE以便捕获错误
|
||||
)
|
||||
# 记录ffmpeg进程PID
|
||||
self.tts_client.current_ffmpeg_pid = self._proc.pid
|
||||
self.tts_client._log("debug", f"ffmpeg进程已启动,PID={self._proc.pid}")
|
||||
|
||||
def on_complete(self):
|
||||
pass
|
||||
|
||||
def on_error(self, message: str):
|
||||
self.tts_client._log("error", f"TTS错误: {message}")
|
||||
|
||||
def on_close(self):
|
||||
self.cleanup()
|
||||
|
||||
def on_event(self, message):
|
||||
pass
|
||||
|
||||
def on_data(self, data: bytes) -> None:
|
||||
"""接收音频数据并播放"""
|
||||
if self._interrupted:
|
||||
return
|
||||
|
||||
if self.interrupt_check and self.interrupt_check():
|
||||
# 停止播放,不停止 TTS
|
||||
self._interrupted = True
|
||||
if self._proc:
|
||||
self._proc.terminate()
|
||||
return
|
||||
|
||||
# 优先写入ffmpeg,避免阻塞播放
|
||||
# 优先写入ffmpeg,避免阻塞播放
|
||||
if self._proc and self._proc.stdin and not self._interrupted:
|
||||
try:
|
||||
self._proc.stdin.write(data)
|
||||
self._proc.stdin.flush()
|
||||
except BrokenPipeError:
|
||||
# ffmpeg进程可能已退出,检查错误
|
||||
if self._proc.stderr:
|
||||
error_msg = self._proc.stderr.read().decode('utf-8', errors='ignore')
|
||||
self.tts_client._log("error", f"ffmpeg错误: {error_msg}")
|
||||
self._interrupted = True
|
||||
|
||||
# 将音频数据添加到参考信号缓冲区(用于回声消除)
|
||||
# 在写入ffmpeg之后处理,避免阻塞播放
|
||||
if self.reference_signal_buffer and data:
|
||||
try:
|
||||
self.reference_signal_buffer.add_reference(
|
||||
data,
|
||||
source_sample_rate=self.tts_client.tts_source_sample_rate,
|
||||
source_channels=self.tts_client.tts_source_channels
|
||||
)
|
||||
except Exception as e:
|
||||
# 参考信号处理失败不应影响播放
|
||||
self.tts_client._log("warning", f"参考信号处理失败: {e}")
|
||||
|
||||
if self.on_chunk:
|
||||
self.on_chunk(data)
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
if self._cleaned_up or not self._proc:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
|
||||
# 关闭stdin,让ffmpeg处理完剩余数据
|
||||
if self._proc.stdin and not self._proc.stdin.closed:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
# 等待进程自然结束(根据文本长度估算,最少10秒,最多30秒)
|
||||
# 假设平均语速:3-4字/秒,加上缓冲时间
|
||||
if self._proc.poll() is None:
|
||||
try:
|
||||
# 增加等待时间,确保ffmpeg播放完成
|
||||
# 对于长文本,可能需要更长时间
|
||||
self._proc.wait(timeout=30.0)
|
||||
except:
|
||||
# 超时后,如果进程还在运行,说明可能卡住了,强制终止
|
||||
if self._proc.poll() is None:
|
||||
self.tts_client._log("warning", "ffmpeg播放超时,强制终止")
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=1.0)
|
||||
except:
|
||||
try:
|
||||
self._proc.kill()
|
||||
self._proc.wait(timeout=0.1)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 清空PID记录
|
||||
if self.tts_client.current_ffmpeg_pid == self._proc.pid:
|
||||
self.tts_client.current_ffmpeg_pid = None
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""感知层"""
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
"""
|
||||
音频处理模块:录音 + VAD + 回声消除
|
||||
音频处理模块:录音 + VAD
|
||||
"""
|
||||
import time
|
||||
import pyaudio
|
||||
import webrtcvad
|
||||
import struct
|
||||
import queue
|
||||
from .echo_cancellation import EchoCanceller, ReferenceSignalBuffer
|
||||
|
||||
|
||||
class VADDetector:
|
||||
@@ -35,8 +34,6 @@ class AudioRecorder:
|
||||
on_audio_chunk=None, # 音频chunk回调(用于声纹录音等,可选)
|
||||
should_put_to_queue=None, # 检查是否应该将音频放入队列(用于阻止ASR,可选)
|
||||
get_silence_threshold=None, # 获取动态静音阈值(毫秒,可选)
|
||||
enable_echo_cancellation: bool = True, # 是否启用回声消除
|
||||
reference_signal_buffer: ReferenceSignalBuffer = None, # 参考信号缓冲区(可选)
|
||||
logger=None):
|
||||
self.device_index = device_index
|
||||
self.sample_rate = sample_rate
|
||||
@@ -89,7 +86,7 @@ class AudioRecorder:
|
||||
self.device_index = found_index
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.warning(f"未自动检测到 iFLYTEK 设备,将继续使用配置的索引: {self.device_index}")
|
||||
self.logger.warning(f"未自动检测到 iFLYTEK 设备,请检查USB连接,或执行 'arecord -l' 确认系统是否识别到录音设备,将继续使用配置的索引: {self.device_index}")
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
@@ -97,39 +94,6 @@ class AudioRecorder:
|
||||
|
||||
self.format = pyaudio.paInt16
|
||||
self._debug_counter = 0
|
||||
|
||||
# 回声消除相关
|
||||
self.enable_echo_cancellation = enable_echo_cancellation
|
||||
self.reference_signal_buffer = reference_signal_buffer
|
||||
if enable_echo_cancellation:
|
||||
# 初始化回声消除器(在录音线程中同步处理,不是单独线程)
|
||||
# frame_size设置为chunk大小,确保每次处理一个chunk
|
||||
frame_size = chunk
|
||||
try:
|
||||
# 获取参考信号声道数(从reference_signal_buffer获取,因为它是根据播放声道数创建的)
|
||||
ref_channels = self.reference_signal_buffer.channels if self.reference_signal_buffer else 1
|
||||
self.echo_canceller = EchoCanceller(
|
||||
sample_rate=sample_rate,
|
||||
frame_size=frame_size,
|
||||
channels=self.channels, # 麦克风输入:1声道
|
||||
ref_channels=ref_channels, # 参考信号:播放声道数(2声道)
|
||||
logger=logger
|
||||
)
|
||||
if self.echo_canceller.aec is not None:
|
||||
if logger:
|
||||
logger.info(f"回声消除器已启用: sample_rate={sample_rate}, frame_size={frame_size}")
|
||||
else:
|
||||
if logger:
|
||||
logger.warning("回声消除器初始化失败,将禁用回声消除功能")
|
||||
self.enable_echo_cancellation = False
|
||||
self.echo_canceller = None
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"回声消除器初始化失败: {e},将禁用回声消除功能")
|
||||
self.enable_echo_cancellation = False
|
||||
self.echo_canceller = None
|
||||
else:
|
||||
self.echo_canceller = None
|
||||
|
||||
def record_with_vad(self):
|
||||
"""录音线程:VAD + 能量检测"""
|
||||
@@ -163,19 +127,7 @@ class AudioRecorder:
|
||||
while not self.stop_flag():
|
||||
# exception_on_overflow=False, 宁可丢帧,也不阻塞
|
||||
data = stream.read(self.chunk, exception_on_overflow=False)
|
||||
|
||||
# 回声消除处理
|
||||
processed_data = data
|
||||
if self.enable_echo_cancellation and self.echo_canceller and self.reference_signal_buffer:
|
||||
try:
|
||||
# 获取参考信号(长度与麦克风信号匹配)
|
||||
ref_signal = self.reference_signal_buffer.get_reference(num_samples=self.chunk)
|
||||
# 执行回声消除
|
||||
processed_data = self.echo_canceller.process(data, ref_signal)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
|
||||
processed_data = data
|
||||
|
||||
# 检查是否应该将音频放入队列(用于阻止ASR,例如无声纹文件时需要注册)
|
||||
if self.should_put_to_queue():
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
"""
|
||||
相机模块 - RealSense相机封装
|
||||
"""
|
||||
import numpy as np
|
||||
import contextlib
|
||||
|
||||
|
||||
class CameraClient:
|
||||
def __init__(self,
|
||||
serial_number: str | None,
|
||||
width: int,
|
||||
height: int,
|
||||
fps: int,
|
||||
format: str,
|
||||
logger=None):
|
||||
self.serial_number = serial_number
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.fps = fps
|
||||
self.format = format
|
||||
self.logger = logger
|
||||
|
||||
self.pipeline = None
|
||||
self.config = None
|
||||
self._is_initialized = False
|
||||
self._rs = None
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
if self.logger:
|
||||
getattr(self.logger, level, self.logger.info)(msg)
|
||||
else:
|
||||
print(f"[相机] {msg}")
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
初始化并启动相机管道
|
||||
"""
|
||||
if self._is_initialized:
|
||||
return True
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs
|
||||
self._rs = rs
|
||||
|
||||
self.pipeline = rs.pipeline()
|
||||
self.config = rs.config()
|
||||
|
||||
if self.serial_number:
|
||||
self.config.enable_device(self.serial_number)
|
||||
|
||||
self.config.enable_stream(
|
||||
rs.stream.color,
|
||||
self.width,
|
||||
self.height,
|
||||
rs.format.rgb8 if self.format == 'RGB8' else rs.format.bgr8,
|
||||
self.fps
|
||||
)
|
||||
|
||||
self.pipeline.start(self.config)
|
||||
self._is_initialized = True
|
||||
self._log("info", f"相机已启动并保持运行: {self.width}x{self.height}@{self.fps}fps")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log("error", f"相机初始化失败: {e}")
|
||||
self.cleanup()
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
"""停止相机管道,释放资源"""
|
||||
if self.pipeline:
|
||||
self.pipeline.stop()
|
||||
self._log("info", "相机已停止")
|
||||
self.pipeline = None
|
||||
self.config = None
|
||||
self._is_initialized = False
|
||||
|
||||
def capture_rgb(self) -> np.ndarray | None:
|
||||
"""
|
||||
从运行中的相机管道捕获一帧RGB图像
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self._log("error", "相机未初始化,无法捕获图像")
|
||||
return None
|
||||
|
||||
try:
|
||||
frames = self.pipeline.wait_for_frames()
|
||||
color_frame = frames.get_color_frame()
|
||||
|
||||
return np.asanyarray(color_frame.get_data())
|
||||
except Exception as e:
|
||||
self._log("error", f"捕获图像失败: {e}")
|
||||
return None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_context(self):
|
||||
"""
|
||||
上下文管理器:拍照并自动清理资源
|
||||
"""
|
||||
image_data = self.capture_rgb()
|
||||
try:
|
||||
yield image_data
|
||||
finally:
|
||||
if image_data is not None:
|
||||
del image_data
|
||||
|
||||
def capture_multiple(self, count: int = 1) -> list[np.ndarray]:
|
||||
"""
|
||||
捕获多张图像(为未来扩展准备)
|
||||
"""
|
||||
images = []
|
||||
for i in range(count):
|
||||
img = self.capture_rgb()
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
else:
|
||||
self._log("warning", f"第{i+1}张图像捕获失败")
|
||||
return images
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_multiple_context(self, count: int = 1):
|
||||
"""
|
||||
上下文管理器:捕获多张图像并自动清理资源
|
||||
"""
|
||||
images = self.capture_multiple(count)
|
||||
try:
|
||||
yield images
|
||||
finally:
|
||||
for img in images:
|
||||
del img
|
||||
images.clear()
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ReferenceSignalBuffer:
|
||||
"""参考信号缓冲区"""
|
||||
|
||||
def __init__(self, sample_rate: int, channels: int, max_duration_ms: int | None = None,
|
||||
buffer_seconds: float = 5.0):
|
||||
self.sample_rate = int(sample_rate)
|
||||
self.channels = int(channels)
|
||||
if max_duration_ms is not None:
|
||||
buffer_seconds = max(float(max_duration_ms) / 1000.0, 0.1)
|
||||
self.max_samples = int(self.sample_rate * buffer_seconds)
|
||||
self._buffer = collections.deque(maxlen=self.max_samples * self.channels)
|
||||
|
||||
def add_reference(self, data: bytes, source_sample_rate: int, source_channels: int):
|
||||
if source_sample_rate != self.sample_rate or source_channels != self.channels:
|
||||
return
|
||||
samples = np.frombuffer(data, dtype=np.int16)
|
||||
self._buffer.extend(samples.tolist())
|
||||
|
||||
def get_reference(self, num_samples: int) -> bytes:
|
||||
needed = int(num_samples) * self.channels
|
||||
if needed <= 0:
|
||||
return b""
|
||||
if len(self._buffer) < needed:
|
||||
data = list(self._buffer) + [0] * (needed - len(self._buffer))
|
||||
else:
|
||||
data = list(self._buffer)[-needed:]
|
||||
return np.array(data, dtype=np.int16).tobytes()
|
||||
|
||||
|
||||
class EchoCanceller:
|
||||
"""回声消除器(基于 aec-audio-processing)"""
|
||||
|
||||
def __init__(self, sample_rate: int, frame_size: int, channels: int, ref_channels: int, logger=None):
|
||||
self.sample_rate = int(sample_rate)
|
||||
self.frame_size = int(frame_size)
|
||||
self.channels = int(channels)
|
||||
self.ref_channels = int(ref_channels)
|
||||
self.logger = logger
|
||||
self.aec = None
|
||||
self._process_reverse = None
|
||||
self._frame_bytes = int(self.sample_rate / 100) * self.channels * 2 # 10ms, int16
|
||||
self._ref_frame_bytes = int(self.sample_rate / 100) * self.ref_channels * 2
|
||||
try:
|
||||
from aec_audio_processing import AudioProcessor
|
||||
self.aec = AudioProcessor(enable_aec=True, enable_ns=False, enable_agc=False)
|
||||
self.aec.set_stream_format(self.sample_rate, self.channels)
|
||||
if hasattr(self.aec, "set_reverse_stream_format"):
|
||||
self.aec.set_reverse_stream_format(self.sample_rate, self.ref_channels)
|
||||
if hasattr(self.aec, "set_stream_delay"):
|
||||
self.aec.set_stream_delay(0)
|
||||
if hasattr(self.aec, "process_reverse_stream"):
|
||||
self._process_reverse = self.aec.process_reverse_stream
|
||||
elif hasattr(self.aec, "process_reverse"):
|
||||
self._process_reverse = self.aec.process_reverse
|
||||
except Exception:
|
||||
self.aec = None
|
||||
|
||||
def process(self, mic_data: bytes, ref_data: bytes) -> bytes:
|
||||
if not self.aec:
|
||||
return mic_data
|
||||
if not mic_data:
|
||||
return mic_data
|
||||
|
||||
try:
|
||||
out_chunks = []
|
||||
total_len = len(mic_data)
|
||||
frame_bytes = self._frame_bytes
|
||||
ref_frame_bytes = self._ref_frame_bytes
|
||||
|
||||
frame_count = (total_len + frame_bytes - 1) // frame_bytes
|
||||
for i in range(frame_count):
|
||||
m_start = i * frame_bytes
|
||||
m_end = m_start + frame_bytes
|
||||
mic_frame = mic_data[m_start:m_end]
|
||||
if len(mic_frame) < frame_bytes:
|
||||
mic_frame = mic_frame + b"\x00" * (frame_bytes - len(mic_frame))
|
||||
|
||||
if ref_data:
|
||||
r_start = i * ref_frame_bytes
|
||||
r_end = r_start + ref_frame_bytes
|
||||
ref_frame = ref_data[r_start:r_end]
|
||||
if len(ref_frame) < ref_frame_bytes:
|
||||
ref_frame = ref_frame + b"\x00" * (ref_frame_bytes - len(ref_frame))
|
||||
if self._process_reverse:
|
||||
self._process_reverse(ref_frame)
|
||||
|
||||
processed = self.aec.process_stream(mic_frame)
|
||||
out_chunks.append(processed if processed is not None else mic_frame)
|
||||
|
||||
return b"".join(out_chunks)[:total_len]
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
|
||||
return mic_data
|
||||
@@ -1,296 +0,0 @@
|
||||
"""
|
||||
声纹识别模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
import tempfile
|
||||
import os
|
||||
import wave
|
||||
import time
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpeakerState(Enum):
|
||||
"""说话人识别状态"""
|
||||
UNKNOWN = "unknown"
|
||||
VERIFIED = "verified"
|
||||
REJECTED = "rejected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SpeakerVerificationClient:
|
||||
"""声纹识别客户端 - 非实时、低频处理"""
|
||||
|
||||
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
|
||||
self.model_path = model_path
|
||||
self.threshold = threshold
|
||||
self.speaker_db_path = speaker_db_path
|
||||
self.logger = logger
|
||||
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "threshold": float, "registered_at": float}}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 优化CPU性能:限制Torch使用的线程数,防止多线程竞争导致性能骤降
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
||||
from funasr import AutoModel
|
||||
model_path = os.path.expanduser(self.model_path)
|
||||
# 禁用自动更新检查,防止每次初始化都联网检查
|
||||
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
|
||||
if self.logger:
|
||||
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
|
||||
|
||||
if self.speaker_db_path:
|
||||
self.load_speakers()
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
|
||||
if self.logger:
|
||||
try:
|
||||
log_methods = {
|
||||
"debug": self.logger.debug,
|
||||
"info": self.logger.info,
|
||||
"warning": self.logger.warning,
|
||||
"error": self.logger.error,
|
||||
"fatal": self.logger.fatal
|
||||
}
|
||||
log_method = log_methods.get(level.lower(), self.logger.info)
|
||||
log_method(msg)
|
||||
except ValueError as e:
|
||||
if "severity cannot be changed" in str(e):
|
||||
try:
|
||||
self.logger.info(f"[声纹-{level.upper()}] {msg}")
|
||||
except:
|
||||
print(f"[声纹-{level.upper()}] {msg}")
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
print(f"[声纹] {msg}")
|
||||
|
||||
def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
|
||||
"""将numpy音频数组写入临时wav文件"""
|
||||
audio_int16 = audio_data.astype(np.int16)
|
||||
|
||||
fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
|
||||
os.close(fd)
|
||||
|
||||
with wave.open(temp_path, 'wb') as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_int16.tobytes())
|
||||
|
||||
return temp_path
|
||||
|
||||
def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
|
||||
"""
|
||||
提取说话人embedding(低频调用,一句话只调用一次)
|
||||
"""
|
||||
# 降采样到 16000Hz (如果需要)
|
||||
# Cam++ 等模型通常只支持 16k,如果传入 48k 会导致内部重采样极慢或计算量剧增
|
||||
target_sr = 16000
|
||||
if sample_rate > target_sr:
|
||||
if sample_rate % target_sr == 0:
|
||||
step = sample_rate // target_sr
|
||||
audio_data = audio_data[::step]
|
||||
sample_rate = target_sr
|
||||
else:
|
||||
# 简单的非整数倍降采样可能导致问题,但对于语音验证通常 48k->16k 是整数倍
|
||||
# 如果不是,此处暂不处理,依赖 funasr 内部处理,或者简单的步长取整
|
||||
step = int(sample_rate / target_sr)
|
||||
audio_data = audio_data[::step]
|
||||
sample_rate = target_sr
|
||||
|
||||
if len(audio_data) < int(sample_rate * 0.5):
|
||||
return None, False
|
||||
|
||||
temp_wav_path = None
|
||||
try:
|
||||
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
|
||||
result = self.model.generate(input=temp_wav_path)
|
||||
|
||||
import torch
|
||||
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
|
||||
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return None, False
|
||||
|
||||
return embedding, True
|
||||
except Exception as e:
|
||||
self._log("error", f"提取embedding失败: {e}")
|
||||
return None, False
|
||||
finally:
|
||||
if temp_wav_path and os.path.exists(temp_wav_path):
|
||||
try:
|
||||
os.unlink(temp_wav_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
def register_speaker(self, speaker_id: str, embedding: np.ndarray,
|
||||
env: str = "near", threshold: float = None) -> bool:
|
||||
"""
|
||||
注册说话人
|
||||
"""
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return False
|
||||
embedding_norm = np.linalg.norm(embedding)
|
||||
if embedding_norm == 0:
|
||||
self._log("error", f"注册失败:embedding范数为0")
|
||||
return False
|
||||
embedding_normalized = embedding / embedding_norm
|
||||
|
||||
speaker_threshold = threshold if threshold is not None else self.threshold
|
||||
|
||||
with self._lock:
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_normalized,
|
||||
"env": env, # 添加 env 字段
|
||||
"threshold": speaker_threshold,
|
||||
"registered_at": time.time()
|
||||
}
|
||||
self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")
|
||||
save_result = self.save_speakers()
|
||||
if not save_result:
|
||||
self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
|
||||
return True
|
||||
|
||||
def match_speaker(self, embedding: np.ndarray):
    """Match one utterance embedding against all registered speakers.

    Intended to be invoked once per utterance. Returns a tuple of
    ``(speaker_id, state, score, threshold)`` where ``state`` is one of the
    ``SpeakerState`` values and ``threshold`` is the best match's own
    verification threshold.
    """
    # Nothing enrolled yet: report UNKNOWN with the global threshold.
    if not self.speaker_db:
        return None, SpeakerState.UNKNOWN, 0.0, self.threshold

    # Degenerate queries (empty vector, zero norm) cannot be scored.
    if len(embedding) == 0:
        return None, SpeakerState.ERROR, 0.0, self.threshold

    query_norm = np.linalg.norm(embedding)
    if query_norm == 0:
        return None, SpeakerState.ERROR, 0.0, self.threshold
    query = embedding / query_norm

    top_id = None
    top_score = -1.0
    top_threshold = self.threshold

    # Stored embeddings are already L2-normalized, so a dot product against
    # the normalized query is cosine similarity.
    with self._lock:
        for sid, record in self.speaker_db.items():
            similarity = np.dot(query, record["embedding"])
            if similarity > top_score:
                top_score = similarity
                top_id = sid
                top_threshold = record["threshold"]

    verdict = SpeakerState.VERIFIED if top_score >= top_threshold else SpeakerState.REJECTED
    return (top_id, verdict, top_score, top_threshold)
|
||||
|
||||
def is_available(self) -> bool:
    """Whether the speaker-embedding model has been loaded successfully."""
    # ``model`` stays None until initialization has completed.
    model_loaded = self.model is not None
    return model_loaded
|
||||
|
||||
def cleanup(self):
    """Release held resources (currently a no-op, kept for interface symmetry)."""
    # Nothing needs explicit teardown yet.
    return None
|
||||
|
||||
def get_speaker_count(self) -> int:
    """Return a thread-safe snapshot of the number of registered speakers."""
    with self._lock:
        registered = len(self.speaker_db)
    return registered
|
||||
|
||||
def remove_speaker(self, speaker_id: str) -> bool:
    """Remove a registered speaker and re-persist the database.

    Returns:
        True when the speaker existed and was removed; False when unknown.
        A failed save is non-fatal — the in-memory removal still stands.
    """
    with self._lock:
        if speaker_id not in self.speaker_db:
            return False
        del self.speaker_db[speaker_id]
    # Fix: save_speakers() acquires self._lock itself; calling it while this
    # thread still holds the (non-reentrant) lock would deadlock. Persist
    # only after the lock is released.
    self.save_speakers()
    return True
|
||||
|
||||
def load_speakers(self) -> bool:
    """Load registered voiceprints from the JSON database file.

    Each record's embedding is re-normalized on load; zero-dimension records
    are skipped with a warning. Legacy records saved before the "env" field
    existed are accepted and default to "near" (mirroring save_speakers).

    Returns:
        True on a successful load, False when the path is unset, the file is
        missing, or parsing fails.
    """
    if not self.speaker_db_path:
        return False

    if not os.path.exists(self.speaker_db_path):
        self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
        return False

    try:
        with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        with self._lock:
            for speaker_id, speaker_data in data.items():
                embedding_list = speaker_data["embedding"]
                embedding_array = np.array(embedding_list, dtype=np.float32)

                embedding_dim = len(embedding_array)
                if embedding_dim == 0:
                    self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
                    continue
                embedding_norm = np.linalg.norm(embedding_array)
                if embedding_norm > 0:
                    embedding_array = embedding_array / embedding_norm

                self.speaker_db[speaker_id] = {
                    "embedding": embedding_array,
                    # Fix: use .get with a default — a legacy record without
                    # "env" previously raised KeyError, which the broad except
                    # below turned into a failure of the ENTIRE load.
                    # save_speakers already applies this same default.
                    "env": speaker_data.get("env", "near"),
                    "threshold": speaker_data["threshold"],
                    "registered_at": speaker_data["registered_at"]
                }

        count = len(self.speaker_db)
        self._log("info", f"已加载 {count} 个已注册说话人")
        return True
    except Exception as e:
        self._log("error", f"加载声纹数据库失败: {e}")
        return False
|
||||
|
||||
def save_speakers(self) -> bool:
    """Persist all registered voiceprints to the JSON database file.

    Writes atomically: serializes to ``<path>.tmp`` first, then renames over
    the real file with ``os.replace`` so readers never see a half-written DB.

    Returns:
        True on success; False when no path is configured or writing fails.
    """
    if not self.speaker_db_path:
        self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
        return False

    try:
        # Make sure the parent directory exists before writing.
        db_dir = os.path.dirname(self.speaker_db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)

        # Snapshot the DB into JSON-safe types while holding the lock.
        with self._lock:
            json_data = {
                sid: {
                    "embedding": rec["embedding"].tolist(),
                    "env": rec.get("env", "near"),
                    "threshold": rec["threshold"],
                    "registered_at": rec["registered_at"],
                }
                for sid, rec in self.speaker_db.items()
            }

        temp_path = self.speaker_db_path + ".tmp"
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)
        # Atomic swap of the finished file into place.
        os.replace(temp_path, self.speaker_db_path)

        self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
        return True
    except Exception as e:
        import traceback
        self._log("error", f"保存声纹数据库失败: {e}")
        self._log("error", f"保存路径: {self.speaker_db_path}")
        self._log("error", f"错误详情: {traceback.format_exc()}")
        # Best-effort removal of a leftover temp file.
        temp_path = self.speaker_db_path + ".tmp"
        if os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except:
                pass
        return False
|
||||
|
||||
22
robot_speaker/services/__init__.py
Normal file
22
robot_speaker/services/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Service节点模块
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
703
robot_speaker/services/asr_audio_node.py
Normal file
703
robot_speaker/services/asr_audio_node.py
Normal file
@@ -0,0 +1,703 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from interfaces.srv import ASRRecognize, AudioData, VADEvent
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import pyaudio
|
||||
import yaml
|
||||
import os
|
||||
import collections
|
||||
import numpy as np
|
||||
import base64
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
class AudioRecorder:
    """Captures microphone audio via PyAudio into a bounded queue.

    ``record()`` is meant to run on a dedicated thread and stops when
    ``stop_event`` is set. When the queue is full the oldest chunk is dropped
    so the producer never blocks.
    """

    def __init__(self, device_index: int, sample_rate: int, channels: int,
                 chunk: int, audio_queue: queue.Queue, stop_event, logger=None):
        # device_index: preferred input device; -1 means "auto/default".
        self.device_index = device_index
        self.sample_rate = sample_rate
        self.channels = channels
        self.chunk = chunk  # frames per read
        self.audio_queue = audio_queue
        self.stop_event = stop_event
        self.logger = logger
        self.audio = pyaudio.PyAudio()

        # Prefer an iFLYTEK microphone if one is present on the system.
        original_index = self.device_index
        try:
            for i in range(self.audio.get_device_count()):
                device_info = self.audio.get_device_info_by_index(i)
                if 'iFLYTEK' in device_info['name'] and device_info['maxInputChannels'] > 0:
                    self.device_index = i
                    if self.logger:
                        self.logger.info(f"[ASR-Recorder] 已自动定位到麦克风设备: {device_info['name']} (Index: {i})")
                    break
        except Exception as e:
            # Device enumeration failure is non-fatal; keep the configured index.
            if self.logger:
                self.logger.error(f"[ASR-Recorder] 设备自动检测过程出错: {e}")

        # No iFLYTEK device found and none configured: fall back to device 0.
        if self.device_index == original_index and original_index == -1:
            self.device_index = 0
            if self.logger:
                self.logger.info("[ASR-Recorder] 未找到 iFLYTEK 设备,使用系统默认输入设备")
        self.format = pyaudio.paInt16  # 16-bit signed PCM

    def record(self):
        """Blocking capture loop: read chunks until ``stop_event`` is set.

        Opens the input stream once; on an OSError from the device the loop
        exits. The stream is always stopped/closed on the way out.
        """
        if self.logger:
            self.logger.info(f"[ASR-Recorder] 录音线程启动,设备索引: {self.device_index}")
        stream = None
        try:
            stream = self.audio.open(
                format=self.format,
                channels=self.channels,
                rate=self.sample_rate,
                input=True,
                # A negative index means "let PyAudio choose the default".
                input_device_index=self.device_index if self.device_index >= 0 else None,
                frames_per_buffer=self.chunk
            )
            if self.logger:
                self.logger.info("[ASR-Recorder] 音频输入设备已打开")
        except Exception as e:
            if self.logger:
                self.logger.error(f"[ASR-Recorder] 无法打开音频输入设备: {e}")
            return
        try:
            while not self.stop_event.is_set():
                try:
                    data = stream.read(self.chunk, exception_on_overflow=False)

                    # Drop the oldest chunk instead of blocking when full.
                    if self.audio_queue.full():
                        self.audio_queue.get_nowait()
                    self.audio_queue.put_nowait(data)
                except OSError as e:
                    # Device-level read failure: leave the loop and clean up.
                    if self.logger:
                        self.logger.debug(f"[ASR-Recorder] 录音设备错误: {e}")
                    break
        except KeyboardInterrupt:
            if self.logger:
                self.logger.info("[ASR-Recorder] 录音线程收到中断信号")
        finally:
            # Best-effort stream teardown; errors here are ignored.
            if stream is not None:
                try:
                    if stream.is_active():
                        stream.stop_stream()
                    stream.close()
                except Exception as e:
                    pass
            if self.logger:
                self.logger.info("[ASR-Recorder] 录音线程已退出")
|
||||
|
||||
|
||||
class DashScopeASR:
    """Streaming ASR client over DashScope's Omni realtime WebSocket API.

    Feeds base64 PCM chunks to the service and surfaces results through three
    optional callbacks: ``on_sentence_end``, ``on_speech_started``,
    ``on_speech_stopped``. Tracks connection age/idle time/usage counters and
    proactively recreates the WebSocket before the server drops it.
    """

    def __init__(self, api_key: str, sample_rate: int, model: str, url: str, logger=None):
        dashscope.api_key = api_key
        self.sample_rate = sample_rate
        self.model = model
        self.url = url
        self.logger = logger

        self.conversation = None   # active OmniRealtimeConversation, or None
        self.running = False       # True while a usable connection exists
        self.on_sentence_end = None      # callable(str) — final transcript
        self.on_speech_started = None    # callable() — server VAD speech start
        self.on_speech_stopped = None    # callable() — server VAD speech stop

        self._stop_lock = threading.Lock()        # serializes stop/teardown
        self._final_result_event = threading.Event()  # set when commit yields a result
        self._pending_commit = False              # a commit is awaiting its transcript

        # ========== Connection lifecycle management: works around DashScope
        # ASR WebSocket timeouts that made recognition unstable ==========
        self._connection_start_time = None  # when the connection was created
        self._last_audio_time = None        # last time audio was sent
        self._recognition_count = 0         # completed recognitions on this connection
        self._audio_send_count = 0          # audio chunks sent on this connection
        self._last_audio_send_success = True    # did the last send succeed
        self._consecutive_send_failures = 0     # consecutive send failures

        # Tunables for when to rebuild the connection.
        self.MAX_CONNECTION_AGE = 300   # max connection lifetime: 5 minutes
        self.MAX_IDLE_TIME = 180        # max idle time: 3 minutes
        self.MAX_RECOGNITIONS = 30      # rebuild after 30 recognitions
        self.MAX_CONSECUTIVE_FAILURES = 3  # rebuild after 3 straight send failures

    def _log(self, level: str, msg: str):
        """Forward to the ROS logger at the given level; never raises."""
        if not self.logger:
            return
        try:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            elif level == "info":
                self.logger.info(msg)
        except Exception:
            pass

    def _should_reconnect(self) -> tuple[bool, str]:
        """Decide whether the current connection should be rebuilt.

        Returns ``(True, reason)`` when any lifecycle limit is exceeded,
        ``(False, "")`` otherwise or when there is no live connection.
        """
        if not self.running or not self.conversation:
            return False, ""
        current_time = time.time()
        # Check 1: connection age
        if self._connection_start_time:
            connection_age = current_time - self._connection_start_time
            if connection_age > self.MAX_CONNECTION_AGE:
                return True, f"连接已存活{connection_age:.0f}秒,超过{self.MAX_CONNECTION_AGE}秒阈值"
        # Check 2: idle time
        if self._last_audio_time:
            idle_time = current_time - self._last_audio_time
            if idle_time > self.MAX_IDLE_TIME:
                return True, f"连接已空闲{idle_time:.0f}秒,超过{self.MAX_IDLE_TIME}秒阈值"
        # Check 3: recognition count
        if self._recognition_count >= self.MAX_RECOGNITIONS:
            return True, f"已完成{self._recognition_count}次识别,达到重连阈值"
        # Check 4: consecutive send failures
        if self._consecutive_send_failures >= self.MAX_CONSECUTIVE_FAILURES:
            return True, f"连续{self._consecutive_send_failures}次音频发送失败"

        return False, ""

    def _reset_connection_stats(self):
        """Reset all lifecycle counters/timestamps for a fresh connection."""
        self._connection_start_time = time.time()
        self._last_audio_time = time.time()
        self._recognition_count = 0
        self._audio_send_count = 0
        self._last_audio_send_success = True
        self._consecutive_send_failures = 0

    def start(self):
        """Open a new realtime session and configure server-side VAD.

        Returns True on success; False when already running or on failure
        (in which case any half-open connection is closed).
        """
        if self.running:
            return False

        try:
            callback = _ASRCallback(self)
            self.conversation = OmniRealtimeConversation(
                model=self.model,
                url=self.url,
                callback=callback
            )
            callback.conversation = self.conversation

            self.conversation.connect()

            transcription_params = TranscriptionParams(
                language='zh',
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )

            # Text-only output with server-side turn detection (VAD).
            self.conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
                enable_turn_detection=True,
                turn_detection_type='server_vad',
                prefix_padding_ms=1000,
                turn_detection_threshold=0.3,
                turn_detection_silence_duration_ms=800,
            )

            self.running = True
            self._reset_connection_stats()
            self._log("info", f"[ASR] 已启动 | 连接ID:{id(self.conversation)}")
            return True
        except Exception as e:
            self.running = False
            self._log("error", f"[ASR] 启动失败: {e}")
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
            self.conversation = None
            return False

    def send_audio(self, audio_chunk: bytes):
        """Send one PCM chunk; transparently rebuilds a stale connection.

        Returns True when the chunk was accepted by the service. On a
        WebSocket-closed error the connection is torn down so the next call
        (or the worker loop) re-establishes it.
        """
        # Proactive reconnect when lifecycle limits are hit.
        should_reconnect, reason = self._should_reconnect()
        if should_reconnect:
            self._log("warning", f"[ASR] 检测到需要重连: {reason}")
            self.running = False
            try:
                if self.conversation:
                    self.conversation.close()
            except:
                pass
            self.conversation = None
            time.sleep(1.0)
            if not self.start():
                self._log("error", "[ASR] 自动重连失败")
                return False
            self._log("info", "[ASR] 自动重连成功")
        import threading
        self._log("debug", f"[ASR] send_audio 被调用 | 线程:{threading.current_thread().name} | running:{self.running} | conversation:{self.conversation is not None}")
        if not self.running or not self.conversation:
            self._log("debug", f"[ASR] send_audio 跳过 | running:{self.running} | conversation:{self.conversation is not None}")
            return False
        try:
            # The realtime API expects base64-encoded PCM.
            audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
            self.conversation.append_audio(audio_b64)
            self._last_audio_time = time.time()
            self._audio_send_count += 1
            self._last_audio_send_success = True
            self._consecutive_send_failures = 0
            self._log("debug", f"[ASR] 音频发送成功 | 总计:{self._audio_send_count} | 连接年龄:{time.time() - self._connection_start_time:.1f}秒")

            return True
        except Exception as e:
            self._last_audio_send_success = False
            self._consecutive_send_failures += 1

            error_msg = str(e)
            error_type = type(e).__name__
            # Classify socket-closed errors by message/type so the connection
            # can be torn down for a rebuild on the next attempt.
            if "Connection is already closed" in error_msg or "WebSocketConnectionClosedException" in error_type or "ConnectionClosed" in error_type or "websocket" in error_msg.lower():
                self._log("warning", f"[ASR] WebSocket 连接已断开 | 错误:{error_msg} | 连续失败:{self._consecutive_send_failures}次")
                self.running = False
                try:
                    if self.conversation:
                        self.conversation.close()
                except:
                    pass
                self.conversation = None
            else:
                self._log("error", f"[ASR] send_audio 异常 | 错误:{error_msg} | 类型:{error_type} | 连续失败:{self._consecutive_send_failures}次")

            return False

    def stop_current_recognition(self):
        """Commit the in-flight utterance, wait briefly for its final
        transcript, then close the connection (a new one is opened lazily).

        Non-blocking with respect to concurrent callers: if another thread is
        already stopping, returns False immediately.
        """
        import threading
        self._log("debug", f"[ASR] stop_current_recognition 被调用 | 线程:{threading.current_thread().name} | running:{self.running}")
        if not self._stop_lock.acquire(blocking=False):
            self._log("debug", f"[ASR] 锁获取失败,有其他线程正在执行 stop_current_recognition")
            return False

        self._final_result_event.clear()
        self._pending_commit = True

        try:
            self._log("debug", f"[ASR] 获得锁,开始停止识别 | conversation:{self.conversation is not None}")
            if not self.running or not self.conversation:
                self._log("debug", f"[ASR] 无法停止 | running:{self.running} | conversation:{self.conversation is not None}")
                return False

            self._recognition_count += 1
            should_reconnect, reason = self._should_reconnect()
            if should_reconnect:
                self._log("info", f"[ASR] 识别完成后检测到需要重连: {reason}")

            self._final_result_event.clear()
            self._pending_commit = True

            # Commit the audio buffer and give the callback up to 3s to
            # deliver the final transcript (it sets _final_result_event).
            try:
                self.conversation.commit()
                self._final_result_event.wait(timeout=3.0)
            except Exception as e:
                self._log("debug", f"[ASR] commit 异常: {e}")

            self._log("debug", f"[ASR] 准备关闭旧连接 | conversation_id:{id(self.conversation)}")
            self.running = False

            # Detach before closing so other threads see a consistent state.
            old_conversation = self.conversation
            self.conversation = None

            self._log("debug", f"[ASR] conversation已设为None,准备关闭旧连接")
            try:
                old_conversation.close()
                self._log("debug", f"[ASR] 旧连接已关闭")
            except Exception as e:
                self._log("warning", f"[ASR] 关闭连接异常: {e}")
            self._log("debug", f"[ASR] 连接已关闭,等待下次语音活动时重连")
            return True

        finally:
            self._pending_commit = False
            self._stop_lock.release()
            self._log("debug", f"[ASR] stop_current_recognition 完成,锁已释放")

    def stop(self):
        """Shut the client down completely; safe to call multiple times."""
        with self._stop_lock:
            self.running = False
            # Release anyone blocked waiting for a final result.
            self._final_result_event.set()
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
                self.conversation = None
            self._log("info", "[ASR] 已完全停止")
|
||||
|
||||
|
||||
class _ASRCallback(OmniRealtimeCallback):
    """Routes DashScope realtime events back to the owning DashScopeASR."""

    def __init__(self, asr_client: DashScopeASR):
        self.asr_client = asr_client
        self.conversation = None  # filled in by DashScopeASR.start()

    def on_event(self, response):
        # Never let an exception escape into the SDK's event loop.
        try:
            client = self.asr_client
            evt = response['type']
            if evt == 'conversation.item.input_audio_transcription.completed':
                # Final transcript for the committed utterance.
                text = response['transcript'].strip()
                if text and client.on_sentence_end:
                    client.on_sentence_end(text)
                # Unblock a stop_current_recognition waiting on the result.
                if client._pending_commit:
                    client._final_result_event.set()
            elif evt == 'input_audio_buffer.speech_started':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音开始")
                if client.on_speech_started:
                    client.on_speech_started()
            elif evt == 'input_audio_buffer.speech_stopped':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音结束")
                if client.on_speech_stopped:
                    client.on_speech_stopped()
        except Exception:
            pass
|
||||
|
||||
|
||||
class ASRAudioNode(Node):
    """ROS 2 node bundling microphone capture and DashScope streaming ASR.

    Exposes three services:
      * ``/asr/recognize`` — get the latest/next recognition result.
      * ``/asr/audio_data`` — start/stop/get raw audio buffering for callers
        that need the PCM itself (e.g. voiceprint enrollment).
      * ``/vad/event`` — long-poll for server-VAD speech start/stop events.

    Two daemon threads run alongside the executor: the recorder thread fills
    ``audio_queue``; ``_asr_worker`` drains it into the ASR client and the
    optional raw-audio buffer.
    """

    def __init__(self):
        super().__init__('asr_audio_node')
        self._load_config()

        self.audio_queue = queue.Queue(maxsize=100)  # mic chunks; bounded
        self.stop_event = threading.Event()
        self._shutdown_in_progress = False

        self._init_components()

        self.recognize_service = self.create_service(
            ASRRecognize, '/asr/recognize', self._recognize_callback
        )
        self.audio_data_service = self.create_service(
            AudioData, '/asr/audio_data', self._audio_data_callback
        )
        self.vad_event_service = self.create_service(
            VADEvent, '/vad/event', self._vad_event_callback
        )

        # Latest recognition result plus when it arrived.
        self._last_result = None
        self._result_event = threading.Event()
        self._last_result_time = None
        self.vad_event_queue = queue.Queue()
        # Raw sample buffer for the /asr/audio_data service (bounded deque).
        self.audio_buffer = collections.deque(maxlen=240000)
        self.audio_recording = False
        self.audio_lock = threading.Lock()

        # ========== Abnormal-recognition detection ==========
        self._abnormal_results = ["嗯。", "。", "啊。", "哦。"]  # known junk transcripts
        self._consecutive_abnormal_count = 0  # consecutive abnormal results
        self.MAX_CONSECUTIVE_ABNORMAL = 5     # force reconnect after this many

        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record, name="RecordingThread", daemon=True
        )
        self.recording_thread.start()

        self.asr_thread = threading.Thread(
            target=self._asr_worker, name="ASRThread", daemon=True
        )
        self.asr_thread.start()

        self.get_logger().info("ASR Audio节点已启动")

    def _load_config(self):
        """Read microphone and DashScope settings from config/voice.yaml."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        mic = config['audio']['microphone']
        self.input_device_index = mic['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']

        # NOTE(review): this local shadows the imported `dashscope` module
        # within this method — harmless here, but worth renaming.
        dashscope = config['dashscope']
        self.dashscope_api_key = dashscope['api_key']
        self.asr_model = dashscope['asr']['model']
        self.asr_url = dashscope['asr']['url']

    def _init_components(self):
        """Build the recorder and ASR client, wire callbacks, and start ASR."""
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            audio_queue=self.audio_queue,
            stop_event=self.stop_event,
            logger=self.get_logger()
        )

        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )

        self.asr_client.on_sentence_end = self._on_asr_result
        self.asr_client.on_speech_started = lambda: self._put_vad_event("speech_started")
        # Speech end invalidates any stale cached result before signaling.
        self.asr_client.on_speech_stopped = lambda: (self._clear_result(), self._put_vad_event("speech_stopped"))
        self.asr_client.start()

    def _on_asr_result(self, text: str):
        """Store a final transcript and track suspicious "junk" results.

        Repeated junk results (e.g. lone filler syllables) are treated as a
        sign of a degraded connection and trigger a forced ASR reconnect.
        """
        if not text or not text.strip():
            return

        self._last_result = text.strip()
        self._last_result_time = time.time()
        self._result_event.set()

        is_abnormal = self._last_result in self._abnormal_results and len(self._last_result) <= 2
        if is_abnormal:
            self._consecutive_abnormal_count += 1
            self.get_logger().warn(f"[ASR] 检测到异常识别结果: '{self._last_result}' | 连续异常:{self._consecutive_abnormal_count}次")
            # After enough consecutive junk results, force a reconnect by
            # maxing out the client's failure counter.
            if self._consecutive_abnormal_count >= self.MAX_CONSECUTIVE_ABNORMAL:
                self.get_logger().error(f"[ASR] 连续{self._consecutive_abnormal_count}次异常识别,标记需要重连")
                self.asr_client._consecutive_send_failures = self.asr_client.MAX_CONSECUTIVE_FAILURES
                self._consecutive_abnormal_count = 0
        else:
            # Normal result: reset the abnormal streak.
            self._consecutive_abnormal_count = 0
        try:
            self.get_logger().info(f"[ASR] 识别结果: {self._last_result}")
        except Exception:
            pass

    def _put_vad_event(self, event_type):
        """Enqueue a VAD event for /vad/event pollers; drop when full."""
        try:
            self.vad_event_queue.put(event_type, timeout=0.1)
        except queue.Full:
            try:
                self.get_logger().warn(f"[ASR] VAD事件队列已满,丢弃{event_type}事件")
            except Exception:
                pass

    def _audio_data_callback(self, request, response):
        """/asr/audio_data handler: "start" buffering, "stop" (return and
        clear the buffer), or "get" (peek without clearing)."""
        import threading
        self.get_logger().debug(f"[ASR-AudioData] 回调触发 | command:{request.command} | 线程:{threading.current_thread().name}")
        response.sample_rate = self.sample_rate
        response.channels = self.channels

        if request.command == "start":
            with self.audio_lock:
                self.get_logger().debug(f"[ASR-AudioData] start命令 | 旧buffer大小:{len(self.audio_buffer)} | recording:{self.audio_recording}")
                self.audio_buffer.clear()
                self.audio_recording = True
                self.get_logger().debug(f"[ASR-AudioData] buffer已清空,recording=True")
            response.success = True
            response.message = "开始录音"
            response.samples = 0
            return response

        if request.command == "stop":
            self.get_logger().debug(f"[ASR-AudioData] stop命令 | recording:{self.audio_recording}")
            with self.audio_lock:
                self.audio_recording = False
                audio_list = list(self.audio_buffer)
                self.get_logger().debug(f"[ASR-AudioData] 读取buffer | 大小:{len(audio_list)}")
                self.audio_buffer.clear()
            if len(audio_list) > 0:
                # Samples are int16; serialize as raw little-endian bytes.
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"录音完成{len(audio_list)}样本"
                self.get_logger().debug(f"[ASR-AudioData] 返回音频 | samples:{len(audio_list)}")
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
                self.get_logger().debug(f"[ASR-AudioData] buffer为空!")
            return response

        if request.command == "get":
            # Non-destructive read of whatever has accumulated so far.
            with self.audio_lock:
                audio_list = list(self.audio_buffer)
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"获取到{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response

    def _vad_event_callback(self, request, response):
        """/vad/event handler: long-poll the VAD queue with an optional
        timeout (``timeout_ms <= 0`` means wait indefinitely)."""
        timeout = request.timeout_ms / 1000.0 if request.timeout_ms > 0 else None
        try:
            event = self.vad_event_queue.get(timeout=timeout)
            response.success = True
            response.event = event
            response.message = "收到VAD事件"
        except queue.Empty:
            response.success = False
            response.event = "none"
            response.message = "等待超时"
        except KeyboardInterrupt:
            # Interrupt while blocked: report shutdown and stop the threads.
            try:
                self.get_logger().info("[ASR-VAD] 收到中断信号,正在关闭")
            except Exception:
                pass
            response.success = False
            response.event = "none"
            response.message = "节点正在关闭"
            self.stop_event.set()
        return response

    def _clear_result(self):
        """Drop any cached recognition result and reset its event."""
        self._last_result = None
        self._last_result_time = None
        self._result_event.clear()

    def _return_result(self, response, text, message):
        """Fill a success response with ``text`` and consume the cache."""
        response.success = True
        response.text = text
        response.message = message
        self._clear_result()
        return response

    def _recognize_callback(self, request, response):
        """/asr/recognize handler.

        ``stop`` / ``reset`` control the ASR client; any other command waits
        for a result: first a very recent cached one (<0.3s old), then up to
        2s on the running session, then a fresh session for up to 5s.
        """
        if request.command == "stop":
            if self.asr_client.running:
                self.asr_client.stop_current_recognition()
            response.success = True
            response.text = ""
            response.message = "识别已停止"
            return response

        if request.command == "reset":
            self.asr_client.stop_current_recognition()
            time.sleep(0.1)
            self.asr_client.start()
            response.success = True
            response.text = ""
            response.message = "识别器已重置"
            return response

        if self.asr_client.running:
            current_time = time.time()
            # Serve a result that just arrived (or is flagged ready).
            if (self._last_result and self._last_result_time and
                    (current_time - self._last_result_time) < 0.3) or (self._result_event.is_set() and self._last_result):
                return self._return_result(response, self._last_result, "返回最近识别结果")
            if self._result_event.wait(timeout=2.0) and self._last_result:
                return self._return_result(response, self._last_result, "识别成功(等待中)")
            # Nothing arrived: finalize the current utterance and fall through.
            self.asr_client.stop_current_recognition()
            time.sleep(0.2)

        self._clear_result()

        if not self.asr_client.running and not self.asr_client.start():
            response.success = False
            response.text = ""
            response.message = "ASR启动失败"
            return response
        if self._result_event.wait(timeout=5.0) and self._last_result:
            response.success = True
            response.text = self._last_result
            response.message = "识别成功"
        else:
            response.success = False
            response.text = ""
            response.message = "识别超时" if not self._result_event.is_set() else "识别结果为空"
        self._clear_result()
        return response

    def _asr_worker(self):
        """Worker loop: drain mic chunks into the raw buffer (when armed)
        and into the ASR client, restarting the client when it is down."""
        while not self.stop_event.is_set():
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except KeyboardInterrupt:
                try:
                    self.get_logger().info("[ASR-Worker] 收到中断信号")
                except Exception:
                    pass
                break

            if self.audio_recording:
                self.get_logger().debug(f"[ASR-Worker] 收到音频chunk | recording:{self.audio_recording} | buffer_size:{len(self.audio_buffer)}")
                try:
                    # Chunks are int16 PCM bytes; store individual samples.
                    audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                    with self.audio_lock:
                        self.audio_buffer.extend(audio_array)
                except Exception as e:
                    self.get_logger().error(f"[ASR-Worker] buffer写入异常 | 错误:{e}")
                    pass

            if self.asr_client.running:
                self.asr_client.send_audio(audio_chunk)
            else:
                # Client is down: try to restart; back off on failure.
                if not self.asr_client.start():
                    time.sleep(1.0)

    def destroy_node(self):
        """Idempotent teardown: stop threads, audio, ASR, then the Node."""
        if self._shutdown_in_progress:
            return
        self._shutdown_in_progress = True
        try:
            self.get_logger().info("ASR Audio节点正在关闭...")
        except Exception:
            pass
        self.stop_event.set()
        if hasattr(self, 'recording_thread') and self.recording_thread.is_alive():
            self.recording_thread.join(timeout=1.0)
        if hasattr(self, 'asr_thread') and self.asr_thread.is_alive():
            self.asr_thread.join(timeout=1.0)
        try:
            if hasattr(self, 'audio_recorder'):
                self.audio_recorder.audio.terminate()
        except Exception:
            pass
        try:
            if hasattr(self, 'asr_client'):
                self.asr_client.stop()
        except Exception:
            pass
        try:
            super().destroy_node()
        except Exception:
            pass
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the ASR node until interrupted, then tear down."""
    rclpy.init(args=args)
    asr_node = ASRAudioNode()
    try:
        rclpy.spin(asr_node)
    except KeyboardInterrupt:
        # Logging may already be unavailable during interpreter shutdown.
        try:
            asr_node.get_logger().info("收到中断信号,正在关闭节点")
        except Exception:
            pass
    finally:
        # Best-effort teardown — each step tolerates a partially shut stack.
        try:
            asr_node.destroy_node()
        except Exception:
            pass
        try:
            rclpy.shutdown()
        except Exception:
            pass
|
||||
|
||||
|
||||
# Allow running the node directly with `python asr_audio_node.py`.
if __name__ == '__main__':
    main()
|
||||
341
robot_speaker/services/tts_audio_node.py
Normal file
341
robot_speaker/services/tts_audio_node.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
from interfaces.srv import TTSSynthesize
|
||||
import threading
|
||||
import yaml
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
class DashScopeTTSClient:
|
||||
def __init__(self, api_key: str,
|
||||
model: str,
|
||||
voice: str,
|
||||
card_index: int,
|
||||
device_index: int,
|
||||
output_sample_rate: int,
|
||||
output_channels: int,
|
||||
output_volume: float,
|
||||
tts_source_sample_rate: int,
|
||||
tts_source_channels: int,
|
||||
tts_ffmpeg_thread_queue_size: int,
|
||||
force_stop_delay: float,
|
||||
cleanup_timeout: float,
|
||||
terminate_timeout: float,
|
||||
logger):
|
||||
dashscope.api_key = api_key
|
||||
self.model = model
|
||||
self.voice = voice
|
||||
self.card_index = card_index
|
||||
self.device_index = device_index
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_channels = output_channels
|
||||
self.output_volume = output_volume
|
||||
self.tts_source_sample_rate = tts_source_sample_rate
|
||||
self.tts_source_channels = tts_source_channels
|
||||
self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
|
||||
self.force_stop_delay = force_stop_delay
|
||||
self.cleanup_timeout = cleanup_timeout
|
||||
self.terminate_timeout = terminate_timeout
|
||||
self.logger = logger
|
||||
self.current_ffmpeg_pid = None
|
||||
self._current_callback = None
|
||||
|
||||
self.alsa_device = f"plughw:{card_index},{device_index}" if (
|
||||
card_index >= 0 and device_index >= 0
|
||||
) else "default"
|
||||
|
||||
|
||||
def force_stop(self):
|
||||
if self._current_callback:
|
||||
self._current_callback._interrupted = True
|
||||
if not self.current_ffmpeg_pid:
|
||||
if self.logger:
|
||||
self.logger.warn("[TTS] force_stop: current_ffmpeg_pid is None")
|
||||
return
|
||||
pid = self.current_ffmpeg_pid
|
||||
try:
|
||||
if self.logger:
|
||||
self.logger.info(f"[TTS] force_stop: 正在kill进程 {pid}")
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
time.sleep(self.force_stop_delay)
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
if self.logger:
|
||||
self.logger.info(f"[TTS] force_stop: 已发送SIGKILL到进程 {pid}")
|
||||
except ProcessLookupError:
|
||||
if self.logger:
|
||||
self.logger.info(f"[TTS] force_stop: 进程 {pid} 已退出")
|
||||
except (ProcessLookupError, OSError) as e:
|
||||
if self.logger:
|
||||
self.logger.warn(f"[TTS] force_stop: kill进程失败 {pid}: {e}")
|
||||
finally:
|
||||
self.current_ffmpeg_pid = None
|
||||
self._current_callback = None
|
||||
|
||||
def synthesize(self, text: str, voice: str = None,
               on_chunk=None,
               interrupt_check=None) -> bool:
    """Synthesize *text* to speech and stream it to the audio device.

    Args:
        text: Text to synthesize.
        voice: Optional voice override; falls back to ``self.voice`` when
            empty or whitespace-only.
        on_chunk: Optional callable receiving each raw PCM chunk.
        interrupt_check: Optional callable polled per chunk; returning
            True aborts playback.

    Returns:
        True when playback finished without interruption, False when it
        was interrupted or the voice parameter was invalid.
    """
    voice_to_use = voice if voice and voice.strip() else self.voice

    # Fix: validate the voice BEFORE creating/registering the callback.
    # The original registered a fresh callback as _current_callback first
    # and then cleared it on the error path, needlessly clobbering any
    # interrupt bookkeeping for no benefit.
    if not voice_to_use or not voice_to_use.strip():
        if self.logger:
            self.logger.error(f"[TTS] Voice参数无效: '{voice_to_use}'")
        return False

    callback = _TTSCallback(self, interrupt_check, on_chunk)
    self._current_callback = callback

    synthesizer = SpeechSynthesizer(
        model=self.model,
        voice=voice_to_use,
        format=AudioFormat.PCM_22050HZ_MONO_16BIT,
        callback=callback,
    )

    try:
        synthesizer.streaming_call(text)
        synthesizer.streaming_complete()
    finally:
        # Always release the ffmpeg process and interrupt bookkeeping,
        # even if the DashScope call raises.
        callback.cleanup()
        self._current_callback = None

    return not callback._interrupted
|
||||
|
||||
|
||||
class _TTSCallback(ResultCallback):
    """DashScope streaming-synthesis callback that pipes PCM chunks into an
    ffmpeg child process for ALSA playback.

    Lifecycle: ``on_open`` spawns ffmpeg, ``on_data`` feeds it PCM,
    ``cleanup`` drains and tears the process down.
    """

    def __init__(self, tts_client: DashScopeTTSClient,
                 interrupt_check=None,
                 on_chunk=None):
        self.tts_client = tts_client
        self.interrupt_check = interrupt_check  # callable -> bool, polled per chunk
        self.on_chunk = on_chunk                # optional tap for raw PCM chunks
        self._proc = None                       # ffmpeg subprocess handle
        self._interrupted = False               # set when playback is aborted
        self._cleaned_up = False                # guards against double cleanup

    def on_open(self):
        """Spawn the ffmpeg process that resamples and plays the PCM stream."""
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 's16le',
            '-ar', str(self.tts_client.tts_source_sample_rate),
            '-ac', str(self.tts_client.tts_source_channels),
            '-i', 'pipe:0',
            '-f', 'alsa',
            '-ar', str(self.tts_client.output_sample_rate),
            '-ac', str(self.tts_client.output_channels),
            '-acodec', 'pcm_s16le',
            '-fflags', 'nobuffer',
            '-flags', 'low_delay',
            '-avioflags', 'direct',
            self.tts_client.alsa_device
        ]

        # -thread_queue_size is an input option and must precede '-i'.
        insert_pos = ffmpeg_cmd.index('-i')
        ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
        ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')

        # Apply a volume filter only when attenuation/gain is requested.
        if self.tts_client.output_volume != 1.0:
            acodec_idx = ffmpeg_cmd.index('-acodec')
            ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
            ffmpeg_cmd.insert(acodec_idx, '-af')

        self._proc = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE
        )
        # Expose the pid so the client's force_stop() can signal it.
        self.tts_client.current_ffmpeg_pid = self._proc.pid

    def on_data(self, data: bytes) -> None:
        """Feed one PCM chunk to ffmpeg; honor interruption requests."""
        if self._interrupted:
            return

        if self.interrupt_check and self.interrupt_check():
            self._interrupted = True
            if self._proc:
                self._proc.terminate()
            return

        if self._proc and self._proc.stdin and not self._interrupted:
            try:
                self._proc.stdin.write(data)
                self._proc.stdin.flush()
            except BrokenPipeError:
                # ffmpeg exited (e.g. killed by force_stop); stop feeding.
                self._interrupted = True
            except OSError:
                self._interrupted = True

        if self.on_chunk and not self._interrupted:
            self.on_chunk(data)

    def cleanup(self):
        """Close stdin and shut the ffmpeg process down, escalating
        wait -> terminate -> kill. Idempotent."""
        if self._cleaned_up or not self._proc:
            return
        self._cleaned_up = True

        if self._proc.stdin and not self._proc.stdin.closed:
            try:
                self._proc.stdin.close()
            except OSError:
                pass  # ffmpeg may already have exited and closed the pipe

        if self._proc.poll() is None:
            try:
                # Bug fix: wait(timeout=...) raises TimeoutExpired; the
                # original let it propagate, aborting cleanup and leaking
                # the ffmpeg process instead of escalating.
                self._proc.wait(timeout=self.tts_client.cleanup_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=self.tts_client.terminate_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.kill()

        # Only clear the shared pid if it still refers to our process.
        if self.tts_client.current_ffmpeg_pid == self._proc.pid:
            self.tts_client.current_ffmpeg_pid = None
|
||||
|
||||
|
||||
class TTSAudioNode(Node):
    """ROS2 node exposing DashScope TTS playback as the /tts/synthesize service.

    Synthesis requests are played asynchronously on a daemon worker thread;
    a request with command == "interrupt" stops whatever is currently playing.
    """

    def __init__(self):
        super().__init__('tts_audio_node')
        self._load_config()
        self._init_tts_client()

        # Reentrant group so an "interrupt" request can be served while a
        # previous "synthesize" request is still in flight.
        self.callback_group = ReentrantCallbackGroup()
        self.synthesize_service = self.create_service(
            TTSSynthesize, '/tts/synthesize', self._synthesize_callback,
            callback_group=self.callback_group
        )

        # Set to ask the playback worker to abort; cleared per new request.
        self.interrupt_event = threading.Event()
        # Guards is_playing across service threads and the worker thread.
        self.playing_lock = threading.Lock()
        self.is_playing = False

        self.get_logger().info("[TTS] TTS Audio节点已启动")

    def _load_config(self):
        # Load audio/TTS settings from the package's installed voice.yaml.
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        audio = config['audio']
        soundcard = audio['soundcard']
        tts_audio = audio['tts']
        # NOTE(review): this local shadows the imported `dashscope` module
        # inside this method only.
        dashscope = config['dashscope']

        # Output (soundcard) parameters.
        self.output_card_index = soundcard['card_index']
        self.output_device_index = soundcard['device_index']
        self.output_sample_rate = soundcard['sample_rate']
        self.output_channels = soundcard['channels']
        self.output_volume = soundcard['volume']

        # TTS source-stream and ffmpeg/teardown tuning parameters.
        self.tts_source_sample_rate = tts_audio['source_sample_rate']
        self.tts_source_channels = tts_audio['source_channels']
        self.tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']
        self.force_stop_delay = tts_audio['force_stop_delay']
        self.cleanup_timeout = tts_audio['cleanup_timeout']
        self.terminate_timeout = tts_audio['terminate_timeout']
        self.interrupt_wait = tts_audio['interrupt_wait']

        # DashScope credentials and model/voice selection.
        self.dashscope_api_key = dashscope['api_key']
        self.tts_model = dashscope['tts']['model']
        self.tts_voice = dashscope['tts']['voice']

    def _init_tts_client(self):
        # Build the TTS client from the configuration loaded above.
        self.tts_client = DashScopeTTSClient(
            api_key=self.dashscope_api_key,
            model=self.tts_model,
            voice=self.tts_voice,
            card_index=self.output_card_index,
            device_index=self.output_device_index,
            output_sample_rate=self.output_sample_rate,
            output_channels=self.output_channels,
            output_volume=self.output_volume,
            tts_source_sample_rate=self.tts_source_sample_rate,
            tts_source_channels=self.tts_source_channels,
            tts_ffmpeg_thread_queue_size=self.tts_ffmpeg_thread_queue_size,
            force_stop_delay=self.force_stop_delay,
            cleanup_timeout=self.cleanup_timeout,
            terminate_timeout=self.terminate_timeout,
            logger=self.get_logger()
        )

    def _synthesize_callback(self, request, response):
        """Service handler for /tts/synthesize.

        command == "interrupt": stop current playback (if any) and return.
        Otherwise: validate the text, preempt any ongoing playback, start
        asynchronous synthesis on a daemon thread, and reply immediately
        with status "playing".
        """
        command = request.command if request.command else "synthesize"

        if command == "interrupt":
            with self.playing_lock:
                was_playing = self.is_playing
                has_pid = self.tts_client.current_ffmpeg_pid is not None
                if was_playing or has_pid:
                    self.interrupt_event.set()
                    self.tts_client.force_stop()
                    self.is_playing = False
                    response.success = True
                    response.message = "已中断播放"
                    response.status = "interrupted"
                else:
                    response.success = False
                    response.message = "没有正在播放的内容"
                    response.status = "none"
            return response

        if not request.text or not request.text.strip():
            response.success = False
            response.message = "文本为空"
            response.status = "error"
            return response

        with self.playing_lock:
            if self.is_playing:
                # Preempt the current playback before starting a new one.
                # NOTE(review): sleeps while holding playing_lock, blocking
                # concurrent interrupt requests for interrupt_wait seconds.
                self.tts_client.force_stop()
                time.sleep(self.interrupt_wait)
            self.is_playing = True

        self.interrupt_event.clear()

        def synthesize_worker():
            # Runs on a daemon thread; performs the blocking synthesis call.
            try:
                success = self.tts_client.synthesize(
                    request.text.strip(),
                    voice=request.voice if request.voice else None,
                    interrupt_check=lambda: self.interrupt_event.is_set()
                )
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    if success:
                        self.get_logger().info("[TTS] 合成并播放成功")
                    else:
                        self.get_logger().info("[TTS] 播放被中断")
            except Exception as e:
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    self.get_logger().error(f"[TTS] 合成失败: {e}")

        thread = threading.Thread(target=synthesize_worker, daemon=True)
        thread.start()

        # Reply immediately; playback continues in the background thread.
        response.success = True
        response.message = "合成任务已启动"
        response.status = "playing"
        return response
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: initialize rclpy, spin the TTS node until shutdown.

    Fix: wrap spin in try/finally so the node is destroyed and rclpy is
    shut down even when spin exits via Ctrl-C/exception — the original
    unconditionally fell through and skipped cleanup on interruption.
    """
    rclpy.init(args=args)
    node = TTSAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""理解层"""
|
||||
|
||||
|
||||
|
||||
5
setup.py
5
setup.py
@@ -15,6 +15,7 @@ setup(
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
|
||||
(os.path.join('share', package_name, 'srv'), glob('srv/*.srv')),
|
||||
],
|
||||
install_requires=[
|
||||
'setuptools',
|
||||
@@ -25,11 +26,13 @@ setup(
|
||||
maintainer_email='mzebra@foxmail.com',
|
||||
description='语音识别和合成ROS2包',
|
||||
license='Apache-2.0',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
|
||||
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
|
||||
'skill_bridge_node = robot_speaker.bridge.skill_bridge_node:main',
|
||||
'asr_audio_node = robot_speaker.services.asr_audio_node:main',
|
||||
'tts_audio_node = robot_speaker.services.tts_audio_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
10
srv/ASRRecognize.srv
Normal file
10
srv/ASRRecognize.srv
Normal file
@@ -0,0 +1,10 @@
|
||||
# 请求:启动识别
|
||||
string command # "start" (默认), "stop", "reset"
|
||||
---
|
||||
# 响应:识别结果
|
||||
bool success
|
||||
string text # 识别文本(空字符串表示未识别到)
|
||||
string message # 状态消息
|
||||
|
||||
|
||||
|
||||
27
srv/AudioData.srv
Normal file
27
srv/AudioData.srv
Normal file
@@ -0,0 +1,27 @@
|
||||
# 请求:获取音频数据
|
||||
string command # "start" (开始录音), "stop" (停止并返回), "get" (获取当前缓冲区)
|
||||
int32 duration_ms # 录音时长(毫秒),仅用于start命令
|
||||
---
|
||||
# 响应:音频数据
|
||||
bool success
|
||||
uint8[] audio_data # PCM音频数据(int16格式)
|
||||
int32 sample_rate
|
||||
int32 channels
|
||||
int32 samples # 样本数
|
||||
string message
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
14
srv/TTSSynthesize.srv
Normal file
14
srv/TTSSynthesize.srv
Normal file
@@ -0,0 +1,14 @@
|
||||
# 请求:合成文本或中断命令
|
||||
string command # "synthesize" (默认), "interrupt"
|
||||
string text
|
||||
string voice # 可选,默认使用配置
|
||||
---
|
||||
# 响应:合成状态
|
||||
bool success
|
||||
string message
|
||||
string status # "playing", "completed", "interrupted"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
11
srv/VADEvent.srv
Normal file
11
srv/VADEvent.srv
Normal file
@@ -0,0 +1,11 @@
|
||||
# 请求:等待VAD事件
|
||||
string command # "wait" (等待下一个事件)
|
||||
int32 timeout_ms # 超时时间(毫秒),0表示无限等待
|
||||
---
|
||||
# 响应:VAD事件
|
||||
bool success
|
||||
string event # "speech_started", "speech_stopped", "none"
|
||||
string message
|
||||
|
||||
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
查看相机画面的简单脚本
|
||||
按空格键保存当前帧,按'q'键退出
|
||||
"""
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
try:
|
||||
import pyrealsense2 as rs
|
||||
except ImportError:
|
||||
print("错误: 未安装pyrealsense2,请运行: pip install pyrealsense2")
|
||||
sys.exit(1)
|
||||
|
||||
def main():
    """Preview RealSense color frames; SPACE saves the current frame, 'q' quits."""
    # Configure the camera pipeline.
    pipeline = rs.pipeline()
    config = rs.config()

    # Enable the color stream (640x480 RGB @ 30 fps).
    config.enable_stream(rs.stream.color, 640, 480, rs.format.rgb8, 30)

    # Start streaming.
    pipeline.start(config)
    print("相机已启动,按空格键保存图片,按'q'键退出")

    frame_count = 0
    try:
        while True:
            # Block until a new frameset arrives.
            frames = pipeline.wait_for_frames()
            color_frame = frames.get_color_frame()

            if not color_frame:
                continue

            # Convert to a numpy array (RGB order).
            color_image = np.asanyarray(color_frame.get_data())

            # OpenCV expects BGR, so convert before display/save.
            bgr_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)

            # Show the live view.
            cv2.imshow('Camera View', bgr_image)

            # Poll the keyboard.
            key = cv2.waitKey(1) & 0xFF

            if key == ord('q'):
                print("退出...")
                break
            elif key == ord(' '):  # SPACE saves the current frame
                frame_count += 1
                filename = f'camera_frame_{frame_count:04d}.jpg'
                cv2.imwrite(filename, bgr_image)
                # Fix: print the saved path; the previous message contained
                # a literal placeholder instead of the filename.
                print(f"已保存: {filename}")

    except KeyboardInterrupt:
        print("\n中断...")
    finally:
        pipeline.stop()
        cv2.destroyAllWindows()
        print("相机已关闭")


if __name__ == '__main__':
    main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user