增加声纹验证
This commit is contained in:
@@ -38,3 +38,7 @@ system:
|
||||
use_wake_word: true # 是否启用唤醒词检测
|
||||
wake_word: "er gou" # 唤醒词(拼音)
|
||||
session_timeout: 10.0 # 会话超时时间(秒)
|
||||
sv_enabled: true # 是否启用声纹识别
|
||||
sv_model_path: "~/ros_learn/speech_campplus_sv_zh-cn_16k-common" # 声纹模型路径
|
||||
sv_threshold: 0.35 # 声纹识别阈值(0.0-1.0,值越小越宽松,值越大越严格)
|
||||
sv_speaker_db_path: "config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)
|
||||
|
||||
@@ -183,6 +183,31 @@ def generate_launch_description():
|
||||
description='会话超时时间(秒)'
|
||||
),
|
||||
|
||||
# 声纹识别参数
|
||||
DeclareLaunchArgument(
|
||||
'sv_enabled',
|
||||
default_value=str(system_config.get('sv_enabled', True)).lower(),
|
||||
description='是否启用声纹识别'
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
'sv_model_path',
|
||||
default_value=os.path.expanduser(system_config.get('sv_model_path', '')),
|
||||
description='声纹模型路径'
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
'sv_threshold',
|
||||
default_value=str(system_config.get('sv_threshold', 0.45)),
|
||||
description='声纹识别阈值'
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
'sv_speaker_db_path',
|
||||
default_value=os.path.join(
|
||||
get_package_share_directory('robot_speaker'),
|
||||
system_config.get('sv_speaker_db_path', 'config/speakers.json')
|
||||
) if system_config.get('sv_speaker_db_path') else '',
|
||||
description='声纹数据库路径'
|
||||
),
|
||||
|
||||
# 相机参数
|
||||
DeclareLaunchArgument(
|
||||
'camera_serial_number',
|
||||
@@ -260,6 +285,12 @@ def generate_launch_description():
|
||||
'wake_word': LaunchConfiguration('wake_word'),
|
||||
'session_timeout': LaunchConfiguration('session_timeout'),
|
||||
|
||||
# 声纹识别参数
|
||||
'sv_enabled': LaunchConfiguration('sv_enabled'),
|
||||
'sv_model_path': LaunchConfiguration('sv_model_path'),
|
||||
'sv_threshold': LaunchConfiguration('sv_threshold'),
|
||||
'sv_speaker_db_path': LaunchConfiguration('sv_speaker_db_path'),
|
||||
|
||||
# 相机参数
|
||||
'camera_serial_number': LaunchConfiguration('camera_serial_number'),
|
||||
'camera_width': LaunchConfiguration('camera_width'),
|
||||
|
||||
@@ -37,6 +37,9 @@ class AudioRecorder:
|
||||
on_speech_start=None, # 检测到人声开始
|
||||
on_speech_end=None, # 检测到静音结束(说话结束)
|
||||
stop_flag=None,
|
||||
on_audio_chunk=None, # 音频chunk回调(用于声纹录音等,可选)
|
||||
should_put_to_queue=None, # 检查是否应该将音频放入队列(用于阻止ASR,可选)
|
||||
get_silence_threshold=None, # 获取动态静音阈值(毫秒,可选)
|
||||
logger=None):
|
||||
self.device_index = device_index
|
||||
self.sample_rate = sample_rate
|
||||
@@ -54,6 +57,9 @@ class AudioRecorder:
|
||||
self.on_speech_start = on_speech_start
|
||||
self.on_speech_end = on_speech_end
|
||||
self.stop_flag = stop_flag or (lambda: False)
|
||||
self.on_audio_chunk = on_audio_chunk # 音频chunk回调(用于声纹录音等)
|
||||
self.should_put_to_queue = should_put_to_queue or (lambda: True) # 默认允许放入队列
|
||||
self.get_silence_threshold = get_silence_threshold # 动态静音阈值回调
|
||||
self.logger = logger
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.format = pyaudio.paInt16
|
||||
@@ -97,10 +103,16 @@ class AudioRecorder:
|
||||
# exception_on_overflow=False, 宁可丢帧,也不阻塞
|
||||
data = stream.read(self.chunk, exception_on_overflow=False)
|
||||
|
||||
# 队列满时丢弃最旧的数据,ASR 跟不上时系统仍然听得见
|
||||
if self.audio_queue.full():
|
||||
self.audio_queue.get_nowait()
|
||||
self.audio_queue.put_nowait(data)
|
||||
# 检查是否应该将音频放入队列(用于阻止ASR,例如无声纹文件时需要注册)
|
||||
if self.should_put_to_queue():
|
||||
# 队列满时丢弃最旧的数据,ASR 跟不上时系统仍然听得见
|
||||
if self.audio_queue.full():
|
||||
self.audio_queue.get_nowait()
|
||||
self.audio_queue.put_nowait(data)
|
||||
|
||||
# 音频chunk回调(用于声纹录音等,仅在需要时调用)
|
||||
if self.on_audio_chunk:
|
||||
self.on_audio_chunk(data)
|
||||
|
||||
audio_buffer.append(data) # 只用于 VAD,不用于 ASR
|
||||
|
||||
@@ -133,8 +145,21 @@ class AudioRecorder:
|
||||
else:
|
||||
if was_speaking:
|
||||
silence_duration = now - last_active_time
|
||||
if silence_duration >= no_speech_threshold:
|
||||
# 动态获取静音阈值(如果提供回调函数)
|
||||
if self.get_silence_threshold:
|
||||
current_silence_ms = self.get_silence_threshold()
|
||||
current_no_speech_threshold = max(current_silence_ms / 1000.0, 0.1)
|
||||
else:
|
||||
current_no_speech_threshold = no_speech_threshold
|
||||
|
||||
# 添加调试日志
|
||||
if self.logger and silence_duration < current_no_speech_threshold:
|
||||
self.logger.debug(f"[VAD] 静音中: {silence_duration:.3f}秒 < {current_no_speech_threshold:.3f}秒阈值")
|
||||
|
||||
if silence_duration >= current_no_speech_threshold:
|
||||
if self.on_speech_end:
|
||||
if self.logger:
|
||||
self.logger.debug(f"[VAD] 触发speech_end: 静音持续时间 {silence_duration:.3f}秒 >= 阈值 {current_no_speech_threshold:.3f}秒")
|
||||
self.on_speech_end() # 通知系统用户停止说话
|
||||
|
||||
if self.on_heartbeat and now - last_heartbeat_time >= self.heartbeat_interval:
|
||||
|
||||
@@ -14,6 +14,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from pypinyin import pinyin, Style
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
import collections
|
||||
from .audio import VADDetector, AudioRecorder
|
||||
from .asr import DashScopeASR
|
||||
from .tts import DashScopeTTSClient, TTSRequest
|
||||
@@ -21,6 +23,15 @@ from .llm import DashScopeLLM
|
||||
from .history import ConversationHistory
|
||||
from .types import LLMMessage
|
||||
from .camera import CameraClient
|
||||
from .speaker_verification import SpeakerVerificationClient, SpeakerState
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""会话状态机"""
|
||||
IDLE = "idle" # 等待用户唤醒或声音
|
||||
CHECK_VOICE = "check_voice" # 用户说话 → 检查声纹
|
||||
REGISTERING = "registering" # 声纹注册中
|
||||
AUTHORIZED = "authorized" # 已注册用户
|
||||
|
||||
|
||||
class RobotSpeakerNode(Node):
|
||||
@@ -46,9 +57,33 @@ class RobotSpeakerNode(Node):
|
||||
self.session_active = False
|
||||
self.session_start_time = 0.0
|
||||
self.session_lock = threading.Lock()
|
||||
|
||||
# 状态机状态
|
||||
self.conversation_state = ConversationState.IDLE # 当前会话状态
|
||||
self.state_lock = threading.Lock() # 保护状态机状态
|
||||
|
||||
# 声纹识别共享状态
|
||||
self.current_speaker_id = None # 当前说话人ID(共享状态,只读)
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN # 当前说话人状态
|
||||
self.sv_lock = threading.Lock() # 保护current_speaker_id和current_speaker_state
|
||||
self.sv_speech_end_event = threading.Event() # 通知声纹线程处理(speech_end触发)
|
||||
self.sv_result_ready_event = threading.Event() # 声纹结果ready事件(用于CHECK_VOICE状态同步)
|
||||
self.sv_audio_buffer = collections.deque(maxlen=64000) # 声纹验证录音缓冲区(只在需要时录音)
|
||||
self.sv_recording = False # 是否正在为声纹验证录音
|
||||
|
||||
# 声纹注册状态
|
||||
self.sv_registration_audio_buffer = collections.deque(maxlen=64000) # 注册时录音缓冲区
|
||||
self.sv_registration_processing = False # 是否正在处理注册(处理中时停止录音)
|
||||
self.sv_registration_start_time = None # 注册开始时间(用于超时检测)
|
||||
|
||||
# 初始化组件(VAD、录音器、ASR、LLM、TTS)
|
||||
self._init_components()
|
||||
|
||||
# 状态机初始状态
|
||||
if self.sv_enabled and self.sv_client:
|
||||
speaker_count = self.sv_client.get_speaker_count()
|
||||
if speaker_count == 0:
|
||||
self.get_logger().info("声纹数据库为空,将在检测到语音时提示注册")
|
||||
|
||||
# ROS订阅
|
||||
self.interrupt_sub = self.create_subscription(
|
||||
@@ -112,6 +147,12 @@ class RobotSpeakerNode(Node):
|
||||
# 闭嘴指令参数
|
||||
self.declare_parameter('shutup_keywords', 'bi zui,ting zhi,bu yao shuo le,bu yao shuo hua')
|
||||
|
||||
# 声纹识别参数
|
||||
self.declare_parameter('sv_model_path', '')
|
||||
self.declare_parameter('sv_threshold', 0.45) # 默认值,实际从voice.yaml读取
|
||||
self.declare_parameter('sv_speaker_db_path', '') # 声纹数据库路径
|
||||
self.declare_parameter('sv_enabled', True) # 默认启用声纹识别,可通过voice.yaml配置
|
||||
|
||||
def _load_parameters(self):
|
||||
"""加载ROS参数"""
|
||||
# 音频参数
|
||||
@@ -167,6 +208,13 @@ class RobotSpeakerNode(Node):
|
||||
shutup_keywords_str = self.get_parameter('shutup_keywords').get_parameter_value().string_value
|
||||
self.shutup_keywords = [k.strip() for k in shutup_keywords_str.split(',') if k.strip()]
|
||||
|
||||
# 声纹识别参数
|
||||
self.sv_model_path = self.get_parameter('sv_model_path').get_parameter_value().string_value
|
||||
self.sv_threshold = self.get_parameter('sv_threshold').get_parameter_value().double_value
|
||||
sv_db_path = self.get_parameter('sv_speaker_db_path').get_parameter_value().string_value
|
||||
self.sv_speaker_db_path = sv_db_path if sv_db_path and sv_db_path.strip() else None
|
||||
self.sv_enabled = self.get_parameter('sv_enabled').get_parameter_value().bool_value
|
||||
|
||||
def _init_components(self):
|
||||
"""初始化所有组件"""
|
||||
# VAD检测器
|
||||
@@ -192,6 +240,9 @@ class RobotSpeakerNode(Node):
|
||||
on_speech_start=self._on_speech_start,
|
||||
on_speech_end=self._on_speech_end,
|
||||
stop_flag=self.stop_event.is_set,
|
||||
on_audio_chunk=self._on_audio_chunk_for_sv if self.sv_enabled else None, # 声纹录音回调
|
||||
should_put_to_queue=self._should_put_audio_to_queue, # 检查是否应该将音频放入队列
|
||||
get_silence_threshold=self._get_silence_threshold, # 动态静音阈值回调
|
||||
logger=self.get_logger()
|
||||
)
|
||||
|
||||
@@ -252,6 +303,26 @@ class RobotSpeakerNode(Node):
|
||||
except Exception as e:
|
||||
self.get_logger().warning(f"相机初始化失败: {e},相机功能将不可用")
|
||||
self.camera_client = None
|
||||
|
||||
# 声纹识别客户端
|
||||
if self.sv_enabled and self.sv_model_path:
|
||||
try:
|
||||
self.sv_client = SpeakerVerificationClient(
|
||||
model_path=self.sv_model_path,
|
||||
threshold=self.sv_threshold,
|
||||
speaker_db_path=self.sv_speaker_db_path,
|
||||
logger=self.get_logger()
|
||||
)
|
||||
if not self.sv_client.is_available():
|
||||
self.get_logger().warning("声纹识别模型不可用,将禁用声纹功能")
|
||||
self.sv_client = None
|
||||
self.sv_enabled = False
|
||||
except Exception as e:
|
||||
self.get_logger().warning(f"声纹识别初始化失败: {e},声纹功能将不可用")
|
||||
self.sv_client = None
|
||||
self.sv_enabled = False
|
||||
else:
|
||||
self.sv_client = None
|
||||
|
||||
def _start_threads(self):
|
||||
"""启动4个线程"""
|
||||
@@ -286,6 +357,17 @@ class RobotSpeakerNode(Node):
|
||||
daemon=True
|
||||
)
|
||||
self.tts_thread.start()
|
||||
|
||||
# 线程5: 声纹识别线程(如果启用)
|
||||
if self.sv_enabled and self.sv_client:
|
||||
self.sv_thread = threading.Thread(
|
||||
target=self._sv_worker,
|
||||
name="SVThread",
|
||||
daemon=True
|
||||
)
|
||||
self.sv_thread.start()
|
||||
else:
|
||||
self.sv_thread = None
|
||||
|
||||
def _handle_interrupt_command(self, msg: String):
|
||||
"""处理ROS中断命令"""
|
||||
@@ -301,17 +383,250 @@ class RobotSpeakerNode(Node):
|
||||
self._drain_queue(self.text_queue)
|
||||
# TTS 队列不清空,由 TTS 线程根据 interrupt_event 自行中断播放
|
||||
|
||||
# ==================== 状态机方法 ====================
|
||||
|
||||
def _change_state(self, new_state: ConversationState, reason: str = ""):
|
||||
"""改变状态机状态"""
|
||||
with self.state_lock:
|
||||
old_state = self.conversation_state
|
||||
self.conversation_state = new_state
|
||||
if reason:
|
||||
self.get_logger().info(f"[状态机] {old_state.value} -> {new_state.value}: {reason}")
|
||||
else:
|
||||
self.get_logger().info(f"[状态机] {old_state.value} -> {new_state.value}")
|
||||
|
||||
def _get_state(self) -> ConversationState:
|
||||
"""获取当前状态"""
|
||||
with self.state_lock:
|
||||
return self.conversation_state
|
||||
|
||||
def _get_silence_threshold(self) -> int:
|
||||
"""获取动态静音阈值(毫秒):REGISTERING状态下使用更短的阈值"""
|
||||
state = self._get_state()
|
||||
if state == ConversationState.REGISTERING:
|
||||
# REGISTERING状态下使用更短的静音阈值(500ms),加快响应
|
||||
return 500
|
||||
else:
|
||||
# 其他状态使用默认阈值
|
||||
return self.silence_duration_ms
|
||||
|
||||
def _should_put_audio_to_queue(self) -> bool:
|
||||
"""
|
||||
检查是否应该将音频放入队列(用于ASR)
|
||||
根据状态机决定是否允许ASR
|
||||
"""
|
||||
state = self._get_state()
|
||||
|
||||
# 以下状态允许ASR:
|
||||
# - IDLE: 等待用户说话,允许ASR检测
|
||||
# - CHECK_VOICE: 检查声纹,需要ASR识别
|
||||
# - AUTHORIZED: 已授权用户,正常处理
|
||||
if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
|
||||
ConversationState.AUTHORIZED]:
|
||||
return True
|
||||
|
||||
# 以下状态不允许ASR:
|
||||
# - REGISTERING: 正在录音注册,不需要ASR
|
||||
return False
|
||||
|
||||
# ==================== 录音线程回调 ====================
|
||||
|
||||
def _on_speech_start(self):
|
||||
"""录音线程检测到人声开始"""
|
||||
self.get_logger().info("[录音线程] 检测到人声,开始录音")
|
||||
|
||||
state = self._get_state()
|
||||
|
||||
if state == ConversationState.IDLE:
|
||||
# Idle -> CheckVoice
|
||||
if self.sv_enabled and self.sv_client:
|
||||
# 开始录音用于声纹验证
|
||||
with self.sv_lock:
|
||||
self.sv_recording = True
|
||||
self.sv_audio_buffer.clear()
|
||||
self.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
self._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
|
||||
else:
|
||||
# 未启用声纹,直接进入Authorized状态
|
||||
self._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
|
||||
|
||||
elif state == ConversationState.REGISTERING:
|
||||
# Registering状态,记录开始时间(用于超时检测)
|
||||
self.sv_registration_start_time = time.time()
|
||||
self.get_logger().info("[声纹注册] 检测到人声开始,记录开始时间")
|
||||
|
||||
elif state == ConversationState.CHECK_VOICE:
|
||||
# CheckVoice状态,继续录音用于声纹验证
|
||||
if self.sv_enabled:
|
||||
with self.sv_lock:
|
||||
self.sv_recording = True
|
||||
self.sv_audio_buffer.clear()
|
||||
self.get_logger().debug("[声纹] 继续录音用于声纹验证")
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
# Authorized状态,开始录音用于声纹验证(验证当前用户)
|
||||
if self.sv_enabled:
|
||||
with self.sv_lock:
|
||||
self.sv_recording = True
|
||||
self.sv_audio_buffer.clear()
|
||||
self.get_logger().debug("[声纹] 开始录音用于声纹验证")
|
||||
|
||||
def _on_audio_chunk_for_sv(self, audio_chunk: bytes):
|
||||
"""录音线程音频chunk回调 - 仅在需要时录音到声纹缓冲区"""
|
||||
state = self._get_state()
|
||||
|
||||
# 声纹验证录音(CHECK_VOICE, AUTHORIZED状态)
|
||||
if self.sv_enabled and self.sv_recording:
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
with self.sv_lock:
|
||||
self.sv_audio_buffer.extend(audio_array)
|
||||
except Exception as e:
|
||||
self.get_logger().debug(f"[声纹] 录音失败: {e}")
|
||||
|
||||
# 注册音频录制(REGISTERING状态,且未开始处理,且TTS不在播放)
|
||||
# 重要:TTS播放完成后才开始录音,确保第一段语音就是声纹信息
|
||||
if state == ConversationState.REGISTERING and not self.sv_registration_processing:
|
||||
# 如果TTS正在播放,不录音(避免录到TTS声音或噪音)
|
||||
if self.tts_playing_event.is_set():
|
||||
return # TTS播放中,不录音
|
||||
|
||||
# 超时检测:如果录音超过10秒,强制处理注册
|
||||
if self.sv_registration_start_time is not None:
|
||||
elapsed_time = time.time() - self.sv_registration_start_time
|
||||
if elapsed_time >= 10.0: # 10秒超时
|
||||
buffer_size = len(self.sv_registration_audio_buffer)
|
||||
buffer_sec = buffer_size / self.sample_rate
|
||||
self.get_logger().warning(f"[声纹注册] 录音超时({elapsed_time:.2f}秒),强制处理注册,音频长度: {buffer_sec:.2f}秒")
|
||||
# 强制触发注册处理
|
||||
self._force_process_registration()
|
||||
return
|
||||
|
||||
try:
|
||||
audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
|
||||
self.sv_registration_audio_buffer.extend(audio_array)
|
||||
# 添加调试日志(每10000个样本记录一次,约0.625秒)
|
||||
buffer_size = len(self.sv_registration_audio_buffer)
|
||||
if buffer_size % 10000 == 0:
|
||||
buffer_sec = buffer_size / self.sample_rate
|
||||
self.get_logger().info(f"[声纹注册] 录音中,当前音频长度: {buffer_sec:.2f}秒 ({buffer_size} 样本)")
|
||||
except Exception as e:
|
||||
self.get_logger().debug(f"[声纹注册] 录音失败: {e}")
|
||||
|
||||
def _force_process_registration(self):
|
||||
"""强制处理注册(用于超时情况)"""
|
||||
if self.sv_registration_processing:
|
||||
return # 已经在处理中
|
||||
|
||||
self.sv_registration_processing = True
|
||||
audio_list = list(self.sv_registration_audio_buffer)
|
||||
buffer_size = len(audio_list)
|
||||
buffer_sec = buffer_size / self.sample_rate
|
||||
self.get_logger().info(f"[声纹注册] 强制处理注册,录音缓冲区大小: {buffer_size} 样本({buffer_sec:.2f}秒)")
|
||||
self.sv_registration_audio_buffer.clear()
|
||||
self.sv_registration_start_time = None # 清除开始时间
|
||||
|
||||
# 需要大于3秒音频
|
||||
min_samples = int(self.sample_rate * 3.0) # 3秒
|
||||
if buffer_size >= min_samples:
|
||||
self.get_logger().info(f"[声纹注册] ✓ 音频长度足够,开始注册,音频长度: {buffer_size} 样本({buffer_sec:.2f}秒)")
|
||||
# 在新线程中处理注册,避免阻塞
|
||||
threading.Thread(
|
||||
target=self._process_sv_registration,
|
||||
args=(audio_list,),
|
||||
daemon=True,
|
||||
name="SVRegistrationThread"
|
||||
).start()
|
||||
else:
|
||||
self.get_logger().warning(f"[声纹注册] ✗ 音频太短: {buffer_size} < {min_samples}({buffer_sec:.2f}秒 < 3秒),注册失败")
|
||||
prompt_text = "音频太短,注册失败,请说话大于3秒"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
# 注册失败,回到Idle状态
|
||||
self.sv_registration_processing = False # 重置处理标志
|
||||
self._change_state(ConversationState.IDLE, "注册失败")
|
||||
|
||||
def _on_speech_end(self):
|
||||
"""录音线程检测到说话结束(静音一段时间)"""
|
||||
self.get_logger().info("[录音线程] 检测到说话结束,触发ASR识别")
|
||||
if self.asr_client and self.asr_client.running:
|
||||
self.asr_client.stop_current_recognition()
|
||||
self.get_logger().info("[录音线程] 检测到说话结束")
|
||||
|
||||
state = self._get_state()
|
||||
self.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
|
||||
|
||||
if state == ConversationState.REGISTERING:
|
||||
# Registering状态,说话结束,直接处理注册(不依赖ASR)
|
||||
if self.sv_registration_processing:
|
||||
return # 已经在处理中
|
||||
|
||||
self.sv_registration_processing = True
|
||||
self.get_logger().info("[声纹注册] 说话结束,开始处理注册,停止录音")
|
||||
audio_list = list(self.sv_registration_audio_buffer)
|
||||
buffer_size = len(audio_list)
|
||||
buffer_sec = buffer_size / self.sample_rate
|
||||
self.get_logger().info(f"[声纹注册] 录音缓冲区大小: {buffer_size} 样本({buffer_sec:.2f}秒)")
|
||||
self.sv_registration_audio_buffer.clear()
|
||||
self.sv_registration_start_time = None # 清除开始时间
|
||||
|
||||
# 需要大于3秒音频
|
||||
min_samples = int(self.sample_rate * 3.0) # 3秒
|
||||
if buffer_size >= min_samples:
|
||||
self.get_logger().info(f"[声纹注册] ✓ 音频长度足够,开始注册,音频长度: {buffer_size} 样本({buffer_sec:.2f}秒)")
|
||||
# 在新线程中处理注册,避免阻塞
|
||||
threading.Thread(
|
||||
target=self._process_sv_registration,
|
||||
args=(audio_list,),
|
||||
daemon=True,
|
||||
name="SVRegistrationThread"
|
||||
).start()
|
||||
else:
|
||||
self.get_logger().warning(f"[声纹注册] ✗ 音频太短: {buffer_size} < {min_samples}({buffer_sec:.2f}秒 < 3秒),注册失败")
|
||||
prompt_text = "音频太短,注册失败,请说话大于3秒"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
# 注册失败,回到Idle状态
|
||||
self.sv_registration_processing = False # 重置处理标志
|
||||
self._change_state(ConversationState.IDLE, "注册失败")
|
||||
return
|
||||
|
||||
elif state == ConversationState.CHECK_VOICE:
|
||||
# CheckVoice状态,语音结束,触发ASR识别,然后检查声纹
|
||||
if self.asr_client and self.asr_client.running:
|
||||
self.asr_client.stop_current_recognition()
|
||||
# 通知声纹线程处理(避免重复触发:如果ASR已经触发了,就不再触发)
|
||||
if self.sv_enabled and self.sv_client:
|
||||
# 检查数据库是否为空,如果为空则跳过验证
|
||||
speaker_count = self.sv_client.get_speaker_count()
|
||||
if speaker_count == 0:
|
||||
self.get_logger().info("[声纹] CHECK_VOICE状态,数据库为空,跳过声纹验证")
|
||||
# 直接设置结果,避免等待验证
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN
|
||||
self.sv_result_ready_event.set()
|
||||
# 检查是否已经触发了声纹验证(避免重复触发)
|
||||
elif not self.sv_speech_end_event.is_set():
|
||||
with self.sv_lock:
|
||||
self.sv_recording = False
|
||||
buffer_size = len(self.sv_audio_buffer)
|
||||
self.get_logger().info(f"[声纹] VAD触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
|
||||
self.sv_speech_end_event.set()
|
||||
else:
|
||||
self.get_logger().debug("[声纹] 声纹验证已由ASR触发,跳过VAD触发")
|
||||
# 状态转换在声纹验证完成后进行(在_process_worker中)
|
||||
return
|
||||
|
||||
elif state == ConversationState.AUTHORIZED:
|
||||
# Authorized状态,正常触发ASR识别
|
||||
if self.asr_client and self.asr_client.running:
|
||||
self.asr_client.stop_current_recognition()
|
||||
# 通知声纹线程处理(可选,用于验证当前用户)
|
||||
if self.sv_enabled:
|
||||
with self.sv_lock:
|
||||
self.sv_recording = False
|
||||
buffer_size = len(self.sv_audio_buffer)
|
||||
self.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
|
||||
self.sv_speech_end_event.set()
|
||||
return
|
||||
|
||||
# IDLE状态:不需要特殊处理
|
||||
|
||||
def _on_new_segment(self):
|
||||
"""录音线程检测到新的人声段,立即中断TTS"""
|
||||
@@ -329,6 +644,38 @@ class RobotSpeakerNode(Node):
|
||||
return
|
||||
text_clean = text.strip()
|
||||
self.get_logger().info(f"[ASR] 识别完成: {text_clean}")
|
||||
|
||||
state = self._get_state()
|
||||
|
||||
# 规则1:REGISTERING状态不允许ASR,由VAD的speech_end处理,这里直接忽略
|
||||
if state == ConversationState.REGISTERING:
|
||||
self.get_logger().warning("[ASR] REGISTERING状态下收到ASR文本,忽略(应由VAD的speech_end处理)")
|
||||
return
|
||||
|
||||
# 规则2:CHECK_VOICE状态下,如果ASR识别完成但VAD还没有触发speech_end,主动触发声纹验证
|
||||
if state == ConversationState.CHECK_VOICE:
|
||||
if self.sv_enabled and self.sv_client:
|
||||
# 检查数据库是否为空,如果为空则跳过验证
|
||||
speaker_count = self.sv_client.get_speaker_count()
|
||||
if speaker_count == 0:
|
||||
self.get_logger().info("[ASR] CHECK_VOICE状态,数据库为空,跳过声纹验证")
|
||||
# 直接设置结果,避免等待验证
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN
|
||||
self.sv_result_ready_event.set()
|
||||
# 检查是否已经触发了声纹验证(通过检查sv_speech_end_event是否已设置)
|
||||
elif not self.sv_speech_end_event.is_set():
|
||||
# 如果还没有触发,主动触发声纹验证
|
||||
self.get_logger().info("[ASR] CHECK_VOICE状态,ASR识别完成,主动触发声纹验证")
|
||||
with self.sv_lock:
|
||||
self.sv_recording = False
|
||||
buffer_size = len(self.sv_audio_buffer)
|
||||
if buffer_size > 0:
|
||||
self.get_logger().info(f"[声纹] ASR触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
|
||||
self.sv_speech_end_event.set()
|
||||
|
||||
# 其他状态,将文本放入队列
|
||||
self.text_queue.put(text_clean, timeout=1.0)
|
||||
|
||||
def _on_asr_text_update(self, text: str):
|
||||
@@ -360,7 +707,9 @@ class RobotSpeakerNode(Node):
|
||||
# ASR线程:只检查中断,不清除(让主线程清除)
|
||||
if self.interrupt_event.is_set():
|
||||
continue
|
||||
if self.asr_client and self.asr_client.running:
|
||||
|
||||
# 检查是否应该发送音频到ASR(由_should_put_audio_to_queue控制)
|
||||
if self._should_put_audio_to_queue() and self.asr_client and self.asr_client.running:
|
||||
self.asr_client.send_audio(audio_chunk)
|
||||
|
||||
def _process_worker(self):
|
||||
@@ -388,8 +737,109 @@ class RobotSpeakerNode(Node):
|
||||
if self.use_llm and self.history:
|
||||
self.history.cancel_turn()
|
||||
continue
|
||||
|
||||
# 检查状态机状态
|
||||
current_state = self._get_state()
|
||||
|
||||
# 根据状态处理文本
|
||||
if current_state == ConversationState.CHECK_VOICE:
|
||||
# CheckVoice状态:先检查唤醒词(如果开启),再检查注册指令,最后检查声纹验证结果
|
||||
|
||||
# 步骤1: 如果开启唤醒词,先检查唤醒词
|
||||
if self.use_wake_word:
|
||||
self.get_logger().info(f"[主线程] CHECK_VOICE状态,检查唤醒词,文本: {text}")
|
||||
processed_text = self._handle_wake_word(text)
|
||||
if not processed_text:
|
||||
# 未检测到唤醒词,直接回到Idle状态
|
||||
self.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{self.wake_word}'),回到Idle状态")
|
||||
self._change_state(ConversationState.IDLE, "未检测到唤醒词")
|
||||
continue # 不处理当前文本
|
||||
# 检测到唤醒词,继续检查
|
||||
self.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
|
||||
text = processed_text # 使用处理后的文本(移除唤醒词)
|
||||
|
||||
# 步骤1.5: 检查是否有注册指令(优先级最高,在声纹验证之前)
|
||||
if self.sv_enabled and self._check_register_intent(text):
|
||||
if self.sv_client:
|
||||
self.get_logger().info("[声纹注册] CHECK_VOICE状态检测到注册指令,直接进入注册状态")
|
||||
self.sv_registration_audio_buffer.clear()
|
||||
self.sv_registration_processing = False
|
||||
self.sv_registration_start_time = None # 重置开始时间
|
||||
self._change_state(ConversationState.REGISTERING, "用户请求声纹注册")
|
||||
prompt_text = "请说话,至少3秒,用于声纹注册"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
else:
|
||||
self.get_logger().warning("[声纹注册] 声纹功能未启用,无法注册")
|
||||
prompt_text = "声纹功能未启用"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
self._change_state(ConversationState.IDLE, "声纹功能未启用")
|
||||
continue
|
||||
|
||||
# 步骤2: 检查声纹验证结果
|
||||
if self.sv_enabled and self.sv_client:
|
||||
# 等待声纹验证结果
|
||||
# 规则2:必须等待声纹结果ready
|
||||
# 先清除事件,确保等待的是新结果
|
||||
self.sv_result_ready_event.clear()
|
||||
self.get_logger().info("[主线程] CHECK_VOICE状态:等待声纹验证结果...")
|
||||
if not self.sv_result_ready_event.wait(timeout=2.0):
|
||||
# 声纹未完成,拒绝本轮,回到Idle状态
|
||||
self.get_logger().warning("[主线程] CHECK_VOICE状态:声纹结果未ready(超时2秒),拒绝本轮")
|
||||
# 清理声纹缓冲区,避免残留数据影响下一轮
|
||||
with self.sv_lock:
|
||||
self.sv_audio_buffer.clear()
|
||||
self._change_state(ConversationState.IDLE, "声纹结果未ready")
|
||||
continue
|
||||
self.get_logger().info("[主线程] CHECK_VOICE状态:声纹结果ready,继续处理")
|
||||
self.sv_result_ready_event.clear() # 清除事件,准备下次
|
||||
|
||||
with self.sv_lock:
|
||||
speaker_id = self.current_speaker_id
|
||||
speaker_state = self.current_speaker_state
|
||||
|
||||
if speaker_id and speaker_state == SpeakerState.VERIFIED:
|
||||
# 声纹匹配成功,进入Authorized状态
|
||||
self.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}")
|
||||
self._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
|
||||
# 继续处理文本(正常交互)
|
||||
else:
|
||||
# 声纹匹配失败,不是已注册用户
|
||||
self.get_logger().info("[主线程] 声纹验证失败,不是已注册用户")
|
||||
# 如果数据库为空,提示注册
|
||||
if self.sv_client.get_speaker_count() == 0:
|
||||
prompt_text = "声纹数据库为空,请说'注册声纹'进行注册"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
self._change_state(ConversationState.IDLE, "声纹数据库为空")
|
||||
else:
|
||||
prompt_text = "声纹验证失败,不是已注册用户"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
self._change_state(ConversationState.IDLE, "声纹验证失败")
|
||||
continue # 不处理当前文本
|
||||
else:
|
||||
# 未启用声纹,直接进入Authorized状态
|
||||
self._change_state(ConversationState.AUTHORIZED, "未启用声纹")
|
||||
|
||||
elif current_state == ConversationState.REGISTERING:
|
||||
# Registering状态:不应该收到文本(由VAD的speech_end处理)
|
||||
self.get_logger().warning("[主线程] Registering状态收到文本,忽略(应由VAD的speech_end处理)")
|
||||
continue
|
||||
|
||||
elif current_state == ConversationState.AUTHORIZED:
|
||||
# Authorized状态:正常处理用户请求
|
||||
pass # 继续处理
|
||||
|
||||
elif current_state == ConversationState.IDLE:
|
||||
# Idle状态:不应该收到文本
|
||||
self.get_logger().warning("[主线程] Idle状态收到文本,忽略")
|
||||
continue
|
||||
|
||||
# 步骤2: 唤醒词处理
|
||||
# 主线程:检查中断,如果中断则清除事件并跳过
|
||||
if self._check_interrupt(auto_clear=True):
|
||||
if self.use_llm and self.history:
|
||||
self.history.cancel_turn()
|
||||
continue
|
||||
|
||||
# 步骤2: 唤醒词处理(AUTHORIZED状态下)
|
||||
processed_text = self._handle_wake_word(text)
|
||||
if not processed_text:
|
||||
self.get_logger().info("[主线程] 唤醒词过滤后为空,跳过处理")
|
||||
@@ -412,7 +862,23 @@ class RobotSpeakerNode(Node):
|
||||
self.history.cancel_turn()
|
||||
continue
|
||||
|
||||
# 步骤2.6: 检测相机指令
|
||||
# 步骤2.6: 检测声纹注册指令(优先级最高)
|
||||
if self.sv_enabled and self._check_register_intent(processed_text):
|
||||
if self.sv_client:
|
||||
self.get_logger().info("[声纹注册] 检测到注册指令,进入注册状态")
|
||||
self.sv_registration_audio_buffer.clear()
|
||||
self.sv_registration_processing = False
|
||||
self.sv_registration_start_time = None # 重置开始时间
|
||||
self._change_state(ConversationState.REGISTERING, "用户请求声纹注册")
|
||||
prompt_text = "请说话,至少3秒,用于声纹注册"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
else:
|
||||
self.get_logger().warning("[声纹注册] 声纹功能未启用,无法注册")
|
||||
prompt_text = "声纹功能未启用"
|
||||
self.tts_queue.put(prompt_text, timeout=0.2)
|
||||
continue
|
||||
|
||||
# 步骤2.7: 检测相机指令
|
||||
need_camera, user_text = self._check_camera_command(processed_text)
|
||||
if need_camera:
|
||||
self.get_logger().info(f"[相机指令] 检测到拍照指令,将进行多模态推理")
|
||||
@@ -784,6 +1250,15 @@ class RobotSpeakerNode(Node):
|
||||
self.get_logger().info("[TTS播放线程] 播放完成")
|
||||
else:
|
||||
self.get_logger().info("[TTS播放线程] 播放被中断")
|
||||
|
||||
# 在清除播放标志之前,检查是否是注册提示播放完成
|
||||
state = self._get_state()
|
||||
if state == ConversationState.REGISTERING and success:
|
||||
# TTS播放完成,清空注册缓冲区,准备开始录音
|
||||
# 这样确保TTS播放完成后检测到的第一段语音就是声纹信息
|
||||
self.sv_registration_audio_buffer.clear()
|
||||
self.get_logger().info("[声纹注册] TTS提示播放完成,已清空缓冲区,准备开始录音")
|
||||
|
||||
self.tts_playing_event.clear()
|
||||
|
||||
# 播放完成后检查中断,如果被中断则清空队列
|
||||
@@ -793,6 +1268,168 @@ class RobotSpeakerNode(Node):
|
||||
# 清除中断事件,准备下一次处理
|
||||
self.interrupt_event.clear()
|
||||
|
||||
def _sv_worker(self):
|
||||
"""
|
||||
线程5: 声纹识别线程 - 非实时、低频(CAM++)
|
||||
要求:不影响录音,不影响ASR,不控制TTS,只更新current_speaker_id
|
||||
|
||||
实现方式:
|
||||
- 通过回调函数 _on_audio_chunk_for_sv 接收音频chunk,写入缓冲区 sv_audio_buffer
|
||||
- 等待 speech_end 事件(sv_speech_end_event)触发处理
|
||||
- 从缓冲区取累积的有效人声(VAD 后)
|
||||
- CAM++ 提取 speaker embedding
|
||||
- 声纹匹配 / 注册
|
||||
- 更新 current_speaker_id(共享状态,只写不控)
|
||||
"""
|
||||
self.get_logger().info("[声纹识别线程] 启动")
|
||||
|
||||
min_audio_samples = 8000 # 至少0.5秒音频
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# 等待speech_end事件(一句话只处理一次)
|
||||
if self.sv_speech_end_event.wait(timeout=0.1):
|
||||
self.sv_speech_end_event.clear() # 清除事件标志
|
||||
|
||||
# 从缓冲区获取录音的音频
|
||||
with self.sv_lock:
|
||||
audio_list = list(self.sv_audio_buffer)
|
||||
buffer_size = len(audio_list)
|
||||
self.sv_audio_buffer.clear() # 清空缓冲区
|
||||
|
||||
self.get_logger().info(f"[声纹识别] 收到speech_end事件,录音长度: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
|
||||
|
||||
# 规则2:清除ready事件,表示正在计算
|
||||
self.sv_result_ready_event.clear()
|
||||
|
||||
# 检查数据库是否为空,如果为空则直接跳过验证
|
||||
speaker_count = self.sv_client.get_speaker_count()
|
||||
if speaker_count == 0:
|
||||
self.get_logger().info("[声纹识别] 数据库为空,跳过验证,直接设置UNKNOWN状态")
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN
|
||||
# 设置ready事件,通知处理线程
|
||||
self.sv_result_ready_event.set()
|
||||
continue # 跳过后续处理
|
||||
|
||||
# 检查是否有足够的音频
|
||||
if buffer_size >= min_audio_samples:
|
||||
# 转换为numpy数组
|
||||
audio_array = np.array(audio_list, dtype=np.int16)
|
||||
|
||||
# 提取embedding(低频调用,一句话只调用一次)
|
||||
embedding, success = self.sv_client.extract_embedding(
|
||||
audio_array,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
if not success or embedding is None:
|
||||
self.get_logger().debug("[声纹识别] 提取embedding失败")
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.ERROR
|
||||
else:
|
||||
# 匹配说话人(一句话只调用一次)
|
||||
speaker_id, match_state = self.sv_client.match_speaker(embedding)
|
||||
|
||||
# 更新current_speaker_id和state(只写不控)
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = speaker_id
|
||||
self.current_speaker_state = match_state
|
||||
|
||||
if match_state == SpeakerState.VERIFIED:
|
||||
self.get_logger().info(f"[声纹识别] ✓ 识别到说话人: {speaker_id}")
|
||||
elif match_state == SpeakerState.REJECTED:
|
||||
self.get_logger().debug("[声纹识别] ✗ 未匹配到已知说话人(相似度不足)")
|
||||
else:
|
||||
self.get_logger().debug(f"[声纹识别] ? 状态: {match_state.value}")
|
||||
else:
|
||||
self.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
|
||||
# 即使音频太短,也设置ready事件,避免处理线程无限等待
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN
|
||||
|
||||
# 规则2:设置ready事件,通知处理线程声纹结果已准备好
|
||||
self.sv_result_ready_event.set()
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[声纹识别线程] 错误: {e}")
|
||||
time.sleep(0.1)
|
||||
|
||||
def _process_sv_registration(self, audio_list: list):
    """Handle one voiceprint-registration request (runs in a worker thread).

    Extracts a speaker embedding from the buffered utterance, registers it
    under a timestamp-based speaker id, announces the outcome over TTS, and
    always returns the conversation state machine to IDLE.

    Args:
        audio_list: raw int16 PCM samples accumulated for the registration
            utterance (sampled at ``self.sample_rate``).
    """
    import queue
    import traceback

    def _announce(text: str) -> None:
        # Best-effort TTS prompt: a congested TTS queue must not abort the
        # registration flow or skip the state transition back to IDLE.
        try:
            self.tts_queue.put(text, timeout=0.2)
        except queue.Full:
            self.get_logger().warning(f"[声纹注册] TTS队列已满,丢弃提示: {text}")

    audio_length_sec = 0.0
    try:
        # Convert the buffered samples to a numpy array for the SV client.
        audio_array = np.array(audio_list, dtype=np.int16)
        audio_length_sec = len(audio_array) / self.sample_rate

        self.get_logger().info(f"[声纹注册] 开始处理注册,音频长度: {len(audio_array)} 样本({audio_length_sec:.2f}秒)")

        # Extract the speaker embedding (one call per utterance).
        embedding, success = self.sv_client.extract_embedding(
            audio_array,
            sample_rate=self.sample_rate
        )

        if not success or embedding is None:
            self.get_logger().error("[声纹注册] ✗ 提取embedding失败")
            _announce("声纹注册失败,无法提取特征")
            # Extraction failed -- fall back to IDLE; the finally-block
            # clears the processing flag.
            self._change_state(ConversationState.IDLE, "注册失败")
            return

        self.get_logger().info(f"[声纹注册] ✓ 成功提取embedding,维度: {len(embedding)}")

        # Auto-generate a speaker id from the current timestamp.
        speaker_id = f"user_{int(time.time())}"

        # Database size before registration (for the audit log below).
        speaker_count_before = self.sv_client.get_speaker_count()

        # Register the speaker embedding.
        success = self.sv_client.register_speaker(speaker_id, embedding)

        if success:
            speaker_count_after = self.sv_client.get_speaker_count()

            # Verbose audit log of the successful registration.
            self.get_logger().info("=" * 60)
            self.get_logger().info("[声纹注册] ✓✓✓ 注册成功 ✓✓✓")
            self.get_logger().info(f"[声纹注册] 说话人ID: {speaker_id}")
            self.get_logger().info(f"[声纹注册] 音频长度: {audio_length_sec:.2f}秒 ({len(audio_array)} 样本)")
            self.get_logger().info(f"[声纹注册] 注册前数据库说话人数量: {speaker_count_before}")
            self.get_logger().info(f"[声纹注册] 注册后数据库说话人数量: {speaker_count_after}")
            self.get_logger().info(f"[声纹注册] 注册时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
            self.get_logger().info("=" * 60)

            _announce("声纹注册完成")
            # Registration succeeded -- back to IDLE awaiting the next turn.
            self._change_state(ConversationState.IDLE, "注册成功")
        else:
            self.get_logger().error("[声纹注册] ✗ 注册失败(register_speaker返回False)")
            _announce("声纹注册失败")
            # Registration rejected -- back to IDLE.
            self._change_state(ConversationState.IDLE, "注册失败")
    except Exception as e:
        self.get_logger().error(f"[声纹注册] ✗ 处理失败,异常: {e}")
        self.get_logger().error(f"[声纹注册] 异常类型: {type(e).__name__}")
        self.get_logger().error(f"[声纹注册] 异常堆栈:\n{traceback.format_exc()}")
        _announce("声纹注册过程出错")
        # Unexpected error -- back to IDLE.
        self._change_state(ConversationState.IDLE, "注册出错")
    finally:
        # Always release the registration-in-progress flag, on every path.
        self.sv_registration_processing = False
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def _check_interrupt(self, auto_clear: bool = False) -> bool:
|
||||
@@ -832,6 +1469,17 @@ class RobotSpeakerNode(Node):
|
||||
py_list = pinyin(''.join(chars), style=Style.NORMAL)
|
||||
return ' '.join([item[0] for item in py_list])
|
||||
|
||||
def _check_register_intent(self, text: str) -> bool:
|
||||
"""检查用户是否有注册意图(优先级最高)"""
|
||||
text_lower = text.lower().strip()
|
||||
register_keywords = ["注册声纹", "开启声纹注册", "录入声纹", "声纹注册", "注册",
|
||||
"录入我的声纹", "录入声纹信息", "开始录入声纹", "开始声纹验证"]
|
||||
# 检查是否包含注册关键词(避免误判,要求关键词完整)
|
||||
for keyword in register_keywords:
|
||||
if keyword in text_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _start_session(self):
|
||||
"""开始会话"""
|
||||
with self.session_lock:
|
||||
@@ -909,8 +1557,11 @@ class RobotSpeakerNode(Node):
|
||||
self.stop_event.set()
|
||||
self.interrupt_event.set()
|
||||
|
||||
for thread in [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]:
|
||||
if thread.is_alive():
|
||||
threads_to_join = [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]
|
||||
if self.sv_thread:
|
||||
threads_to_join.append(self.sv_thread)
|
||||
for thread in threads_to_join:
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=2.0)
|
||||
|
||||
if hasattr(self, 'asr_client') and self.asr_client:
|
||||
@@ -923,6 +1574,14 @@ class RobotSpeakerNode(Node):
|
||||
if hasattr(self, 'camera_client') and self.camera_client:
|
||||
self.camera_client.cleanup()
|
||||
|
||||
# 清理声纹识别资源
|
||||
if hasattr(self, 'sv_client') and self.sv_client:
|
||||
try:
|
||||
self.sv_client.save_speakers()
|
||||
self.sv_client.cleanup()
|
||||
except Exception as e:
|
||||
self.get_logger().warning(f"清理声纹识别资源时出错: {e}")
|
||||
|
||||
super().destroy_node()
|
||||
|
||||
|
||||
|
||||
386
robot_speaker/speaker_verification.py
Normal file
386
robot_speaker/speaker_verification.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
声纹识别模块
|
||||
"""
|
||||
import numpy as np
|
||||
import threading
|
||||
import tempfile
|
||||
import os
|
||||
import wave
|
||||
import time
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SpeakerState(Enum):
    """Speaker-recognition result states."""
    UNKNOWN = "unknown"    # no decision (e.g. speaker database empty, or audio too short)
    VERIFIED = "verified"  # best match scored at or above its threshold
    REJECTED = "rejected"  # a best match existed but fell below its threshold
    ERROR = "error"        # embedding extraction/matching failed (bad dim, zero norm, ...)
|
||||
|
||||
|
||||
class SpeakerVerificationClient:
    """Speaker-verification client -- non-realtime, low-frequency processing.

    Keeps an in-memory database of L2-normalized speaker embeddings, persists
    it as JSON (with a one-shot migration path for legacy pickle files), and
    matches new embeddings against it via cosine similarity. Database access
    is guarded by an internal lock so the client can be shared between the
    recognition and registration threads.
    """

    def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
        """
        Args:
            model_path: local path of the speaker-verification model (loaded
                via funasr's AutoModel on CPU).
            threshold: default cosine-similarity acceptance threshold.
            speaker_db_path: optional JSON file used to persist registered
                speakers; when set, existing entries are loaded eagerly.
            logger: optional ROS2-style logger; falls back to print() when None.
        """
        self.model_path = model_path
        self.threshold = threshold
        self.speaker_db_path = speaker_db_path
        self.logger = logger
        # {speaker_id: {"embedding": np.ndarray (L2-normalized), "env": str,
        #               "threshold": float, "registered_at": float}}
        self.speaker_db = {}
        self._lock = threading.Lock()
        # Embedding dimensionality is pinned on the first successful
        # extraction/load so one anomalous embedding cannot poison all later
        # matching. Common sizes: 192 (CAM++), 256, ...
        self._expected_embedding_dim = None  # int, not a shape tuple

        from funasr import AutoModel
        self.model = AutoModel(model=self.model_path, device="cpu")
        if self.logger:
            self.logger.info(f"声纹模型已加载: {self.model_path}, 阈值: {self.threshold}")

        if self.speaker_db_path:
            self.load_speakers()

    def _log(self, level: str, msg: str):
        """Log through the injected logger, tolerating ROS2 multi-thread quirks."""
        if self.logger:
            try:
                # Static dispatch table; resolving the method dynamically can
                # trip ROS2 logger internals.
                log_methods = {
                    "debug": self.logger.debug,
                    "info": self.logger.info,
                    "warning": self.logger.warning,
                    "error": self.logger.error,
                    "fatal": self.logger.fatal,
                }
                log_method = log_methods.get(level.lower(), self.logger.info)
                log_method(msg)
            except ValueError as e:
                # ROS2 loggers may raise "Logger severity cannot be changed
                # between calls" when used from several threads; degrade to
                # info level (and ultimately to print) rather than dropping
                # the message.
                if "severity cannot be changed" in str(e):
                    try:
                        self.logger.info(f"[声纹-{level.upper()}] {msg}")
                    except Exception:
                        print(f"[声纹-{level.upper()}] {msg}")
                else:
                    raise
        else:
            print(f"[声纹] {msg}")

    def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Write a numpy sample buffer to a temporary mono 16-bit WAV file.

        Returns the temp file path; the caller is responsible for deleting it.
        """
        audio_int16 = audio_data.astype(np.int16)

        fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
        os.close(fd)  # wave.open reopens by path; release the raw descriptor

        with wave.open(temp_path, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 2 bytes == int16
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_int16.tobytes())

        return temp_path

    def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
        """Extract a speaker embedding (low-frequency: once per utterance).

        Args:
            audio_data: int16 mono PCM samples.
            sample_rate: sampling rate of *audio_data* in Hz.

        Returns:
            (embedding, True) on success, (None, False) on failure.
        """
        # Require at least 0.5 s of audio; shorter clips are rejected outright.
        if len(audio_data) < int(sample_rate * 0.5):
            return None, False

        temp_wav_path = None
        try:
            temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
            result = self.model.generate(input=temp_wav_path)

            # funasr returns list[dict]; 'spk_embedding' is a torch.Tensor of
            # shape [1, D] -- squeeze it down to a 1-D numpy vector.
            # (No explicit torch import needed: the methods live on the tensor.)
            embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0]

            embedding_dim = len(embedding)
            if embedding_dim == 0:
                return None, False

            # Enforce / establish the pinned dimensionality (see __init__).
            if self._expected_embedding_dim is not None:
                if embedding_dim != self._expected_embedding_dim:
                    self._log("error", f"embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
                    return None, False
            else:
                # Pin only on the first *successful* extraction so a failed
                # one cannot lock in a bogus dimension.
                self._expected_embedding_dim = embedding_dim
                self._log("info", f"固定embedding维度: {embedding_dim}")

            return embedding, True
        except Exception as e:
            self._log("error", f"提取embedding失败: {e}")
            return None, False
        finally:
            # Always clean up the temp WAV, even on failure.
            if temp_wav_path and os.path.exists(temp_wav_path):
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    pass

    def register_speaker(self, speaker_id: str, embedding: np.ndarray,
                         env: str = "near", threshold: float = None) -> bool:
        """Register a speaker.

        The embedding is L2-normalized once at registration time so matching
        can use a plain dot product as cosine similarity.

        Args:
            speaker_id: unique identifier for the speaker (overwrites any
                existing entry with the same id).
            embedding: raw embedding vector.
            env: acoustic environment tag stored with the entry.
            threshold: per-speaker acceptance threshold; defaults to the
                client-wide threshold when None.

        Returns:
            True when the speaker is stored in memory (a persistence failure
            is logged but does not fail the registration), False otherwise.
        """
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return False

        # Enforce / establish the pinned dimensionality.
        if self._expected_embedding_dim is not None:
            if embedding_dim != self._expected_embedding_dim:
                self._log("error", f"注册失败:embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
                return False
        else:
            self._expected_embedding_dim = embedding_dim

        # Normalize once here; match_speaker relies on stored vectors being
        # unit-length.
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            self._log("error", f"注册失败:embedding范数为0")
            return False
        embedding_normalized = embedding / embedding_norm

        speaker_threshold = threshold if threshold is not None else self.threshold

        with self._lock:
            # Plain-dict record keeps the entry JSON-serializable.
            self.speaker_db[speaker_id] = {
                "embedding": embedding_normalized,  # already normalized
                "env": env,
                "threshold": speaker_threshold,
                "registered_at": time.time()
            }
            self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")

        # save_speakers() acquires self._lock itself; self._lock is not
        # reentrant, so it must be called outside the critical section.
        save_result = self.save_speakers()
        if not save_result:
            # info (not warning) to sidestep ROS2 multi-thread logger issues.
            self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
        return True

    def match_speaker(self, embedding: np.ndarray):
        """Match an embedding against the registered speakers (once per utterance).

        Returns:
            (speaker_id, SpeakerState.VERIFIED) when the best cosine score
            reaches that speaker's threshold; (None, REJECTED) when it does
            not; (None, UNKNOWN) for an empty database; (None, ERROR) for a
            malformed embedding.
        """
        if not self.speaker_db:
            return None, SpeakerState.UNKNOWN

        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return None, SpeakerState.ERROR

        # Dimensionality must agree with the pinned value (when set).
        if self._expected_embedding_dim is not None and embedding_dim != self._expected_embedding_dim:
            return None, SpeakerState.ERROR

        # Normalize the probe; stored references are normalized at
        # registration/load time.
        embedding_norm = np.linalg.norm(embedding)
        if embedding_norm == 0:
            return None, SpeakerState.ERROR
        embedding_normalized = embedding / embedding_norm

        best_match = None
        best_score = -1.0
        best_threshold = self.threshold

        with self._lock:
            for speaker_id, speaker_data in self.speaker_db.items():
                ref_embedding = speaker_data["embedding"]  # unit-length

                # Defensive: skip entries whose dimension somehow disagrees.
                if len(ref_embedding) != embedding_dim:
                    continue

                # Both vectors are unit-length, so the dot product IS the
                # cosine similarity.
                score = np.dot(embedding_normalized, ref_embedding)

                if score > best_score:
                    best_score = score
                    best_match = speaker_id
                    best_threshold = speaker_data.get("threshold", self.threshold)

        return (best_match, SpeakerState.VERIFIED) if best_score >= best_threshold else (None, SpeakerState.REJECTED)

    def is_available(self) -> bool:
        """Return True when the embedding model is loaded."""
        return self.model is not None

    def cleanup(self):
        """Release resources (currently nothing to do; kept for interface symmetry)."""
        pass

    def get_speaker_count(self) -> int:
        """Return the number of registered speakers (thread-safe)."""
        with self._lock:
            return len(self.speaker_db)

    def remove_speaker(self, speaker_id: str) -> bool:
        """Remove a registered speaker and persist the change.

        Returns:
            True when the speaker existed and was removed, False otherwise.
        """
        with self._lock:
            if speaker_id not in self.speaker_db:
                return False
            del self.speaker_db[speaker_id]
        # Persist outside the lock (save_speakers takes the lock itself).
        self.save_speakers()
        return True

    def load_speakers(self) -> bool:
        """Load registered voiceprints from the JSON database file.

        Format: {speaker_id: {"embedding": [floats], "env": str,
                              "threshold": float, "registered_at": float}}
        A legacy pickle database is detected and migrated to JSON in place.

        Returns:
            True when a database was loaded (or migrated), False otherwise.
        """
        if not self.speaker_db_path:
            return False

        if not os.path.exists(self.speaker_db_path):
            self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
            return False

        try:
            with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            with self._lock:
                # Convert lists back to numpy arrays, validating as we go.
                for speaker_id, speaker_data in data.items():
                    embedding_array = np.array(speaker_data["embedding"], dtype=np.float32)

                    embedding_dim = len(embedding_array)
                    if embedding_dim == 0:
                        self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
                        continue

                    # Pin / enforce the expected dimensionality.
                    if self._expected_embedding_dim is None:
                        self._expected_embedding_dim = embedding_dim
                    elif embedding_dim != self._expected_embedding_dim:
                        self._log("warning", f"跳过维度不匹配的声纹: {speaker_id} (期望{self._expected_embedding_dim}, 实际{embedding_dim})")
                        continue

                    # Re-normalize for backward compatibility with databases
                    # written before normalization-at-registration existed.
                    embedding_norm = np.linalg.norm(embedding_array)
                    if embedding_norm > 0:
                        embedding_array = embedding_array / embedding_norm

                    self.speaker_db[speaker_id] = {
                        "embedding": embedding_array,
                        "env": speaker_data.get("env", "near"),
                        "threshold": speaker_data.get("threshold", self.threshold),
                        "registered_at": speaker_data.get("registered_at", time.time())
                    }

            count = len(self.speaker_db)
            if self._expected_embedding_dim:
                self._log("info", f"已加载 {count} 个已注册说话人,embedding维度: {self._expected_embedding_dim}")
            else:
                self._log("info", f"已加载 {count} 个已注册说话人")
            return True
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Not JSON -- possibly a legacy pickle database. NOTE: a binary
            # pickle opened in text mode raises UnicodeDecodeError before
            # json.load can even produce a JSONDecodeError, so both types
            # must be caught here for the migration path to be reachable.
            return self._migrate_pickle_db(e)
        except Exception as e:
            self._log("error", f"加载声纹数据库失败: {e}")
            return False

    def _migrate_pickle_db(self, json_error: Exception) -> bool:
        """One-shot migration of a legacy pickle speaker database to JSON.

        NOTE(review): pickle.load is only safe because this file is our own
        locally-written database; never point speaker_db_path at untrusted data.
        """
        try:
            import pickle
            with open(self.speaker_db_path, 'rb') as f:
                old_data = pickle.load(f)
            self._log("warning", "检测到旧的pickle格式,正在迁移...")
            with self._lock:
                for speaker_id, speaker_info in old_data.items():
                    if hasattr(speaker_info, 'embedding'):
                        # Legacy format: a SpeakerInfo-style object.
                        embedding = speaker_info.embedding
                        embedding_norm = np.linalg.norm(embedding)
                        if embedding_norm > 0:
                            embedding = embedding / embedding_norm
                        self.speaker_db[speaker_id] = {
                            "embedding": embedding,
                            "env": getattr(speaker_info, 'env', 'near'),
                            "threshold": getattr(speaker_info, 'threshold', self.threshold),
                            "registered_at": getattr(speaker_info, 'registered_at', time.time())
                        }
                    else:
                        # Legacy format: a plain dict record.
                        embedding = speaker_info.get("embedding")
                        if embedding is not None:
                            embedding_norm = np.linalg.norm(embedding)
                            if embedding_norm > 0:
                                embedding = embedding / embedding_norm
                            self.speaker_db[speaker_id] = {
                                "embedding": embedding,
                                "env": speaker_info.get("env", "near"),
                                "threshold": speaker_info.get("threshold", self.threshold),
                                "registered_at": speaker_info.get("registered_at", time.time())
                            }
            # Persist in the new JSON format.
            self.save_speakers()
            self._log("info", "已迁移到新格式")
            return True  # fix: the original fell through and returned None
        except Exception as e2:
            self._log("error", f"加载声纹数据库失败(JSON和pickle都失败): {json_error}, {e2}")
            return False

    def save_speakers(self) -> bool:
        """Persist the registered voiceprints as JSON.

        Writes to a temp file and atomically replaces the target so a crash
        mid-write cannot corrupt the database.

        Returns:
            True on success, False when no path is configured or writing fails.
        """
        if not self.speaker_db_path:
            self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
            return False

        try:
            db_dir = os.path.dirname(self.speaker_db_path)
            if db_dir and not os.path.exists(db_dir):
                os.makedirs(db_dir, exist_ok=True)

            # Snapshot into a JSON-serializable structure under the lock.
            json_data = {}
            with self._lock:
                for speaker_id, speaker_data in self.speaker_db.items():
                    json_data[speaker_id] = {
                        "embedding": speaker_data["embedding"].tolist(),  # ndarray -> list
                        "env": speaker_data.get("env", "near"),
                        "threshold": speaker_data.get("threshold", self.threshold),
                        "registered_at": speaker_data.get("registered_at", time.time())
                    }

            # Temp file + atomic replace protects against partial writes.
            temp_path = self.speaker_db_path + ".tmp"
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=2, ensure_ascii=False)

            os.replace(temp_path, self.speaker_db_path)

            self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
            return True
        except Exception as e:
            import traceback
            self._log("error", f"保存声纹数据库失败: {e}")
            self._log("error", f"保存路径: {self.speaker_db_path}")
            self._log("error", f"错误详情: {traceback.format_exc()}")
            # Best-effort cleanup of the orphaned temp file.
            temp_path = self.speaker_db_path + ".tmp"
            if os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            return False
|
||||
|
||||
Reference in New Issue
Block a user