merge remote
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,3 +4,6 @@ log/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
dist/
|
||||
lib/
|
||||
installed_files.txt
|
||||
|
||||
142
CMakeLists.txt
Normal file
142
CMakeLists.txt
Normal file
@@ -0,0 +1,142 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(robot_speaker)
|
||||
|
||||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
endif()
|
||||
|
||||
find_package(ament_cmake REQUIRED)
|
||||
find_package(ament_cmake_python REQUIRED)
|
||||
find_package(rosidl_default_generators REQUIRED)
|
||||
|
||||
# 确保使用系统 Python(而不是 conda/miniconda 的 Python)
|
||||
find_program(PYTHON3_CMD python3 PATHS /usr/bin /usr/local/bin NO_DEFAULT_PATH)
|
||||
if(NOT PYTHON3_CMD)
|
||||
find_program(PYTHON3_CMD python3)
|
||||
endif()
|
||||
if(PYTHON3_CMD)
|
||||
set(Python3_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python 3 executable" FORCE)
|
||||
set(PYTHON_EXECUTABLE ${PYTHON3_CMD} CACHE FILEPATH "Python executable" FORCE)
|
||||
endif()
|
||||
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"srv/ASRRecognize.srv"
|
||||
"srv/TTSSynthesize.srv"
|
||||
"srv/VADEvent.srv"
|
||||
"srv/AudioData.srv"
|
||||
)
|
||||
|
||||
install(CODE "
|
||||
execute_process(
|
||||
COMMAND ${PYTHON3_CMD} -m pip install --prefix=${CMAKE_INSTALL_PREFIX} --no-deps ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE install_result
|
||||
OUTPUT_VARIABLE install_output
|
||||
ERROR_VARIABLE install_error
|
||||
)
|
||||
if(NOT install_result EQUAL 0)
|
||||
message(FATAL_ERROR \"Failed to install Python package. Output: ${install_output} Error: ${install_error}\")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${PYTHON3_CMD} -c \"
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
import sysconfig
|
||||
|
||||
install_prefix = '${CMAKE_INSTALL_PREFIX}'
|
||||
build_dir = '${CMAKE_CURRENT_BINARY_DIR}'
|
||||
python_version = f'{sysconfig.get_python_version()}'
|
||||
|
||||
# ROS2 期望的 Python 包位置
|
||||
ros2_site_packages = os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages')
|
||||
os.makedirs(ros2_site_packages, exist_ok=True)
|
||||
|
||||
# pip install --prefix 可能将包安装到不同位置(系统环境通常是 local/lib/pythonX/dist-packages)
|
||||
pip_locations = [
|
||||
os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'dist-packages'),
|
||||
os.path.join(install_prefix, 'lib', f'python{python_version}', 'site-packages'),
|
||||
os.path.join(install_prefix, 'local', 'lib', f'python{python_version}', 'site-packages'),
|
||||
]
|
||||
|
||||
# 查找并复制 robot_speaker 包到 ROS2 期望的位置
|
||||
robot_speaker_src = None
|
||||
for location in pip_locations:
|
||||
candidate = os.path.join(location, 'robot_speaker')
|
||||
if os.path.exists(candidate) and os.path.isdir(candidate):
|
||||
robot_speaker_src = candidate
|
||||
break
|
||||
|
||||
if robot_speaker_src:
|
||||
robot_speaker_dest = os.path.join(ros2_site_packages, 'robot_speaker')
|
||||
if os.path.exists(robot_speaker_dest):
|
||||
shutil.rmtree(robot_speaker_dest)
|
||||
if robot_speaker_src != robot_speaker_dest:
|
||||
shutil.copytree(robot_speaker_src, robot_speaker_dest)
|
||||
print(f'Copied robot_speaker from {robot_speaker_src} to {ros2_site_packages}')
|
||||
else:
|
||||
print(f'robot_speaker already in correct location')
|
||||
|
||||
# 复制 ROS2 生成的 srv 模块(rosidl_generate_interfaces 生成的)
|
||||
rosidl_py_src = os.path.join(build_dir, 'rosidl_generator_py', 'robot_speaker')
|
||||
if os.path.exists(rosidl_py_src):
|
||||
# 复制 srv 目录
|
||||
srv_src = os.path.join(rosidl_py_src, 'srv')
|
||||
srv_dest = os.path.join(robot_speaker_dest, 'srv')
|
||||
if os.path.exists(srv_src):
|
||||
if os.path.exists(srv_dest):
|
||||
shutil.rmtree(srv_dest)
|
||||
shutil.copytree(srv_src, srv_dest)
|
||||
print(f'Copied srv module to {srv_dest}')
|
||||
|
||||
# 复制生成的接口文件(.so 和 .c 文件)
|
||||
for pattern in ['robot_speaker_s__rosidl_typesupport*.so', '_robot_speaker_s*.c']:
|
||||
for file in glob.glob(os.path.join(rosidl_py_src, pattern)):
|
||||
dest_file = os.path.join(robot_speaker_dest, os.path.basename(file))
|
||||
shutil.copy2(file, dest_file)
|
||||
print(f'Copied {os.path.basename(file)} to {robot_speaker_dest}')
|
||||
|
||||
# 处理 entry_points 脚本
|
||||
lib_dir = os.path.join(install_prefix, 'lib', 'robot_speaker')
|
||||
os.makedirs(lib_dir, exist_ok=True)
|
||||
|
||||
# 脚本可能在 local/bin 或 bin 中
|
||||
for bin_dir in [os.path.join(install_prefix, 'local', 'bin'), os.path.join(install_prefix, 'bin')]:
|
||||
if os.path.exists(bin_dir):
|
||||
scripts = glob.glob(os.path.join(bin_dir, '*_node'))
|
||||
for script in scripts:
|
||||
script_name = os.path.basename(script)
|
||||
dest = os.path.join(lib_dir, script_name)
|
||||
if script != dest:
|
||||
shutil.copy2(script, dest)
|
||||
os.chmod(dest, 0o755)
|
||||
print(f'Copied {script_name} to {lib_dir}')
|
||||
\"
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE python_result
|
||||
OUTPUT_VARIABLE python_output
|
||||
)
|
||||
if(python_result EQUAL 0)
|
||||
message(STATUS \"${python_output}\")
|
||||
else()
|
||||
message(WARNING \"Failed to setup Python package: ${python_output}\")
|
||||
endif()
|
||||
")
|
||||
|
||||
install(DIRECTORY launch/
|
||||
DESTINATION share/${PROJECT_NAME}/launch
|
||||
FILES_MATCHING PATTERN "*.launch.py"
|
||||
)
|
||||
|
||||
install(DIRECTORY config/
|
||||
DESTINATION share/${PROJECT_NAME}/config
|
||||
FILES_MATCHING PATTERN "*.yaml" PATTERN "*.json"
|
||||
)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
ament_package()
|
||||
45
README.md
45
README.md
@@ -45,37 +45,26 @@ source install/setup.bash
|
||||
ros2 launch robot_speaker voice.launch.py
|
||||
```
|
||||
|
||||
## 架构说明
|
||||
[录音线程] - 唯一实时线程
|
||||
├─ 麦克风采集 PCM
|
||||
├─ VAD + 能量检测
|
||||
├─ 检测到人声 → 立即中断TTS
|
||||
├─ 语音 PCM → ASR 音频队列
|
||||
└─ 语音 PCM → 声纹音频队列(旁路,不阻塞)
|
||||
3. ASR节点
|
||||
```bash
|
||||
ros2 run robot_speaker asr_audio_node
|
||||
```
|
||||
|
||||
[ASR推理线程] - 只做 audio → text
|
||||
└─ 从 ASR 音频队列取音频→ 实时 / 流式 ASR → text → 文本队列
|
||||
4. TTS节点
|
||||
```bash
|
||||
# 终端1: 启动TTS节点
|
||||
ros2 run robot_speaker tts_audio_node
|
||||
|
||||
[声纹识别线程] - 非实时、低频(CAM++)
|
||||
├─ 通过回调函数接收音频chunk,写入缓冲区,等待 speech_end 事件触发处理
|
||||
├─ 累积 1~2 秒有效人声(VAD 后)
|
||||
├─ CAM++ 提取 speaker embedding
|
||||
├─ 声纹匹配 / 注册
|
||||
└─ 更新 current_speaker_id(共享状态,只写不控)
|
||||
声纹线程要求:不影响录音,不影响ASR,不控制TTS,只更新当前说话人是谁
|
||||
# 终端2: 启动播放
|
||||
source install/setup.bash
|
||||
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
|
||||
"{command: 'synthesize', text: '这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。让我继续说下去,这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。让我继续说下去,这是一段很长的测试文本,用于测试TTS中断功能。我需要说很多很多内容,这样你才有足够的时间来测试中断命令。', voice: ''}"
|
||||
|
||||
[主线程/处理线程] - 处理业务逻辑
|
||||
├─ 从 文本队列 取 ASR 文本
|
||||
├─ 读取 current_speaker_id(只读)
|
||||
├─ 唤醒词处理(结合 speaker_id)
|
||||
├─ 权限 / 身份判断(是否允许继续)
|
||||
├─ VLM处理(文本 / 多模态)
|
||||
└─ TTS播放(启动TTS线程,不等待)
|
||||
|
||||
[TTS播放线程] - 只播放(可被中断)
|
||||
├─ 接收 TTS 音频流
|
||||
├─ 播放到输出设备
|
||||
└─ 响应中断标志(由录音线程触发)
|
||||
# 终端3: 立即执行中断
|
||||
source install/setup.bash
|
||||
ros2 service call /tts/synthesize robot_speaker/srv/TTSSynthesize \
|
||||
"{command: 'interrupt', text: '', voice: ''}"
|
||||
```
|
||||
|
||||
|
||||
## 用到的命令
|
||||
|
||||
@@ -394,5 +394,205 @@
|
||||
],
|
||||
"env": "near",
|
||||
"registered_at": 1768964438.9963026
|
||||
},
|
||||
"user_1769515089": {
|
||||
"embedding": [
|
||||
[
|
||||
-1.5295110940933228,
|
||||
0.5238341093063354,
|
||||
0.08633111417293549,
|
||||
0.11756575852632523,
|
||||
1.44246244430542,
|
||||
-1.6976442337036133,
|
||||
0.2645050585269928,
|
||||
1.5642119646072388,
|
||||
1.4558132886886597,
|
||||
-2.018132448196411,
|
||||
0.30136486887931824,
|
||||
1.5590322017669678,
|
||||
0.3676050007343292,
|
||||
2.096036434173584,
|
||||
-1.203681468963623,
|
||||
0.2745387852191925,
|
||||
1.128976821899414,
|
||||
-0.8042266368865967,
|
||||
-0.04837780073285103,
|
||||
-0.8245053291320801,
|
||||
-0.6101562976837158,
|
||||
0.08143205940723419,
|
||||
-1.1198647022247314,
|
||||
1.7753965854644775,
|
||||
-0.5257269144058228,
|
||||
-0.6572340726852417,
|
||||
-0.08467039465904236,
|
||||
0.08285830914974213,
|
||||
0.49599483609199524,
|
||||
-2.871098756790161,
|
||||
-1.1618938446044922,
|
||||
0.7318744659423828,
|
||||
2.08620548248291,
|
||||
0.18100303411483765,
|
||||
-0.5528441071510315,
|
||||
0.13717415928840637,
|
||||
0.22606758773326874,
|
||||
0.23349706828594208,
|
||||
0.40789690613746643,
|
||||
-0.23644576966762543,
|
||||
-0.12830045819282532,
|
||||
1.0583454370498657,
|
||||
0.3954410254955292,
|
||||
-1.0476133823394775,
|
||||
0.6569878458976746,
|
||||
0.43412935733795166,
|
||||
0.7459996938705444,
|
||||
0.25105446577072144,
|
||||
0.40695688128471375,
|
||||
0.41371095180511475,
|
||||
-0.5081073045730591,
|
||||
-0.15921951830387115,
|
||||
0.6312111020088196,
|
||||
2.678532123565674,
|
||||
1.5355063676834106,
|
||||
1.898784875869751,
|
||||
1.257870078086853,
|
||||
2.026048421859741,
|
||||
1.1490176916122437,
|
||||
0.742881178855896,
|
||||
-1.206595540046692,
|
||||
0.5405871272087097,
|
||||
0.01001159567385912,
|
||||
-0.7743952870368958,
|
||||
-0.1243305653333664,
|
||||
0.4287954568862915,
|
||||
-1.1704397201538086,
|
||||
2.057995557785034,
|
||||
0.30912983417510986,
|
||||
1.0761916637420654,
|
||||
1.3979746103286743,
|
||||
-1.070613145828247,
|
||||
2.0996458530426025,
|
||||
-0.16294217109680176,
|
||||
-0.15417678654193878,
|
||||
-0.6481220722198486,
|
||||
0.9156526923179626,
|
||||
0.7209145426750183,
|
||||
-1.3280514478683472,
|
||||
0.08632978051900864,
|
||||
-0.09424483776092529,
|
||||
1.8493571281433105,
|
||||
0.917565107345581,
|
||||
-0.0257036704570055,
|
||||
-1.0192301273345947,
|
||||
-0.8172388672828674,
|
||||
0.37842708826065063,
|
||||
0.20112906396389008,
|
||||
-0.18812096118927002,
|
||||
0.12312255054712296,
|
||||
0.3173609673976898,
|
||||
0.029730113223195076,
|
||||
-0.662641704082489,
|
||||
0.6436728239059448,
|
||||
0.3574063181877136,
|
||||
0.27612701058387756,
|
||||
-0.6808024644851685,
|
||||
-1.1454781293869019,
|
||||
0.7457495927810669,
|
||||
-1.8407135009765625,
|
||||
-0.6051219701766968,
|
||||
2.167180299758911,
|
||||
0.181788831949234,
|
||||
1.2942312955856323,
|
||||
-2.2572178840637207,
|
||||
-0.6572328209877014,
|
||||
-0.44301870465278625,
|
||||
0.5519763827323914,
|
||||
-0.02834797278046608,
|
||||
-1.118048906326294,
|
||||
-0.44019994139671326,
|
||||
1.2326226234436035,
|
||||
-0.2865355312824249,
|
||||
-1.9306018352508545,
|
||||
0.4287217855453491,
|
||||
-0.5471329092979431,
|
||||
-1.8593220710754395,
|
||||
-0.2029312551021576,
|
||||
0.6949507594108582,
|
||||
-0.2491024136543274,
|
||||
-0.6223251819610596,
|
||||
-0.5916008949279785,
|
||||
1.3497960567474365,
|
||||
-0.47974079847335815,
|
||||
1.6955225467681885,
|
||||
0.17834797501564026,
|
||||
0.13161484897136688,
|
||||
0.20850282907485962,
|
||||
-0.04633784666657448,
|
||||
-0.9113361835479736,
|
||||
-1.1419169902801514,
|
||||
-1.0826172828674316,
|
||||
-0.2316463589668274,
|
||||
0.45178237557411194,
|
||||
0.18495112657546997,
|
||||
0.535635232925415,
|
||||
1.923178791999817,
|
||||
-0.7357022762298584,
|
||||
-0.5064287185668945,
|
||||
0.5609160661697388,
|
||||
1.1650713682174683,
|
||||
-0.5384876728057861,
|
||||
1.2522424459457397,
|
||||
-1.309113621711731,
|
||||
0.22394417226314545,
|
||||
-0.14331775903701782,
|
||||
0.7612791061401367,
|
||||
-1.8949273824691772,
|
||||
-0.8273413181304932,
|
||||
0.15730154514312744,
|
||||
0.5960761904716492,
|
||||
-1.5179729461669922,
|
||||
-1.3346058130264282,
|
||||
-1.0774084329605103,
|
||||
-0.960814356803894,
|
||||
-0.14860300719738007,
|
||||
-0.9822415113449097,
|
||||
1.821016788482666,
|
||||
-0.4035312235355377,
|
||||
0.6270486116409302,
|
||||
0.6994175910949707,
|
||||
-0.8607892394065857,
|
||||
0.7216717004776001,
|
||||
-1.2650134563446045,
|
||||
0.05397822707891464,
|
||||
0.2296375185251236,
|
||||
-0.40239569544792175,
|
||||
-0.44462206959724426,
|
||||
0.12279012054204941,
|
||||
-0.3110475540161133,
|
||||
1.0768173933029175,
|
||||
-0.21416479349136353,
|
||||
-0.44052380323410034,
|
||||
0.743086040019989,
|
||||
-1.3203964233398438,
|
||||
0.47284168004989624,
|
||||
0.16021426022052765,
|
||||
1.2153557538986206,
|
||||
0.7987464666366577,
|
||||
-0.27521243691444397,
|
||||
0.25042879581451416,
|
||||
-0.36083176732063293,
|
||||
1.5787007808685303,
|
||||
1.2494744062423706,
|
||||
0.16907380521297455,
|
||||
0.01833455078303814,
|
||||
-0.16504760086536407,
|
||||
1.3832142353057861,
|
||||
-0.331011027097702,
|
||||
-0.28575095534324646,
|
||||
-0.3638729751110077,
|
||||
0.37575358152389526
|
||||
]
|
||||
],
|
||||
"env": "",
|
||||
"registered_at": 1769515089.7623787
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
# ROS 语音包配置文件
|
||||
|
||||
asr:
|
||||
mode: 'cloud' # 'cloud' | 'local' - ASR模式选择
|
||||
local:
|
||||
server_url: "ws://127.0.0.1:10095" # 本地FunASR服务地址
|
||||
# 云端模式配置在dashscope中
|
||||
|
||||
dashscope:
|
||||
api_key: "sk-7215a5ab7a00469db4072e1672a0661e"
|
||||
asr:
|
||||
@@ -33,6 +39,10 @@ audio:
|
||||
source_sample_rate: 22050 # TTS服务固定输出采样率(DashScope服务固定值,不可修改)
|
||||
source_channels: 1 # TTS服务固定输出声道数(DashScope服务固定值,不可修改)
|
||||
ffmpeg_thread_queue_size: 4096 # ffmpeg输入线程队列大小(增大以减少卡顿)
|
||||
force_stop_delay: 0.1 # 强制停止时的延迟(秒)
|
||||
cleanup_timeout: 30.0 # 清理超时(秒)
|
||||
terminate_timeout: 1.0 # 终止超时(秒)
|
||||
interrupt_wait: 0.1 # 中断等待时间(秒)
|
||||
|
||||
vad:
|
||||
vad_mode: 3 # VAD模式:0-3,3最严格
|
||||
@@ -40,7 +50,6 @@ vad:
|
||||
min_energy_threshold: 300 # 最小能量阈值
|
||||
|
||||
system:
|
||||
use_llm: true # 是否使用LLM
|
||||
use_wake_word: true # 是否启用唤醒词检测
|
||||
wake_word: "er gou" # 唤醒词(拼音)
|
||||
session_timeout: 3.0 # 会话超时时间(秒)
|
||||
@@ -53,7 +62,9 @@ system:
|
||||
sv_speaker_db_path: "~/hivecore_robot_os1/config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
# sv_speaker_db_path: "~/ros_learn/hivecore_robot_voice/config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
sv_buffer_size: 240000 # 声纹验证录音缓冲区大小(样本数,48kHz下5秒=240000)
|
||||
continue_without_image: false # 多模态意图(skill_sequence/chat_camera)未获取到图片时是否继续推理
|
||||
continue_without_image: true # 多模态意图(skill_sequence/chat_camera)未获取到图片时是否继续推理
|
||||
skill_auto_retry: true
|
||||
skill_max_retries: 5
|
||||
|
||||
camera:
|
||||
image:
|
||||
|
||||
54
launch/register_speaker.launch.py
Normal file
54
launch/register_speaker.launch.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import SetEnvironmentVariable, RegisterEventHandler
|
||||
from launch.event_handlers import OnProcessExit
|
||||
from launch.actions import EmitEvent
|
||||
from launch.events import Shutdown
|
||||
import os
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""启动声纹注册节点(需要ASR服务)"""
|
||||
# 获取interfaces包的install路径
|
||||
interfaces_install_path = os.path.expanduser('~/ros_learn/hivecore_robot_interfaces/install')
|
||||
|
||||
# 设置AMENT_PREFIX_PATH,确保能找到interfaces包的消息类型
|
||||
ament_prefix_path = os.environ.get('AMENT_PREFIX_PATH', '')
|
||||
if interfaces_install_path not in ament_prefix_path:
|
||||
if ament_prefix_path:
|
||||
ament_prefix_path = f'{ament_prefix_path}:{interfaces_install_path}'
|
||||
else:
|
||||
ament_prefix_path = interfaces_install_path
|
||||
|
||||
# ASR + 音频输入设备节点(提供ASR和AudioData服务)
|
||||
asr_audio_node = Node(
|
||||
package='robot_speaker',
|
||||
executable='asr_audio_node',
|
||||
name='asr_audio_node',
|
||||
output='screen'
|
||||
)
|
||||
|
||||
# 声纹注册节点
|
||||
register_speaker_node = Node(
|
||||
package='robot_speaker',
|
||||
executable='register_speaker_node',
|
||||
name='register_speaker_node',
|
||||
output='screen'
|
||||
)
|
||||
|
||||
# 当注册节点退出时,关闭整个 launch
|
||||
register_exit_handler = RegisterEventHandler(
|
||||
OnProcessExit(
|
||||
target_action=register_speaker_node,
|
||||
on_exit=[
|
||||
EmitEvent(event=Shutdown(reason='注册完成,关闭所有节点'))
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return LaunchDescription([
|
||||
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
|
||||
asr_audio_node,
|
||||
register_speaker_node,
|
||||
register_exit_handler,
|
||||
])
|
||||
@@ -19,6 +19,21 @@ def generate_launch_description():
|
||||
|
||||
return LaunchDescription([
|
||||
SetEnvironmentVariable('AMENT_PREFIX_PATH', ament_prefix_path),
|
||||
# ASR + 音频输入设备节点(同时提供VAD事件Service,利用云端ASR的VAD)
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='asr_audio_node',
|
||||
name='asr_audio_node',
|
||||
output='screen'
|
||||
),
|
||||
# TTS + 音频输出设备节点
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='tts_audio_node',
|
||||
name='tts_audio_node',
|
||||
output='screen'
|
||||
),
|
||||
# 主业务逻辑节点
|
||||
Node(
|
||||
package='robot_speaker',
|
||||
executable='robot_speaker_node',
|
||||
|
||||
@@ -13,6 +13,11 @@
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>ament_index_python</depend>
|
||||
<depend>interfaces</depend>
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<buildtool_depend>ament_cmake_python</buildtool_depend>
|
||||
<buildtool_depend>rosidl_default_generators</buildtool_depend>
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<exec_depend>python3-pyaudio</exec_depend>
|
||||
<exec_depend>python3-requests</exec_depend>
|
||||
@@ -27,6 +32,6 @@
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
</package>
|
||||
|
||||
@@ -12,6 +12,7 @@ aec-audio-processing
|
||||
modelscope>=1.33.0
|
||||
funasr>=1.0.0
|
||||
datasets==3.6.0
|
||||
websocket-client>=1.6.0
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -4,3 +4,14 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -7,3 +7,15 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""
|
||||
对话历史管理模块
|
||||
"""
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""LLM消息"""
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
|
||||
|
||||
class ConversationHistory:
|
||||
"""对话历史管理器 - 实时语音"""
|
||||
|
||||
@@ -109,3 +116,8 @@ class ConversationHistory:
|
||||
self.summary = None
|
||||
self._pending_user_message = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""会话状态机"""
|
||||
IDLE = "idle" # 等待用户唤醒或声音
|
||||
CHECK_VOICE = "check_voice" # 用户说话 → 检查声纹
|
||||
AUTHORIZED = "authorized" # 已注册用户
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ class IntentRouter:
|
||||
"ju qi", "sheng qi", # 举起、升起
|
||||
"jia zhua", "jia qi", "jia", # 夹爪、夹起、夹
|
||||
"shen you bi", "shen zuo bi", "shen chu", "shen shou", # 伸右臂、伸左臂、伸出、伸手
|
||||
"zhuan quan", "zhuan yi quan", "zhuan", # 转个圈、转一圈、转
|
||||
]
|
||||
self.kb_keywords = [
|
||||
"ni shi shui", "ni de ming zi", "tiao ge wu", "ni jiao sha", "ni hui gan", "ni neng gan"
|
||||
@@ -153,6 +154,7 @@ class IntentRouter:
|
||||
return (
|
||||
"你是机器人任务规划器。\n"
|
||||
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照,并结合当前环境信息生成简洁、可执行的技能序列。\n"
|
||||
"如果用户明确要求或者任务明显需要双手/双臂协作(如扶稳+操作、抓取大体积的物体),必须规划双手技能。\n"
|
||||
+ execution_hint
|
||||
+ "\n"
|
||||
"【规划要求】\n"
|
||||
@@ -161,18 +163,13 @@ class IntentRouter:
|
||||
" - parallel(并行):技能可以同时执行\n"
|
||||
"2. parameters规划:根据目标物距离和任务需求,规划具体参数值\n"
|
||||
" - parameters字典必须包含该技能接口文件目标字段的所有字段\n"
|
||||
" - 对于包含body_id字段的技能(如Arm),body_id值根据目标物在图片中的方位选择:\n"
|
||||
" * 目标物在图片左侧或机器人左侧,使用body_id=0(左臂)\n"
|
||||
" * 目标物在图片右侧或机器人右侧,使用body_id=1(右臂)\n"
|
||||
" * 目标物在图片中央或需要头部操作,使用body_id=2(头部)\n"
|
||||
"\n"
|
||||
"【输出格式要求】\n"
|
||||
"必须输出JSON格式,包含sequence数组。每个技能对象包含3个一级字段:\n"
|
||||
"1. skill: 技能名称(字符串)\n"
|
||||
"2. execution: 执行方式,serial(串行)或 parallel(并行)\n"
|
||||
"3. parameters: 参数字典,包含该技能接口文件目标字段的所有字段,并填入合理的预测值。如果技能无参数,使用null。\n"
|
||||
"\n"
|
||||
"注意:一级字段(skill, execution, parameters)是固定结构,直接使用即可,不需要预测。\n"
|
||||
"注意:一级字段(skill, execution, parameters)是固定结构。\n"
|
||||
"\n"
|
||||
"【技能参数说明】\n"
|
||||
+ skill_params_doc +
|
||||
@@ -191,12 +188,13 @@ class IntentRouter:
|
||||
def build_chat_prompt(self, need_camera: bool) -> str:
|
||||
if need_camera:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请结合图片内容简短回答。不要超过100个token。"
|
||||
"你是一个机器人视觉助理,擅长分析图片中物体的相对位置和空间关系。\n"
|
||||
"请结合图片内容,重点描述物体之间的相对位置(如左右、前后、上下、远近),仅基于可观察信息回答。\n"
|
||||
"回答应简短、客观,不要超过100个token。"
|
||||
)
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请自然、简短地与用户对话。不要超过100个token。"
|
||||
"你是一个表达清晰、语气自然的真人助理。\n"
|
||||
"请简短地与用户对话,不要超过100个token。"
|
||||
)
|
||||
|
||||
def _load_kb_data(self) -> list[dict]:
|
||||
@@ -233,7 +231,7 @@ class IntentRouter:
|
||||
|
||||
def build_default_system_prompt(self) -> str:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"你是一个工厂专业的助手。\n"
|
||||
"- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
|
||||
"- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
|
||||
"请根据对话模式调整你的回答风格。"
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeCallbacks:
|
||||
# ==================== 初始化与内部工具 ====================
|
||||
def __init__(self, node):
|
||||
self.node = node
|
||||
|
||||
def _mark_utterance_processed(self) -> bool:
|
||||
node = self.node
|
||||
with node.utterance_lock:
|
||||
if node.current_utterance_id == node.last_processed_utterance_id:
|
||||
return False
|
||||
node.last_processed_utterance_id = node.current_utterance_id
|
||||
return True
|
||||
|
||||
def _trigger_sv_for_check_voice(self, source: str):
|
||||
node = self.node
|
||||
if not (node.sv_enabled and node.sv_client):
|
||||
return
|
||||
if not self._mark_utterance_processed():
|
||||
return
|
||||
if node._handle_empty_speaker_db():
|
||||
node.get_logger().info(f"[声纹] CHECK_VOICE状态,数据库为空,跳过声纹验证(来源: {source})")
|
||||
return
|
||||
if not node.sv_speech_end_event.is_set():
|
||||
with node.sv_lock:
|
||||
node.sv_recording = False
|
||||
buffer_size = len(node.sv_audio_buffer)
|
||||
node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
|
||||
if buffer_size > 0:
|
||||
node.sv_speech_end_event.set()
|
||||
else:
|
||||
node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source})")
|
||||
|
||||
# ==================== 业务逻辑代理 ====================
|
||||
def handle_interrupt_command(self, msg):
|
||||
return self.node._handle_interrupt_command(msg)
|
||||
|
||||
def check_interrupt_and_cancel_turn(self) -> bool:
|
||||
return self.node._check_interrupt_and_cancel_turn()
|
||||
|
||||
def handle_wake_word(self, text: str) -> str:
|
||||
return self.node._handle_wake_word(text)
|
||||
|
||||
def check_shutup_command(self, text: str) -> bool:
|
||||
return self.node._check_shutup_command(text)
|
||||
|
||||
def check_camera_command(self, text: str):
|
||||
return self.node.intent_router.check_camera_command(text)
|
||||
|
||||
def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
|
||||
return self.node._llm_process_stream_with_camera(user_text, need_camera)
|
||||
|
||||
def put_tts_text(self, text: str):
|
||||
return self.node._put_tts_text(text)
|
||||
|
||||
def force_stop_tts(self):
|
||||
return self.node._force_stop_tts()
|
||||
|
||||
def drain_queue(self, q):
|
||||
return self.node._drain_queue(q)
|
||||
|
||||
# ==================== 录音/VAD回调 ====================
|
||||
def get_silence_threshold(self) -> int:
|
||||
"""获取动态静音阈值(毫秒)"""
|
||||
node = self.node
|
||||
return node.silence_duration_ms
|
||||
|
||||
def should_put_audio_to_queue(self) -> bool:
|
||||
"""
|
||||
检查是否应该将音频放入队列(用于ASR),根据状态机决定是否允许ASR
|
||||
"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
|
||||
ConversationState.AUTHORIZED]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_speech_start(self):
|
||||
"""录音线程检测到人声开始"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 检测到人声,开始录音")
|
||||
|
||||
with node.utterance_lock:
|
||||
node.current_utterance_id += 1
|
||||
|
||||
state = node._get_state()
|
||||
|
||||
if state == ConversationState.IDLE:
|
||||
# Idle -> CheckVoice
|
||||
if node.sv_enabled and node.sv_client:
|
||||
# 开始录音用于声纹验证
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
|
||||
else:
|
||||
node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
|
||||
|
||||
elif state == ConversationState.CHECK_VOICE:
|
||||
# CheckVoice状态,继续录音用于声纹验证
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 继续录音用于声纹验证")
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
# Authorized状态,开始录音用于声纹验证(验证当前用户)
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = True
|
||||
node.sv_audio_buffer.clear()
|
||||
node.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
|
||||
def on_audio_chunk_for_sv(self, audio_chunk: bytes):
|
||||
"""录音线程音频chunk回调 - 仅在需要时录音到声纹缓冲区"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
|
||||
# 声纹验证录音(CHECK_VOICE, AUTHORIZED状态)
|
||||
if node.sv_enabled and node.sv_recording:
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
with node.sv_lock:
|
||||
node.sv_audio_buffer.extend(audio_array)
|
||||
except Exception as e:
|
||||
node.get_logger().debug(f"[声纹] 录音失败: {e}")
|
||||
|
||||
def on_speech_end(self):
|
||||
"""录音线程检测到说话结束(静音一段时间)"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 检测到说话结束")
|
||||
|
||||
state = node._get_state()
|
||||
node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
|
||||
|
||||
if state == ConversationState.CHECK_VOICE:
|
||||
if node.asr_client and node.asr_client.running:
|
||||
node.asr_client.stop_current_recognition()
|
||||
self._trigger_sv_for_check_voice("VAD")
|
||||
return
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
if node.asr_client and node.asr_client.running:
|
||||
node.asr_client.stop_current_recognition()
|
||||
if node.sv_enabled:
|
||||
with node.sv_lock:
|
||||
node.sv_recording = False
|
||||
buffer_size = len(node.sv_audio_buffer)
|
||||
node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
|
||||
node.sv_speech_end_event.set()
|
||||
|
||||
# 如果TTS正在播放,异步等待声纹验证结果,如果通过才中断TTS
|
||||
# 使用独立线程避免阻塞录音线程,影响TTS播放
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().info("[打断] TTS播放中,用户说话结束,异步等待声纹验证结果...")
|
||||
def _check_sv_and_interrupt():
|
||||
# 等待声纹验证结果(最多等待2秒)
|
||||
with node.sv_result_cv:
|
||||
current_seq = node.sv_result_seq
|
||||
if node.sv_result_cv.wait_for(
|
||||
lambda: node.sv_result_seq > current_seq,
|
||||
timeout=2.0
|
||||
):
|
||||
# 声纹验证完成,检查结果
|
||||
with node.sv_lock:
|
||||
speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
|
||||
if speaker_id and speaker_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[打断] 声纹验证通过({speaker_id}),中断TTS播放")
|
||||
node._interrupt_tts("检测到人声(已授权用户,说话结束)")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 声纹验证未通过,不中断TTS(状态: {speaker_state.value})")
|
||||
else:
|
||||
node.get_logger().warning("[打断] 声纹验证超时,不中断TTS")
|
||||
# 在独立线程中等待,避免阻塞录音线程
|
||||
threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
|
||||
return
|
||||
|
||||
def on_new_segment(self):
|
||||
"""录音线程检测到新的已授权用户声段,开始录音用于声纹验证(不立即中断)"""
|
||||
node = self.node
|
||||
state = node._get_state()
|
||||
if state == ConversationState.AUTHORIZED:
|
||||
# TTS播放期间,检测到人声时不立即中断,而是开始录音用于声纹验证
|
||||
# 等待用户说话结束(speech_end)后,如果声纹验证通过,才中断TTS
|
||||
# 这样可以避免TTS回声误触发,但支持真正的用户打断
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().debug("[打断] TTS播放中,检测到人声,开始录音用于声纹验证(等待说话结束后验证)")
|
||||
# 录音已经在 on_speech_start 中开始了,这里不需要额外操作
|
||||
else:
|
||||
# TTS未播放时,检查声纹验证结果并立即中断
|
||||
if node.sv_enabled and node.sv_client:
|
||||
with node.sv_lock:
|
||||
current_speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
|
||||
node._interrupt_tts("检测到人声(已授权用户)")
|
||||
node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言,中断TTS播放")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 检测到人声,但声纹未验证或未匹配,不中断TTS(当前状态: {speaker_state.value})")
|
||||
else:
|
||||
# 未启用声纹,直接中断(保持原有行为)
|
||||
node._interrupt_tts("检测到人声(未启用声纹)")
|
||||
node.get_logger().info("[打断] 检测到人声,中断TTS播放")
|
||||
else:
|
||||
node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value},非已授权用户,不允许打断TTS")
|
||||
|
||||
def on_heartbeat(self):
|
||||
"""录音线程静音心跳回调"""
|
||||
self.node.get_logger().info("[录音线程] 静音中")
|
||||
|
||||
# ==================== ASR回调 ====================
|
||||
def on_asr_sentence_end(self, text: str):
|
||||
"""ASR sentence_end回调 - 将文本放入队列"""
|
||||
node = self.node
|
||||
if not text or not text.strip():
|
||||
return
|
||||
text_clean = text.strip()
|
||||
node.get_logger().info(f"[ASR] 识别完成: {text_clean}")
|
||||
|
||||
state = node._get_state()
|
||||
|
||||
# 规则2:CHECK_VOICE状态下,如果ASR识别完成但VAD还没有触发speech_end,主动触发声纹验证
|
||||
if state == ConversationState.CHECK_VOICE:
|
||||
if node.sv_enabled and node.sv_client:
|
||||
node.get_logger().info("[ASR] CHECK_VOICE状态,ASR识别完成,主动触发声纹验证")
|
||||
self._trigger_sv_for_check_voice("ASR")
|
||||
|
||||
# 其他状态,将文本放入队列
|
||||
node.text_queue.put(text_clean, timeout=1.0)
|
||||
|
||||
def on_asr_text_update(self, text: str):
|
||||
"""ASR 实时文本更新回调 - 用于多轮提示"""
|
||||
if not text or not text.strip():
|
||||
return
|
||||
self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")
|
||||
@@ -1,188 +0,0 @@
|
||||
import queue
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeWorkers:
|
||||
def __init__(self, node):
|
||||
self.node = node
|
||||
|
||||
def recording_worker(self):
|
||||
"""线程1: 录音线程 - 唯一实时线程"""
|
||||
node = self.node
|
||||
node.get_logger().info("[录音线程] 启动")
|
||||
node.audio_recorder.record_with_vad()
|
||||
|
||||
def asr_worker(self):
|
||||
"""线程2: ASR推理线程 - 只做 audio → text"""
|
||||
node = self.node
|
||||
node.get_logger().info("[ASR推理线程] 启动")
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
audio_chunk = node.audio_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if node.interrupt_event.is_set():
|
||||
continue
|
||||
if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
|
||||
node.asr_client.send_audio(audio_chunk)
|
||||
|
||||
def process_worker(self):
|
||||
"""线程3: 主线程 - 处理业务逻辑"""
|
||||
node = self.node
|
||||
node.get_logger().info("[主线程] 启动")
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
text = node.text_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
node.get_logger().info(f"[主线程] 收到识别文本: {text}")
|
||||
current_state = node._get_state()
|
||||
|
||||
if current_state == ConversationState.CHECK_VOICE:
|
||||
if node.use_wake_word:
|
||||
node.get_logger().info(f"[主线程] CHECK_VOICE状态,检查唤醒词,文本: {text}")
|
||||
processed_text = node.callbacks.handle_wake_word(text)
|
||||
if not processed_text:
|
||||
node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'),回到Idle状态")
|
||||
node._change_state(ConversationState.IDLE, "未检测到唤醒词")
|
||||
continue
|
||||
node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
|
||||
text = processed_text
|
||||
|
||||
if node.sv_enabled and node.sv_client:
|
||||
node.get_logger().info("[主线程] CHECK_VOICE状态:等待声纹验证结果...")
|
||||
with node.sv_result_cv:
|
||||
current_seq = node.sv_result_seq
|
||||
if not node.sv_result_cv.wait_for(
|
||||
lambda: node.sv_result_seq > current_seq,
|
||||
timeout=15.0
|
||||
):
|
||||
node.get_logger().warning("[主线程] CHECK_VOICE状态:声纹结果未ready(超时15秒),拒绝本轮")
|
||||
with node.sv_lock:
|
||||
node.sv_audio_buffer.clear()
|
||||
node._change_state(ConversationState.IDLE, "声纹结果未ready")
|
||||
continue
|
||||
node.get_logger().info("[主线程] CHECK_VOICE状态:声纹结果ready,继续处理")
|
||||
|
||||
with node.sv_lock:
|
||||
speaker_id = node.current_speaker_id
|
||||
speaker_state = node.current_speaker_state
|
||||
score = node.current_speaker_score
|
||||
|
||||
if speaker_id and speaker_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
|
||||
node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
|
||||
else:
|
||||
node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
|
||||
node.callbacks.put_tts_text("声纹验证失败")
|
||||
node._change_state(ConversationState.IDLE, "声纹验证失败")
|
||||
continue
|
||||
else:
|
||||
node._change_state(ConversationState.AUTHORIZED, "未启用声纹")
|
||||
|
||||
elif current_state == ConversationState.AUTHORIZED:
|
||||
if node.tts_playing_event.is_set():
|
||||
node.get_logger().debug("[主线程] AUTHORIZED状态,TTS播放中,忽略ASR识别结果(只有VAD检测到已授权用户人声才能中断)")
|
||||
continue
|
||||
|
||||
elif current_state == ConversationState.IDLE:
|
||||
node.get_logger().warning("[主线程] Idle状态收到文本,忽略")
|
||||
continue
|
||||
|
||||
if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
|
||||
processed_text = node.callbacks.handle_wake_word(text)
|
||||
if not processed_text:
|
||||
node._change_state(ConversationState.IDLE, "未检测到唤醒词")
|
||||
continue
|
||||
text = processed_text
|
||||
|
||||
if node.callbacks.check_shutup_command(text):
|
||||
node.get_logger().info("[主线程] 检测到闭嘴指令")
|
||||
node.interrupt_event.set()
|
||||
node.callbacks.force_stop_tts()
|
||||
node._change_state(ConversationState.IDLE, "用户闭嘴指令")
|
||||
continue
|
||||
|
||||
intent_payload = node.intent_router.route(text)
|
||||
node._handle_intent(intent_payload)
|
||||
|
||||
if current_state == ConversationState.AUTHORIZED:
|
||||
node.session_start_time = time.time()
|
||||
|
||||
def sv_worker(self):
|
||||
"""线程5: 声纹识别线程 - 非实时、低频(CAM++)"""
|
||||
node = self.node
|
||||
node.get_logger().info("[声纹识别线程] 启动")
|
||||
|
||||
# 动态计算最小音频样本数,确保降采样到16kHz后≥0.5秒
|
||||
target_sr = 16000 # CAM++模型目标采样率
|
||||
min_duration_seconds = 0.5
|
||||
min_samples_at_target_sr = int(target_sr * min_duration_seconds) # 8000样本@16kHz
|
||||
|
||||
if node.sample_rate >= target_sr:
|
||||
downsample_step = int(node.sample_rate / target_sr)
|
||||
min_audio_samples = min_samples_at_target_sr * downsample_step
|
||||
else:
|
||||
min_audio_samples = int(node.sample_rate * min_duration_seconds)
|
||||
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
if node.sv_speech_end_event.wait(timeout=0.1):
|
||||
node.sv_speech_end_event.clear()
|
||||
with node.sv_lock:
|
||||
audio_list = list(node.sv_audio_buffer)
|
||||
buffer_size = len(audio_list)
|
||||
node.sv_audio_buffer.clear()
|
||||
|
||||
node.get_logger().info(f"[声纹识别] 收到speech_end事件,录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
|
||||
|
||||
if node._handle_empty_speaker_db():
|
||||
node.get_logger().info("[声纹识别] 数据库为空,跳过验证,直接设置UNKNOWN状态")
|
||||
continue
|
||||
|
||||
if buffer_size >= min_audio_samples:
|
||||
audio_array = np.array(audio_list, dtype=np.int16)
|
||||
embedding, success = node.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=node.sample_rate
|
||||
)
|
||||
|
||||
if not success or embedding is None:
|
||||
node.get_logger().debug("[声纹识别] 提取embedding失败")
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = None
|
||||
node.current_speaker_state = SpeakerState.ERROR
|
||||
node.current_speaker_score = 0.0
|
||||
else:
|
||||
speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = speaker_id
|
||||
node.current_speaker_state = match_state
|
||||
node.current_speaker_score = score
|
||||
|
||||
if match_state == SpeakerState.VERIFIED:
|
||||
node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
|
||||
elif match_state == SpeakerState.REJECTED:
|
||||
node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
|
||||
else:
|
||||
node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
|
||||
else:
|
||||
node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
|
||||
with node.sv_lock:
|
||||
node.current_speaker_id = None
|
||||
node.current_speaker_state = SpeakerState.UNKNOWN
|
||||
node.current_speaker_score = 0.0
|
||||
|
||||
with node.sv_result_cv:
|
||||
node.sv_result_seq += 1
|
||||
node.sv_result_cv.notify_all()
|
||||
|
||||
except Exception as e:
|
||||
node.get_logger().error(f"[声纹识别线程] 错误: {e}")
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
"""
|
||||
声纹注册独立节点:运行完成后退出
|
||||
"""
|
||||
import collections
|
||||
"""声纹注册独立节点:运行完成后退出"""
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import threading
|
||||
import queue
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
|
||||
from robot_speaker.models.asr.dashscope import DashScopeASR
|
||||
from robot_speaker.srv import ASRRecognize, AudioData, VADEvent
|
||||
from robot_speaker.core.speaker_verifier import SpeakerVerificationClient
|
||||
from pypinyin import pinyin, Style
|
||||
|
||||
|
||||
@@ -24,66 +19,15 @@ class RegisterSpeakerNode(Node):
|
||||
super().__init__('register_speaker_node')
|
||||
self._load_config()
|
||||
|
||||
self.stop_event = threading.Event()
|
||||
self.buffer_lock = threading.Lock()
|
||||
self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
|
||||
self.asr_client = self.create_client(ASRRecognize, '/asr/recognize')
|
||||
self.audio_data_client = self.create_client(AudioData, '/asr/audio_data')
|
||||
self.vad_client = self.create_client(VADEvent, '/vad/event')
|
||||
|
||||
self.speech_start_idx = None
|
||||
self.speech_end_idx = None
|
||||
self.speech_start_time = None
|
||||
|
||||
self.audio_queue = queue.Queue()
|
||||
self.text_queue = queue.Queue()
|
||||
|
||||
self.vad_detector = VADDetector(
|
||||
mode=self.vad_mode,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
self.audio_recorder = AudioRecorder(
|
||||
device_index=self.input_device_index,
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels,
|
||||
chunk=self.chunk,
|
||||
vad_detector=self.vad_detector,
|
||||
audio_queue=self.audio_queue,
|
||||
silence_duration_ms=self.silence_duration_ms,
|
||||
min_energy_threshold=self.min_energy_threshold,
|
||||
heartbeat_interval=self.audio_microphone_heartbeat_interval,
|
||||
on_heartbeat=None,
|
||||
is_playing=lambda: False,
|
||||
on_new_segment=None,
|
||||
on_speech_start=self._on_speech_start,
|
||||
on_speech_end=self._on_speech_end,
|
||||
stop_flag=self.stop_event.is_set,
|
||||
on_audio_chunk=self._on_audio_chunk,
|
||||
get_silence_threshold=lambda: self.silence_duration_ms,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
|
||||
self.asr_client = DashScopeASR(
|
||||
api_key=self.dashscope_api_key,
|
||||
sample_rate=self.sample_rate,
|
||||
model=self.asr_model,
|
||||
url=self.asr_url,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
self.asr_client.on_sentence_end = self._on_asr_sentence_end
|
||||
self.asr_client.start()
|
||||
|
||||
self.asr_thread = threading.Thread(
|
||||
target=self._asr_worker,
|
||||
name="RegisterASRThread",
|
||||
daemon=True
|
||||
)
|
||||
self.asr_thread.start()
|
||||
|
||||
self.text_thread = threading.Thread(
|
||||
target=self._text_worker,
|
||||
name="RegisterTextThread",
|
||||
daemon=True
|
||||
)
|
||||
self.text_thread.start()
|
||||
self.get_logger().info('等待服务启动...')
|
||||
self.asr_client.wait_for_service(timeout_sec=10.0)
|
||||
self.audio_data_client.wait_for_service(timeout_sec=10.0)
|
||||
self.vad_client.wait_for_service(timeout_sec=10.0)
|
||||
self.get_logger().info('所有服务已就绪')
|
||||
|
||||
self.sv_client = SpeakerVerificationClient(
|
||||
model_path=self.sv_model_path,
|
||||
@@ -92,15 +36,20 @@ class RegisterSpeakerNode(Node):
|
||||
logger=self.get_logger()
|
||||
)
|
||||
|
||||
self.get_logger().info("声纹注册节点启动,请说'er gou我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去。'")
|
||||
self.recording_thread = threading.Thread(
|
||||
target=self.audio_recorder.record_with_vad,
|
||||
name="RegisterRecordingThread",
|
||||
daemon=True
|
||||
)
|
||||
self.recording_thread.start()
|
||||
self.registered = False
|
||||
self.shutting_down = False
|
||||
self.get_logger().info("声纹注册节点启动,请说唤醒词开始注册(例如:'二狗我现在正在注册声纹,这是一段很长的测试语音,请把我的声音录进去')")
|
||||
|
||||
self.timer = self.create_timer(0.2, self._check_done)
|
||||
# 使用队列在线程间传递 VAD 事件,避免在子线程中调用 spin_until_future_complete
|
||||
self.vad_event_queue = queue.Queue()
|
||||
self.recording = False # 录音状态标志
|
||||
self.pending_asr_future = None # 待处理的 ASR future
|
||||
self.pending_audio_future = None # 待处理的 AudioData future
|
||||
self.state = "waiting_speech" # 状态机:waiting_speech, waiting_asr, waiting_audio
|
||||
|
||||
self.vad_thread = threading.Thread(target=self._vad_event_worker, daemon=True)
|
||||
self.vad_thread.start()
|
||||
self.timer = self.create_timer(0.1, self._main_loop)
|
||||
|
||||
def _load_config(self):
|
||||
config_file = os.path.join(
|
||||
@@ -111,129 +60,59 @@ class RegisterSpeakerNode(Node):
|
||||
with open(config_file, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
dashscope = config['dashscope']
|
||||
audio = config['audio']
|
||||
mic = audio['microphone']
|
||||
soundcard = audio['soundcard']
|
||||
vad = config['vad']
|
||||
system = config['system']
|
||||
|
||||
self.dashscope_api_key = dashscope['api_key']
|
||||
self.asr_model = dashscope['asr']['model']
|
||||
self.asr_url = dashscope['asr']['url']
|
||||
|
||||
self.input_device_index = mic['device_index']
|
||||
self.sample_rate = mic['sample_rate']
|
||||
self.channels = mic['channels']
|
||||
self.chunk = mic['chunk']
|
||||
self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
|
||||
|
||||
self.vad_mode = vad['vad_mode']
|
||||
self.silence_duration_ms = vad['silence_duration_ms']
|
||||
self.min_energy_threshold = vad['min_energy_threshold']
|
||||
|
||||
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
|
||||
self.sv_threshold = system['sv_threshold']
|
||||
self.sv_speaker_db_path = os.path.expanduser(system['sv_speaker_db_path'])
|
||||
self.sv_buffer_size = system['sv_buffer_size']
|
||||
self.wake_word = system['wake_word']
|
||||
|
||||
def _on_speech_start(self):
|
||||
with self.buffer_lock:
|
||||
if self.speech_start_idx is None:
|
||||
self.speech_start_idx = len(self.audio_buffer)
|
||||
self.speech_start_time = time.time()
|
||||
|
||||
def _on_audio_chunk(self, audio_chunk: bytes):
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
with self.buffer_lock:
|
||||
self.audio_buffer.extend(audio_array)
|
||||
except Exception as e:
|
||||
self.get_logger().debug(f"[注册录音] 录音失败: {e}")
|
||||
|
||||
def _on_speech_end(self):
|
||||
with self.buffer_lock:
|
||||
self.speech_end_idx = len(self.audio_buffer)
|
||||
|
||||
def _process_voiceprint_audio(self):
|
||||
"""处理声纹音频:使用用户完整的一句话进行注册"""
|
||||
with self.buffer_lock:
|
||||
audio_list = list(self.audio_buffer)
|
||||
start_idx = self.speech_start_idx if self.speech_start_idx is not None else 0
|
||||
end_idx = self.speech_end_idx if self.speech_end_idx is not None else len(audio_list)
|
||||
self.audio_buffer.clear()
|
||||
self.speech_start_idx = None
|
||||
self.speech_end_idx = None
|
||||
self.speech_start_time = None
|
||||
|
||||
audio_list = audio_list[start_idx:end_idx]
|
||||
buffer_sec = len(audio_list) / self.sample_rate
|
||||
self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}秒")
|
||||
|
||||
try:
|
||||
audio_array = np.array(audio_list, dtype=np.int16)
|
||||
embedding, success = self.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
if not success:
|
||||
self.get_logger().error("[注册录音] 提取embedding失败")
|
||||
return
|
||||
|
||||
speaker_id = f"user_{int(time.time())}"
|
||||
if self.sv_client.register_speaker(speaker_id, embedding):
|
||||
self.get_logger().info(f"[注册录音] 注册成功,用户ID: {speaker_id},准备退出")
|
||||
self.stop_event.set()
|
||||
else:
|
||||
self.get_logger().error("[注册录音] 注册失败")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册录音] 注册异常: {e}")
|
||||
|
||||
def _asr_worker(self):
|
||||
"""ASR处理线程"""
|
||||
while not self.stop_event.is_set():
|
||||
def _vad_event_worker(self):
|
||||
"""VAD 事件监听线程,只负责接收事件并放入队列,不调用 spin_until_future_complete"""
|
||||
while not self.registered and not self.shutting_down:
|
||||
try:
|
||||
audio_chunk = self.audio_queue.get(timeout=0.1)
|
||||
if self.asr_client and self.asr_client.running:
|
||||
self.asr_client.send_audio(audio_chunk)
|
||||
except queue.Empty:
|
||||
continue
|
||||
request = VADEvent.Request()
|
||||
request.command = "wait"
|
||||
request.timeout_ms = 1000
|
||||
future = self.vad_client.call_async(request)
|
||||
|
||||
# 简单等待 future 完成,不使用 spin_until_future_complete
|
||||
start_time = time.time()
|
||||
while not future.done() and (time.time() - start_time) < 1.5:
|
||||
time.sleep(0.01)
|
||||
|
||||
if not future.done() or self.registered or self.shutting_down:
|
||||
continue
|
||||
|
||||
response = future.result()
|
||||
if response.success and response.event in ["speech_started", "speech_stopped"]:
|
||||
# 将事件放入队列,由主线程处理
|
||||
try:
|
||||
self.vad_event_queue.put(response.event, timeout=0.1)
|
||||
except queue.Full:
|
||||
self.get_logger().warn(f"[VAD] 事件队列已满,丢弃事件: {response.event}")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册ASR] 处理异常: {e}")
|
||||
if not self.shutting_down:
|
||||
self.get_logger().error(f"[VAD] 线程异常: {e}")
|
||||
break
|
||||
|
||||
def _on_asr_sentence_end(self, text: str):
|
||||
"""ASR识别完成回调"""
|
||||
if text and text.strip():
|
||||
self.text_queue.put(text.strip())
|
||||
|
||||
def _text_worker(self):
|
||||
"""文本处理线程:检测唤醒词"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
text = self.text_queue.get(timeout=0.1)
|
||||
self._check_wake_word(text)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[注册文本] 处理异常: {e}")
|
||||
def _start_recording(self):
|
||||
"""启动录音,返回 future 供主线程处理"""
|
||||
request = AudioData.Request()
|
||||
request.command = "start"
|
||||
return self.audio_data_client.call_async(request)
|
||||
|
||||
def _to_pinyin(self, text: str) -> str:
|
||||
"""将中文文本转换为拼音"""
|
||||
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
|
||||
if not chars:
|
||||
return ""
|
||||
py_list = pinyin(chars, style=Style.NORMAL)
|
||||
# 多音字处理:取第一个读音
|
||||
return ' '.join([item[0] for item in py_list]).lower().strip()
|
||||
|
||||
def _check_wake_word(self, text: str):
|
||||
"""检查是否包含唤醒词,如果有则注册,没有则继续等待"""
|
||||
text_pinyin = self._to_pinyin(text)
|
||||
wake_word_pinyin = self.wake_word.lower().strip()
|
||||
|
||||
if not wake_word_pinyin:
|
||||
self.get_logger().info(f"[注册唤醒词] 唤醒词配置为空,继续等待")
|
||||
return
|
||||
|
||||
text_pinyin_parts = text_pinyin.split() if text_pinyin else []
|
||||
@@ -246,27 +125,112 @@ class RegisterSpeakerNode(Node):
|
||||
break
|
||||
|
||||
if has_wake_word:
|
||||
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}',使用完整音频注册")
|
||||
self._process_voiceprint_audio()
|
||||
self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}',停止录音并获取音频")
|
||||
request = AudioData.Request()
|
||||
request.command = "stop"
|
||||
future = self.audio_data_client.call_async(request)
|
||||
future._future_type = "stop"
|
||||
self.pending_audio_future = future
|
||||
|
||||
def _process_voiceprint_audio(self, response):
|
||||
"""处理声纹音频数据 - 直接使用 AudioData 返回的音频,不再过滤"""
|
||||
if not response or not response.success or response.samples == 0:
|
||||
self.get_logger().error(f"[注册录音] 获取音频数据失败: {response.message if response else '无响应'}")
|
||||
return
|
||||
|
||||
audio_array = np.frombuffer(response.audio_data, dtype=np.int16)
|
||||
buffer_sec = response.samples / response.sample_rate
|
||||
self.get_logger().info(f"[注册录音] 音频长度: {buffer_sec:.2f}秒")
|
||||
|
||||
# 直接使用音频,不再进行 VAD 过滤
|
||||
# 因为 AudioData 服务基于 DashScope VAD,已经是语音活动片段
|
||||
embedding, success = self.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=response.sample_rate
|
||||
)
|
||||
if not success or embedding is None:
|
||||
self.get_logger().error("[注册录音] 提取embedding失败")
|
||||
return
|
||||
|
||||
speaker_id = f"user_{int(time.time())}"
|
||||
if self.sv_client.register_speaker(speaker_id, embedding):
|
||||
# 注册成功后立即保存到文件
|
||||
self.sv_client.save_speakers()
|
||||
self.get_logger().info(f"[注册录音] 注册成功,用户ID: {speaker_id},已保存到文件,准备退出")
|
||||
self.registered = True
|
||||
else:
|
||||
self.get_logger().info(f"[注册唤醒词] 未检测到唤醒词,继续等待用户说话")
|
||||
|
||||
def _check_done(self):
|
||||
if self.stop_event.is_set():
|
||||
self.get_logger().error("[注册录音] 注册失败")
|
||||
|
||||
def _main_loop(self):
|
||||
"""主循环,在主线程中处理所有异步操作"""
|
||||
# 检查是否完成注册
|
||||
if self.registered:
|
||||
self.get_logger().info("注册完成,节点退出")
|
||||
if self.asr_client:
|
||||
self.asr_client.stop()
|
||||
self.destroy_node()
|
||||
self.shutting_down = True
|
||||
self.timer.cancel()
|
||||
rclpy.shutdown()
|
||||
return
|
||||
|
||||
# 处理待处理的 ASR future
|
||||
if self.pending_asr_future and self.pending_asr_future.done():
|
||||
response = self.pending_asr_future.result()
|
||||
self.pending_asr_future = None
|
||||
|
||||
if response.success and response.text:
|
||||
text = response.text.strip()
|
||||
if text:
|
||||
self._check_wake_word(text)
|
||||
|
||||
self.state = "waiting_speech"
|
||||
|
||||
# 处理待处理的 AudioData future
|
||||
if self.pending_audio_future and self.pending_audio_future.done():
|
||||
response = self.pending_audio_future.result()
|
||||
future_type = getattr(self.pending_audio_future, '_future_type', None)
|
||||
self.pending_audio_future = None
|
||||
|
||||
if future_type == "start":
|
||||
if response.success:
|
||||
self.get_logger().info("[注册录音] 已开始录音")
|
||||
self.recording = True
|
||||
else:
|
||||
self.get_logger().warn(f"[注册录音] 启动录音失败: {response.message}")
|
||||
self.state = "waiting_speech"
|
||||
elif future_type == "stop":
|
||||
self.recording = False
|
||||
self._process_voiceprint_audio(response)
|
||||
|
||||
# 处理 VAD 事件队列
|
||||
try:
|
||||
event = self.vad_event_queue.get_nowait()
|
||||
|
||||
if event == "speech_started" and self.state == "waiting_speech" and not self.recording:
|
||||
self.get_logger().info("[VAD] 检测到语音开始,启动录音")
|
||||
future = self._start_recording()
|
||||
future._future_type = "start"
|
||||
self.pending_audio_future = future
|
||||
|
||||
elif event == "speech_stopped" and self.recording and self.state == "waiting_speech":
|
||||
self.get_logger().info("[VAD] 检测到语音结束,请求 ASR 识别")
|
||||
self.state = "waiting_asr"
|
||||
request = ASRRecognize.Request()
|
||||
request.command = "start"
|
||||
self.pending_asr_future = self.asr_client.call_async(request)
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = RegisterSpeakerNode()
|
||||
rclpy.spin(node)
|
||||
node.destroy_node()
|
||||
try:
|
||||
rclpy.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
198
robot_speaker/core/speaker_verifier.py
Normal file
198
robot_speaker/core/speaker_verifier.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
声纹识别模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpeakerState(Enum):
|
||||
"""说话人识别状态"""
|
||||
UNKNOWN = "unknown"
|
||||
VERIFIED = "verified"
|
||||
REJECTED = "rejected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SpeakerVerificationClient:
|
||||
"""声纹识别客户端 - 非实时、低频处理"""
|
||||
|
||||
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
|
||||
self.model_path = model_path
|
||||
self.threshold = threshold
|
||||
self.speaker_db_path = speaker_db_path
|
||||
self.logger = logger
|
||||
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 优化CPU性能:限制Torch使用的线程数,防止多线程竞争导致性能骤降
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
||||
from funasr import AutoModel
|
||||
model_path = os.path.expanduser(self.model_path)
|
||||
# 禁用自动更新检查,防止每次初始化都联网检查
|
||||
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
|
||||
if self.logger:
|
||||
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
|
||||
|
||||
if self.speaker_db_path:
|
||||
self.load_speakers()
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
|
||||
if self.logger:
|
||||
try:
|
||||
if level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warning(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
elif level == "debug":
|
||||
self.logger.debug(msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def load_speakers(self):
|
||||
if not self.speaker_db_path:
|
||||
return
|
||||
|
||||
db_path = os.path.expanduser(self.speaker_db_path)
|
||||
if not os.path.exists(db_path):
|
||||
self._log("info", f"声纹数据库文件不存在: {db_path},将创建新文件")
|
||||
return
|
||||
try:
|
||||
with open(db_path, 'rb') as f:
|
||||
data = json.load(f)
|
||||
with self._lock:
|
||||
self.speaker_db = {}
|
||||
for speaker_id, info in data.items():
|
||||
embedding_array = np.array(info["embedding"], dtype=np.float32)
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_array,
|
||||
"env": info.get("env", ""),
|
||||
"registered_at": info.get("registered_at", 0.0)
|
||||
}
|
||||
self._log("info", f"已加载 {len(self.speaker_db)} 个已注册说话人")
|
||||
except Exception as e:
|
||||
self._log("error", f"加载声纹数据库失败: {e}")
|
||||
|
||||
def save_speakers(self):
|
||||
if not self.speaker_db_path:
|
||||
return
|
||||
db_path = os.path.expanduser(self.speaker_db_path)
|
||||
try:
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
with self._lock:
|
||||
data = {}
|
||||
for speaker_id, info in self.speaker_db.items():
|
||||
data[speaker_id] = {
|
||||
"embedding": info["embedding"].tolist(),
|
||||
"env": info.get("env", ""),
|
||||
"registered_at": info.get("registered_at", 0.0)
|
||||
}
|
||||
with open(db_path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
self._log("info", f"已保存 {len(data)} 个已注册说话人到: {db_path}")
|
||||
except Exception as e:
|
||||
self._log("error", f"保存声纹数据库失败: {e}")
|
||||
|
||||
def extract_embedding(self, audio_array: np.ndarray, sample_rate: int = 16000) -> tuple[np.ndarray | None, bool]:
|
||||
try:
|
||||
if len(audio_array) == 0:
|
||||
return None, False
|
||||
# 确保是int16格式
|
||||
if audio_array.dtype != np.int16:
|
||||
audio_array = audio_array.astype(np.int16)
|
||||
# 转换为float32并归一化到[-1, 1]
|
||||
audio_float = audio_array.astype(np.float32) / 32768.0
|
||||
# 调用模型提取embedding
|
||||
result = self.model.generate(input=audio_float, cache={})
|
||||
if result and len(result) > 0 and "spk_embedding" in result[0]:
|
||||
embedding = result[0]["spk_embedding"]
|
||||
if embedding is not None and len(embedding) > 0:
|
||||
embedding_array = np.array(embedding, dtype=np.float32)
|
||||
if embedding_array.ndim > 1:
|
||||
embedding_array = embedding_array.flatten()
|
||||
return embedding_array, True
|
||||
return None, False
|
||||
except Exception as e:
|
||||
self._log("error", f"提取声纹特征失败: {e}")
|
||||
return None, False
|
||||
|
||||
def match_speaker(self, embedding: np.ndarray) -> tuple[str | None, SpeakerState, float, float]:
|
||||
if embedding is None or len(embedding) == 0:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
|
||||
|
||||
with self._lock:
|
||||
if len(self.speaker_db) == 0:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, float(self.threshold)
|
||||
try:
|
||||
best_speaker_id = None
|
||||
best_score = 0.0
|
||||
with self._lock:
|
||||
for speaker_id, info in self.speaker_db.items():
|
||||
stored_embedding = info["embedding"]
|
||||
# 计算余弦相似度
|
||||
dot_product = np.dot(embedding, stored_embedding)
|
||||
norm_embedding = np.linalg.norm(embedding)
|
||||
norm_stored = np.linalg.norm(stored_embedding)
|
||||
|
||||
if norm_embedding > 0 and norm_stored > 0:
|
||||
score = dot_product / (norm_embedding * norm_stored)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_speaker_id = speaker_id
|
||||
|
||||
state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
|
||||
return best_speaker_id, state, float(best_score), float(self.threshold)
|
||||
except Exception as e:
|
||||
self._log("error", f"匹配说话人失败: {e}")
|
||||
return None, SpeakerState.ERROR, 0.0, float(self.threshold)
|
||||
|
||||
def register_speaker(self, speaker_id: str, embedding: np.ndarray, env: str = "") -> bool:
|
||||
if embedding is None or len(embedding) == 0:
|
||||
return False
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": np.array(embedding, dtype=np.float32),
|
||||
"env": env,
|
||||
"registered_at": time.time()
|
||||
}
|
||||
self._log("info", f"已注册说话人: {speaker_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log("error", f"注册说话人失败: {e}")
|
||||
return False
|
||||
|
||||
def get_speaker_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self.speaker_db)
|
||||
|
||||
def get_speaker_list(self) -> list[str]:
|
||||
with self._lock:
|
||||
return list(self.speaker_db.keys())
|
||||
|
||||
def remove_speaker(self, speaker_id: str) -> bool:
|
||||
with self._lock:
|
||||
if speaker_id in self.speaker_db:
|
||||
del self.speaker_db[speaker_id]
|
||||
self._log("info", f"已删除说话人: {speaker_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
try:
|
||||
self.save_speakers()
|
||||
if hasattr(self, 'model') and self.model:
|
||||
del self.model
|
||||
except Exception as e:
|
||||
self._log("error", f"清理资源失败: {e}")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
统一数据结构定义
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ASRResult:
|
||||
"""ASR识别结果"""
|
||||
text: str
|
||||
confidence: float | None = None
|
||||
language: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""LLM消息"""
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSRequest:
|
||||
"""TTS请求"""
|
||||
text: str
|
||||
voice: str | None = None # 如果为None,使用控制台配置的默认音色
|
||||
speed: float | None = None
|
||||
pitch: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageMessage:
|
||||
"""图像消息 - 用于多模态LLM"""
|
||||
image_data: bytes # base64编码的图像数据
|
||||
image_format: str = "jpeg"
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""模型层"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""ASR模型"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
class ASRClient:
|
||||
def start(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def send_audio(self, audio_data: bytes) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
"""
|
||||
ASR语音识别模块
|
||||
"""
|
||||
import base64
|
||||
import time
|
||||
import threading
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
|
||||
from robot_speaker.models.asr.base import ASRClient
|
||||
|
||||
|
||||
class DashScopeASR(ASRClient):
|
||||
"""DashScope实时ASR识别器封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
sample_rate: int,
|
||||
model: str,
|
||||
url: str,
|
||||
logger=None):
|
||||
dashscope.api_key = api_key
|
||||
self.sample_rate = sample_rate
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.logger = logger
|
||||
|
||||
self.conversation = None
|
||||
self.running = False
|
||||
self.on_sentence_end = None
|
||||
self.on_text_update = None # 实时文本更新回调
|
||||
|
||||
# 线程同步机制
|
||||
self._stop_lock = threading.Lock() # 防止并发调用 stop_current_recognition
|
||||
self._final_result_event = threading.Event() # 等待 final 回调完成
|
||||
self._pending_commit = False # 标记是否有待处理的 commit
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
else:
|
||||
print(f"[ASR] {msg}")
|
||||
|
||||
def start(self):
|
||||
"""启动ASR识别器"""
|
||||
if self.running:
|
||||
return False
|
||||
|
||||
try:
|
||||
callback = _ASRCallback(self)
|
||||
self.conversation = OmniRealtimeConversation(
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
callback=callback
|
||||
)
|
||||
callback.conversation = self.conversation
|
||||
|
||||
self.conversation.connect()
|
||||
|
||||
transcription_params = TranscriptionParams(
|
||||
language='zh',
|
||||
sample_rate=self.sample_rate,
|
||||
input_audio_format="pcm",
|
||||
)
|
||||
|
||||
# 本地 VAD → 只控制 TTS 打断
|
||||
# 服务端 turn detection → 只控制 ASR 输出、LLM 生成轮次
|
||||
|
||||
self.conversation.update_session(
|
||||
output_modalities=[MultiModality.TEXT],
|
||||
enable_input_audio_transcription=True,
|
||||
transcription_params=transcription_params,
|
||||
enable_turn_detection=True,
|
||||
# 保留服务端 turn detection
|
||||
turn_detection_type='server_vad', # 服务端VAD
|
||||
turn_detection_threshold=0.2, # 可调
|
||||
turn_detection_silence_duration_ms=800
|
||||
)
|
||||
|
||||
self.running = True
|
||||
self._log("info", "ASR已启动")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.running = False
|
||||
self._log("error", f"ASR启动失败: {e}")
|
||||
if self.conversation:
|
||||
try:
|
||||
self.conversation.close()
|
||||
except:
|
||||
pass
|
||||
self.conversation = None
|
||||
return False
|
||||
|
||||
def send_audio(self, audio_chunk: bytes):
|
||||
"""发送音频chunk到ASR"""
|
||||
if not self.running or not self.conversation:
|
||||
return False
|
||||
try:
|
||||
audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
|
||||
self.conversation.append_audio(audio_b64)
|
||||
return True
|
||||
except Exception as e:
|
||||
# 连接已关闭或其他错误,静默处理(避免日志过多)
|
||||
# running状态会在stop_current_recognition中正确设置
|
||||
return False
|
||||
|
||||
def stop_current_recognition(self):
|
||||
"""
|
||||
停止当前识别,触发final结果,然后重新启动
|
||||
优化:
|
||||
1. 使用事件代替 sleep,等待 final 回调完成
|
||||
2. 使用锁防止并发调用
|
||||
3. 处理 start() 失败的情况,确保 running 状态正确
|
||||
4. 添加超时机制,避免无限等待
|
||||
"""
|
||||
# 使用锁防止并发调用
|
||||
if not self._stop_lock.acquire(blocking=False):
|
||||
self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
|
||||
return False
|
||||
|
||||
try:
|
||||
if not self.running or not self.conversation:
|
||||
return False
|
||||
|
||||
# 重置事件,准备等待 final 回调
|
||||
self._final_result_event.clear()
|
||||
self._pending_commit = True
|
||||
|
||||
# 触发 commit,等待 final 结果
|
||||
self.conversation.commit()
|
||||
|
||||
# 等待 final 回调完成(最多等待3秒)
|
||||
if self._final_result_event.wait(timeout=3.0):
|
||||
self._log("debug", "已收到 final 回调,准备关闭连接")
|
||||
else:
|
||||
self._log("warning", "等待 final 回调超时,继续执行")
|
||||
|
||||
# 先设置running=False,防止ASR线程继续发送音频
|
||||
self.running = False
|
||||
|
||||
# 关闭当前连接
|
||||
old_conversation = self.conversation
|
||||
self.conversation = None # 立即清空,防止send_audio继续使用
|
||||
try:
|
||||
old_conversation.close()
|
||||
except Exception as e:
|
||||
self._log("warning", f"关闭连接时出错: {e}")
|
||||
|
||||
# 短暂等待,确保连接完全关闭
|
||||
time.sleep(0.1)
|
||||
|
||||
# 重新启动,如果失败则保持 running=False
|
||||
if not self.start():
|
||||
self._log("error", "ASR重启失败,running状态已重置")
|
||||
return False
|
||||
|
||||
# 启动成功,running已在start()中设置为True
|
||||
return True
|
||||
finally:
|
||||
self._pending_commit = False
|
||||
self._stop_lock.release()
|
||||
|
||||
def stop(self):
|
||||
"""停止ASR识别器"""
|
||||
# 等待正在执行的 stop_current_recognition 完成
|
||||
with self._stop_lock:
|
||||
self.running = False
|
||||
self._final_result_event.set() # 唤醒可能正在等待的线程
|
||||
if self.conversation:
|
||||
try:
|
||||
self.conversation.close()
|
||||
except Exception as e:
|
||||
self._log("warning", f"停止时关闭连接出错: {e}")
|
||||
self.conversation = None
|
||||
self._log("info", "ASR已停止")
|
||||
|
||||
|
||||
class _ASRCallback(OmniRealtimeCallback):
|
||||
"""ASR回调处理"""
|
||||
|
||||
def __init__(self, asr_client: DashScopeASR):
|
||||
self.asr_client = asr_client
|
||||
self.conversation = None
|
||||
|
||||
def on_open(self):
|
||||
self.asr_client._log("info", "ASR WebSocket已连接")
|
||||
|
||||
def on_close(self, code, msg):
|
||||
self.asr_client._log("info", f"ASR WebSocket已关闭: code={code}, msg={msg}")
|
||||
|
||||
def on_event(self, response):
|
||||
event_type = response.get('type', '')
|
||||
|
||||
if event_type == 'session.created':
|
||||
session_id = response.get('session', {}).get('id', '')
|
||||
self.asr_client._log("info", f"ASR会话已创建: {session_id}")
|
||||
|
||||
elif event_type == 'conversation.item.input_audio_transcription.completed':
|
||||
# 最终识别结果
|
||||
transcript = response.get('transcript', '')
|
||||
if transcript and transcript.strip() and self.asr_client.on_sentence_end:
|
||||
self.asr_client.on_sentence_end(transcript.strip())
|
||||
|
||||
# 如果有待处理的 commit,通知等待的线程
|
||||
if self.asr_client._pending_commit:
|
||||
self.asr_client._final_result_event.set()
|
||||
|
||||
elif event_type == 'conversation.item.input_audio_transcription.text':
|
||||
# 实时识别文本更新(多轮提示)
|
||||
transcript = response.get('transcript', '') or response.get('text', '')
|
||||
if transcript and transcript.strip() and self.asr_client.on_text_update:
|
||||
self.asr_client.on_text_update(transcript.strip())
|
||||
|
||||
elif event_type == 'input_audio_buffer.speech_started':
|
||||
self.asr_client._log("info", "ASR检测到说话开始")
|
||||
|
||||
elif event_type == 'input_audio_buffer.speech_stopped':
|
||||
self.asr_client._log("info", "ASR检测到说话结束")
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""LLM模型"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def chat(self, messages: list[LLMMessage]) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def chat_stream(self, messages: list[LLMMessage],
|
||||
on_token=None,
|
||||
interrupt_check=None) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
LLM大语言模型模块
|
||||
支持多模态(文本+图像)
|
||||
"""
|
||||
from openai import OpenAI
|
||||
from typing import Optional, List
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
from robot_speaker.models.llm.base import LLMClient
|
||||
|
||||
|
||||
class DashScopeLLM(LLMClient):
|
||||
"""DashScope LLM客户端封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
model: str,
|
||||
base_url: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
name: str = "LLM",
|
||||
logger=None):
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.name = name
|
||||
self.logger = logger
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
msg = f"[{self.name}] {msg}"
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
|
||||
def chat(self, messages: list[LLMMessage]) -> str | None:
|
||||
"""非流式聊天:任务规划"""
|
||||
payload_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=payload_messages,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
stream=False
|
||||
)
|
||||
reply = response.choices[0].message.content.strip()
|
||||
return reply if reply else None
|
||||
|
||||
def chat_stream(self, messages: list[LLMMessage],
|
||||
on_token=None,
|
||||
images: Optional[List[str]] = None,
|
||||
interrupt_check=None) -> str | None:
|
||||
"""
|
||||
流式聊天:语音系统
|
||||
支持多模态(文本+图像)
|
||||
支持中断检查(interrupt_check: 返回True表示需要中断)
|
||||
"""
|
||||
# 转换消息格式,支持多模态
|
||||
# 图像只添加到最后一个user消息中
|
||||
payload_messages = []
|
||||
last_user_idx = -1
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role == "user":
|
||||
last_user_idx = i
|
||||
|
||||
has_images_in_message = False
|
||||
for i, msg in enumerate(messages):
|
||||
msg_dict = {"role": msg.role}
|
||||
|
||||
# 如果当前消息是最后一个user消息且有图像,构建多模态content
|
||||
if i == last_user_idx and msg.role == "user" and images and len(images) > 0:
|
||||
content_list = [{"type": "text", "text": msg.content}]
|
||||
# 添加所有图像
|
||||
for img_idx, img_base64 in enumerate(images):
|
||||
image_url = f"data:image/jpeg;base64,{img_base64[:50]}..." if len(img_base64) > 50 else f"data:image/jpeg;base64,{img_base64}"
|
||||
content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||
}
|
||||
})
|
||||
self._log("info", f"[多模态] 添加图像 #{img_idx+1} 到user消息,base64长度: {len(img_base64)}")
|
||||
msg_dict["content"] = content_list
|
||||
has_images_in_message = True
|
||||
else:
|
||||
msg_dict["content"] = msg.content
|
||||
|
||||
payload_messages.append(msg_dict)
|
||||
|
||||
# 记录多模态信息
|
||||
if images and len(images) > 0:
|
||||
if has_images_in_message:
|
||||
# 找到最后一个user消息,记录其content结构
|
||||
last_user_msg = payload_messages[last_user_idx] if last_user_idx >= 0 else None
|
||||
if last_user_msg and isinstance(last_user_msg.get("content"), list):
|
||||
content_items = last_user_msg["content"]
|
||||
text_items = [item for item in content_items if item.get("type") == "text"]
|
||||
image_items = [item for item in content_items if item.get("type") == "image_url"]
|
||||
self._log("info", f"[多模态] 已发送多模态请求: {len(text_items)}个文本 + {len(image_items)}张图片")
|
||||
self._log("debug", f"[多模态] 用户文本: {text_items[0].get('text', '')[:50] if text_items else 'N/A'}")
|
||||
else:
|
||||
self._log("warning", "[多模态] 消息格式异常,无法确认图片是否添加")
|
||||
else:
|
||||
self._log("warning", f"[多模态] 有{len(images)}张图片,但未找到user消息,图片未被添加")
|
||||
else:
|
||||
self._log("debug", "[多模态] 纯文本请求(无图片)")
|
||||
|
||||
full_reply = ""
|
||||
interrupted = False
|
||||
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=payload_messages,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
# 检查中断标志
|
||||
if interrupt_check and interrupt_check():
|
||||
self._log("info", "LLM流式处理被中断")
|
||||
interrupted = True
|
||||
break
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
full_reply += content
|
||||
if on_token:
|
||||
on_token(content)
|
||||
# 在on_token回调后再次检查中断(on_token可能设置中断标志)
|
||||
if interrupt_check and interrupt_check():
|
||||
self._log("info", "LLM流式处理在on_token回调后被中断")
|
||||
interrupted = True
|
||||
break
|
||||
|
||||
if interrupted:
|
||||
return None # 被中断时返回None,表示未完成
|
||||
|
||||
return full_reply.strip() if full_reply else None
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""TTS模型"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
|
||||
|
||||
class TTSClient:
|
||||
"""TTS客户端抽象基类"""
|
||||
|
||||
def synthesize(self, request: TTSRequest,
|
||||
on_chunk=None,
|
||||
interrupt_check=None) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
"""
|
||||
TTS语音合成模块
|
||||
"""
|
||||
import subprocess
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
from robot_speaker.models.tts.base import TTSClient
|
||||
|
||||
|
||||
class DashScopeTTSClient(TTSClient):
|
||||
"""DashScope流式TTS客户端封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
model: str,
|
||||
voice: str,
|
||||
card_index: int,
|
||||
device_index: int,
|
||||
output_sample_rate: int = 44100,
|
||||
output_channels: int = 2,
|
||||
output_volume: float = 1.0,
|
||||
tts_source_sample_rate: int = 22050, # TTS服务固定输出采样率
|
||||
tts_source_channels: int = 1, # TTS服务固定输出声道数
|
||||
tts_ffmpeg_thread_queue_size: int = 1024, # ffmpeg输入线程队列大小
|
||||
reference_signal_buffer=None, # 参考信号缓冲区(用于回声消除)
|
||||
logger=None):
|
||||
dashscope.api_key = api_key
|
||||
self.model = model
|
||||
self.voice = voice
|
||||
self.card_index = card_index
|
||||
self.device_index = device_index
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.output_channels = output_channels
|
||||
self.output_volume = output_volume
|
||||
self.tts_source_sample_rate = tts_source_sample_rate
|
||||
self.tts_source_channels = tts_source_channels
|
||||
self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
|
||||
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
|
||||
self.logger = logger
|
||||
self.current_ffmpeg_pid = None # 当前ffmpeg进程的PID
|
||||
|
||||
# 构建ALSA设备, 允许 ffmpeg 自动重采样 / 重声道
|
||||
self.alsa_device = f"plughw:{card_index},{device_index}" if (
|
||||
card_index >= 0 and device_index >= 0
|
||||
) else "default"
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志,根据级别调用对应的ROS2日志方法"""
|
||||
if self.logger:
|
||||
# ROS2 logger不能动态改变severity级别,需要显式调用对应方法
|
||||
if level == "debug":
|
||||
self.logger.debug(msg)
|
||||
elif level == "info":
|
||||
self.logger.info(msg)
|
||||
elif level == "warning":
|
||||
self.logger.warn(msg)
|
||||
elif level == "error":
|
||||
self.logger.error(msg)
|
||||
else:
|
||||
self.logger.info(msg) # 默认使用info级别
|
||||
else:
|
||||
print(f"[TTS] {msg}")
|
||||
|
||||
def synthesize(self, request: TTSRequest,
|
||||
on_chunk=None,
|
||||
interrupt_check=None) -> bool:
|
||||
"""主流程:流式合成并播放"""
|
||||
callback = _TTSCallback(self, interrupt_check, on_chunk, self.reference_signal_buffer)
|
||||
# 使用配置的voice,request.voice为None或空时使用self.voice
|
||||
voice_to_use = request.voice if request.voice and request.voice.strip() else self.voice
|
||||
|
||||
if not voice_to_use or not voice_to_use.strip():
|
||||
self._log("error", f"Voice参数无效: '{voice_to_use}'")
|
||||
return False
|
||||
|
||||
self._log("info", f"TTS开始: 文本='{request.text[:50]}...', voice='{voice_to_use}'")
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=self.model,
|
||||
voice=voice_to_use,
|
||||
format=AudioFormat.PCM_22050HZ_MONO_16BIT,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
try:
|
||||
synthesizer.streaming_call(request.text)
|
||||
synthesizer.streaming_complete()
|
||||
finally:
|
||||
callback.cleanup()
|
||||
|
||||
return not callback._interrupted
|
||||
|
||||
|
||||
class _TTSCallback(ResultCallback):
|
||||
"""TTS回调处理 - 使用ffmpeg播放,自动处理采样率转换"""
|
||||
|
||||
def __init__(self, tts_client: DashScopeTTSClient,
|
||||
interrupt_check=None,
|
||||
on_chunk=None,
|
||||
reference_signal_buffer=None):
|
||||
self.tts_client = tts_client
|
||||
self.interrupt_check = interrupt_check
|
||||
self.on_chunk = on_chunk
|
||||
self.reference_signal_buffer = reference_signal_buffer # 参考信号缓冲区
|
||||
self._proc = None
|
||||
self._interrupted = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def on_open(self):
|
||||
# 使用ffmpeg播放,自动处理采样率转换(TTS源采样率 -> 设备采样率)
|
||||
# TTS服务输出固定采样率和声道数,ffmpeg会自动转换为设备采样率和声道数
|
||||
ffmpeg_cmd = [
|
||||
'ffmpeg',
|
||||
'-f', 's16le', # 原始 PCM
|
||||
'-ar', str(self.tts_client.tts_source_sample_rate), # TTS输出采样率(从配置文件读取)
|
||||
'-ac', str(self.tts_client.tts_source_channels), # TTS输出声道数(从配置文件读取)
|
||||
'-i', 'pipe:0', # stdin
|
||||
'-f', 'alsa', # 输出到 ALSA
|
||||
'-ar', str(self.tts_client.output_sample_rate), # 输出设备采样率(从配置文件读取)
|
||||
'-ac', str(self.tts_client.output_channels), # 输出设备声道数(从配置文件读取)
|
||||
'-acodec', 'pcm_s16le', # 输出编码
|
||||
'-fflags', 'nobuffer', # 减少缓冲
|
||||
'-flags', 'low_delay', # 低延迟
|
||||
'-avioflags', 'direct', # 尝试直通写入 ALSA,减少延迟
|
||||
self.tts_client.alsa_device
|
||||
]
|
||||
|
||||
# 将 -thread_queue_size 放到输入文件之前
|
||||
insert_pos = ffmpeg_cmd.index('-i')
|
||||
ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
|
||||
ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')
|
||||
|
||||
# 添加音量调节filter(如果音量不是1.0)
|
||||
if self.tts_client.output_volume != 1.0:
|
||||
# 在输出编码前插入音量filter
|
||||
# volume filter放在输入之后、输出编码之前
|
||||
acodec_idx = ffmpeg_cmd.index('-acodec')
|
||||
ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
|
||||
ffmpeg_cmd.insert(acodec_idx, '-af')
|
||||
|
||||
self.tts_client._log("info", f"启动ffmpeg播放: ALSA设备={self.tts_client.alsa_device}, "
|
||||
f"输出采样率={self.tts_client.output_sample_rate}Hz, "
|
||||
f"输出声道数={self.tts_client.output_channels}, "
|
||||
f"音量={self.tts_client.output_volume * 100:.0f}%")
|
||||
self._proc = subprocess.Popen(
|
||||
ffmpeg_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE # 改为PIPE以便捕获错误
|
||||
)
|
||||
# 记录ffmpeg进程PID
|
||||
self.tts_client.current_ffmpeg_pid = self._proc.pid
|
||||
self.tts_client._log("debug", f"ffmpeg进程已启动,PID={self._proc.pid}")
|
||||
|
||||
def on_complete(self):
|
||||
pass
|
||||
|
||||
def on_error(self, message: str):
|
||||
self.tts_client._log("error", f"TTS错误: {message}")
|
||||
|
||||
def on_close(self):
|
||||
self.cleanup()
|
||||
|
||||
def on_event(self, message):
|
||||
pass
|
||||
|
||||
def on_data(self, data: bytes) -> None:
|
||||
"""接收音频数据并播放"""
|
||||
if self._interrupted:
|
||||
return
|
||||
|
||||
if self.interrupt_check and self.interrupt_check():
|
||||
# 停止播放,不停止 TTS
|
||||
self._interrupted = True
|
||||
if self._proc:
|
||||
self._proc.terminate()
|
||||
return
|
||||
|
||||
# 优先写入ffmpeg,避免阻塞播放
|
||||
# 优先写入ffmpeg,避免阻塞播放
|
||||
if self._proc and self._proc.stdin and not self._interrupted:
|
||||
try:
|
||||
self._proc.stdin.write(data)
|
||||
self._proc.stdin.flush()
|
||||
except BrokenPipeError:
|
||||
# ffmpeg进程可能已退出,检查错误
|
||||
if self._proc.stderr:
|
||||
error_msg = self._proc.stderr.read().decode('utf-8', errors='ignore')
|
||||
self.tts_client._log("error", f"ffmpeg错误: {error_msg}")
|
||||
self._interrupted = True
|
||||
|
||||
# 将音频数据添加到参考信号缓冲区(用于回声消除)
|
||||
# 在写入ffmpeg之后处理,避免阻塞播放
|
||||
if self.reference_signal_buffer and data:
|
||||
try:
|
||||
self.reference_signal_buffer.add_reference(
|
||||
data,
|
||||
source_sample_rate=self.tts_client.tts_source_sample_rate,
|
||||
source_channels=self.tts_client.tts_source_channels
|
||||
)
|
||||
except Exception as e:
|
||||
# 参考信号处理失败不应影响播放
|
||||
self.tts_client._log("warning", f"参考信号处理失败: {e}")
|
||||
|
||||
if self.on_chunk:
|
||||
self.on_chunk(data)
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
if self._cleaned_up or not self._proc:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
|
||||
# 关闭stdin,让ffmpeg处理完剩余数据
|
||||
if self._proc.stdin and not self._proc.stdin.closed:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
# 等待进程自然结束(根据文本长度估算,最少10秒,最多30秒)
|
||||
# 假设平均语速:3-4字/秒,加上缓冲时间
|
||||
if self._proc.poll() is None:
|
||||
try:
|
||||
# 增加等待时间,确保ffmpeg播放完成
|
||||
# 对于长文本,可能需要更长时间
|
||||
self._proc.wait(timeout=30.0)
|
||||
except:
|
||||
# 超时后,如果进程还在运行,说明可能卡住了,强制终止
|
||||
if self._proc.poll() is None:
|
||||
self.tts_client._log("warning", "ffmpeg播放超时,强制终止")
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=1.0)
|
||||
except:
|
||||
try:
|
||||
self._proc.kill()
|
||||
self._proc.wait(timeout=0.1)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 清空PID记录
|
||||
if self.tts_client.current_ffmpeg_pid == self._proc.pid:
|
||||
self.tts_client.current_ffmpeg_pid = None
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""感知层"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
"""
|
||||
声纹识别模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
import tempfile
|
||||
import os
|
||||
import wave
|
||||
import time
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpeakerState(Enum):
|
||||
"""说话人识别状态"""
|
||||
UNKNOWN = "unknown"
|
||||
VERIFIED = "verified"
|
||||
REJECTED = "rejected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SpeakerVerificationClient:
|
||||
"""声纹识别客户端 - 非实时、低频处理"""
|
||||
|
||||
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
|
||||
self.model_path = model_path
|
||||
self.threshold = threshold
|
||||
self.speaker_db_path = speaker_db_path
|
||||
self.logger = logger
|
||||
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "registered_at": float}}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 优化CPU性能:限制Torch使用的线程数,防止多线程竞争导致性能骤降
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
||||
from funasr import AutoModel
|
||||
model_path = os.path.expanduser(self.model_path)
|
||||
# 禁用自动更新检查,防止每次初始化都联网检查
|
||||
self.model = AutoModel(model=model_path, device="cpu", disable_update=True)
|
||||
if self.logger:
|
||||
self.logger.info(f"声纹模型已加载: {model_path}, 阈值: {self.threshold}")
|
||||
|
||||
if self.speaker_db_path:
|
||||
self.load_speakers()
|
||||
|
||||
def _log(self, level: str, msg: str):
|
||||
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
|
||||
if self.logger:
|
||||
try:
|
||||
log_methods = {
|
||||
"debug": self.logger.debug,
|
||||
"info": self.logger.info,
|
||||
"warning": self.logger.warning,
|
||||
"error": self.logger.error,
|
||||
"fatal": self.logger.fatal
|
||||
}
|
||||
log_method = log_methods.get(level.lower(), self.logger.info)
|
||||
log_method(msg)
|
||||
except ValueError as e:
|
||||
if "severity cannot be changed" in str(e):
|
||||
try:
|
||||
self.logger.info(f"[声纹-{level.upper()}] {msg}")
|
||||
except:
|
||||
print(f"[声纹-{level.upper()}] {msg}")
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
print(f"[声纹] {msg}")
|
||||
|
||||
def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
|
||||
"""将numpy音频数组写入临时wav文件"""
|
||||
audio_int16 = audio_data.astype(np.int16)
|
||||
|
||||
fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
|
||||
os.close(fd)
|
||||
|
||||
with wave.open(temp_path, 'wb') as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(audio_int16.tobytes())
|
||||
|
||||
return temp_path
|
||||
|
||||
def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
|
||||
"""
|
||||
提取说话人embedding(低频调用,一句话只调用一次)
|
||||
"""
|
||||
# 降采样到 16000Hz (如果需要)
|
||||
# Cam++ 等模型通常只支持 16k,如果传入 48k 会导致内部重采样极慢或计算量剧增
|
||||
target_sr = 16000
|
||||
if sample_rate > target_sr:
|
||||
if sample_rate % target_sr == 0:
|
||||
step = sample_rate // target_sr
|
||||
audio_data = audio_data[::step]
|
||||
sample_rate = target_sr
|
||||
else:
|
||||
# 简单的非整数倍降采样可能导致问题,但对于语音验证通常 48k->16k 是整数倍
|
||||
# 如果不是,此处暂不处理,依赖 funasr 内部处理,或者简单的步长取整
|
||||
step = int(sample_rate / target_sr)
|
||||
audio_data = audio_data[::step]
|
||||
sample_rate = target_sr
|
||||
|
||||
if len(audio_data) < int(sample_rate * 0.5):
|
||||
return None, False
|
||||
|
||||
temp_wav_path = None
|
||||
try:
|
||||
# 限制Torch在推理时使用单线程,避免在多任务环境下(尤其是一边录音一边识别)
|
||||
# 出现的极端CPU竞争和上下文切换开销
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
# 临时设置,虽然全局已经设置了,但在调用前再次确保
|
||||
# 注意:set_num_threads 是全局的,这里再次确认
|
||||
if torch.get_num_threads() != 1:
|
||||
torch.set_num_threads(1)
|
||||
|
||||
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
|
||||
result = self.model.generate(input=temp_wav_path)
|
||||
|
||||
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
|
||||
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return None, False
|
||||
|
||||
return embedding, True
|
||||
except Exception as e:
|
||||
self._log("error", f"提取embedding失败: {e}")
|
||||
return None, False
|
||||
finally:
|
||||
if temp_wav_path and os.path.exists(temp_wav_path):
|
||||
try:
|
||||
os.unlink(temp_wav_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
def register_speaker(self, speaker_id: str, embedding: np.ndarray,
|
||||
env: str = "near") -> bool:
|
||||
"""
|
||||
注册说话人
|
||||
"""
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return False
|
||||
embedding_norm = np.linalg.norm(embedding)
|
||||
if embedding_norm == 0:
|
||||
self._log("error", f"注册失败:embedding范数为0")
|
||||
return False
|
||||
embedding_normalized = embedding / embedding_norm
|
||||
|
||||
with self._lock:
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_normalized,
|
||||
"env": env, # 添加 env 字段
|
||||
"registered_at": time.time()
|
||||
}
|
||||
self._log("info", f"已注册说话人: {speaker_id}, 维度: {embedding_dim}")
|
||||
save_result = self.save_speakers()
|
||||
if not save_result:
|
||||
self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
|
||||
return True
|
||||
|
||||
def match_speaker(self, embedding: np.ndarray):
|
||||
"""
|
||||
匹配说话人(一句话只调用一次)
|
||||
"""
|
||||
if not self.speaker_db:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, self.threshold
|
||||
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return None, SpeakerState.ERROR, 0.0, self.threshold
|
||||
|
||||
embedding_norm = np.linalg.norm(embedding)
|
||||
if embedding_norm == 0:
|
||||
return None, SpeakerState.ERROR, 0.0, self.threshold
|
||||
embedding_normalized = embedding / embedding_norm
|
||||
|
||||
best_match = None
|
||||
best_score = -float('inf')
|
||||
|
||||
with self._lock:
|
||||
for speaker_id, speaker_data in self.speaker_db.items():
|
||||
ref_embedding = speaker_data["embedding"]
|
||||
score = np.dot(embedding_normalized, ref_embedding)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = speaker_id
|
||||
|
||||
state = SpeakerState.VERIFIED if best_score >= self.threshold else SpeakerState.REJECTED
|
||||
return (best_match, state, best_score, self.threshold)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.model is not None
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
pass
|
||||
|
||||
def get_speaker_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self.speaker_db)
|
||||
|
||||
def remove_speaker(self, speaker_id: str) -> bool:
|
||||
with self._lock:
|
||||
if speaker_id not in self.speaker_db:
|
||||
return False
|
||||
del self.speaker_db[speaker_id]
|
||||
self.save_speakers()
|
||||
return True
|
||||
|
||||
def load_speakers(self) -> bool:
|
||||
"""
|
||||
从文件加载已注册的声纹
|
||||
"""
|
||||
if not self.speaker_db_path:
|
||||
return False
|
||||
|
||||
if not os.path.exists(self.speaker_db_path):
|
||||
self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
with self._lock:
|
||||
for speaker_id, speaker_data in data.items():
|
||||
embedding_list = speaker_data["embedding"]
|
||||
embedding_array = np.array(embedding_list, dtype=np.float32)
|
||||
|
||||
embedding_dim = len(embedding_array)
|
||||
if embedding_dim == 0:
|
||||
self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
|
||||
continue
|
||||
embedding_norm = np.linalg.norm(embedding_array)
|
||||
if embedding_norm > 0:
|
||||
embedding_array = embedding_array / embedding_norm
|
||||
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_array,
|
||||
"env": speaker_data["env"],
|
||||
"registered_at": speaker_data["registered_at"]
|
||||
}
|
||||
|
||||
count = len(self.speaker_db)
|
||||
self._log("info", f"已加载 {count} 个已注册说话人")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log("error", f"加载声纹数据库失败: {e}")
|
||||
return False
|
||||
|
||||
def save_speakers(self) -> bool:
|
||||
"""
|
||||
保存已注册的声纹到文件
|
||||
"""
|
||||
if not self.speaker_db_path:
|
||||
self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
|
||||
return False
|
||||
|
||||
try:
|
||||
db_dir = os.path.dirname(self.speaker_db_path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
json_data = {}
|
||||
with self._lock:
|
||||
for speaker_id, speaker_data in self.speaker_db.items():
|
||||
json_data[speaker_id] = {
|
||||
"embedding": speaker_data["embedding"].tolist(), # numpy array -> list
|
||||
"env": speaker_data.get("env", "near"), # 兼容旧数据,默认使用 "near"
|
||||
"registered_at": speaker_data["registered_at"]
|
||||
}
|
||||
|
||||
temp_path = self.speaker_db_path + ".tmp"
|
||||
with open(temp_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(json_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
os.replace(temp_path, self.speaker_db_path)
|
||||
|
||||
self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self._log("error", f"保存声纹数据库失败: {e}")
|
||||
self._log("error", f"保存路径: {self.speaker_db_path}")
|
||||
self._log("error", f"错误详情: {traceback.format_exc()}")
|
||||
temp_path = self.speaker_db_path + ".tmp"
|
||||
if os.path.exists(temp_path):
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
15
robot_speaker/services/__init__.py
Normal file
15
robot_speaker/services/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Service节点模块
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
566
robot_speaker/services/asr_audio_node.py
Normal file
566
robot_speaker/services/asr_audio_node.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from robot_speaker.srv import ASRRecognize, AudioData, VADEvent
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import pyaudio
|
||||
import yaml
|
||||
import os
|
||||
import collections
|
||||
import numpy as np
|
||||
import base64
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
class AudioRecorder:
|
||||
def __init__(self, device_index: int, sample_rate: int, channels: int,
|
||||
chunk: int, audio_queue: queue.Queue, stop_event, logger=None):
|
||||
self.device_index = device_index
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.chunk = chunk
|
||||
self.audio_queue = audio_queue
|
||||
self.stop_event = stop_event
|
||||
self.logger = logger
|
||||
self.audio = pyaudio.PyAudio()
|
||||
|
||||
original_index = self.device_index
|
||||
try:
|
||||
for i in range(self.audio.get_device_count()):
|
||||
device_info = self.audio.get_device_info_by_index(i)
|
||||
if 'iFLYTEK' in device_info['name'] and device_info['maxInputChannels'] > 0:
|
||||
self.device_index = i
|
||||
if self.logger:
|
||||
self.logger.info(f"已自动定位到麦克风设备: {device_info['name']} (Index: {i})")
|
||||
break
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"设备自动检测过程出错: {e}")
|
||||
|
||||
if self.device_index == original_index and original_index == -1:
|
||||
self.device_index = 0
|
||||
if self.logger:
|
||||
self.logger.info("未找到 iFLYTEK 设备,使用系统默认输入设备")
|
||||
self.format = pyaudio.paInt16
|
||||
|
||||
def record(self):
|
||||
if self.logger:
|
||||
self.logger.info(f"录音线程启动,设备索引: {self.device_index}")
|
||||
stream = None
|
||||
try:
|
||||
stream = self.audio.open(
|
||||
format=self.format,
|
||||
channels=self.channels,
|
||||
rate=self.sample_rate,
|
||||
input=True,
|
||||
input_device_index=self.device_index if self.device_index >= 0 else None,
|
||||
frames_per_buffer=self.chunk
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.info("音频输入设备已打开")
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"无法打开音频输入设备: {e}")
|
||||
return
|
||||
try:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
data = stream.read(self.chunk, exception_on_overflow=False)
|
||||
|
||||
if self.audio_queue.full():
|
||||
self.audio_queue.get_nowait()
|
||||
self.audio_queue.put_nowait(data)
|
||||
except OSError as e:
|
||||
if self.logger:
|
||||
self.logger.debug(f"录音设备错误: {e}")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
if self.logger:
|
||||
self.logger.info("录音线程收到中断信号")
|
||||
finally:
|
||||
if stream is not None:
|
||||
try:
|
||||
if stream.is_active():
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
if self.logger:
|
||||
self.logger.info("录音线程已退出")
|
||||
|
||||
|
||||
class DashScopeASR:
    # Streaming speech recognizer backed by DashScope's Omni realtime API.
    # Owns at most one OmniRealtimeConversation at a time; server events are
    # routed back here through the companion _ASRCallback class.

    def __init__(self, api_key: str, sample_rate: int, model: str, url: str, logger=None):
        # NOTE(review): assigns the module-global dashscope.api_key — this
        # affects every DashScope client in the process, not just this one.
        dashscope.api_key = api_key
        self.sample_rate = sample_rate
        self.model = model
        self.url = url
        self.logger = logger

        self.conversation = None
        self.running = False
        # Optional hooks, assigned by the owning node after construction.
        self.on_sentence_end = None
        self.on_speech_started = None
        self.on_speech_stopped = None

        # Serializes stop() against stop_current_recognition().
        self._stop_lock = threading.Lock()
        # Set by the callback when the final transcript after a commit arrives.
        self._final_result_event = threading.Event()
        # True only while stop_current_recognition() is waiting for that transcript.
        self._pending_commit = False

    def _log(self, level: str, msg: str):
        # Logger facade that tolerates a missing or broken logger.
        if not self.logger:
            return
        try:
            if level == "debug":
                self.logger.debug(msg)
            elif level == "warning":
                self.logger.warn(msg)
            elif level == "error":
                self.logger.error(msg)
            elif level == "info":
                self.logger.info(msg)
        except Exception:
            pass

    def start(self):
        # Open a new realtime conversation and configure the session for
        # text-only output, input transcription and server-side VAD.
        # Returns True on success, False if already running or on any error.
        if self.running:
            return False

        try:
            callback = _ASRCallback(self)
            self.conversation = OmniRealtimeConversation(
                model=self.model,
                url=self.url,
                callback=callback
            )
            # Back-reference so the callback can reach the conversation.
            callback.conversation = self.conversation

            self.conversation.connect()

            transcription_params = TranscriptionParams(
                language='zh',
                sample_rate=self.sample_rate,
                input_audio_format="pcm",
            )

            self.conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
                enable_turn_detection=True,
                turn_detection_type='server_vad',
                turn_detection_threshold=0.2,
                turn_detection_silence_duration_ms=800
            )

            self.running = True
            self._log("info", "ASR已启动")
            return True
        except Exception as e:
            # Roll back to a clean "not running" state on any failure.
            self.running = False
            self._log("error", f"ASR启动失败: {e}")
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
                self.conversation = None
            return False

    def send_audio(self, audio_chunk: bytes):
        # Forward one raw PCM chunk, base64-encoded as the API requires.
        # Returns False when not running or on any send error (errors are
        # swallowed deliberately: the worker loop retries with the next chunk).
        if not self.running or not self.conversation:
            return False
        try:
            audio_b64 = base64.b64encode(audio_chunk).decode('ascii')
            self.conversation.append_audio(audio_b64)
            return True
        except Exception:
            return False

    def stop_current_recognition(self):
        # Commit the buffered audio, wait briefly for the final transcript,
        # then tear the conversation down and restart it so the next utterance
        # starts with a fresh session. Returns True on successful restart.
        # Non-blocking acquire: concurrent callers are skipped, not queued.
        if not self._stop_lock.acquire(blocking=False):
            self._log("warning", "stop_current_recognition 正在执行,跳过本次调用")
            return False

        try:
            if not self.running or not self.conversation:
                return False

            self._final_result_event.clear()
            self._pending_commit = True

            self.conversation.commit()
            # Give the server up to 3s to deliver the final transcript;
            # proceed regardless (the event may also be set by stop()).
            self._final_result_event.wait(timeout=3.0)
            self.running = False

            old_conversation = self.conversation
            self.conversation = None
            try:
                old_conversation.close()
            except Exception:
                pass
            # Small grace period before reconnecting.
            time.sleep(0.1)

            if not self.start():
                self._log("error", "ASR重启失败")
                return False

            return True
        finally:
            self._pending_commit = False
            self._stop_lock.release()

    def stop(self):
        # Final shutdown: release any transcript waiter and close the session.
        with self._stop_lock:
            self.running = False
            self._final_result_event.set()
            if self.conversation:
                try:
                    self.conversation.close()
                except Exception:
                    pass
                self.conversation = None
|
||||
|
||||
|
||||
class _ASRCallback(OmniRealtimeCallback):
    """Routes DashScope realtime events back onto the owning DashScopeASR client."""

    def __init__(self, asr_client: DashScopeASR):
        self.asr_client = asr_client
        # Filled in by the client right after the conversation is created.
        self.conversation = None

    def on_event(self, response):
        # Dispatch on the event type; any malformed payload is ignored so a
        # single bad event cannot kill the callback stream.
        client = self.asr_client
        try:
            kind = response['type']
            if kind == 'conversation.item.input_audio_transcription.completed':
                text = response['transcript'].strip()
                if text and client.on_sentence_end:
                    client.on_sentence_end(text)
                # Wake up a stop_current_recognition() waiting on the final result.
                if client._pending_commit:
                    client._final_result_event.set()
            elif kind == 'input_audio_buffer.speech_started':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音开始")
                if client.on_speech_started:
                    client.on_speech_started()
            elif kind == 'input_audio_buffer.speech_stopped':
                if client.logger:
                    client.logger.info("[ASR] 检测到语音结束")
                if client.on_speech_stopped:
                    client.on_speech_stopped()
        except Exception:
            pass
|
||||
|
||||
|
||||
class ASRAudioNode(Node):
    """ROS2 node exposing microphone capture, DashScope streaming ASR and VAD
    events through three services: /asr/recognize, /asr/audio_data, /vad/event.

    Two daemon threads run alongside the executor: a recorder thread filling
    ``audio_queue`` and an ASR worker draining it.
    """

    def __init__(self):
        super().__init__('asr_audio_node')
        self._load_config()

        # Raw microphone chunks: produced by the recorder thread,
        # consumed by _asr_worker. Bounded; recorder drops oldest when full.
        self.audio_queue = queue.Queue(maxsize=100)
        self.stop_event = threading.Event()
        self._shutdown_in_progress = False

        self._init_components()

        self.recognize_service = self.create_service(
            ASRRecognize, '/asr/recognize', self._recognize_callback
        )
        self.audio_data_service = self.create_service(
            AudioData, '/asr/audio_data', self._audio_data_callback
        )
        self.vad_event_service = self.create_service(
            VADEvent, '/vad/event', self._vad_event_callback
        )

        # Latest recognized sentence plus its timestamp; _result_event signals
        # a new result to a waiting /asr/recognize call.
        self._last_result = None
        self._result_event = threading.Event()
        self._last_result_time = None
        self.vad_event_queue = queue.Queue()
        # Bounded sample buffer for the /asr/audio_data capture commands.
        # NOTE(review): 240000 int16 samples = 15s at 16kHz — confirm against config.
        self.audio_buffer = collections.deque(maxlen=240000)
        self.audio_recording = False
        self.audio_lock = threading.Lock()

        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record, name="RecordingThread", daemon=True
        )
        self.recording_thread.start()

        self.asr_thread = threading.Thread(
            target=self._asr_worker, name="ASRThread", daemon=True
        )
        self.asr_thread.start()

        self.get_logger().info("ASR Audio节点已启动")

    def _load_config(self):
        # Read microphone and DashScope settings from the installed voice.yaml.
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        mic = config['audio']['microphone']
        self.input_device_index = mic['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']

        # Renamed from "dashscope": the original local shadowed the imported
        # dashscope module within this method.
        ds_cfg = config['dashscope']
        self.dashscope_api_key = ds_cfg['api_key']
        self.asr_model = ds_cfg['asr']['model']
        self.asr_url = ds_cfg['asr']['url']

    def _init_components(self):
        # Build the recorder and the ASR client, wire the ASR hooks, and
        # start the first recognition session.
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            audio_queue=self.audio_queue,
            stop_event=self.stop_event,
            logger=self.get_logger()
        )

        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )

        self.asr_client.on_sentence_end = self._on_asr_result
        self.asr_client.on_speech_started = lambda: self._put_vad_event("speech_started")
        self.asr_client.on_speech_stopped = lambda: self._put_vad_event("speech_stopped")
        self.asr_client.start()

    def _on_asr_result(self, text: str):
        # Called from the DashScope callback thread with a final transcript.
        if not text or not text.strip():
            return

        self._last_result = text.strip()
        self._last_result_time = time.time()
        self._result_event.set()
        try:
            self.get_logger().info(f"[ASR] 识别结果: {self._last_result}")
        except Exception:
            pass

    def _put_vad_event(self, event_type):
        # Push a VAD event for /vad/event waiters; drop it if the queue blocks.
        try:
            self.vad_event_queue.put(event_type, timeout=0.1)
        except queue.Full:
            try:
                self.get_logger().warn(f"[ASR] VAD事件队列已满,丢弃{event_type}事件")
            except Exception:
                pass

    def _audio_data_callback(self, request, response):
        # /asr/audio_data service: "start" begins buffering, "stop" ends it and
        # returns the captured PCM, "get" returns the current buffer contents.
        response.sample_rate = self.sample_rate
        response.channels = self.channels

        if request.command == "start":
            with self.audio_lock:
                self.audio_buffer.clear()
                self.audio_recording = True
            response.success = True
            response.message = "开始录音"
            response.samples = 0
            return response

        if request.command == "stop":
            with self.audio_lock:
                self.audio_recording = False
                audio_list = list(self.audio_buffer)
                self.audio_buffer.clear()
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"录音完成{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response

        if request.command == "get":
            with self.audio_lock:
                audio_list = list(self.audio_buffer)
            if len(audio_list) > 0:
                audio_array = np.array(audio_list, dtype=np.int16)
                response.success = True
                response.audio_data = audio_array.tobytes()
                response.samples = len(audio_list)
                response.message = f"获取到{len(audio_list)}样本"
            else:
                response.success = False
                response.message = "缓冲区为空"
                response.samples = 0
            return response

        # FIX: unknown commands previously fell off the end of the method,
        # implicitly returning None and breaking the rclpy service call.
        response.success = False
        response.message = f"未知命令: {request.command}"
        response.samples = 0
        return response

    def _vad_event_callback(self, request, response):
        # /vad/event service: block until the next VAD event or the timeout.
        # NOTE(review): timeout_ms <= 0 means wait forever, which ties up the
        # executor thread serving this callback — confirm that is intended.
        timeout = request.timeout_ms / 1000.0 if request.timeout_ms > 0 else None
        try:
            event = self.vad_event_queue.get(timeout=timeout)
            response.success = True
            response.event = event
            response.message = "收到VAD事件"
        except queue.Empty:
            response.success = False
            response.event = "none"
            response.message = "等待超时"
        except KeyboardInterrupt:
            try:
                self.get_logger().info("[VAD] 收到中断信号,正在关闭")
            except Exception:
                pass
            response.success = False
            response.event = "none"
            response.message = "节点正在关闭"
            self.stop_event.set()
        return response

    def _clear_result(self):
        # Reset the cached recognition result and its signal.
        self._last_result = None
        self._last_result_time = None
        self._result_event.clear()

    def _return_result(self, response, text, message):
        # Fill a successful recognize response and consume the cached result.
        response.success = True
        response.text = text
        response.message = message
        self._clear_result()
        return response

    def _recognize_callback(self, request, response):
        # /asr/recognize service: "stop" ends the current session, "reset"
        # recycles it, anything else (including empty) requests a result.
        if request.command == "stop":
            if self.asr_client.running:
                self.asr_client.stop_current_recognition()
            response.success = True
            response.text = ""
            response.message = "识别已停止"
            return response

        if request.command == "reset":
            self.asr_client.stop_current_recognition()
            time.sleep(0.1)
            self.asr_client.start()
            response.success = True
            response.text = ""
            response.message = "识别器已重置"
            return response

        if self.asr_client.running:
            current_time = time.time()
            # Serve a cached result if it is fresh (<5s) or already signalled.
            if (self._last_result and self._last_result_time and
                    (current_time - self._last_result_time) < 5.0) or (self._result_event.is_set() and self._last_result):
                return self._return_result(response, self._last_result, "返回最近识别结果")
            # Otherwise wait briefly for an in-flight result before recycling.
            if self._result_event.wait(timeout=2.0) and self._last_result:
                return self._return_result(response, self._last_result, "识别成功(等待中)")
            self.asr_client.stop_current_recognition()
            time.sleep(0.2)

        self._clear_result()

        if not self.asr_client.running and not self.asr_client.start():
            response.success = False
            response.text = ""
            response.message = "ASR启动失败"
            return response
        if self._result_event.wait(timeout=5.0) and self._last_result:
            response.success = True
            response.text = self._last_result
            response.message = "识别成功"
        else:
            response.success = False
            response.text = ""
            response.message = "识别超时" if not self._result_event.is_set() else "识别结果为空"
        self._clear_result()
        return response

    def _asr_worker(self):
        # Drains the microphone queue: mirrors chunks into the capture buffer
        # when recording, and streams them to the ASR client when it runs.
        while not self.stop_event.is_set():
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except KeyboardInterrupt:
                try:
                    self.get_logger().info("[ASR Worker] 收到中断信号")
                except Exception:
                    pass
                break

            # NOTE(review): audio_recording is read without audio_lock here;
            # a stale read only mis-routes one chunk, so this is tolerated.
            if self.audio_recording:
                try:
                    audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                    with self.audio_lock:
                        self.audio_buffer.extend(audio_array)
                except Exception:
                    pass

            if self.asr_client.running:
                self.asr_client.send_audio(audio_chunk)

    def destroy_node(self):
        # Idempotent shutdown: stop threads, release audio and ASR resources.
        if self._shutdown_in_progress:
            return
        self._shutdown_in_progress = True
        try:
            self.get_logger().info("ASR Audio节点正在关闭...")
        except Exception:
            pass
        self.stop_event.set()
        if hasattr(self, 'recording_thread') and self.recording_thread.is_alive():
            self.recording_thread.join(timeout=1.0)
        if hasattr(self, 'asr_thread') and self.asr_thread.is_alive():
            self.asr_thread.join(timeout=1.0)
        try:
            if hasattr(self, 'audio_recorder'):
                self.audio_recorder.audio.terminate()
        except Exception:
            pass
        try:
            if hasattr(self, 'asr_client'):
                self.asr_client.stop()
        except Exception:
            pass
        try:
            super().destroy_node()
        except Exception:
            pass
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the ASR node until interrupted, then tear down cleanly."""
    rclpy.init(args=args)
    node = ASRAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Log best-effort; the logger may already be gone during shutdown.
        try:
            node.get_logger().info("收到中断信号,正在关闭节点")
        except Exception:
            pass
    finally:
        # Each teardown step is independent and must not block the other.
        for action in (node.destroy_node, rclpy.shutdown):
            try:
                action()
            except Exception:
                pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
341
robot_speaker/services/tts_audio_node.py
Normal file
341
robot_speaker/services/tts_audio_node.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||
from robot_speaker.srv import TTSSynthesize
|
||||
import threading
|
||||
import yaml
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
class DashScopeTTSClient:
    # Text-to-speech client: streams DashScope TTS audio into an ffmpeg
    # subprocess that plays it on an ALSA device. One playback at a time.

    def __init__(self, api_key: str,
                 model: str,
                 voice: str,
                 card_index: int,
                 device_index: int,
                 output_sample_rate: int,
                 output_channels: int,
                 output_volume: float,
                 tts_source_sample_rate: int,
                 tts_source_channels: int,
                 tts_ffmpeg_thread_queue_size: int,
                 force_stop_delay: float,
                 cleanup_timeout: float,
                 terminate_timeout: float,
                 logger):
        # NOTE(review): sets the module-global dashscope.api_key.
        dashscope.api_key = api_key
        self.model = model
        self.voice = voice
        self.card_index = card_index
        self.device_index = device_index
        self.output_sample_rate = output_sample_rate
        self.output_channels = output_channels
        self.output_volume = output_volume
        self.tts_source_sample_rate = tts_source_sample_rate
        self.tts_source_channels = tts_source_channels
        self.tts_ffmpeg_thread_queue_size = tts_ffmpeg_thread_queue_size
        self.force_stop_delay = force_stop_delay
        self.cleanup_timeout = cleanup_timeout
        self.terminate_timeout = terminate_timeout
        self.logger = logger
        # PID of the currently playing ffmpeg process (None when idle).
        self.current_ffmpeg_pid = None
        self._current_callback = None

        # Explicit card/device pair → plughw address; otherwise ALSA default.
        self.alsa_device = f"plughw:{card_index},{device_index}" if (
            card_index >= 0 and device_index >= 0
        ) else "default"

    def force_stop(self):
        # Interrupt the current playback: flag the callback, then SIGTERM the
        # ffmpeg process and escalate to SIGKILL if it is still alive.
        if self._current_callback:
            self._current_callback._interrupted = True
        if not self.current_ffmpeg_pid:
            if self.logger:
                self.logger.warn("[TTS] force_stop: current_ffmpeg_pid is None")
            return
        pid = self.current_ffmpeg_pid
        try:
            if self.logger:
                self.logger.info(f"[TTS] force_stop: 正在kill进程 {pid}")
            os.kill(pid, signal.SIGTERM)
            time.sleep(self.force_stop_delay)
            try:
                # kill(pid, 0) probes for existence; raises if already gone.
                os.kill(pid, 0)
                os.kill(pid, signal.SIGKILL)
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 已发送SIGKILL到进程 {pid}")
            except ProcessLookupError:
                if self.logger:
                    self.logger.info(f"[TTS] force_stop: 进程 {pid} 已退出")
        except (ProcessLookupError, OSError) as e:
            if self.logger:
                self.logger.warn(f"[TTS] force_stop: kill进程失败 {pid}: {e}")
        finally:
            self.current_ffmpeg_pid = None
            self._current_callback = None

    def synthesize(self, text: str, voice: str = None,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        # Synthesize and play `text`. Blocks until playback ends or is
        # interrupted. Returns True iff playback completed uninterrupted.
        #   voice: overrides the configured voice when non-empty.
        #   on_chunk: optional hook receiving each PCM chunk as it plays.
        #   interrupt_check: callable polled per chunk; truthy aborts playback.
        callback = _TTSCallback(self, interrupt_check, on_chunk)
        self._current_callback = callback
        voice_to_use = voice if voice and voice.strip() else self.voice

        if not voice_to_use or not voice_to_use.strip():
            if self.logger:
                self.logger.error(f"Voice参数无效: '{voice_to_use}'")
            self._current_callback = None
            return False
        synthesizer = SpeechSynthesizer(
            model=self.model,
            voice=voice_to_use,
            # NOTE(review): hardcoded 22050Hz mono PCM — presumably matches
            # tts_source_sample_rate/channels from config; verify.
            format=AudioFormat.PCM_22050HZ_MONO_16BIT,
            callback=callback,
        )

        try:
            synthesizer.streaming_call(text)
            synthesizer.streaming_complete()
        finally:
            # Always reap the ffmpeg process, even on synthesis errors.
            callback.cleanup()
            self._current_callback = None

        return not callback._interrupted
|
||||
|
||||
|
||||
class _TTSCallback(ResultCallback):
    """DashScope TTS result callback: pipes each synthesized PCM chunk into an
    ffmpeg subprocess playing on ALSA, honoring an external interrupt check."""

    def __init__(self, tts_client: DashScopeTTSClient,
                 interrupt_check=None,
                 on_chunk=None):
        self.tts_client = tts_client
        # Callable polled per chunk; truthy return aborts playback.
        self.interrupt_check = interrupt_check
        # Optional per-chunk observer hook.
        self.on_chunk = on_chunk
        self._proc = None
        self._interrupted = False
        self._cleaned_up = False

    def on_open(self):
        # Launch ffmpeg: raw s16le PCM on stdin -> resampled ALSA playback,
        # tuned for low latency.
        ffmpeg_cmd = [
            'ffmpeg',
            '-f', 's16le',
            '-ar', str(self.tts_client.tts_source_sample_rate),
            '-ac', str(self.tts_client.tts_source_channels),
            '-i', 'pipe:0',
            '-f', 'alsa',
            '-ar', str(self.tts_client.output_sample_rate),
            '-ac', str(self.tts_client.output_channels),
            '-acodec', 'pcm_s16le',
            '-fflags', 'nobuffer',
            '-flags', 'low_delay',
            '-avioflags', 'direct',
            self.tts_client.alsa_device
        ]

        # -thread_queue_size must precede the input it applies to.
        insert_pos = ffmpeg_cmd.index('-i')
        ffmpeg_cmd.insert(insert_pos, str(self.tts_client.tts_ffmpeg_thread_queue_size))
        ffmpeg_cmd.insert(insert_pos, '-thread_queue_size')

        # Only add a volume filter when it actually changes the level.
        if self.tts_client.output_volume != 1.0:
            acodec_idx = ffmpeg_cmd.index('-acodec')
            ffmpeg_cmd.insert(acodec_idx, f'volume={self.tts_client.output_volume}')
            ffmpeg_cmd.insert(acodec_idx, '-af')

        self._proc = subprocess.Popen(
            ffmpeg_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE
        )
        # Expose the PID so force_stop() can kill playback from outside.
        self.tts_client.current_ffmpeg_pid = self._proc.pid

    def on_data(self, data: bytes) -> None:
        # Stream one PCM chunk to ffmpeg, honoring interruption.
        if self._interrupted:
            return

        if self.interrupt_check and self.interrupt_check():
            self._interrupted = True
            if self._proc:
                self._proc.terminate()
            return

        if self._proc and self._proc.stdin and not self._interrupted:
            try:
                self._proc.stdin.write(data)
                self._proc.stdin.flush()
            except BrokenPipeError:
                # ffmpeg exited (killed or crashed); stop feeding it.
                self._interrupted = True
            except OSError:
                self._interrupted = True

        if self.on_chunk and not self._interrupted:
            self.on_chunk(data)

    def cleanup(self):
        # Reap the ffmpeg process exactly once: EOF its stdin, then escalate
        # wait -> terminate -> kill.
        if self._cleaned_up or not self._proc:
            return
        self._cleaned_up = True

        # Close stdin so ffmpeg sees EOF and can drain its buffers.
        if self._proc.stdin and not self._proc.stdin.closed:
            self._proc.stdin.close()

        # FIX: Popen.wait(timeout=...) raises subprocess.TimeoutExpired on
        # timeout rather than returning; the original code left the waits
        # unguarded, so a slow ffmpeg skipped terminate()/kill() entirely
        # and leaked the process. Each wait is now guarded.
        if self._proc.poll() is None:
            try:
                self._proc.wait(timeout=self.tts_client.cleanup_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=self.tts_client.terminate_timeout)
            except subprocess.TimeoutExpired:
                pass
        if self._proc.poll() is None:
            self._proc.kill()

        # Clear the shared PID only if it still refers to our process.
        if self.tts_client.current_ffmpeg_pid == self._proc.pid:
            self.tts_client.current_ffmpeg_pid = None
|
||||
|
||||
|
||||
class TTSAudioNode(Node):
    # ROS2 node exposing /tts/synthesize: synthesizes text via DashScope TTS
    # and plays it locally. Playback runs on a worker thread so the service
    # returns immediately with status "playing"; "interrupt" stops playback.

    def __init__(self):
        super().__init__('tts_audio_node')
        self._load_config()
        self._init_tts_client()

        # Reentrant group so an "interrupt" request can be served while a
        # previous "synthesize" call is still in flight.
        self.callback_group = ReentrantCallbackGroup()
        self.synthesize_service = self.create_service(
            TTSSynthesize, '/tts/synthesize', self._synthesize_callback,
            callback_group=self.callback_group
        )

        # Signals the playback worker to abort the current utterance.
        self.interrupt_event = threading.Event()
        # Guards is_playing across service callbacks and the worker thread.
        self.playing_lock = threading.Lock()
        self.is_playing = False

        self.get_logger().info("TTS Audio节点已启动")

    def _load_config(self):
        # Read soundcard, TTS pipeline and DashScope settings from voice.yaml.
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        audio = config['audio']
        soundcard = audio['soundcard']
        tts_audio = audio['tts']
        # NOTE(review): this local shadows the imported dashscope module
        # within this method.
        dashscope = config['dashscope']

        self.output_card_index = soundcard['card_index']
        self.output_device_index = soundcard['device_index']
        self.output_sample_rate = soundcard['sample_rate']
        self.output_channels = soundcard['channels']
        self.output_volume = soundcard['volume']

        self.tts_source_sample_rate = tts_audio['source_sample_rate']
        self.tts_source_channels = tts_audio['source_channels']
        self.tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']
        self.force_stop_delay = tts_audio['force_stop_delay']
        self.cleanup_timeout = tts_audio['cleanup_timeout']
        self.terminate_timeout = tts_audio['terminate_timeout']
        self.interrupt_wait = tts_audio['interrupt_wait']

        self.dashscope_api_key = dashscope['api_key']
        self.tts_model = dashscope['tts']['model']
        self.tts_voice = dashscope['tts']['voice']

    def _init_tts_client(self):
        # Construct the playback client from the loaded configuration.
        self.tts_client = DashScopeTTSClient(
            api_key=self.dashscope_api_key,
            model=self.tts_model,
            voice=self.tts_voice,
            card_index=self.output_card_index,
            device_index=self.output_device_index,
            output_sample_rate=self.output_sample_rate,
            output_channels=self.output_channels,
            output_volume=self.output_volume,
            tts_source_sample_rate=self.tts_source_sample_rate,
            tts_source_channels=self.tts_source_channels,
            tts_ffmpeg_thread_queue_size=self.tts_ffmpeg_thread_queue_size,
            force_stop_delay=self.force_stop_delay,
            cleanup_timeout=self.cleanup_timeout,
            terminate_timeout=self.terminate_timeout,
            logger=self.get_logger()
        )

    def _synthesize_callback(self, request, response):
        # /tts/synthesize service handler. Commands:
        #   "interrupt"              — stop current playback, if any.
        #   "synthesize" (default)   — start asynchronous playback of request.text.
        command = request.command if request.command else "synthesize"

        if command == "interrupt":
            with self.playing_lock:
                was_playing = self.is_playing
                # Also treat a live ffmpeg PID as "playing" in case the flag lags.
                has_pid = self.tts_client.current_ffmpeg_pid is not None
                if was_playing or has_pid:
                    self.interrupt_event.set()
                    self.tts_client.force_stop()
                    self.is_playing = False
                    response.success = True
                    response.message = "已中断播放"
                    response.status = "interrupted"
                else:
                    response.success = False
                    response.message = "没有正在播放的内容"
                    response.status = "none"
            return response

        if not request.text or not request.text.strip():
            response.success = False
            response.message = "文本为空"
            response.status = "error"
            return response

        with self.playing_lock:
            # Preempt any ongoing playback before starting the new one.
            if self.is_playing:
                self.tts_client.force_stop()
                time.sleep(self.interrupt_wait)
            self.is_playing = True

        self.interrupt_event.clear()

        def synthesize_worker():
            # Runs on a daemon thread: performs the blocking synthesis+playback.
            try:
                success = self.tts_client.synthesize(
                    request.text.strip(),
                    voice=request.voice if request.voice else None,
                    interrupt_check=lambda: self.interrupt_event.is_set()
                )
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    if success:
                        self.get_logger().info("[TTS] 合成并播放成功")
                    else:
                        self.get_logger().info("[TTS] 播放被中断")
            except Exception as e:
                with self.playing_lock:
                    self.is_playing = False
                if self.get_logger():
                    self.get_logger().error(f"[TTS] 合成失败: {e}")

        thread = threading.Thread(target=synthesize_worker, daemon=True)
        thread.start()

        # Respond immediately; playback continues in the background.
        response.success = True
        response.message = "合成任务已启动"
        response.status = "playing"
        return response
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the TTS node; always release ROS resources on exit.

    FIX: the original had no KeyboardInterrupt handling and no finally block,
    so Ctrl-C skipped destroy_node()/shutdown() entirely — inconsistent with
    the ASR node's main() in this package.
    """
    rclpy.init(args=args)
    node = TTSAudioNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        try:
            node.destroy_node()
        except Exception:
            pass
        try:
            rclpy.shutdown()
        except Exception:
            pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""理解层"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
4
setup.py
4
setup.py
@@ -15,6 +15,7 @@ setup(
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
|
||||
(os.path.join('share', package_name, 'srv'), glob('srv/*.srv')),
|
||||
],
|
||||
install_requires=[
|
||||
'setuptools',
|
||||
@@ -25,12 +26,13 @@ setup(
|
||||
maintainer_email='mzebra@foxmail.com',
|
||||
description='语音识别和合成ROS2包',
|
||||
license='Apache-2.0',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
|
||||
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
|
||||
'skill_bridge_node = robot_speaker.bridge.skill_bridge_node:main',
|
||||
'asr_audio_node = robot_speaker.services.asr_audio_node:main',
|
||||
'tts_audio_node = robot_speaker.services.tts_audio_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
18
srv/ASRRecognize.srv
Normal file
18
srv/ASRRecognize.srv
Normal file
@@ -0,0 +1,18 @@
|
||||
# 请求:启动识别
|
||||
string command # "start" (默认), "stop", "reset"
|
||||
---
|
||||
# 响应:识别结果
|
||||
bool success
|
||||
string text # 识别文本(空字符串表示未识别到)
|
||||
string message # 状态消息
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
20
srv/AudioData.srv
Normal file
20
srv/AudioData.srv
Normal file
@@ -0,0 +1,20 @@
|
||||
# 请求:获取音频数据
|
||||
string command # "start" (开始录音), "stop" (停止并返回), "get" (获取当前缓冲区)
|
||||
int32 duration_ms # 录音时长(毫秒),仅用于start命令
|
||||
---
|
||||
# 响应:音频数据
|
||||
bool success
|
||||
uint8[] audio_data # PCM音频数据(int16格式)
|
||||
int32 sample_rate
|
||||
int32 channels
|
||||
int32 samples # 样本数
|
||||
string message
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
14
srv/TTSSynthesize.srv
Normal file
14
srv/TTSSynthesize.srv
Normal file
@@ -0,0 +1,14 @@
|
||||
# 请求:合成文本或中断命令
|
||||
string command # "synthesize" (默认), "interrupt"
|
||||
string text
|
||||
string voice # 可选,默认使用配置
|
||||
---
|
||||
# 响应:合成状态
|
||||
bool success
|
||||
string message
|
||||
string status # "playing", "completed", "interrupted"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
17
srv/VADEvent.srv
Normal file
17
srv/VADEvent.srv
Normal file
@@ -0,0 +1,17 @@
|
||||
# 请求:等待VAD事件
|
||||
string command # "wait" (等待下一个事件)
|
||||
int32 timeout_ms # 超时时间(毫秒),0表示无限等待
|
||||
---
|
||||
# 响应:VAD事件
|
||||
bool success
|
||||
string event # "speech_started", "speech_stopped", "none"
|
||||
string message
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user