merge remote

This commit is contained in:
NuoDaJia02
2026-01-28 14:45:42 +08:00
41 changed files with 2297 additions and 2524 deletions

3
.gitignore vendored
View File

@@ -4,3 +4,6 @@ log/
__pycache__/
*.pyc
*.egg-info/
dist/
lib/
installed_files.txt

142
CMakeLists.txt Normal file
View File

@@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.8)
project(robot_speaker)

# Enable a strict warning set on GCC and Clang compilers.
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
  add_compile_options(-Wall -Wextra -Wpedantic)
endif()

find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(rosidl_default_generators REQUIRED)

# Make sure the system Python is used, not a conda/miniconda Python:
# first search only the system bin directories, then fall back to a normal search.
find_program(PYTHON3_CMD python3 PATHS /usr/bin /usr/local/bin NO_DEFAULT_PATH)
if(NOT PYTHON3_CMD)
  find_program(PYTHON3_CMD python3)
endif()
# NOTE(review): `CACHE ... FORCE` stomps any user-supplied Python path on every
# configure. That appears deliberate here (to defeat conda shadowing), but it is
# normally an anti-pattern — confirm it is still wanted before copying elsewhere.
if(PYTHON3_CMD)
  set(Python3_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python 3 executable" FORCE)
  set(PYTHON_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python executable" FORCE)
endif()

# Generate the ROS 2 service interfaces shipped by this package.
rosidl_generate_interfaces(${PROJECT_NAME}
  "srv/ASRRecognize.srv"
  "srv/TTSSynthesize.srv"
  "srv/VADEvent.srv"
  "srv/AudioData.srv"
)
# Install the package's Python code at `cmake --install` time:
#   1. pip-installs this source tree into CMAKE_INSTALL_PREFIX (no deps), then
#   2. runs an embedded Python script that reconciles pip's --prefix layout
#      (local/lib/pythonX/dist-packages etc.) with the lib/pythonX/site-packages
#      layout ROS 2 expects, copies the rosidl-generated `srv` module and
#      typesupport artifacts into the installed package, and relocates
#      entry-point scripts into lib/robot_speaker.
# Everything inside the quoted CODE string runs at install time; the comments
# inside the string are part of that installed script and are left untouched.
# NOTE(review): when robot_speaker_src == robot_speaker_dest, the script removes
# the destination *before* the equality check and then reports "already in
# correct location" — that ordering looks suspect; confirm intended behavior.
install(CODE "
  execute_process(
    COMMAND ${PYTHON3_CMD} -m pip install --prefix=${CMAKE_INSTALL_PREFIX} --no-deps ${CMAKE_CURRENT_SOURCE_DIR}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE install_result
    OUTPUT_VARIABLE install_output
    ERROR_VARIABLE install_error
  )
  if(NOT install_result EQUAL 0)
    message(FATAL_ERROR \"Failed to install Python package. Output: ${install_output} Error: ${install_error}\")
  endif()
  execute_process(
    COMMAND ${PYTHON3_CMD} -c \"
import os
import shutil
import glob
import sysconfig
install_prefix = '${CMAKE_INSTALL_PREFIX}'
build_dir = '${CMAKE_CURRENT_BINARY_DIR}'
python_version = f'{sysconfig.get_python_version()}'
# ROS2 期望的 Python 包位置
ros2_site_packages = os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages')
os.makedirs(ros2_site_packages, exist_ok=True)
# pip install --prefix 可能将包安装到不同位置(系统环境通常是 local/lib/pythonX/dist-packages
pip_locations = [
    os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'dist-packages'),
    os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages'),
    os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'site-packages'),
]
# 查找并复制 robot_speaker 包到 ROS2 期望的位置
robot_speaker_src = None
for location in pip_locations:
    candidate = os.path.join(location, 'robot_speaker')
    if os.path.exists(candidate) and os.path.isdir(candidate):
        robot_speaker_src = candidate
        break
if robot_speaker_src:
    robot_speaker_dest = os.path.join(ros2_site_packages, 'robot_speaker')
    if os.path.exists(robot_speaker_dest):
        shutil.rmtree(robot_speaker_dest)
    if robot_speaker_src != robot_speaker_dest:
        shutil.copytree(robot_speaker_src, robot_speaker_dest)
        print(f'Copied robot_speaker from {robot_speaker_src} to {ros2_site_packages}')
    else:
        print(f'robot_speaker already in correct location')
    # 复制 ROS2 生成的 srv 模块rosidl_generate_interfaces 生成的)
    rosidl_py_src = os.path.join(build_dir, 'rosidl_generator_py', 'robot_speaker')
    if os.path.exists(rosidl_py_src):
        # 复制 srv 目录
        srv_src = os.path.join(rosidl_py_src, 'srv')
        srv_dest = os.path.join(robot_speaker_dest, 'srv')
        if os.path.exists(srv_src):
            if os.path.exists(srv_dest):
                shutil.rmtree(srv_dest)
            shutil.copytree(srv_src, srv_dest)
            print(f'Copied srv module to {srv_dest}')
        # 复制生成的接口文件(.so 和 .c 文件)
        for pattern in ['robot_speaker_s__rosidl_typesupport*.so', '_robot_speaker_s*.c']:
            for file in glob.glob(os.path.join(rosidl_py_src, pattern)):
                dest_file = os.path.join(robot_speaker_dest, os.path.basename(file))
                shutil.copy2(file, dest_file)
                print(f'Copied {os.path.basename(file)} to {robot_speaker_dest}')
# 处理 entry_points 脚本
lib_dir = os.path.join(install_prefix, 'lib', 'robot_speaker')
os.makedirs(lib_dir, exist_ok=True)
# 脚本可能在 local/bin 或 bin 中
for bin_dir in [os.path.join(install_prefix, 'local', 'bin'), os.path.join(install_prefix, 'bin')]:
    if os.path.exists(bin_dir):
        scripts = glob.glob(os.path.join(bin_dir, '*_node'))
        for script in scripts:
            script_name = os.path.basename(script)
            dest = os.path.join(lib_dir, script_name)
            if script != dest:
                shutil.copy2(script, dest)
                os.chmod(dest, 0o755)
                print(f'Copied {script_name} to {lib_dir}')
\"
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE python_result
    OUTPUT_VARIABLE python_output
  )
  if(python_result EQUAL 0)
    message(STATUS \"${python_output}\")
  else()
    message(WARNING \"Failed to setup Python package: ${python_output}\")
  endif()
")
# Ship launch files into the package share directory.
install(
  DIRECTORY launch/
  DESTINATION share/${PROJECT_NAME}/launch
  FILES_MATCHING
    PATTERN "*.launch.py"
)

# Ship configuration files (YAML and JSON) alongside them.
install(
  DIRECTORY config/
  DESTINATION share/${PROJECT_NAME}/config
  FILES_MATCHING
    PATTERN "*.yaml"
    PATTERN "*.json"
)

# Standard ament linting hooks, enabled only when testing is on.
if(BUILD_TESTING)
  find_package(ament_lint_auto REQUIRED)
  ament_lint_auto_find_test_dependencies()
endif()

# Must remain the final call in this file.
ament_package()

View File

@@ -45,37 +45,26 @@ source install/setup.bash
ros2 launch robot_speaker voice.launch.py
```
## 架构说明
[录音线程] - 唯一实时线程
├─ 麦克风采集 PCM
├─ VAD + 能量检测
├─ 检测到人声 → 立即中断TTS
├─ 语音 PCM → ASR 音频队列
└─ 语音 PCM → 声纹音频队列(旁路,不阻塞)
3. ASR节点
```bash
ros2 run robot_speaker asr_audio_node
```
[ASR推理线程] - 只做 audio → text
└─ 从 ASR 音频队列取音频→ 实时 / 流式 ASR → text → 文本队列
4. TTS节点
```bash
# 终端1: 启动TTS节点
ros2 run robot_speaker tts_audio_node
[声纹识别线程] - 非实时、低频CAM++
├─ 通过回调函数接收音频chunk写入缓冲区等待 speech_end 事件触发处理
├─ 累积 1~2 秒有效人声VAD 后)
├─ CAM++ 提取 speaker embedding
├─ 声纹匹配 / 注册
└─ 更新 current_speaker_id共享状态只写不控
声纹线程要求:不影响录音、不影响ASR、不控制TTS,只更新当前说话人是谁
# 终端2: 启动播放
source install/setup.bash
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
"{command: 'synthesize', text: '这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。让我继续说下去这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。让我继续说下去这是一段很长的测试文本用于测试TTS中断功能。我需要说很多很多内容这样你才有足够的时间来测试中断命令。', voice: ''}"
[主线程/处理线程] - 处理业务逻辑
├─ 从 文本队列 取 ASR 文本
├─ 读取 current_speaker_id只读
├─ 唤醒词处理(结合 speaker_id
├─ 权限 / 身份判断(是否允许继续)
├─ VLM处理文本 / 多模态)
└─ TTS播放启动TTS线程不等待
[TTS播放线程] - 只播放(可被中断)
├─ 接收 TTS 音频流
├─ 播放到输出设备
└─ 响应中断标志(由录音线程触发)
# 终端3: 立即执行中断
source install/setup.bash
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
"{command: 'interrupt', text: '', voice: ''}"
```
## 用到的命令

View File

@@ -394,5 +394,205 @@
],
"env": "near",
"registered_at": 1768964438.9963026
},
"user_1769515089": {
"embedding": [
[
-1.5295110940933228,
0.5238341093063354,
0.08633111417293549,
0.11756575852632523,
1.44246244430542,
-1.6976442337036133,
0.2645050585269928,
1.5642119646072388,
1.4558132886886597,
-2.018132448196411,
0.30136486887931824,
1.5590322017669678,
0.3676050007343292,
2.096036434173584,
-1.203681468963623,
0.2745387852191925,
1.128976821899414,
-0.8042266368865967,
-0.04837780073285103,
-0.8245053291320801,
-0.6101562976837158,
0.08143205940723419,
-1.1198647022247314,
1.7753965854644775,
-0.5257269144058228,
-0.6572340726852417,
-0.08467039465904236,
0.08285830914974213,
0.49599483609199524,
-2.871098756790161,
-1.1618938446044922,
0.7318744659423828,
2.08620548248291,
0.18100303411483765,
-0.5528441071510315,
0.13717415928840637,
0.22606758773326874,
0.23349706828594208,
0.40789690613746643,
-0.23644576966762543,
-0.12830045819282532,
1.0583454370498657,
0.3954410254955292,
-1.0476133823394775,
0.6569878458976746,
0.43412935733795166,
0.7459996938705444,
0.25105446577072144,
0.40695688128471375,
0.41371095180511475,
-0.5081073045730591,
-0.15921951830387115,
0.6312111020088196,
2.678532123565674,
1.5355063676834106,
1.898784875869751,
1.257870078086853,
2.026048421859741,
1.1490176916122437,
0.742881178855896,
-1.206595540046692,
0.5405871272087097,
0.01001159567385912,
-0.7743952870368958,
-0.1243305653333664,
0.4287954568862915,
-1.1704397201538086,
2.057995557785034,
0.30912983417510986,
1.0761916637420654,
1.3979746103286743,
-1.070613145828247,
2.0996458530426025,
-0.16294217109680176,
-0.15417678654193878,
-0.6481220722198486,
0.9156526923179626,
0.7209145426750183,
-1.3280514478683472,
0.08632978051900864,
-0.09424483776092529,
1.8493571281433105,
0.917565107345581,
-0.0257036704570055,
-1.0192301273345947,
-0.8172388672828674,
0.37842708826065063,
0.20112906396389008,
-0.18812096118927002,
0.12312255054712296,
0.3173609673976898,
0.029730113223195076,
-0.662641704082489,
0.6436728239059448,
0.3574063181877136,
0.27612701058387756,
-0.6808024644851685,
-1.1454781293869019,
0.7457495927810669,
-1.8407135009765625,
-0.6051219701766968,
2.167180299758911,
0.181788831949234,
1.2942312955856323,
-2.2572178840637207,
-0.6572328209877014,
-0.44301870465278625,
0.5519763827323914,
-0.02834797278046608,
-1.118048906326294,
-0.44019994139671326,
1.2326226234436035,
-0.2865355312824249,
-1.9306018352508545,
0.4287217855453491,
-0.5471329092979431,
-1.8593220710754395,
-0.2029312551021576,
0.6949507594108582,
-0.2491024136543274,
-0.6223251819610596,
-0.5916008949279785,
1.3497960567474365,
-0.47974079847335815,
1.6955225467681885,
0.17834797501564026,
0.13161484897136688,
0.20850282907485962,
-0.04633784666657448,
-0.9113361835479736,
-1.1419169902801514,
-1.0826172828674316,
-0.2316463589668274,
0.45178237557411194,
0.18495112657546997,
0.535635232925415,
1.923178791999817,
-0.7357022762298584,
-0.5064287185668945,
0.5609160661697388,
1.1650713682174683,
-0.5384876728057861,
1.2522424459457397,
-1.309113621711731,
0.22394417226314545,
-0.14331775903701782,
0.7612791061401367,
-1.8949273824691772,
-0.8273413181304932,
0.15730154514312744,
0.5960761904716492,
-1.5179729461669922,
-1.3346058130264282,
-1.0774084329605103,
-0.960814356803894,
-0.14860300719738007,
-0.9822415113449097,
1.821016788482666,
-0.4035312235355377,
0.6270486116409302,
0.6994175910949707,
-0.8607892394065857,
0.7216717004776001,
-1.2650134563446045,
0.05397822707891464,
0.2296375185251236,
-0.40239569544792175,
-0.44462206959724426,
0.12279012054204941,
-0.3110475540161133,
1.0768173933029175,
-0.21416479349136353,
-0.44052380323410034,
0.743086040019989,
-1.3203964233398438,
0.47284168004989624,
0.16021426022052765,
1.2153557538986206,
0.7987464666366577,
-0.27521243691444397,
0.25042879581451416,
-0.36083176732063293,
1.5787007808685303,
1.2494744062423706,
0.16907380521297455,
0.01833455078303814,
-0.16504760086536407,
1.3832142353057861,
-0.331011027097702,
-0.28575095534324646,
-0.3638729751110077,
0.37575358152389526
]
],
"env": "",
"registered_at": 1769515089.7623787
}
}

View File

@@ -1,5 +1,11 @@
# ROS 语音包配置文件
asr:
mode: 'cloud' # 'cloud' | 'local' - ASR模式选择
local:
server_url: "ws://127.0.0.1:10095" # 本地FunASR服务地址
# 云端模式配置在dashscope中
dashscope:
api_key: "sk-7215a5ab7a00469db4072e1672a0661e"  # SECURITY(review): secret committed to version control — rotate this key immediately and load it from an environment variable or secret store instead
asr:
@@ -33,6 +39,10 @@ audio:
source_sample_rate: 22050 # TTS服务固定输出采样率DashScope服务固定值不可修改
source_channels: 1 # TTS服务固定输出声道数DashScope服务固定值不可修改
ffmpeg_thread_queue_size: 4096 # ffmpeg输入线程队列大小增大以减少卡顿
force_stop_delay: 0.1 # 强制停止时的延迟(秒)
cleanup_timeout: 30.0 # 清理超时(秒)
terminate_timeout: 1.0 # 终止超时(秒)
interrupt_wait: 0.1 # 中断等待时间(秒)
vad:
vad_mode: 3 # VAD模式0-33最严格
@@ -40,7 +50,6 @@ vad:
min_energy_threshold: 300 # 最小能量阈值
system:
use_llm: true # 是否使用LLM
use_wake_word: true # 是否启用唤醒词检测
wake_word: "er gou" # 唤醒词(拼音)
session_timeout: 3.0 # 会话超时时间(秒)
@@ -53,7 +62,9 @@ system:
sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径JSON格式相对于ROS2包share目录
# sv_speaker_db_path: "~/ros_learn/hivecore_robot_voice/config/speakers.json" # 声纹数据库保存路径JSON格式相对于ROS2包share目录
sv_buffer_size: 240000 # 声纹验证录音缓冲区大小样本数48kHz下5秒=240000
continue_without_image: false # 多模态意图skill_sequence/chat_camera未获取到图片时是否继续推理
continue_without_image: true # 多模态意图skill_sequence/chat_camera未获取到图片时是否继续推理
skill_auto_retry: true
skill_max_retries: 5
camera:
image:

View File

@@ -0,0 +1,54 @@
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.actions import SetEnvironmentVariable, RegisterEventHandler
from launch.event_handlers import OnProcessExit
from launch.actions import EmitEvent
from launch.events import Shutdown
import os
def generate_launch_description():
    """Launch speaker-voiceprint registration (requires the ASR service).

    Starts the ASR/audio-input node together with the registration node, and
    shuts the entire launch down once the registration node exits.

    Returns:
        LaunchDescription: environment setup, the two nodes, and the exit handler.
    """
    # Install path of the interfaces package (hard-coded developer path).
    interfaces_install_path = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
    # Extend AMENT_PREFIX_PATH so the interfaces package's message types are found.
    # Fix: compare against individual path entries instead of a raw substring
    # test, which could falsely match a partial component of a longer entry
    # (e.g. '.../install' matching '.../install_backup').
    ament_prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
    prefix_entries = [p for p in ament_prefix_path.split(os.pathsep) if p]
    if interfaces_install_path not in prefix_entries:
        prefix_entries.append(interfaces_install_path)
    ament_prefix_path = os.pathsep.join(prefix_entries)
    # ASR + audio input device node (provides the ASR and AudioData services).
    asr_audio_node = Node(
        package='robot_speaker',
        executable='asr_audio_node',
        name='asr_audio_node',
        output='screen'
    )
    # Voiceprint registration node.
    register_speaker_node = Node(
        package='robot_speaker',
        executable='register_speaker_node',
        name='register_speaker_node',
        output='screen'
    )
    # When the registration node exits, shut down the whole launch.
    register_exit_handler = RegisterEventHandler(
        OnProcessExit(
            target_action=register_speaker_node,
            on_exit=[
                EmitEvent(event=Shutdown(reason='注册完成,关闭所有节点'))
            ]
        )
    )
    return LaunchDescription([
        SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
        asr_audio_node,
        register_speaker_node,
        register_exit_handler,
    ])

View File

@@ -19,6 +19,21 @@ def generate_launch_description():
return LaunchDescription([
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
# ASR + 音频输入设备节点同时提供VAD事件Service利用云端ASR的VAD
Node(
package='robot_speaker',
executable='asr_audio_node',
name='asr_audio_node',
output='screen'
),
# TTS + 音频输出设备节点
Node(
package='robot_speaker',
executable='tts_audio_node',
name='tts_audio_node',
output='screen'
),
# 主业务逻辑节点
Node(
package='robot_speaker',
executable='robot_speaker_node',

View File

@@ -13,6 +13,11 @@
<depend>cv_bridge</depend>
<depend>ament_index_python</depend>
<depend>interfaces</depend>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_python</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<exec_depend>python3-pyaudio</exec_depend>
<exec_depend>python3-requests</exec_depend>
@@ -27,6 +32,6 @@
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@@ -12,6 +12,7 @@ aec-audio-processing
modelscope>=1.33.0
funasr>=1.0.0
datasets==3.6.0
websocket-client>=1.6.0

View File

@@ -4,3 +4,14 @@

View File

@@ -7,3 +7,15 @@

View File

@@ -1,10 +1,17 @@
"""
对话历史管理模块
"""
from robot_speaker.core.types import LLMMessage
from dataclasses import dataclass
import threading
@dataclass
class LLMMessage:
    """A single LLM chat message (role + content)."""
    role: str  # one of "user", "assistant", "system"
    content: str  # message text
class ConversationHistory:
"""对话历史管理器 - 实时语音"""
@@ -109,3 +116,8 @@ class ConversationHistory:
self.summary = None
self._pending_user_message = None

View File

@@ -1,10 +0,0 @@
from enum import Enum
class ConversationState(Enum):
"""会话状态机"""
IDLE = "idle" # 等待用户唤醒或声音
CHECK_VOICE = "check_voice" # 用户说话 → 检查声纹
AUTHORIZED = "authorized" # 已注册用户

View File

@@ -37,6 +37,7 @@ class IntentRouter:
"ju qi", "sheng qi", # 举起、升起
"jia zhua", "jia qi", "jia", # 夹爪、夹起、夹
"shen you bi", "shen zuo bi", "shen chu", "shen shou", # 伸右臂、伸左臂、伸出、伸手
"zhuan quan", "zhuan yi quan", "zhuan", # 转个圈、转一圈、转
]
self.kb_keywords = [
"ni shi shui", "ni de ming zi", "tiao ge wu", "ni jiao sha", "ni hui gan", "ni neng gan"
@@ -153,6 +154,7 @@ class IntentRouter:
return (
"你是机器人任务规划器。\n"
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照,并结合当前环境信息生成简洁、可执行的技能序列。\n"
"如果用户明确要求或者任务明显需要双手/双臂协作(如扶稳+操作、抓取大体积的物体),必须规划双手技能。\n"
+ execution_hint
+ "\n"
"【规划要求】\n"
@@ -161,18 +163,13 @@ class IntentRouter:
" - parallel并行技能可以同时执行\n"
"2. parameters规划根据目标物距离和任务需求规划具体参数值\n"
" - parameters字典必须包含该技能接口文件目标字段的所有字段\n"
" - 对于包含body_id字段的技能如Armbody_id值根据目标物在图片中的方位选择\n"
" * 目标物在图片左侧或机器人左侧使用body_id=0左臂\n"
" * 目标物在图片右侧或机器人右侧使用body_id=1右臂\n"
" * 目标物在图片中央或需要头部操作使用body_id=2头部\n"
"\n"
"【输出格式要求】\n"
"必须输出JSON格式包含sequence数组。每个技能对象包含3个一级字段\n"
"1. skill: 技能名称(字符串)\n"
"2. execution: 执行方式serial串行或 parallel并行\n"
"3. parameters: 参数字典包含该技能接口文件目标字段的所有字段并填入合理的预测值。如果技能无参数使用null。\n"
"\n"
"注意一级字段skill, execution, parameters是固定结构,直接使用即可,不需要预测\n"
"注意一级字段skill, execution, parameters是固定结构。\n"
"\n"
"【技能参数说明】\n"
+ skill_params_doc +
@@ -191,12 +188,13 @@ class IntentRouter:
def build_chat_prompt(self, need_camera: bool) -> str:
if need_camera:
return (
"你是一个智能语音助手\n"
"请结合图片内容简短回答。不要超过100个token。"
"你是一个机器人视觉助理,擅长分析图片中物体的相对位置和空间关系\n"
"请结合图片内容,重点描述物体之间的相对位置(如左右、前后、上下、远近),仅基于可观察信息回答。\n"
"回答应简短、客观不要超过100个token。"
)
return (
"你是一个智能语音助手\n"
"自然、简短地与用户对话不要超过100个token。"
"你是一个表达清晰、语气自然的真人助理\n"
"请简短地与用户对话不要超过100个token。"
)
def _load_kb_data(self) -> list[dict]:
@@ -233,7 +231,7 @@ class IntentRouter:
def build_default_system_prompt(self) -> str:
return (
"你是一个智能语音助手。\n"
"你是一个工厂专业的助手。\n"
"- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
"- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
"请根据对话模式调整你的回答风格。"

View File

@@ -1,246 +0,0 @@
import threading
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeCallbacks:
    """Callback hub for the main node: routes recorder/VAD/ASR events and
    delegates business logic back to the owning node object.

    Note: indentation below is reconstructed from syntax (the source dump had
    whitespace stripped) — verify nesting against the original file.
    """

    # ==================== Initialization & internal helpers ====================
    def __init__(self, node):
        # Keep a back-reference to the owning node; all state lives on it.
        self.node = node

    def _mark_utterance_processed(self) -> bool:
        # Returns True exactly once per utterance id (guarded by utterance_lock),
        # so speaker verification is triggered at most once per utterance.
        node = self.node
        with node.utterance_lock:
            if node.current_utterance_id == node.last_processed_utterance_id:
                return False
            node.last_processed_utterance_id = node.current_utterance_id
            return True

    def _trigger_sv_for_check_voice(self, source: str):
        # Fire the speaker-verification pipeline for the CHECK_VOICE state.
        # `source` names the trigger ("VAD" or "ASR") for logging only.
        node = self.node
        if not (node.sv_enabled and node.sv_client):
            return
        if not self._mark_utterance_processed():
            return
        if node._handle_empty_speaker_db():
            node.get_logger().info(f"[声纹] CHECK_VOICE状态数据库为空跳过声纹验证来源: {source}")
            return
        if not node.sv_speech_end_event.is_set():
            with node.sv_lock:
                node.sv_recording = False
                buffer_size = len(node.sv_audio_buffer)
            node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
            # Only signal the SV worker when there is audio to verify.
            if buffer_size > 0:
                node.sv_speech_end_event.set()
        else:
            node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source}")

    # ==================== Business-logic delegation ====================
    # Thin proxies so collaborators only need a reference to this object.
    def handle_interrupt_command(self, msg):
        return self.node._handle_interrupt_command(msg)

    def check_interrupt_and_cancel_turn(self) -> bool:
        return self.node._check_interrupt_and_cancel_turn()

    def handle_wake_word(self, text: str) -> str:
        return self.node._handle_wake_word(text)

    def check_shutup_command(self, text: str) -> bool:
        return self.node._check_shutup_command(text)

    def check_camera_command(self, text: str):
        return self.node.intent_router.check_camera_command(text)

    def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
        return self.node._llm_process_stream_with_camera(user_text, need_camera)

    def put_tts_text(self, text: str):
        return self.node._put_tts_text(text)

    def force_stop_tts(self):
        return self.node._force_stop_tts()

    def drain_queue(self, q):
        return self.node._drain_queue(q)

    # ==================== Recording / VAD callbacks ====================
    def get_silence_threshold(self) -> int:
        """Return the dynamic silence threshold in milliseconds."""
        node = self.node
        return node.silence_duration_ms

    def should_put_audio_to_queue(self) -> bool:
        """
        Whether captured audio should be queued for ASR, decided by the
        conversation state machine.
        """
        node = self.node
        state = node._get_state()
        if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
                     ConversationState.AUTHORIZED]:
            return True
        return False

    def on_speech_start(self):
        """Recorder thread detected the start of speech."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到人声,开始录音")
        # New utterance: bump the id so _mark_utterance_processed fires once.
        with node.utterance_lock:
            node.current_utterance_id += 1
        state = node._get_state()
        if state == ConversationState.IDLE:
            # Idle -> CheckVoice
            if node.sv_enabled and node.sv_client:
                # Start buffering audio for speaker verification.
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")
                node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
            else:
                node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
        elif state == ConversationState.CHECK_VOICE:
            # CHECK_VOICE: keep recording for speaker verification.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 继续录音用于声纹验证")
        elif state == ConversationState.AUTHORIZED:
            # AUTHORIZED: record to re-verify the current speaker.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")

    def on_audio_chunk_for_sv(self, audio_chunk: bytes):
        """Recorder audio-chunk callback — append to the speaker-verification
        buffer only while SV recording is active."""
        node = self.node
        state = node._get_state()  # NOTE(review): 'state' is currently unused here
        # SV capture happens in the CHECK_VOICE and AUTHORIZED states.
        if node.sv_enabled and node.sv_recording:
            try:
                audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                with node.sv_lock:
                    node.sv_audio_buffer.extend(audio_array)
            except Exception as e:
                node.get_logger().debug(f"[声纹] 录音失败: {e}")

    def on_speech_end(self):
        """Recorder thread detected the end of speech (sustained silence)."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到说话结束")
        state = node._get_state()
        node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
        if state == ConversationState.CHECK_VOICE:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            self._trigger_sv_for_check_voice("VAD")
            return
        elif state == ConversationState.AUTHORIZED:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = False
                    buffer_size = len(node.sv_audio_buffer)
                node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
                node.sv_speech_end_event.set()
                # If TTS is playing, wait asynchronously for the SV result and
                # only interrupt TTS when verification passes. A separate thread
                # avoids blocking the recorder thread (which would stall TTS).
                if node.tts_playing_event.is_set():
                    node.get_logger().info("[打断] TTS播放中用户说话结束异步等待声纹验证结果...")
                    def _check_sv_and_interrupt():
                        # Wait for the SV result, at most 2 seconds.
                        with node.sv_result_cv:
                            current_seq = node.sv_result_seq
                            if node.sv_result_cv.wait_for(
                                lambda: node.sv_result_seq > current_seq,
                                timeout=2.0
                            ):
                                # Verification finished — inspect the result.
                                with node.sv_lock:
                                    speaker_id = node.current_speaker_id
                                    speaker_state = node.current_speaker_state
                                if speaker_id and speaker_state == SpeakerState.VERIFIED:
                                    node.get_logger().info(f"[打断] 声纹验证通过({speaker_id})中断TTS播放")
                                    node._interrupt_tts("检测到人声(已授权用户,说话结束)")
                                else:
                                    node.get_logger().debug(f"[打断] 声纹验证未通过不中断TTS状态: {speaker_state.value}")
                            else:
                                node.get_logger().warning("[打断] 声纹验证超时不中断TTS")
                    # Run the wait off the recorder thread.
                    threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
            return

    def on_new_segment(self):
        """Recorder detected a new speech segment while AUTHORIZED: start SV
        recording instead of interrupting TTS immediately."""
        node = self.node
        state = node._get_state()
        if state == ConversationState.AUTHORIZED:
            # While TTS plays, don't interrupt on voice onset — record for SV and
            # interrupt after speech_end only if the voice verifies. This avoids
            # TTS echo false-triggering while still supporting real barge-in.
            if node.tts_playing_event.is_set():
                node.get_logger().debug("[打断] TTS播放中检测到人声开始录音用于声纹验证等待说话结束后验证")
                # Recording already began in on_speech_start; nothing more to do.
            else:
                # TTS idle: check the cached SV result and interrupt immediately.
                if node.sv_enabled and node.sv_client:
                    with node.sv_lock:
                        current_speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                    if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
                        node._interrupt_tts("检测到人声(已授权用户)")
                        node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言中断TTS播放")
                    else:
                        node.get_logger().debug(f"[打断] 检测到人声但声纹未验证或未匹配不中断TTS当前状态: {speaker_state.value}")
                else:
                    # SV disabled: interrupt directly (legacy behavior).
                    node._interrupt_tts("检测到人声(未启用声纹)")
                    node.get_logger().info("[打断] 检测到人声中断TTS播放")
        else:
            node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value}非已授权用户不允许打断TTS")

    def on_heartbeat(self):
        """Recorder-thread silence heartbeat callback."""
        self.node.get_logger().info("[录音线程] 静音中")

    # ==================== ASR callbacks ====================
    def on_asr_sentence_end(self, text: str):
        """ASR sentence_end callback — push the final text into the queue."""
        node = self.node
        if not text or not text.strip():
            return
        text_clean = text.strip()
        node.get_logger().info(f"[ASR] 识别完成: {text_clean}")
        state = node._get_state()
        # In CHECK_VOICE, if ASR finishes before VAD signals speech_end,
        # proactively trigger speaker verification.
        if state == ConversationState.CHECK_VOICE:
            if node.sv_enabled and node.sv_client:
                node.get_logger().info("[ASR] CHECK_VOICE状态ASR识别完成主动触发声纹验证")
                self._trigger_sv_for_check_voice("ASR")
        # Queue the text for the processing thread.
        node.text_queue.put(text_clean, timeout=1.0)

    def on_asr_text_update(self, text: str):
        """ASR incremental-text callback — used for multi-turn hints."""
        if not text or not text.strip():
            return
        self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")

View File

@@ -1,188 +0,0 @@
import queue
import time
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeWorkers:
    """Worker-thread bodies for the main node (recording, ASR, processing,
    speaker verification).

    Note: indentation below is reconstructed from syntax (the source dump had
    whitespace stripped) — verify nesting against the original file.
    """

    def __init__(self, node):
        # Back-reference to the owning node; all queues/events/locks live on it.
        self.node = node

    def recording_worker(self):
        """Thread 1: recording — the only realtime thread."""
        node = self.node
        node.get_logger().info("[录音线程] 启动")
        node.audio_recorder.record_with_vad()

    def asr_worker(self):
        """Thread 2: ASR inference — audio in, text out, nothing else."""
        node = self.node
        node.get_logger().info("[ASR推理线程] 启动")
        while not node.stop_event.is_set():
            try:
                audio_chunk = node.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            # Drop audio while an interrupt is pending.
            if node.interrupt_event.is_set():
                continue
            if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
                node.asr_client.send_audio(audio_chunk)

    def process_worker(self):
        """Thread 3: main processing — business logic over recognized text."""
        node = self.node
        node.get_logger().info("[主线程] 启动")
        while not node.stop_event.is_set():
            try:
                text = node.text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            node.get_logger().info(f"[主线程] 收到识别文本: {text}")
            current_state = node._get_state()
            if current_state == ConversationState.CHECK_VOICE:
                # Optional wake-word gate before anything else.
                if node.use_wake_word:
                    node.get_logger().info(f"[主线程] CHECK_VOICE状态检查唤醒词文本: {text}")
                    processed_text = node.callbacks.handle_wake_word(text)
                    if not processed_text:
                        node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'回到Idle状态")
                        node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                        continue
                    node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
                    text = processed_text
                if node.sv_enabled and node.sv_client:
                    # Block (up to 15 s) for the SV worker to publish a result.
                    node.get_logger().info("[主线程] CHECK_VOICE状态等待声纹验证结果...")
                    with node.sv_result_cv:
                        current_seq = node.sv_result_seq
                        if not node.sv_result_cv.wait_for(
                            lambda: node.sv_result_seq > current_seq,
                            timeout=15.0
                        ):
                            node.get_logger().warning("[主线程] CHECK_VOICE状态声纹结果未ready超时15秒拒绝本轮")
                            with node.sv_lock:
                                node.sv_audio_buffer.clear()
                            node._change_state(ConversationState.IDLE, "声纹结果未ready")
                            continue
                        node.get_logger().info("[主线程] CHECK_VOICE状态声纹结果ready继续处理")
                    with node.sv_lock:
                        speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                        score = node.current_speaker_score
                    if speaker_id and speaker_state == SpeakerState.VERIFIED:
                        node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
                        node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
                    else:
                        node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
                        node.callbacks.put_tts_text("声纹验证失败")
                        node._change_state(ConversationState.IDLE, "声纹验证失败")
                        continue
                else:
                    node._change_state(ConversationState.AUTHORIZED, "未启用声纹")
            elif current_state == ConversationState.AUTHORIZED:
                # While TTS plays, ASR text is ignored; only VAD-verified voice
                # from an authorized user may interrupt.
                if node.tts_playing_event.is_set():
                    node.get_logger().debug("[主线程] AUTHORIZED状态TTS播放中忽略ASR识别结果只有VAD检测到已授权用户人声才能中断")
                    continue
            elif current_state == ConversationState.IDLE:
                node.get_logger().warning("[主线程] Idle状态收到文本忽略")
                continue
            # Wake-word gate also applies in AUTHORIZED when enabled.
            if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
                processed_text = node.callbacks.handle_wake_word(text)
                if not processed_text:
                    node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                    continue
                text = processed_text
            # "Shut up" command: stop TTS and drop back to IDLE.
            if node.callbacks.check_shutup_command(text):
                node.get_logger().info("[主线程] 检测到闭嘴指令")
                node.interrupt_event.set()
                node.callbacks.force_stop_tts()
                node._change_state(ConversationState.IDLE, "用户闭嘴指令")
                continue
            # Route the text to the intent pipeline.
            intent_payload = node.intent_router.route(text)
            node._handle_intent(intent_payload)
            # Refresh the session timer after an authorized turn.
            if current_state == ConversationState.AUTHORIZED:
                node.session_start_time = time.time()

    def sv_worker(self):
        """Thread 5: speaker verification — non-realtime, low-frequency (CAM++)."""
        node = self.node
        node.get_logger().info("[声纹识别线程] 启动")
        # Compute the minimum sample count so that after downsampling to 16 kHz
        # at least 0.5 s of audio remains.
        target_sr = 16000  # CAM++ model target sample rate
        min_duration_seconds = 0.5
        min_samples_at_target_sr = int(target_sr * min_duration_seconds)  # 8000 samples @ 16 kHz
        if node.sample_rate >= target_sr:
            downsample_step = int(node.sample_rate / target_sr)
            min_audio_samples = min_samples_at_target_sr * downsample_step
        else:
            min_audio_samples = int(node.sample_rate * min_duration_seconds)
        while not node.stop_event.is_set():
            try:
                if node.sv_speech_end_event.wait(timeout=0.1):
                    node.sv_speech_end_event.clear()
                    # Snapshot and clear the shared buffer under the lock.
                    with node.sv_lock:
                        audio_list = list(node.sv_audio_buffer)
                        buffer_size = len(audio_list)
                        node.sv_audio_buffer.clear()
                    node.get_logger().info(f"[声纹识别] 收到speech_end事件录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
                    if node._handle_empty_speaker_db():
                        # NOTE(review): this `continue` skips the sv_result_seq
                        # bump below — waiters presumably rely on
                        # _handle_empty_speaker_db to unblock them; confirm.
                        node.get_logger().info("[声纹识别] 数据库为空跳过验证直接设置UNKNOWN状态")
                        continue
                    if buffer_size >= min_audio_samples:
                        audio_array = np.array(audio_list, dtype=np.int16)
                        embedding, success = node.sv_client.extract_embedding(
                            audio_array,
                            sample_rate=node.sample_rate
                        )
                        if not success or embedding is None:
                            node.get_logger().debug("[声纹识别] 提取embedding失败")
                            with node.sv_lock:
                                node.current_speaker_id = None
                                node.current_speaker_state = SpeakerState.ERROR
                                node.current_speaker_score = 0.0
                        else:
                            speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
                            with node.sv_lock:
                                node.current_speaker_id = speaker_id
                                node.current_speaker_state = match_state
                                node.current_speaker_score = score
                            if match_state == SpeakerState.VERIFIED:
                                node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
                            elif match_state == SpeakerState.REJECTED:
                                node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
                            else:
                                node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
                    else:
                        # Clip too short to verify — report UNKNOWN.
                        node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
                        with node.sv_lock:
                            node.current_speaker_id = None
                            node.current_speaker_state = SpeakerState.UNKNOWN
                            node.current_speaker_score = 0.0
                    # Publish the result: bump the sequence and wake all waiters.
                    with node.sv_result_cv:
                        node.sv_result_seq += 1
                        node.sv_result_cv.notify_all()
            except Exception as e:
                node.get_logger().error(f"[声纹识别线程] 错误: {e}")
                time.sleep(0.1)

View File

@@ -1,21 +1,16 @@
"""
声纹注册独立节点:运行完成后退出
"""
import collections
"""声纹注册独立节点:运行完成后退出"""
import os
import queue
import threading
import time
import yaml
import numpy as np
import threading
import queue
import rclpy
from rclpy.node import Node
from ament_index_python.packages import get_package_share_directory
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
from robot_speaker.models.asr.dashscope import DashScopeASR
from robot_speaker.srv import ASRRecognize, AudioData, VADEvent
from robot_speaker.core.speaker_verifier import SpeakerVerificationClient
from pypinyin import pinyin, Style
@@ -24,66 +19,15 @@ class RegisterSpeakerNode(Node):
super().__init__('register_speaker_node')
self._load_config()
self.stop_event = threading.Event()
self.buffer_lock = threading.Lock()
self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
self.asr_client = self.create_client(ASRRecognize, '/asr/recognize')
self.audio_data_client = self.create_client(AudioData, '/asr/audio_data')
self.vad_client = self.create_client(VADEvent, '/vad/event')
self.speech_start_idx = None
self.speech_end_idx = None
self.speech_start_time = None
self.audio_queue = queue.Queue()
self.text_queue = queue.Queue()
self.vad_detector = VADDetector(
mode=self.vad_mode,
sample_rate=self.sample_rate
)
self.audio_recorder = AudioRecorder(
device_index=self.input_device_index,
sample_rate=self.sample_rate,
channels=self.channels,
chunk=self.chunk,
vad_detector=self.vad_detector,
audio_queue=self.audio_queue,
silence_duration_ms=self.silence_duration_ms,
min_energy_threshold=self.min_energy_threshold,
heartbeat_interval=self.audio_microphone_heartbeat_interval,
on_heartbeat=None,
is_playing=lambda: False,
on_new_segment=None,
on_speech_start=self._on_speech_start,
on_speech_end=self._on_speech_end,
stop_flag=self.stop_event.is_set,
on_audio_chunk=self._on_audio_chunk,
get_silence_threshold=lambda: self.silence_duration_ms,
logger=self.get_logger()
)
self.asr_client = DashScopeASR(
api_key=self.dashscope_api_key,
sample_rate=self.sample_rate,
model=self.asr_model,
url=self.asr_url,
logger=self.get_logger()
)
self.asr_client.on_sentence_end = self._on_asr_sentence_end
self.asr_client.start()
self.asr_thread = threading.Thread(
target=self._asr_worker,
name="RegisterASRThread",
daemon=True
)
self.asr_thread.start()
self.text_thread = threading.Thread(
target=self._text_worker,
name="RegisterTextThread",
daemon=True
)
self.text_thread.start()
self.get_logger().info('等待服务启动...')
self.asr_client.wait_for_service(timeout_sec=10.0)
self.audio_data_client.wait_for_service(timeout_sec=10.0)
self.vad_client.wait_for_service(timeout_sec=10.0)
self.get_logger().info('所有服务已就绪')
self.sv_client = SpeakerVerificationClient(
model_path=self.sv_model_path,
@@ -92,15 +36,20 @@ class RegisterSpeakerNode(Node):
logger=self.get_logger()
)
self.get_logger().info("声纹注册节点启动,请说'er gou我现在正在注册声纹这是一段很长的测试语音请把我的声音录进去。'")
self.recording_thread = threading.Thread(
target=self.audio_recorder.record_with_vad,
name="RegisterRecordingThread",
daemon=True
)
self.recording_thread.start()
self.registered = False
self.shutting_down = False
self.get_logger().info("声纹注册节点启动,请说唤醒词开始注册(例如:'二狗我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去'")
self.timer = self.create_timer(0.2, self._check_done)
# 使用队列在线程间传递 VAD 事件,避免在子线程中调用 spin_until_future_complete
self.vad_event_queue = queue.Queue()
self.recording = False # 录音状态标志
self.pending_asr_future = None # 待处理的 ASR future
self.pending_audio_future = None # 待处理的 AudioData future
self.state = "waiting_speech" # 状态机waiting_speech, waiting_asr, waiting_audio
self.vad_thread = threading.Thread(target=self._vad_event_worker, daemon=True)
self.vad_thread.start()
self.timer = self.create_timer(0.1, self._main_loop)
def _load_config(self):
config_file = os.path.join(
@@ -111,129 +60,59 @@ class RegisterSpeakerNode(Node):
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
dashscope = config['dashscope']
audio = config['audio']
mic = audio['microphone']
soundcard = audio['soundcard']
vad = config['vad']
system = config['system']
self.dashscope_api_key = dashscope['api_key']
self.asr_model = dashscope['asr']['model']
self.asr_url = dashscope['asr']['url']
self.input_device_index = mic['device_index']
self.sample_rate = mic['sample_rate']
self.channels = mic['channels']
self.chunk = mic['chunk']
self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
self.vad_mode = vad['vad_mode']
self.silence_duration_ms = vad['silence_duration_ms']
self.min_energy_threshold = vad['min_energy_threshold']
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
self.sv_threshold = system['sv_threshold']
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
self.sv_buffer_size = system['sv_buffer_size']
self.wake_word = system['wake_word']
def _on_speech_start(self):
with self.buffer_lock:
if self.speech_start_idx is None:
self.speech_start_idx = len(self.audio_buffer)
self.speech_start_time = time.time()
def _on_audio_chunk(self, audio_chunk: bytes):
try:
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
with self.buffer_lock:
self.audio_buffer.extend(audio_array)
except Exception as e:
self.get_logger().debug(f"[注册录音] 录音失败: {e}")
def _on_speech_end(self):
with self.buffer_lock:
self.speech_end_idx = len(self.audio_buffer)
def _process_voiceprint_audio(self):
"""处理声纹音频:使用用户完整的一句话进行注册"""
with self.buffer_lock:
audio_list = list(self.audio_buffer)
start_idx = self.speech_start_idx if self.speech_start_idx is not None else 0
end_idx = self.speech_end_idx if self.speech_end_idx is not None else len(audio_list)
self.audio_buffer.clear()
self.speech_start_idx = None
self.speech_end_idx = None
self.speech_start_time = None
audio_list = audio_list[start_idx:end_idx]
buffer_sec = len(audio_list) / self.sample_rate
self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}")
try:
audio_array = np.array(audio_list, dtype=np.int16)
embedding, success = self.sv_client.extract_embedding(
audio_array,
sample_rate=self.sample_rate
)
if not success:
self.get_logger().error("[注册录音] 提取embedding失败")
return
speaker_id = f"user_{int(time.time())}"
if self.sv_client.register_speaker(speaker_id, embedding):
self.get_logger().info(f"[注册录音] 注册成功用户ID: {speaker_id},准备退出")
self.stop_event.set()
else:
self.get_logger().error("[注册录音] 注册失败")
except Exception as e:
self.get_logger().error(f"[注册录音] 注册异常: {e}")
def _asr_worker(self):
"""ASR处理线程"""
while not self.stop_event.is_set():
def _vad_event_worker(self):
"""VAD 事件监听线程,只负责接收事件并放入队列,不调用 spin_until_future_complete"""
while not self.registered and not self.shutting_down:
try:
audio_chunk = self.audio_queue.get(timeout=0.1)
if self.asr_client and self.asr_client.running:
self.asr_client.send_audio(audio_chunk)
except queue.Empty:
continue
request = VADEvent.Request()
request.command = "wait"
request.timeout_ms = 1000
future = self.vad_client.call_async(request)
# 简单等待 future 完成,不使用 spin_until_future_complete
start_time = time.time()
while not future.done() and (time.time() - start_time) < 1.5:
time.sleep(0.01)
if not future.done() or self.registered or self.shutting_down:
continue
response = future.result()
if response.success and response.event in ["speech_started", "speech_stopped"]:
# 将事件放入队列,由主线程处理
try:
self.vad_event_queue.put(response.event, timeout=0.1)
except queue.Full:
self.get_logger().warn(f"[VAD] 事件队列已满,丢弃事件: {response.event}")
except Exception as e:
self.get_logger().error(f"[注册ASR] 处理异常: {e}")
if not self.shutting_down:
self.get_logger().error(f"[VAD] 线程异常: {e}")
break
def _on_asr_sentence_end(self, text: str):
"""ASR识别完成回调"""
if text and text.strip():
self.text_queue.put(text.strip())
def _text_worker(self):
"""文本处理线程:检测唤醒词"""
while not self.stop_event.is_set():
try:
text = self.text_queue.get(timeout=0.1)
self._check_wake_word(text)
except queue.Empty:
continue
except Exception as e:
self.get_logger().error(f"[注册文本] 处理异常: {e}")
def _start_recording(self):
    """Fire an async AudioData "start" request.

    Returns the rclpy future; the caller (_main_loop) polls it rather than
    spinning, so this never blocks.
    """
    req = AudioData.Request()
    req.command = "start"
    return self.audio_data_client.call_async(req)
def _to_pinyin(self, text: str) -> str:
"""将中文文本转换为拼音"""
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
if not chars:
return ""
py_list = pinyin(chars, style=Style.NORMAL)
# 多音字处理:取第一个读音
return ' '.join([item[0] for item in py_list]).lower().strip()
def _check_wake_word(self, text: str):
"""检查是否包含唤醒词,如果有则注册,没有则继续等待"""
text_pinyin = self._to_pinyin(text)
wake_word_pinyin = self.wake_word.lower().strip()
if not wake_word_pinyin:
self.get_logger().info(f"[注册唤醒词] 唤醒词配置为空,继续等待")
return
text_pinyin_parts = text_pinyin.split() if text_pinyin else []
@@ -246,27 +125,112 @@ class RegisterSpeakerNode(Node):
break
if has_wake_word:
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'使用完整音频注册")
self._process_voiceprint_audio()
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'停止录音并获取音频")
request = AudioData.Request()
request.command = "stop"
future = self.audio_data_client.call_async(request)
future._future_type = "stop"
self.pending_audio_future = future
def _process_voiceprint_audio(self, response):
"""处理声纹音频数据 - 直接使用 AudioData 返回的音频,不再过滤"""
if not response or not response.success or response.samples == 0:
self.get_logger().error(f"[注册录音] 获取音频数据失败: {response.message if response else '无响应'}")
return
audio_array = np.frombuffer(response.audio_data, dtype=np.int16)
buffer_sec = response.samples / response.sample_rate
self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}")
# 直接使用音频,不再进行 VAD 过滤
# 因为 AudioData 服务基于 DashScope VAD已经是语音活动片段
embedding, success = self.sv_client.extract_embedding(
audio_array,
sample_rate=response.sample_rate
)
if not success or embedding is None:
self.get_logger().error("[注册录音] 提取embedding失败")
return
speaker_id = f"user_{int(time.time())}"
if self.sv_client.register_speaker(speaker_id, embedding):
# 注册成功后立即保存到文件
self.sv_client.save_speakers()
self.get_logger().info(f"[注册录音] 注册成功用户ID: {speaker_id},已保存到文件,准备退出")
self.registered = True
else:
self.get_logger().info(f"[注册唤醒词] 未检测到唤醒词,继续等待用户说话")
def _check_done(self):
if self.stop_event.is_set():
self.get_logger().error("[注册录音] 注册失败")
def _main_loop(self):
    """Main 0.1 s timer callback: runs the whole registration state machine
    on the executor thread, so worker threads never have to spin futures.

    Per tick, in order: shutdown check -> resolve pending ASR future ->
    resolve pending AudioData future -> consume at most one queued VAD event.
    """
    # Registration finished: tear everything down from the main thread.
    if self.registered:
        self.get_logger().info("注册完成,节点退出")
        # NOTE(review): elsewhere in this file asr_client is a ROS service
        # client (create_client), which has no stop(); confirm which object
        # is bound here before relying on this branch.
        if self.asr_client:
            self.asr_client.stop()
        # NOTE(review): destroy_node() runs before timer.cancel(); confirm
        # the timer handle is still usable after the node is destroyed.
        self.destroy_node()
        self.shutting_down = True
        self.timer.cancel()
        rclpy.shutdown()
        return
    # Resolve a finished ASR request, if any.
    if self.pending_asr_future and self.pending_asr_future.done():
        response = self.pending_asr_future.result()
        self.pending_asr_future = None
        if response.success and response.text:
            text = response.text.strip()
            if text:
                self._check_wake_word(text)
        # Whatever the outcome, go back to listening for speech.
        self.state = "waiting_speech"
    # Resolve a finished AudioData request, if any.
    if self.pending_audio_future and self.pending_audio_future.done():
        response = self.pending_audio_future.result()
        # _future_type tags which command this future answered ("start"/"stop").
        future_type = getattr(self.pending_audio_future, '_future_type', None)
        self.pending_audio_future = None
        if future_type == "start":
            if response.success:
                self.get_logger().info("[注册录音] 已开始录音")
                self.recording = True
            else:
                self.get_logger().warn(f"[注册录音] 启动录音失败: {response.message}")
                self.state = "waiting_speech"
        elif future_type == "stop":
            # "stop" returns the captured audio; hand it to registration.
            self.recording = False
            self._process_voiceprint_audio(response)
    # Consume one VAD event queued by the VAD worker thread (non-blocking).
    try:
        event = self.vad_event_queue.get_nowait()
        if event == "speech_started" and self.state == "waiting_speech" and not self.recording:
            self.get_logger().info("[VAD] 检测到语音开始,启动录音")
            future = self._start_recording()
            future._future_type = "start"
            self.pending_audio_future = future
        elif event == "speech_stopped" and self.recording and self.state == "waiting_speech":
            self.get_logger().info("[VAD] 检测到语音结束,请求 ASR 识别")
            self.state = "waiting_asr"
            request = ASRRecognize.Request()
            request.command = "start"
            self.pending_asr_future = self.asr_client.call_async(request)
    except queue.Empty:
        pass
def main(args=None):
    """Entry point: spin the registration node until it shuts itself down.

    The node calls rclpy.shutdown() from its own timer once registration
    completes, which ends spin(); Ctrl-C raises KeyboardInterrupt. Cleanup
    runs in a finally block either way, and destroy_node() is guarded
    because the node may already have destroyed itself in _main_loop.
    """
    rclpy.init(args=args)
    node = RegisterSpeakerNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # The node may already be destroyed / rclpy already shut down.
        try:
            node.destroy_node()
        except Exception:
            pass
        try:
            rclpy.shutdown()
        except Exception:
            pass


# FIX: the original also called main() unconditionally at module level,
# which would start the node on import; only the guarded call remains.
if __name__ == '__main__':
    main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,198 @@
"""
声纹识别模块
"""
import numpy as np
import threading
import os
import time
import json
from enum import Enum
class SpeakerState(Enum):
    """Speaker-verification outcome for one match attempt."""
    UNKNOWN = "unknown"    # empty query embedding or no registered speakers
    VERIFIED = "verified"  # best cosine score >= threshold
    REJECTED = "rejected"  # best cosine score < threshold
    ERROR = "error"        # matching raised an exception
class SpeakerVerificationClient:
"""声纹识别客户端 - 非实时、低频处理"""
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
self.model_path = model_path
self.threshold = threshold
self.speaker_db_path = speaker_db_path
self.logger = logger
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
self._lock = threading.Lock()
# 优化CPU性能限制Torch使用的线程数防止多线程竞争导致性能骤降
import torch
torch.set_num_threads(1)
from funasr import AutoModel
model_path = os.path.expanduser(self.model_path)
# 禁用自动更新检查,防止每次初始化都联网检查
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
if self.logger:
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
if self.speaker_db_path:
self.load_speakers()
def _log(self, level: str, msg: str):
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
if self.logger:
try:
if level == "info":
self.logger.info(msg)
elif level == "warning":
self.logger.warning(msg)
elif level == "error":
self.logger.error(msg)
elif level == "debug":
self.logger.debug(msg)
except Exception:
pass
def load_speakers(self):
if not self.speaker_db_path:
return
db_path = os.path.expanduser(self.speaker_db_path)
if not os.path.exists(db_path):
self._log("info", f"声纹数据库文件不存在: {db_path},将创建新文件")
return
try:
with open(db_path, 'rb') as f:
data = json.load(f)
with self._lock:
self.speaker_db = {}
for speaker_id, info in data.items():
embedding_array = np.array(info["embedding"], dtype=np.float32)
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
self.speaker_db[speaker_id] = {
"embedding": embedding_array,
"env": info.get("env", ""),
"registered_at": info.get("registered_at", 0.0)
}
self._log("info", f"已加载 {len(self.speaker_db)} 个已注册说话人")
except Exception as e:
self._log("error", f"加载声纹数据库失败: {e}")
def save_speakers(self):
if not self.speaker_db_path:
return
db_path = os.path.expanduser(self.speaker_db_path)
try:
os.makedirs(os.path.dirname(db_path), exist_ok=True)
with self._lock:
data = {}
for speaker_id, info in self.speaker_db.items():
data[speaker_id] = {
"embedding": info["embedding"].tolist(),
"env": info.get("env", ""),
"registered_at": info.get("registered_at", 0.0)
}
with open(db_path, 'w') as f:
json.dump(data, f, indent=2)
self._log("info", f"已保存 {len(data)} 个已注册说话人到: {db_path}")
except Exception as e:
self._log("error", f"保存声纹数据库失败: {e}")
def extract_embedding(self, audio_array: np.ndarray, sample_rate: int = 16000) -> tuple[np.ndarray | None, bool]:
try:
if len(audio_array) == 0:
return None, False
# 确保是int16格式
if audio_array.dtype != np.int16:
audio_array = audio_array.astype(np.int16)
# 转换为float32并归一化到[-1, 1]
audio_float = audio_array.astype(np.float32) / 32768.0
# 调用模型提取embedding
result = self.model.generate(input=audio_float, cache={})
if result and len(result) > 0 and "spk_embedding" in result[0]:
embedding = result[0]["spk_embedding"]
if embedding is not None and len(embedding) > 0:
embedding_array = np.array(embedding, dtype=np.float32)
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
return embedding_array, True
return None, False
except Exception as e:
self._log("error", f"提取声纹特征失败: {e}")
return None, False
def match_speaker(self, embedding: np.ndarray) -> tuple[str | None, SpeakerState, float, float]:
if embedding is None or len(embedding) == 0:
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
with self._lock:
if len(self.speaker_db) == 0:
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
try:
best_speaker_id = None
best_score = 0.0
with self._lock:
for speaker_id, info in self.speaker_db.items():
stored_embedding = info["embedding"]
# 计算余弦相似度
dot_product = np.dot(embedding, stored_embedding)
norm_embedding = np.linalg.norm(embedding)
norm_stored = np.linalg.norm(stored_embedding)
if norm_embedding > 0 and norm_stored > 0:
score = dot_product / (norm_embedding * norm_stored)
if score > best_score:
best_score = score
best_speaker_id = speaker_id
state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
return best_speaker_id, state, float(best_score), float(self.threshold)
except Exception as e:
self._log("error", f"匹配说话人失败: {e}")
return None, SpeakerState.ERROR, 0.0, float(self.threshold)
def register_speaker(self, speaker_id: str, embedding: np.ndarray, env: str = "") -> bool:
if embedding is None or len(embedding) == 0:
return False
try:
with self._lock:
self.speaker_db[speaker_id] = {
"embedding": np.array(embedding, dtype=np.float32),
"env": env,
"registered_at": time.time()
}
self._log("info", f"已注册说话人: {speaker_id}")
return True
except Exception as e:
self._log("error", f"注册说话人失败: {e}")
return False
def get_speaker_count(self) -> int:
with self._lock:
return len(self.speaker_db)
def get_speaker_list(self) -> list[str]:
with self._lock:
return list(self.speaker_db.keys())
def remove_speaker(self, speaker_id: str) -> bool:
with self._lock:
if speaker_id in self.speaker_db:
del self.speaker_db[speaker_id]
self._log("info", f"已删除说话人: {speaker_id}")
return True
return False
def cleanup(self):
try:
self.save_speakers()
if hasattr(self, 'model') and self.model:
del self.model
except Exception as e:
self._log("error", f"清理资源失败: {e}")

View File

@@ -1,36 +0,0 @@
"""
统一数据结构定义
"""
from dataclasses import dataclass
@dataclass
class ASRResult:
"""ASR识别结果"""
text: str
confidence: float | None = None
language: str | None = None
@dataclass
class LLMMessage:
"""LLM消息"""
role: str # "user", "assistant", "system"
content: str
@dataclass
class TTSRequest:
"""TTS请求"""
text: str
voice: str | None = None # 如果为None使用控制台配置的默认音色
speed: float | None = None
pitch: float | None = None
@dataclass
class ImageMessage:
"""图像消息 - 用于多模态LLM"""
image_data: bytes # base64编码的图像数据
image_format: str = "jpeg"

View File

@@ -1,9 +0,0 @@
"""模型层"""

View File

@@ -1,9 +0,0 @@
"""ASR模型"""

View File

@@ -1,17 +0,0 @@
class ASRClient:
def start(self) -> bool:
raise NotImplementedError
def stop(self) -> bool:
raise NotImplementedError
def send_audio(self, audio_data: bytes) -> bool:
raise NotImplementedError

View File

@@ -1,229 +0,0 @@
"""
ASR语音识别模块
"""
import base64
import time
import threading
import dashscope
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
from robot_speaker.models.asr.base import ASRClient
class DashScopeASR(ASRClient):
"""DashScope实时ASR识别器封装"""
def __init__(self, api_key: str,
sample_rate: int,
model: str,
url: str,
logger=None):
dashscope.api_key = api_key
self.sample_rate = sample_rate
self.model = model
self.url = url
self.logger = logger
self.conversation = None
self.running = False
self.on_sentence_end = None
self.on_text_update = None # 实时文本更新回调
# 线程同步机制
self._stop_lock = threading.Lock() # 防止并发调用 stop_current_recognition
self._final_result_event = threading.Event() # 等待 final 回调完成
self._pending_commit = False # 标记是否有待处理的 commit
def _log(self, level: str, msg: str):
"""记录日志根据级别调用对应的ROS2日志方法"""
if self.logger:
# ROS2 logger不能动态改变severity级别需要显式调用对应方法
if level == "debug":
self.logger.debug(msg)
elif level == "info":
self.logger.info(msg)
elif level == "warning":
self.logger.warn(msg)
elif level == "error":
self.logger.error(msg)
else:
self.logger.info(msg) # 默认使用info级别
else:
print(f"[ASR] {msg}")
def start(self):
"""启动ASR识别器"""
if self.running:
return False
try:
callback = _ASRCallback(self)
self.conversation = OmniRealtimeConversation(
model=self.model,
url=self.url,
callback=callback
)
callback.conversation = self.conversation
self.conversation.connect()
transcription_params = TranscriptionParams(
language='zh',
sample_rate=self.sample_rate,
input_audio_format="pcm",
)
# 本地 VAD → 只控制 TTS 打断
# 服务端 turn detection → 只控制 ASR 输出、LLM 生成轮次
self.conversation.update_session(
output_modalities=[MultiModality.TEXT],
enable_input_audio_transcription=True,
transcription_params=transcription_params,
enable_turn_detection=True,
# 保留服务端 turn detection
turn_detection_type='server_vad', # 服务端VAD
turn_detection_threshold=0.2, # 可调
turn_detection_silence_duration_ms=800
)
self.running = True
self._log("info", "ASR已启动")
return True
except Exception as e:
self.running = False
self._log("error", f"ASR启动失败: {e}")
if self.conversation:
try:
self.conversation.close()
except:
pass
self.conversation = None
return False
def send_audio(self, audio_chunk: bytes):
"""发送音频chunk到ASR"""
if not self.running or not self.conversation:
return False
try:
audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
self.conversation.append_audio(audio_b64)
return True
except Exception as e:
# 连接已关闭或其他错误,静默处理(避免日志过多)
# running状态会在stop_current_recognition中正确设置
return False
def stop_current_recognition(self):
"""
停止当前识别触发final结果然后重新启动
优化:
1. 使用事件代替 sleep等待 final 回调完成
2. 使用锁防止并发调用
3. 处理 start() 失败的情况,确保 running 状态正确
4. 添加超时机制,避免无限等待
"""
# 使用锁防止并发调用
if not self._stop_lock.acquire(blocking=False):
self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
return False
try:
if not self.running or not self.conversation:
return False
# 重置事件,准备等待 final 回调
self._final_result_event.clear()
self._pending_commit = True
# 触发 commit等待 final 结果
self.conversation.commit()
# 等待 final 回调完成最多等待3秒
if self._final_result_event.wait(timeout=3.0):
self._log("debug", "已收到 final 回调,准备关闭连接")
else:
self._log("warning", "等待 final 回调超时,继续执行")
# 先设置running=False防止ASR线程继续发送音频
self.running = False
# 关闭当前连接
old_conversation = self.conversation
self.conversation = None # 立即清空防止send_audio继续使用
try:
old_conversation.close()
except Exception as e:
self._log("warning", f"关闭连接时出错: {e}")
# 短暂等待,确保连接完全关闭
time.sleep(0.1)
# 重新启动,如果失败则保持 running=False
if not self.start():
self._log("error", "ASR重启失败running状态已重置")
return False
# 启动成功running已在start()中设置为True
return True
finally:
self._pending_commit = False
self._stop_lock.release()
def stop(self):
"""停止ASR识别器"""
# 等待正在执行的 stop_current_recognition 完成
with self._stop_lock:
self.running = False
self._final_result_event.set() # 唤醒可能正在等待的线程
if self.conversation:
try:
self.conversation.close()
except Exception as e:
self._log("warning", f"停止时关闭连接出错: {e}")
self.conversation = None
self._log("info", "ASR已停止")
class _ASRCallback(OmniRealtimeCallback):
"""ASR回调处理"""
def __init__(self, asr_client: DashScopeASR):
self.asr_client = asr_client
self.conversation = None
def on_open(self):
self.asr_client._log("info", "ASR WebSocket已连接")
def on_close(self, code, msg):
self.asr_client._log("info", f"ASR WebSocket已关闭: code={code}, msg={msg}")
def on_event(self, response):
event_type = response.get('type', '')
if event_type == 'session.created':
session_id = response.get('session', {}).get('id', '')
self.asr_client._log("info", f"ASR会话已创建: {session_id}")
elif event_type == 'conversation.item.input_audio_transcription.completed':
# 最终识别结果
transcript = response.get('transcript', '')
if transcript and transcript.strip() and self.asr_client.on_sentence_end:
self.asr_client.on_sentence_end(transcript.strip())
# 如果有待处理的 commit通知等待的线程
if self.asr_client._pending_commit:
self.asr_client._final_result_event.set()
elif event_type == 'conversation.item.input_audio_transcription.text':
# 实时识别文本更新(多轮提示)
transcript = response.get('transcript', '') or response.get('text', '')
if transcript and transcript.strip() and self.asr_client.on_text_update:
self.asr_client.on_text_update(transcript.strip())
elif event_type == 'input_audio_buffer.speech_started':
self.asr_client._log("info", "ASR检测到说话开始")
elif event_type == 'input_audio_buffer.speech_stopped':
self.asr_client._log("info", "ASR检测到说话结束")

View File

@@ -1,9 +0,0 @@
"""LLM模型"""

View File

@@ -1,19 +0,0 @@
from robot_speaker.core.types import LLMMessage
class LLMClient:
def chat(self, messages: list[LLMMessage]) -> str | None:
raise NotImplementedError
def chat_stream(self, messages: list[LLMMessage],
on_token=None,
interrupt_check=None) -> str | None:
raise NotImplementedError

View File

@@ -1,149 +0,0 @@
"""
LLM大语言模型模块
支持多模态(文本+图像)
"""
from openai import OpenAI
from typing import Optional, List
from robot_speaker.core.types import LLMMessage
from robot_speaker.models.llm.base import LLMClient
class DashScopeLLM(LLMClient):
"""DashScope LLM客户端封装"""
def __init__(self, api_key: str,
model: str,
base_url: str,
temperature: float,
max_tokens: int,
name: str = "LLM",
logger=None):
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.name = name
self.logger = logger
def _log(self, level: str, msg: str):
"""记录日志根据级别调用对应的ROS2日志方法"""
msg = f"[{self.name}] {msg}"
if self.logger:
# ROS2 logger不能动态改变severity级别需要显式调用对应方法
if level == "debug":
self.logger.debug(msg)
elif level == "info":
self.logger.info(msg)
elif level == "warning":
self.logger.warn(msg)
elif level == "error":
self.logger.error(msg)
else:
self.logger.info(msg) # 默认使用info级别
def chat(self, messages: list[LLMMessage]) -> str | None:
"""非流式聊天:任务规划"""
payload_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
response = self.client.chat.completions.create(
model=self.model,
messages=payload_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=False
)
reply = response.choices[0].message.content.strip()
return reply if reply else None
def chat_stream(self, messages: list[LLMMessage],
on_token=None,
images: Optional[List[str]] = None,
interrupt_check=None) -> str | None:
"""
流式聊天:语音系统
支持多模态(文本+图像)
支持中断检查interrupt_check: 返回True表示需要中断
"""
# 转换消息格式,支持多模态
# 图像只添加到最后一个user消息中
payload_messages = []
last_user_idx = -1
for i, msg in enumerate(messages):
if msg.role == "user":
last_user_idx = i
has_images_in_message = False
for i, msg in enumerate(messages):
msg_dict = {"role": msg.role}
# 如果当前消息是最后一个user消息且有图像构建多模态content
if i == last_user_idx and msg.role == "user" and images and len(images) > 0:
content_list = [{"type": "text", "text": msg.content}]
# 添加所有图像
for img_idx, img_base64 in enumerate(images):
image_url = f"data:image/jpeg;base64,{img_base64[:50]}..." if len(img_base64) > 50 else f"data:image/jpeg;base64,{img_base64}"
content_list.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64}"
}
})
self._log("info", f"[多模态] 添加图像 #{img_idx+1} 到user消息base64长度: {len(img_base64)}")
msg_dict["content"] = content_list
has_images_in_message = True
else:
msg_dict["content"] = msg.content
payload_messages.append(msg_dict)
# 记录多模态信息
if images and len(images) > 0:
if has_images_in_message:
# 找到最后一个user消息记录其content结构
last_user_msg = payload_messages[last_user_idx] if last_user_idx >= 0 else None
if last_user_msg and isinstance(last_user_msg.get("content"), list):
content_items = last_user_msg["content"]
text_items = [item for item in content_items if item.get("type") == "text"]
image_items = [item for item in content_items if item.get("type") == "image_url"]
self._log("info", f"[多模态] 已发送多模态请求: {len(text_items)}个文本 + {len(image_items)}张图片")
self._log("debug", f"[多模态] 用户文本: {text_items[0].get('text', '')[:50] if text_items else 'N/A'}")
else:
self._log("warning", "[多模态] 消息格式异常,无法确认图片是否添加")
else:
self._log("warning", f"[多模态] 有{len(images)}张图片但未找到user消息图片未被添加")
else:
self._log("debug", "[多模态] 纯文本请求(无图片)")
full_reply = ""
interrupted = False
stream = self.client.chat.completions.create(
model=self.model,
messages=payload_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=True
)
for chunk in stream:
# 检查中断标志
if interrupt_check and interrupt_check():
self._log("info", "LLM流式处理被中断")
interrupted = True
break
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_reply += content
if on_token:
on_token(content)
# 在on_token回调后再次检查中断on_token可能设置中断标志
if interrupt_check and interrupt_check():
self._log("info", "LLM流式处理在on_token回调后被中断")
interrupted = True
break
if interrupted:
return None # 被中断时返回None表示未完成
return full_reply.strip() if full_reply else None

View File

@@ -1,9 +0,0 @@
"""TTS模型"""

View File

@@ -1,18 +0,0 @@
from robot_speaker.core.types import TTSRequest
class TTSClient:
"""TTS客户端抽象基类"""
def synthesize(self, request: TTSRequest,
on_chunk=None,
interrupt_check=None) -> bool:
raise NotImplementedError

View File

@@ -1,244 +0,0 @@
"""
TTS语音合成模块
"""
import subprocess
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
from robot_speaker.core.types import TTSRequest
from robot_speaker.models.tts.base import TTSClient
class DashScopeTTSClient(TTSClient):
"""DashScope流式TTS客户端封装"""
def __init__(self, api_key: str,
model: str,
voice: str,
card_index: int,
device_index: int,
output_sample_rate: int = 44100,
output_channels: int = 2,
output_volume: float = 1.0,
tts_source_sample_rate: int = 22050, # TTS服务固定输出采样率
tts_source_channels: int = 1, # TTS服务固定输出声道数
tts_ffmpeg_thread_queue_size: int = 1024, # ffmpeg输入线程队列大小
reference_signal_buffer=None, # 参考信号缓冲区(用于回声消除)
logger=None):
dashscope.api_key = api_key
self.model = model
self.voice = voice
self.card_index = card_index
self.device_index = device_index
self.output_sample_rate = output_sample_rate
self.output_channels = output_channels
self.output_volume = output_volume
self.tts_source_sample_rate = tts_source_sample_rate
self.tts_source_channels = tts_source_channels
self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
self.logger = logger
self.current_ffmpeg_pid = None # 当前ffmpeg进程的PID
# 构建ALSA设备, 允许 ffmpeg 自动重采样 / 重声道
self.alsa_device = f"plughw:{card_index},{device_index}" if (
card_index >= 0 and device_index >= 0
) else "default"
def _log(self, level: str, msg: str):
"""记录日志根据级别调用对应的ROS2日志方法"""
if self.logger:
# ROS2 logger不能动态改变severity级别需要显式调用对应方法
if level == "debug":
self.logger.debug(msg)
elif level == "info":
self.logger.info(msg)
elif level == "warning":
self.logger.warn(msg)
elif level == "error":
self.logger.error(msg)
else:
self.logger.info(msg) # 默认使用info级别
else:
print(f"[TTS] {msg}")
def synthesize(self, request: TTSRequest,
on_chunk=None,
interrupt_check=None) -> bool:
"""主流程:流式合成并播放"""
callback = _TTSCallback(self, interrupt_check, on_chunk, self.reference_signal_buffer)
# 使用配置的voicerequest.voice为None或空时使用self.voice
voice_to_use = request.voice if request.voice and request.voice.strip() else self.voice
if not voice_to_use or not voice_to_use.strip():
self._log("error", f"Voice参数无效: '{voice_to_use}'")
return False
self._log("info", f"TTS开始: 文本='{request.text[:50]}...', voice='{voice_to_use}'")
synthesizer = SpeechSynthesizer(
model=self.model,
voice=voice_to_use,
format=AudioFormat.PCM_22050HZ_MONO_16BIT,
callback=callback,
)
try:
synthesizer.streaming_call(request.text)
synthesizer.streaming_complete()
finally:
callback.cleanup()
return not callback._interrupted
class _TTSCallback(ResultCallback):
"""TTS回调处理 - 使用ffmpeg播放自动处理采样率转换"""
def __init__(self, tts_client: DashScopeTTSClient,
interrupt_check=None,
on_chunk=None,
reference_signal_buffer=None):
self.tts_client = tts_client
self.interrupt_check = interrupt_check
self.on_chunk = on_chunk
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
self._proc = None
self._interrupted = False
self._cleaned_up = False
def on_open(self):
# 使用ffmpeg播放自动处理采样率转换TTS源采样率 -> 设备采样率)
# TTS服务输出固定采样率和声道数ffmpeg会自动转换为设备采样率和声道数
ffmpeg_cmd = [
'ffmpeg',
'-f', 's16le', # 原始 PCM
'-ar', str(self.tts_client.tts_source_sample_rate), # TTS输出采样率从配置文件读取
'-ac', str(self.tts_client.tts_source_channels), # TTS输出声道数从配置文件读取
'-i', 'pipe:0', # stdin
'-f', 'alsa', # 输出到 ALSA
'-ar', str(self.tts_client.output_sample_rate), # 输出设备采样率(从配置文件读取)
'-ac', str(self.tts_client.output_channels), # 输出设备声道数(从配置文件读取)
'-acodec', 'pcm_s16le', # 输出编码
'-fflags', 'nobuffer', # 减少缓冲
'-flags', 'low_delay', # 低延迟
'-avioflags', 'direct', # 尝试直通写入 ALSA减少延迟
self.tts_client.alsa_device
]
# 将 -thread_queue_size 放到输入文件之前
insert_pos = ffmpeg_cmd.index('-i')
ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
# 添加音量调节filter如果音量不是1.0
if self.tts_client.output_volume != 1.0:
# 在输出编码前插入音量filter
# volume filter放在输入之后、输出编码之前
acodec_idx = ffmpeg_cmd.index('-acodec')
ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
ffmpeg_cmd.insert(acodec_idx, '-af')
self.tts_client._log("info", f"启动ffmpeg播放: ALSA设备={self.tts_client.alsa_device}, "
f"输出采样率={self.tts_client.output_sample_rate}Hz, "
f"输出声道数={self.tts_client.output_channels}, "
f"音量={self.tts_client.output_volume * 100:.0f}%")
self._proc = subprocess.Popen(
ffmpeg_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE # 改为PIPE以便捕获错误
)
# 记录ffmpeg进程PID
self.tts_client.current_ffmpeg_pid = self._proc.pid
self.tts_client._log("debug", f"ffmpeg进程已启动PID={self._proc.pid}")
def on_complete(self):
pass
def on_error(self, message: str):
self.tts_client._log("error", f"TTS错误: {message}")
def on_close(self):
self.cleanup()
def on_event(self, message):
pass
def on_data(self, data: bytes) -> None:
"""接收音频数据并播放"""
if self._interrupted:
return
if self.interrupt_check and self.interrupt_check():
# 停止播放,不停止 TTS
self._interrupted = True
if self._proc:
self._proc.terminate()
return
# 优先写入ffmpeg避免阻塞播放
# 优先写入ffmpeg避免阻塞播放
if self._proc and self._proc.stdin and not self._interrupted:
try:
self._proc.stdin.write(data)
self._proc.stdin.flush()
except BrokenPipeError:
# ffmpeg进程可能已退出检查错误
if self._proc.stderr:
error_msg = self._proc.stderr.read().decode('utf-8', errors='ignore')
self.tts_client._log("error", f"ffmpeg错误: {error_msg}")
self._interrupted = True
# 将音频数据添加到参考信号缓冲区(用于回声消除)
# 在写入ffmpeg之后处理避免阻塞播放
if self.reference_signal_buffer and data:
try:
self.reference_signal_buffer.add_reference(
data,
source_sample_rate=self.tts_client.tts_source_sample_rate,
source_channels=self.tts_client.tts_source_channels
)
except Exception as e:
# 参考信号处理失败不应影响播放
self.tts_client._log("warning", f"参考信号处理失败: {e}")
if self.on_chunk:
self.on_chunk(data)
def cleanup(self):
"""清理资源"""
if self._cleaned_up or not self._proc:
return
self._cleaned_up = True
# 关闭stdin让ffmpeg处理完剩余数据
if self._proc.stdin and not self._proc.stdin.closed:
try:
self._proc.stdin.close()
except:
pass
# 等待进程自然结束根据文本长度估算最少10秒最多30秒
# 假设平均语速3-4字/秒,加上缓冲时间
if self._proc.poll() is None:
try:
# 增加等待时间确保ffmpeg播放完成
# 对于长文本,可能需要更长时间
self._proc.wait(timeout=30.0)
except:
# 超时后,如果进程还在运行,说明可能卡住了,强制终止
if self._proc.poll() is None:
self.tts_client._log("warning", "ffmpeg播放超时强制终止")
try:
self._proc.terminate()
self._proc.wait(timeout=1.0)
except:
try:
self._proc.kill()
self._proc.wait(timeout=0.1)
except:
pass
# 清空PID记录
if self.tts_client.current_ffmpeg_pid == self._proc.pid:
self.tts_client.current_ffmpeg_pid = None

View File

@@ -1,9 +0,0 @@
"""感知层"""

View File

@@ -1,297 +0,0 @@
"""
声纹识别模块
"""
import numpy as np
import threading
import tempfile
import os
import wave
import time
import json
from enum import Enum
class SpeakerState(Enum):
    """Outcome of a speaker-verification attempt."""
    UNKNOWN = "unknown"    # no speakers enrolled yet
    VERIFIED = "verified"  # best match scored at or above the threshold
    REJECTED = "rejected"  # best match scored below the threshold
    ERROR = "error"        # invalid embedding (empty or zero norm)
class SpeakerVerificationClient:
    """Speaker-verification (voiceprint) client — non-realtime, low-frequency.

    Keeps an in-memory DB ``{speaker_id: {"embedding", "env", "registered_at"}}``
    of L2-normalized embeddings, optionally persisted as JSON at
    ``speaker_db_path``. Matching uses cosine similarity against ``threshold``.

    Thread-safety: ``self._lock`` (a non-reentrant ``threading.Lock``) guards
    ``speaker_db``; ``save_speakers()`` acquires it itself, so it must never be
    called while the lock is already held.
    """

    def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
        """Load the FunASR speaker model on CPU and any persisted speaker DB.

        model_path: path to the embedding model (``~`` is expanded).
        threshold: cosine-similarity acceptance threshold.
        speaker_db_path: optional JSON persistence file.
        logger: optional ROS2-style logger; falls back to print().
        """
        self.model_path = model_path
        self.threshold = threshold
        self.speaker_db_path = speaker_db_path
        self.logger = logger
        self.speaker_db = {}  # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
        self._lock = threading.Lock()
        # CPU optimization: cap Torch at one thread so concurrent recording +
        # inference do not degrade each other through thread contention.
        import torch
        torch.set_num_threads(1)
        from funasr import AutoModel
        model_path = os.path.expanduser(self.model_path)
        # disable_update prevents a network update-check on every init.
        self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
        if self.logger:
            self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
        if self.speaker_db_path:
            self.load_speakers()

    def _log(self, level: str, msg: str):
        """Log via the ROS2 logger, tolerating its multi-thread quirks.

        Falls back to print() when no logger is configured, and downgrades to
        info-level (with the severity embedded in the text) when rclpy raises
        its "severity cannot be changed" ValueError from a worker thread.
        """
        if self.logger:
            try:
                log_methods = {
                    "debug": self.logger.debug,
                    "info": self.logger.info,
                    "warning": self.logger.warning,
                    "error": self.logger.error,
                    "fatal": self.logger.fatal
                }
                log_method = log_methods.get(level.lower(), self.logger.info)
                log_method(msg)
            except ValueError as e:
                if "severity cannot be changed" in str(e):
                    try:
                        self.logger.info(f"[声纹-{level.upper()}] {msg}")
                    except Exception:
                        print(f"[声纹-{level.upper()}] {msg}")
                else:
                    raise
        else:
            print(f"[声纹] {msg}")

    def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Write a numpy int16 mono buffer to a temporary WAV file.

        Returns the file path; the caller is responsible for deleting it.
        """
        audio_int16 = audio_data.astype(np.int16)
        fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
        os.close(fd)
        with wave.open(temp_path, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())
        return temp_path

    def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Extract a speaker embedding (called once per utterance).

        Returns (embedding, True) on success, (None, False) otherwise.
        Input shorter than 0.5 s is rejected.
        """
        # Downsample to 16 kHz if needed: Cam++-style models expect 16k, and
        # feeding 48k makes internal resampling extremely slow.
        target_sr = 16000
        if sample_rate > target_sr:
            if sample_rate % target_sr == 0:
                step = sample_rate // target_sr
                audio_data = audio_data[::step]
                sample_rate = target_sr
            else:
                # Non-integer ratios are approximated by integer striding;
                # in practice 48k -> 16k is an integer ratio so this branch
                # is rarely exercised.
                step = int(sample_rate / target_sr)
                audio_data = audio_data[::step]
                sample_rate = target_sr
        if len(audio_data) < int(sample_rate * 0.5):
            return None, False
        temp_wav_path = None
        try:
            # Keep Torch single-threaded during inference to avoid CPU
            # contention while recording/recognition run concurrently.
            import torch
            with torch.inference_mode():
                # set_num_threads is global; re-assert it defensively here.
                if torch.get_num_threads() != 1:
                    torch.set_num_threads(1)
                temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
                result = self.model.generate(input=temp_wav_path)
                embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0]  # shape [1, 192] -> [192]
            embedding_dim = len(embedding)
            if embedding_dim == 0:
                return None, False
            return embedding, True
        except Exception as e:
            self._log("error", f"提取embedding失败: {e}")
            return None, False
        finally:
            if temp_wav_path and os.path.exists(temp_wav_path):
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    pass

    def register_speaker(self, speaker_id: str, embedding: np.ndarray,
                         env: str = "near") -> bool:
        """Register (or overwrite) a speaker with an L2-normalized embedding.

        Returns True on success; persistence failure does not undo the
        in-memory registration.
        """
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return False
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            self._log("error", f"注册失败:embedding范数为0")
            return False
        embedding_normalized = embedding / embedding_norm
        with self._lock:
            self.speaker_db[speaker_id] = {
                "embedding": embedding_normalized,
                "env": env,
                "registered_at": time.time()
            }
        self._log("info", f"已注册说话人: {speaker_id}, 维度: {embedding_dim}")
        # Persist outside the lock: save_speakers() acquires self._lock
        # itself and threading.Lock is not reentrant.
        save_result = self.save_speakers()
        if not save_result:
            self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
        return True

    def match_speaker(self, embedding: np.ndarray):
        """Match an embedding against the DB (called once per utterance).

        Returns (speaker_id, SpeakerState, best_score, threshold).
        """
        if not self.speaker_db:
            return None, SpeakerState.UNKNOWN, 0.0, self.threshold
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return None, SpeakerState.ERROR, 0.0, self.threshold
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            return None, SpeakerState.ERROR, 0.0, self.threshold
        embedding_normalized = embedding / embedding_norm
        best_match = None
        best_score = -float('inf')
        with self._lock:
            for speaker_id, speaker_data in self.speaker_db.items():
                ref_embedding = speaker_data["embedding"]
                # Both vectors are unit-norm, so the dot product is the
                # cosine similarity.
                score = np.dot(embedding_normalized, ref_embedding)
                if score > best_score:
                    best_score = score
                    best_match = speaker_id
        state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
        return (best_match, state, best_score, self.threshold)

    def is_available(self) -> bool:
        """True once the embedding model has been loaded."""
        return self.model is not None

    def cleanup(self):
        """Release resources (currently a no-op)."""
        pass

    def get_speaker_count(self) -> int:
        """Number of registered speakers."""
        with self._lock:
            return len(self.speaker_db)

    def remove_speaker(self, speaker_id: str) -> bool:
        """Remove a registered speaker; returns False if unknown."""
        with self._lock:
            if speaker_id not in self.speaker_db:
                return False
            del self.speaker_db[speaker_id]
        # Must run outside the lock: save_speakers() re-acquires self._lock
        # (non-reentrant), so calling it while holding the lock deadlocks.
        self.save_speakers()
        return True

    def load_speakers(self) -> bool:
        """Load registered voiceprints from the JSON DB file.

        Returns True on success; missing file or parse error yields False.
        Embeddings are re-normalized on load; zero-dimension entries are
        skipped.
        """
        if not self.speaker_db_path:
            return False
        if not os.path.exists(self.speaker_db_path):
            self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
            return False
        try:
            with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            with self._lock:
                for speaker_id, speaker_data in data.items():
                    embedding_list = speaker_data["embedding"]
                    embedding_array = np.array(embedding_list, dtype=np.float32)
                    embedding_dim = len(embedding_array)
                    if embedding_dim == 0:
                        self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
                        continue
                    embedding_norm = np.linalg.norm(embedding_array)
                    if embedding_norm > 0:
                        embedding_array = embedding_array / embedding_norm
                    self.speaker_db[speaker_id] = {
                        "embedding": embedding_array,
                        # Old DB files may lack these keys; default like
                        # save_speakers() does for backward compatibility.
                        "env": speaker_data.get("env", "near"),
                        "registered_at": speaker_data.get("registered_at", 0.0)
                    }
                count = len(self.speaker_db)
            self._log("info", f"已加载 {count} 个已注册说话人")
            return True
        except Exception as e:
            self._log("error", f"加载声纹数据库失败: {e}")
            return False

    def save_speakers(self) -> bool:
        """Persist registered voiceprints to the JSON DB file.

        Writes atomically via a ``.tmp`` file + os.replace. Acquires
        ``self._lock`` while snapshotting the DB — never call this with the
        lock already held.
        """
        if not self.speaker_db_path:
            self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
            return False
        try:
            db_dir = os.path.dirname(self.speaker_db_path)
            if db_dir and not os.path.exists(db_dir):
                os.makedirs(db_dir, exist_ok=True)
            json_data = {}
            with self._lock:
                for speaker_id, speaker_data in self.speaker_db.items():
                    json_data[speaker_id] = {
                        "embedding": speaker_data["embedding"].tolist(),  # numpy array -> list
                        "env": speaker_data.get("env", "near"),  # old data may lack this key
                        "registered_at": speaker_data["registered_at"]
                    }
            temp_path = self.speaker_db_path + ".tmp"
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=2, ensure_ascii=False)
            os.replace(temp_path, self.speaker_db_path)
            self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
            return True
        except Exception as e:
            import traceback
            self._log("error", f"保存声纹数据库失败: {e}")
            self._log("error", f"保存路径: {self.speaker_db_path}")
            self._log("error", f"错误详情: {traceback.format_exc()}")
            temp_path = self.speaker_db_path + ".tmp"
            if os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            return False

View File

@@ -0,0 +1,15 @@
"""
Service节点模块
"""

View File

@@ -0,0 +1,566 @@
import rclpy
from rclpy.node import Node
from robot_speaker.srv import ASRRecognize, AudioData, VADEvent
import threading
import queue
import time
import pyaudio
import yaml
import os
import collections
import numpy as np
import base64
import dashscope
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
from ament_index_python.packages import get_package_share_directory
class AudioRecorder:
    """Continuously captures microphone audio into a bounded queue.

    Runs record() on a dedicated thread; when the queue is full the oldest
    chunk is dropped so the consumer always sees recent audio.
    """

    def __init__(self, device_index: int, sample_rate: int, channels: int,
                 chunk: int, audio_queue: queue.Queue, stop_event, logger=None):
        """Open PyAudio and auto-locate an iFLYTEK input device if present.

        device_index: preferred input device; -1 means "no preference".
        audio_queue: bounded queue receiving raw int16 PCM chunks.
        stop_event: threading.Event that terminates record() when set.
        """
        self.device_index = device_index
        self.sample_rate = sample_rate
        self.channels = channels
        self.chunk = chunk
        self.audio_queue = audio_queue
        self.stop_event = stop_event
        self.logger = logger
        self.audio = pyaudio.PyAudio()
        original_index = self.device_index
        try:
            # Prefer the iFLYTEK microphone array when one is attached.
            for i in range(self.audio.get_device_count()):
                device_info = self.audio.get_device_info_by_index(i)
                if 'iFLYTEK' in device_info['name'] and device_info['maxInputChannels'] > 0:
                    self.device_index = i
                    if self.logger:
                        self.logger.info(f"已自动定位到麦克风设备: {device_info['name']} (Index: {i})")
                    break
        except Exception as e:
            if self.logger:
                self.logger.error(f"设备自动检测过程出错: {e}")
        # No iFLYTEK device found and no explicit index given (-1):
        # fall back to device 0 (the system default input).
        if self.device_index == original_index and original_index == -1:
            self.device_index = 0
            if self.logger:
                self.logger.info("未找到 iFLYTEK 设备,使用系统默认输入设备")
        self.format = pyaudio.paInt16

    def record(self):
        """Blocking capture loop; intended to run on its own thread.

        Reads chunks until stop_event is set or the device errors out.
        A full queue drops its oldest entry before the new chunk is added.
        """
        if self.logger:
            self.logger.info(f"录音线程启动,设备索引: {self.device_index}")
        stream = None
        try:
            stream = self.audio.open(
                format=self.format,
                channels=self.channels,
                rate=self.sample_rate,
                input=True,
                # index < 0 lets PyAudio pick its default device
                input_device_index=self.device_index if self.device_index >= 0 else None,
                frames_per_buffer=self.chunk
            )
            if self.logger:
                self.logger.info("音频输入设备已打开")
        except Exception as e:
            if self.logger:
                self.logger.error(f"无法打开音频输入设备: {e}")
            return
        try:
            while not self.stop_event.is_set():
                try:
                    # exception_on_overflow=False: silently drop overflowed
                    # frames instead of raising mid-capture.
                    data = stream.read(self.chunk, exception_on_overflow=False)
                    if self.audio_queue.full():
                        # Drop the oldest chunk to keep latency bounded.
                        self.audio_queue.get_nowait()
                    self.audio_queue.put_nowait(data)
                except OSError as e:
                    # Device unplugged or stream invalidated: stop recording.
                    if self.logger:
                        self.logger.debug(f"录音设备错误: {e}")
                    break
        except KeyboardInterrupt:
            if self.logger:
                self.logger.info("录音线程收到中断信号")
        finally:
            if stream is not None:
                try:
                    if stream.is_active():
                        stream.stop_stream()
                    stream.close()
                except Exception as e:
                    pass
            if self.logger:
                self.logger.info("录音线程已退出")
class DashScopeASR:
    """Streaming ASR over the DashScope Omni realtime API with server VAD.

    Lifecycle: start() connects a realtime conversation; send_audio() streams
    PCM chunks; results and VAD events arrive via the on_sentence_end /
    on_speech_started / on_speech_stopped callbacks (set by the owner).
    stop_current_recognition() commits the current utterance, waits briefly
    for its final transcript, then restarts the session.
    """

    def __init__(self, api_key: str, sample_rate: int, model: str, url: str, logger=None):
        """Store connection parameters; no network activity until start()."""
        dashscope.api_key = api_key
        self.sample_rate = sample_rate
        self.model = model
        self.url = url
        self.logger = logger
        self.conversation = None       # live OmniRealtimeConversation, or None
        self.running = False
        self.on_sentence_end = None    # callable(str): final transcript
        self.on_speech_started = None  # callable(): server VAD speech start
        self.on_speech_stopped = None  # callable(): server VAD speech end
        self._stop_lock = threading.Lock()          # serializes stop/restart
        self._final_result_event = threading.Event()  # set when a committed utterance finishes
        self._pending_commit = False   # True while waiting for a committed result

    def _log(self, level: str, msg: str):
        """Best-effort logging; all logger failures are swallowed."""
        if not self.logger:
            return
        try:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            elif level == "info":
                self.logger.info(msg)
        except Exception:
            pass

    def start(self):
        """Connect a new realtime session with server-side VAD enabled.

        Returns True on success, False if already running or on any error
        (in which case the half-open conversation is closed).
        """
        if self.running:
            return False
        try:
            callback = _ASRCallback(self)
            self.conversation = OmniRealtimeConversation(
                model=self.model,
                url=self.url,
                callback=callback
            )
            callback.conversation = self.conversation
            self.conversation.connect()
            transcription_params = TranscriptionParams(
                language='zh',
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )
            # Server VAD segments utterances; threshold/silence values tuned
            # for responsive end-pointing (800 ms trailing silence).
            self.conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
                enable_turn_detection=True,
                turn_detection_type='server_vad',
                turn_detection_threshold=0.2,
                turn_detection_silence_duration_ms=800
            )
            self.running = True
            self._log("info", "ASR已启动")
            return True
        except Exception as e:
            self.running = False
            self._log("error", f"ASR启动失败: {e}")
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
            self.conversation = None
            return False

    def send_audio(self, audio_chunk: bytes):
        """Stream one PCM chunk (base64-encoded) to the session.

        Returns False when not running or on any send failure.
        """
        if not self.running or not self.conversation:
            return False
        try:
            audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
            self.conversation.append_audio(audio_b64)
            return True
        except Exception:
            return False

    def stop_current_recognition(self):
        """Finish the in-flight utterance and restart the session.

        Sequence: commit the audio buffer, wait up to 3 s for its final
        transcript (signalled by _ASRCallback via _final_result_event),
        close the old conversation, then start() a fresh one. Non-blocking
        re-entry: a concurrent call returns False immediately.
        """
        if not self._stop_lock.acquire(blocking=False):
            self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
            return False
        try:
            if not self.running or not self.conversation:
                return False
            self._final_result_event.clear()
            self._pending_commit = True
            self.conversation.commit()
            # Give the server a bounded window to deliver the final result.
            self._final_result_event.wait(timeout=3.0)
            self.running = False
            old_conversation = self.conversation
            self.conversation = None
            try:
                old_conversation.close()
            except Exception:
                pass
            # Brief pause before reconnecting to avoid racing the close.
            time.sleep(0.1)
            if not self.start():
                self._log("error", "ASR重启失败")
                return False
            return True
        finally:
            self._pending_commit = False
            self._stop_lock.release()

    def stop(self):
        """Shut the session down for good (used on node teardown)."""
        with self._stop_lock:
            self.running = False
            # Release anyone blocked in stop_current_recognition's wait.
            self._final_result_event.set()
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
            self.conversation = None
class _ASRCallback(OmniRealtimeCallback):
    """Routes DashScope realtime events back onto a DashScopeASR client."""

    def __init__(self, asr_client: DashScopeASR):
        self.asr_client = asr_client
        self.conversation = None

    def _handle_transcription(self, response):
        # Final transcript for the current utterance.
        client = self.asr_client
        text = response['transcript'].strip()
        if text and client.on_sentence_end:
            client.on_sentence_end(text)
        # Unblock a pending stop_current_recognition() commit.
        if client._pending_commit:
            client._final_result_event.set()

    def _handle_speech_started(self, response):
        client = self.asr_client
        if client.logger:
            client.logger.info("[ASR] 检测到语音开始")
        if client.on_speech_started:
            client.on_speech_started()

    def _handle_speech_stopped(self, response):
        client = self.asr_client
        if client.logger:
            client.logger.info("[ASR] 检测到语音结束")
        if client.on_speech_stopped:
            client.on_speech_stopped()

    def on_event(self, response):
        # Dispatch by event type; unknown events and any handler errors are
        # silently ignored, matching the SDK's fire-and-forget contract.
        dispatch = {
            'conversation.item.input_audio_transcription.completed': self._handle_transcription,
            'input_audio_buffer.speech_started': self._handle_speech_started,
            'input_audio_buffer.speech_stopped': self._handle_speech_stopped,
        }
        try:
            handler = dispatch.get(response['type'])
            if handler is not None:
                handler(response)
        except Exception:
            pass
class ASRAudioNode(Node):
    """ROS2 node exposing microphone capture + DashScope ASR as services.

    Services:
      /asr/recognize  — trigger/collect speech recognition results
      /asr/audio_data — raw-audio buffer control (start/stop/get)
      /vad/event      — blocking wait for the next server-VAD event
    Two daemon threads run continuously: one captures audio into a queue,
    the other drains the queue into the ASR stream (and, when enabled, into
    the raw-audio ring buffer).
    """

    def __init__(self):
        super().__init__('asr_audio_node')
        self._load_config()
        self.audio_queue = queue.Queue(maxsize=100)  # mic -> worker hand-off
        self.stop_event = threading.Event()
        self._shutdown_in_progress = False
        self._init_components()
        self.recognize_service = self.create_service(
            ASRRecognize, '/asr/recognize', self._recognize_callback
        )
        self.audio_data_service = self.create_service(
            AudioData, '/asr/audio_data', self._audio_data_callback
        )
        self.vad_event_service = self.create_service(
            VADEvent, '/vad/event', self._vad_event_callback
        )
        # Latest final transcript + timestamp, signalled via _result_event.
        self._last_result = None
        self._result_event = threading.Event()
        self._last_result_time = None
        self.vad_event_queue = queue.Queue()
        # Raw-audio ring buffer of int16 samples (~15 s at 16 kHz).
        self.audio_buffer = collections.deque(maxlen=240000)
        self.audio_recording = False
        self.audio_lock = threading.Lock()
        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record, name="RecordingThread", daemon=True
        )
        self.recording_thread.start()
        self.asr_thread = threading.Thread(
            target=self._asr_worker, name="ASRThread", daemon=True
        )
        self.asr_thread.start()
        self.get_logger().info("ASR Audio节点已启动")

    def _load_config(self):
        """Read microphone and DashScope settings from config/voice.yaml."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
        mic = config['audio']['microphone']
        self.input_device_index = mic['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']
        # NOTE(review): this local name shadows the imported `dashscope`
        # module inside this method only.
        dashscope = config['dashscope']
        self.dashscope_api_key = dashscope['api_key']
        self.asr_model = dashscope['asr']['model']
        self.asr_url = dashscope['asr']['url']

    def _init_components(self):
        """Build the recorder and ASR client and wire up ASR callbacks."""
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            audio_queue=self.audio_queue,
            stop_event=self.stop_event,
            logger=self.get_logger()
        )
        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )
        self.asr_client.on_sentence_end = self._on_asr_result
        self.asr_client.on_speech_started = lambda: self._put_vad_event("speech_started")
        self.asr_client.on_speech_stopped = lambda: self._put_vad_event("speech_stopped")
        self.asr_client.start()

    def _on_asr_result(self, text: str):
        """ASR callback: store the final transcript and wake any waiter."""
        if not text or not text.strip():
            return
        self._last_result = text.strip()
        self._last_result_time = time.time()
        self._result_event.set()
        try:
            self.get_logger().info(f"[ASR] 识别结果: {self._last_result}")
        except Exception:
            pass

    def _put_vad_event(self, event_type):
        """Queue a VAD event for /vad/event waiters; drop when saturated."""
        try:
            self.vad_event_queue.put(event_type, timeout=0.1)
        except queue.Full:
            try:
                self.get_logger().warn(f"[ASR] VAD事件队列已满,丢弃{event_type}事件")
            except Exception:
                pass

    def _audio_data_callback(self, request, response):
        """Service /asr/audio_data: control the raw-audio ring buffer.

        Commands: "start" clears the buffer and begins accumulation;
        "stop" ends accumulation and returns the captured samples;
        "get" returns the current buffer contents without stopping.
        """
        response.sample_rate = self.sample_rate
        response.channels = self.channels
        if request.command == "start":
            with self.audio_lock:
                self.audio_buffer.clear()
                self.audio_recording = True
            response.success = True
            response.message = "开始录音"
            response.samples = 0
            return response
        if request.command == "stop":
            with self.audio_lock:
                self.audio_recording = False
                audio_list = list(self.audio_buffer)
                self.audio_buffer.clear()
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"录音完成,{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response
        if request.command == "get":
            with self.audio_lock:
                audio_list = list(self.audio_buffer)
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"获取到{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response

    def _vad_event_callback(self, request, response):
        """Service /vad/event: block until the next VAD event or timeout.

        timeout_ms <= 0 waits indefinitely.
        """
        timeout = request.timeout_ms / 1000.0 if request.timeout_ms > 0 else None
        try:
            event = self.vad_event_queue.get(timeout=timeout)
            response.success = True
            response.event = event
            response.message = "收到VAD事件"
        except queue.Empty:
            response.success = False
            response.event = "none"
            response.message = "等待超时"
        except KeyboardInterrupt:
            try:
                self.get_logger().info("[VAD] 收到中断信号,正在关闭")
            except Exception:
                pass
            response.success = False
            response.event = "none"
            response.message = "节点正在关闭"
            self.stop_event.set()
        return response

    def _clear_result(self):
        """Forget the cached transcript and reset the result event."""
        self._last_result = None
        self._last_result_time = None
        self._result_event.clear()

    def _return_result(self, response, text, message):
        """Fill a success response with `text`, then clear the cache."""
        response.success = True
        response.text = text
        response.message = message
        self._clear_result()
        return response

    def _recognize_callback(self, request, response):
        """Service /asr/recognize: control recognition and fetch results.

        Commands: "stop" halts the current utterance; "reset" restarts the
        recognizer. Otherwise: return a cached result if it is fresh
        (< 5 s old), else wait briefly for an in-flight one, else force a
        new utterance cycle and wait up to 5 s for its transcript.
        """
        if request.command == "stop":
            if self.asr_client.running:
                self.asr_client.stop_current_recognition()
            response.success = True
            response.text = ""
            response.message = "识别已停止"
            return response
        if request.command == "reset":
            self.asr_client.stop_current_recognition()
            time.sleep(0.1)
            self.asr_client.start()
            response.success = True
            response.text = ""
            response.message = "识别器已重置"
            return response
        if self.asr_client.running:
            current_time = time.time()
            # A recent (or already-signalled) result is returned as-is.
            if (self._last_result and self._last_result_time and
                    (current_time - self._last_result_time) < 5.0) or (self._result_event.is_set() and self._last_result):
                return self._return_result(response, self._last_result, "返回最近识别结果")
            # Short grace period for an utterance already in flight.
            if self._result_event.wait(timeout=2.0) and self._last_result:
                return self._return_result(response, self._last_result, "识别成功(等待中)")
            # Nothing pending: commit/restart so a fresh utterance begins.
            self.asr_client.stop_current_recognition()
            time.sleep(0.2)
        self._clear_result()
        if not self.asr_client.running and not self.asr_client.start():
            response.success = False
            response.text = ""
            response.message = "ASR启动失败"
            return response
        if self._result_event.wait(timeout=5.0) and self._last_result:
            response.success = True
            response.text = self._last_result
            response.message = "识别成功"
        else:
            response.success = False
            response.text = ""
            response.message = "识别超时" if not self._result_event.is_set() else "识别结果为空"
        self._clear_result()
        return response

    def _asr_worker(self):
        """Drain captured audio: copy into the ring buffer (when recording)
        and stream every chunk to the ASR client."""
        while not self.stop_event.is_set():
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except KeyboardInterrupt:
                try:
                    self.get_logger().info("[ASR Worker] 收到中断信号")
                except Exception:
                    pass
                break
            if self.audio_recording:
                try:
                    audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                    with self.audio_lock:
                        self.audio_buffer.extend(audio_array)
                except Exception:
                    pass
            if self.asr_client.running:
                self.asr_client.send_audio(audio_chunk)

    def destroy_node(self):
        """Idempotent shutdown: stop threads, PyAudio, and the ASR session."""
        if self._shutdown_in_progress:
            return
        self._shutdown_in_progress = True
        try:
            self.get_logger().info("ASR Audio节点正在关闭...")
        except Exception:
            pass
        self.stop_event.set()
        if hasattr(self, 'recording_thread') and self.recording_thread.is_alive():
            self.recording_thread.join(timeout=1.0)
        if hasattr(self, 'asr_thread') and self.asr_thread.is_alive():
            self.asr_thread.join(timeout=1.0)
        try:
            if hasattr(self, 'audio_recorder'):
                self.audio_recorder.audio.terminate()
        except Exception:
            pass
        try:
            if hasattr(self, 'asr_client'):
                self.asr_client.stop()
        except Exception:
            pass
        try:
            super().destroy_node()
        except Exception:
            pass
def main(args=None):
    """Entry point: spin the ASR node until interrupted, then tear down."""
    def _best_effort(action):
        # Run one teardown step, swallowing any error so shutdown continues.
        try:
            action()
        except Exception:
            pass

    rclpy.init(args=args)
    node = ASRAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        _best_effort(lambda: node.get_logger().info("收到中断信号,正在关闭节点"))
    finally:
        _best_effort(node.destroy_node)
        _best_effort(rclpy.shutdown)


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,341 @@
import rclpy
from rclpy.node import Node
from rclpy.callback_groups import ReentrantCallbackGroup
from robot_speaker.srv import TTSSynthesize
import threading
import yaml
import os
import signal
import subprocess
import time
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
from ament_index_python.packages import get_package_share_directory
class DashScopeTTSClient:
    """Synthesizes speech via DashScope TTS and plays it through ffmpeg/ALSA.

    synthesize() streams PCM from the cloud into a per-call _TTSCallback,
    which pipes it to an ffmpeg child process. force_stop() kills the
    current ffmpeg process by PID and flags the active callback interrupted.
    """

    def __init__(self, api_key: str,
                 model: str,
                 voice: str,
                 card_index: int,
                 device_index: int,
                 output_sample_rate: int,
                 output_channels: int,
                 output_volume: float,
                 tts_source_sample_rate: int,
                 tts_source_channels: int,
                 tts_ffmpeg_thread_queue_size: int,
                 force_stop_delay: float,
                 cleanup_timeout: float,
                 terminate_timeout: float,
                 logger):
        """Store playback/ALSA configuration; no process is spawned here.

        card_index/device_index: ALSA hw identifiers; either < 0 selects the
        "default" ALSA device. The various *_timeout/delay values tune how
        aggressively playback processes are reaped.
        """
        dashscope.api_key = api_key
        self.model = model
        self.voice = voice
        self.card_index = card_index
        self.device_index = device_index
        self.output_sample_rate = output_sample_rate
        self.output_channels = output_channels
        self.output_volume = output_volume
        self.tts_source_sample_rate = tts_source_sample_rate
        self.tts_source_channels = tts_source_channels
        self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
        self.force_stop_delay = force_stop_delay
        self.cleanup_timeout = cleanup_timeout
        self.terminate_timeout = terminate_timeout
        self.logger = logger
        self.current_ffmpeg_pid = None   # PID of the ffmpeg player, if any
        self._current_callback = None    # active _TTSCallback, if any
        self.alsa_device = f"plughw:{card_index},{device_index}" if (
            card_index >= 0 and device_index >= 0
        ) else "default"

    def force_stop(self):
        """Abort playback: flag the callback interrupted, then SIGTERM the
        ffmpeg process and escalate to SIGKILL if it survives the delay."""
        if self._current_callback:
            self._current_callback._interrupted = True
        if not self.current_ffmpeg_pid:
            if self.logger:
                self.logger.warn("[TTS] force_stop: current_ffmpeg_pid is None")
            return
        pid = self.current_ffmpeg_pid
        try:
            if self.logger:
                self.logger.info(f"[TTS] force_stop: 正在kill进程 {pid}")
            os.kill(pid, signal.SIGTERM)
            time.sleep(self.force_stop_delay)
            try:
                # Signal 0 probes liveness; if still alive, hard-kill.
                os.kill(pid, 0)
                os.kill(pid, signal.SIGKILL)
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 已发送SIGKILL到进程 {pid}")
            except ProcessLookupError:
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 进程 {pid} 已退出")
        except (ProcessLookupError, OSError) as e:
            if self.logger:
                self.logger.warn(f"[TTS] force_stop: kill进程失败 {pid}: {e}")
        finally:
            self.current_ffmpeg_pid = None
            self._current_callback = None

    def synthesize(self, text: str, voice: str = None,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        """Synthesize `text` and play it; blocks until playback finishes.

        voice: optional override of the configured voice.
        on_chunk: optional observer for each raw PCM chunk.
        interrupt_check: optional callable polled per chunk; True aborts.
        Returns True if playback completed, False if interrupted or if the
        voice parameter resolved to empty.
        """
        callback = _TTSCallback(self, interrupt_check, on_chunk)
        self._current_callback = callback
        voice_to_use = voice if voice and voice.strip() else self.voice
        if not voice_to_use or not voice_to_use.strip():
            if self.logger:
                self.logger.error(f"Voice参数无效: '{voice_to_use}'")
            self._current_callback = None
            return False
        # NOTE(review): the audio format is fixed at 22050 Hz mono here while
        # tts_source_sample_rate/channels are configurable — confirm the
        # config always matches PCM_22050HZ_MONO_16BIT.
        synthesizer = SpeechSynthesizer(
            model=self.model,
            voice=voice_to_use,
            format=AudioFormat.PCM_22050HZ_MONO_16BIT,
            callback=callback,
        )
        try:
            synthesizer.streaming_call(text)
            # Blocks until the stream (and hence playback feed) completes.
            synthesizer.streaming_complete()
        finally:
            callback.cleanup()
            self._current_callback = None
        return not callback._interrupted
class _TTSCallback(ResultCallback):
    """Streams DashScope TTS PCM chunks into a local ffmpeg/ALSA player."""

    def __init__(self, tts_client: DashScopeTTSClient,
                 interrupt_check=None,
                 on_chunk=None):
        self.tts_client = tts_client
        self.interrupt_check = interrupt_check  # callable -> True aborts playback
        self.on_chunk = on_chunk                # optional raw-PCM observer
        self._proc = None                       # ffmpeg subprocess
        self._interrupted = False
        self._cleaned_up = False

    def on_open(self):
        """Spawn the ffmpeg process that plays stdin PCM to the ALSA device."""
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 's16le',
            '-ar', str(self.tts_client.tts_source_sample_rate),
            '-ac', str(self.tts_client.tts_source_channels),
            '-i', 'pipe:0',
            '-f', 'alsa',
            '-ar', str(self.tts_client.output_sample_rate),
            '-ac', str(self.tts_client.output_channels),
            '-acodec', 'pcm_s16le',
            '-fflags', 'nobuffer',
            '-flags', 'low_delay',
            '-avioflags', 'direct',
            self.tts_client.alsa_device
        ]
        # -thread_queue_size is an input option and must precede `-i`.
        insert_pos = ffmpeg_cmd.index('-i')
        ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
        ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
        if self.tts_client.output_volume != 1.0:
            # Insert the volume filter among the output options.
            acodec_idx = ffmpeg_cmd.index('-acodec')
            ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
            ffmpeg_cmd.insert(acodec_idx, '-af')
        self._proc = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE
        )
        self.tts_client.current_ffmpeg_pid = self._proc.pid

    def on_data(self, data: bytes) -> None:
        """Play one PCM chunk; abort on interruption or pipe failure."""
        if self._interrupted:
            return
        if self.interrupt_check and self.interrupt_check():
            self._interrupted = True
            if self._proc:
                self._proc.terminate()
            return
        if self._proc and self._proc.stdin and not self._interrupted:
            try:
                self._proc.stdin.write(data)
                self._proc.stdin.flush()
            except BrokenPipeError:
                # ffmpeg exited underneath us.
                self._interrupted = True
            except OSError:
                self._interrupted = True
        if self.on_chunk and not self._interrupted:
            self.on_chunk(data)

    def cleanup(self):
        """Release the player: drain, then terminate/kill on timeout.

        Guards every wait with try/except so a TimeoutExpired from
        Popen.wait() cannot escape and skip the terminate/kill escalation
        (which would leak the ffmpeg process). Idempotent.
        """
        if self._cleaned_up or not self._proc:
            return
        self._cleaned_up = True
        # EOF on stdin lets ffmpeg flush its remaining audio.
        if self._proc.stdin and not self._proc.stdin.closed:
            try:
                self._proc.stdin.close()
            except OSError:
                pass
        if self._proc.poll() is None:
            try:
                self._proc.wait(timeout=self.tts_client.cleanup_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=self.tts_client.terminate_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.kill()
        if self.tts_client.current_ffmpeg_pid == self._proc.pid:
            self.tts_client.current_ffmpeg_pid = None
class TTSAudioNode(Node):
    """ROS2 node exposing /tts/synthesize for speech synthesis/playback.

    Playback runs on a worker thread so the service returns immediately
    ("playing"); a follow-up request with command "interrupt" stops it.
    Uses a ReentrantCallbackGroup so an interrupt request can be served
    while a synthesize request's worker is still active.
    """

    def __init__(self):
        super().__init__('tts_audio_node')
        self._load_config()
        self._init_tts_client()
        self.callback_group = ReentrantCallbackGroup()
        self.synthesize_service = self.create_service(
            TTSSynthesize, '/tts/synthesize', self._synthesize_callback,
            callback_group=self.callback_group
        )
        self.interrupt_event = threading.Event()  # polled by the TTS chunk loop
        self.playing_lock = threading.Lock()      # guards is_playing
        self.is_playing = False
        self.get_logger().info("TTS Audio节点已启动")

    def _load_config(self):
        """Read soundcard, TTS-playback, and DashScope settings from
        config/voice.yaml."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
        audio = config['audio']
        soundcard = audio['soundcard']
        tts_audio = audio['tts']
        # NOTE(review): this local name shadows the imported `dashscope`
        # module inside this method only.
        dashscope = config['dashscope']
        self.output_card_index = soundcard['card_index']
        self.output_device_index = soundcard['device_index']
        self.output_sample_rate = soundcard['sample_rate']
        self.output_channels = soundcard['channels']
        self.output_volume = soundcard['volume']
        self.tts_source_sample_rate = tts_audio['source_sample_rate']
        self.tts_source_channels = tts_audio['source_channels']
        self.tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']
        self.force_stop_delay = tts_audio['force_stop_delay']
        self.cleanup_timeout = tts_audio['cleanup_timeout']
        self.terminate_timeout = tts_audio['terminate_timeout']
        self.interrupt_wait = tts_audio['interrupt_wait']
        self.dashscope_api_key = dashscope['api_key']
        self.tts_model = dashscope['tts']['model']
        self.tts_voice = dashscope['tts']['voice']

    def _init_tts_client(self):
        """Construct the DashScope TTS client from the loaded config."""
        self.tts_client = DashScopeTTSClient(
            api_key=self.dashscope_api_key,
            model=self.tts_model,
            voice=self.tts_voice,
            card_index=self.output_card_index,
            device_index=self.output_device_index,
            output_sample_rate=self.output_sample_rate,
            output_channels=self.output_channels,
            output_volume=self.output_volume,
            tts_source_sample_rate=self.tts_source_sample_rate,
            tts_source_channels=self.tts_source_channels,
            tts_ffmpeg_thread_queue_size=self.tts_ffmpeg_thread_queue_size,
            force_stop_delay=self.force_stop_delay,
            cleanup_timeout=self.cleanup_timeout,
            terminate_timeout=self.terminate_timeout,
            logger=self.get_logger()
        )

    def _synthesize_callback(self, request, response):
        """Service /tts/synthesize: start playback or interrupt it.

        command "interrupt": stop current playback (if any) and return.
        Otherwise: validate text, preempt any active playback, then launch
        a daemon worker and return immediately with status "playing".
        """
        command = request.command if request.command else "synthesize"
        if command == "interrupt":
            with self.playing_lock:
                was_playing = self.is_playing
                has_pid = self.tts_client.current_ffmpeg_pid is not None
                if was_playing or has_pid:
                    self.interrupt_event.set()
                    self.tts_client.force_stop()
                    self.is_playing = False
                    response.success = True
                    response.message = "已中断播放"
                    response.status = "interrupted"
                else:
                    response.success = False
                    response.message = "没有正在播放的内容"
                    response.status = "none"
            return response
        if not request.text or not request.text.strip():
            response.success = False
            response.message = "文本为空"
            response.status = "error"
            return response
        with self.playing_lock:
            if self.is_playing:
                # Preempt the previous utterance before starting the new one.
                self.tts_client.force_stop()
                time.sleep(self.interrupt_wait)
            self.is_playing = True
            self.interrupt_event.clear()

        def synthesize_worker():
            # Runs on a daemon thread; clears is_playing when done.
            try:
                success = self.tts_client.synthesize(
                    request.text.strip(),
                    voice=request.voice if request.voice else None,
                    interrupt_check=lambda: self.interrupt_event.is_set()
                )
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    if success:
                        self.get_logger().info("[TTS] 合成并播放成功")
                    else:
                        self.get_logger().info("[TTS] 播放被中断")
            except Exception as e:
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    self.get_logger().error(f"[TTS] 合成失败: {e}")

        thread = threading.Thread(target=synthesize_worker, daemon=True)
        thread.start()
        response.success = True
        response.message = "合成任务已启动"
        response.status = "playing"
        return response
def main(args=None):
    """Entry point: spin the TTS node until interrupted, then shut down.

    Mirrors the ASR node's entry point: Ctrl-C (KeyboardInterrupt) must
    still reach destroy_node()/shutdown() instead of skipping them, and
    teardown errors are swallowed so both steps always run.
    """
    rclpy.init(args=args)
    node = TTSAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        try:
            node.destroy_node()
        except Exception:
            pass
        try:
            rclpy.shutdown()
        except Exception:
            pass


if __name__ == '__main__':
    main()

View File

@@ -1,9 +0,0 @@
"""理解层"""

View File

@@ -15,6 +15,7 @@ setup(
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
(os.path.join('share', package_name, 'srv'), glob('srv/*.srv')),
],
install_requires=[
'setuptools',
@@ -25,12 +26,13 @@ setup(
maintainer_email='mzebra@foxmail.com',
description='语音识别和合成ROS2包',
license='Apache-2.0',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
'skill_bridge_node = robot_speaker.bridge.skill_bridge_node:main',
'asr_audio_node = robot_speaker.services.asr_audio_node:main',
'tts_audio_node = robot_speaker.services.tts_audio_node:main',
],
},
)

18
srv/ASRRecognize.srv Normal file
View File

@@ -0,0 +1,18 @@
# 请求:启动识别
string command # "start" (默认), "stop", "reset"
---
# 响应:识别结果
bool success
string text # 识别文本(空字符串表示未识别到)
string message # 状态消息

20
srv/AudioData.srv Normal file
View File

@@ -0,0 +1,20 @@
# 请求:获取音频数据
string command # "start" (开始录音), "stop" (停止并返回), "get" (获取当前缓冲区)
int32 duration_ms # 录音时长毫秒仅用于start命令
---
# 响应:音频数据
bool success
uint8[] audio_data # PCM音频数据int16格式
int32 sample_rate
int32 channels
int32 samples # 样本数
string message

14
srv/TTSSynthesize.srv Normal file
View File

@@ -0,0 +1,14 @@
# 请求:合成文本或中断命令
string command # "synthesize" (默认), "interrupt"
string text
string voice # 可选,默认使用配置
---
# 响应:合成状态
bool success
string message
string status # "playing", "completed", "interrupted"

17
srv/VADEvent.srv Normal file
View File

@@ -0,0 +1,17 @@
# 请求等待VAD事件
string command # "wait" (等待下一个事件)
int32 timeout_ms # 超时时间毫秒0表示无限等待
---
# 响应VAD事件
bool success
string event # "speech_started", "speech_stopped", "none"
string message