增加声纹验证

This commit is contained in:
lxy
2026-01-12 20:39:47 +08:00
parent 9c775cff5c
commit 838a4a357c
5 changed files with 1118 additions and 13 deletions

View File

@@ -38,3 +38,7 @@ system:
use_wake_word: true # 是否启用唤醒词检测
wake_word: "er gou" # 唤醒词(拼音)
session_timeout: 10.0 # 会话超时时间(秒)
sv_enabled: true # 是否启用声纹识别
sv_model_path: "~/ros_learn/speech_campplus_sv_zh-cn_16k-common" # 声纹模型路径
sv_threshold: 0.35 # 声纹识别阈值(0.0-1.0,值越小越宽松,值越大越严格)
sv_speaker_db_path: "config/speakers.json" # 声纹数据库保存路径(JSON格式,相对于ROS2包share目录)

View File

@@ -183,6 +183,31 @@ def generate_launch_description():
description='会话超时时间(秒)'
),
# 声纹识别参数
DeclareLaunchArgument(
'sv_enabled',
default_value=str(system_config.get('sv_enabled', True)).lower(),
description='是否启用声纹识别'
),
DeclareLaunchArgument(
'sv_model_path',
default_value=os.path.expanduser(system_config.get('sv_model_path', '')),
description='声纹模型路径'
),
DeclareLaunchArgument(
'sv_threshold',
default_value=str(system_config.get('sv_threshold', 0.45)),
description='声纹识别阈值'
),
DeclareLaunchArgument(
'sv_speaker_db_path',
default_value=os.path.join(
get_package_share_directory('robot_speaker'),
system_config.get('sv_speaker_db_path', 'config/speakers.json')
) if system_config.get('sv_speaker_db_path') else '',
description='声纹数据库路径'
),
# 相机参数
DeclareLaunchArgument(
'camera_serial_number',
@@ -260,6 +285,12 @@ def generate_launch_description():
'wake_word': LaunchConfiguration('wake_word'),
'session_timeout': LaunchConfiguration('session_timeout'),
# 声纹识别参数
'sv_enabled': LaunchConfiguration('sv_enabled'),
'sv_model_path': LaunchConfiguration('sv_model_path'),
'sv_threshold': LaunchConfiguration('sv_threshold'),
'sv_speaker_db_path': LaunchConfiguration('sv_speaker_db_path'),
# 相机参数
'camera_serial_number': LaunchConfiguration('camera_serial_number'),
'camera_width': LaunchConfiguration('camera_width'),

View File

@@ -37,6 +37,9 @@ class AudioRecorder:
on_speech_start=None, # 检测到人声开始
on_speech_end=None, # 检测到静音结束(说话结束)
stop_flag=None,
on_audio_chunk=None, # 音频chunk回调用于声纹录音等可选
should_put_to_queue=None, # 检查是否应该将音频放入队列用于阻止ASR可选
get_silence_threshold=None, # 获取动态静音阈值(毫秒,可选)
logger=None):
self.device_index = device_index
self.sample_rate = sample_rate
@@ -54,6 +57,9 @@ class AudioRecorder:
self.on_speech_start = on_speech_start
self.on_speech_end = on_speech_end
self.stop_flag = stop_flag or (lambda: False)
self.on_audio_chunk = on_audio_chunk # 音频chunk回调用于声纹录音等
self.should_put_to_queue = should_put_to_queue or (lambda: True) # 默认允许放入队列
self.get_silence_threshold = get_silence_threshold # 动态静音阈值回调
self.logger = logger
self.audio = pyaudio.PyAudio()
self.format = pyaudio.paInt16
@@ -97,10 +103,16 @@ class AudioRecorder:
# exception_on_overflow=False, 宁可丢帧,也不阻塞
data = stream.read(self.chunk, exception_on_overflow=False)
# 队列满时丢弃最旧的数据ASR 跟不上时系统仍然听得见
if self.audio_queue.full():
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(data)
# 检查是否应该将音频放入队列用于阻止ASR例如无声纹文件时需要注册
if self.should_put_to_queue():
# 队列满时丢弃最旧的数据ASR 跟不上时系统仍然听得见
if self.audio_queue.full():
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(data)
# 音频chunk回调用于声纹录音等仅在需要时调用
if self.on_audio_chunk:
self.on_audio_chunk(data)
audio_buffer.append(data) # 只用于 VAD不用于 ASR
@@ -133,8 +145,21 @@ class AudioRecorder:
else:
if was_speaking:
silence_duration = now - last_active_time
if silence_duration >= no_speech_threshold:
# 动态获取静音阈值(如果提供回调函数)
if self.get_silence_threshold:
current_silence_ms = self.get_silence_threshold()
current_no_speech_threshold = max(current_silence_ms / 1000.0, 0.1)
else:
current_no_speech_threshold = no_speech_threshold
# 添加调试日志
if self.logger and silence_duration < current_no_speech_threshold:
self.logger.debug(f"[VAD] 静音中: {silence_duration:.3f}秒 < {current_no_speech_threshold:.3f}秒阈值")
if silence_duration >= current_no_speech_threshold:
if self.on_speech_end:
if self.logger:
self.logger.debug(f"[VAD] 触发speech_end: 静音持续时间 {silence_duration:.3f}秒 >= 阈值 {current_no_speech_threshold:.3f}")
self.on_speech_end() # 通知系统用户停止说话
if self.on_heartbeat and now - last_heartbeat_time >= self.heartbeat_interval:

View File

@@ -14,6 +14,8 @@ import numpy as np
from PIL import Image
from pypinyin import pinyin, Style
import subprocess
from enum import Enum
import collections
from .audio import VADDetector, AudioRecorder
from .asr import DashScopeASR
from .tts import DashScopeTTSClient, TTSRequest
@@ -21,6 +23,15 @@ from .llm import DashScopeLLM
from .history import ConversationHistory
from .types import LLMMessage
from .camera import CameraClient
from .speaker_verification import SpeakerVerificationClient, SpeakerState
class ConversationState(Enum):
    """States of the conversation state machine.

    IDLE         -> waiting for a wake word or any voice activity
    CHECK_VOICE  -> user started speaking; verify the voiceprint
    REGISTERING  -> recording audio for voiceprint enrollment
    AUTHORIZED   -> speaker accepted as a registered user
    """

    IDLE = "idle"
    CHECK_VOICE = "check_voice"
    REGISTERING = "registering"
    AUTHORIZED = "authorized"
class RobotSpeakerNode(Node):
@@ -46,9 +57,33 @@ class RobotSpeakerNode(Node):
self.session_active = False
self.session_start_time = 0.0
self.session_lock = threading.Lock()
# 状态机状态
self.conversation_state = ConversationState.IDLE # 当前会话状态
self.state_lock = threading.Lock() # 保护状态机状态
# 声纹识别共享状态
self.current_speaker_id = None # 当前说话人ID共享状态只读
self.current_speaker_state = SpeakerState.UNKNOWN # 当前说话人状态
self.sv_lock = threading.Lock() # 保护current_speaker_id和current_speaker_state
self.sv_speech_end_event = threading.Event() # 通知声纹线程处理speech_end触发
self.sv_result_ready_event = threading.Event() # 声纹结果ready事件用于CHECK_VOICE状态同步
self.sv_audio_buffer = collections.deque(maxlen=64000) # 声纹验证录音缓冲区(只在需要时录音)
self.sv_recording = False # 是否正在为声纹验证录音
# 声纹注册状态
self.sv_registration_audio_buffer = collections.deque(maxlen=64000) # 注册时录音缓冲区
self.sv_registration_processing = False # 是否正在处理注册(处理中时停止录音)
self.sv_registration_start_time = None # 注册开始时间(用于超时检测)
# 初始化组件VAD、录音器、ASR、LLM、TTS
self._init_components()
# 状态机初始状态
if self.sv_enabled and self.sv_client:
speaker_count = self.sv_client.get_speaker_count()
if speaker_count == 0:
self.get_logger().info("声纹数据库为空,将在检测到语音时提示注册")
# ROS订阅
self.interrupt_sub = self.create_subscription(
@@ -112,6 +147,12 @@ class RobotSpeakerNode(Node):
# 闭嘴指令参数
self.declare_parameter('shutup_keywords', 'bi zui,ting zhi,bu yao shuo le,bu yao shuo hua')
# 声纹识别参数
self.declare_parameter('sv_model_path', '')
self.declare_parameter('sv_threshold', 0.45) # 默认值实际从voice.yaml读取
self.declare_parameter('sv_speaker_db_path', '') # 声纹数据库路径
self.declare_parameter('sv_enabled', True) # 默认启用声纹识别可通过voice.yaml配置
def _load_parameters(self):
"""加载ROS参数"""
# 音频参数
@@ -167,6 +208,13 @@ class RobotSpeakerNode(Node):
shutup_keywords_str = self.get_parameter('shutup_keywords').get_parameter_value().string_value
self.shutup_keywords = [k.strip() for k in shutup_keywords_str.split(',') if k.strip()]
# 声纹识别参数
self.sv_model_path = self.get_parameter('sv_model_path').get_parameter_value().string_value
self.sv_threshold = self.get_parameter('sv_threshold').get_parameter_value().double_value
sv_db_path = self.get_parameter('sv_speaker_db_path').get_parameter_value().string_value
self.sv_speaker_db_path = sv_db_path if sv_db_path and sv_db_path.strip() else None
self.sv_enabled = self.get_parameter('sv_enabled').get_parameter_value().bool_value
def _init_components(self):
"""初始化所有组件"""
# VAD检测器
@@ -192,6 +240,9 @@ class RobotSpeakerNode(Node):
on_speech_start=self._on_speech_start,
on_speech_end=self._on_speech_end,
stop_flag=self.stop_event.is_set,
on_audio_chunk=self._on_audio_chunk_for_sv if self.sv_enabled else None, # 声纹录音回调
should_put_to_queue=self._should_put_audio_to_queue, # 检查是否应该将音频放入队列
get_silence_threshold=self._get_silence_threshold, # 动态静音阈值回调
logger=self.get_logger()
)
@@ -252,6 +303,26 @@ class RobotSpeakerNode(Node):
except Exception as e:
self.get_logger().warning(f"相机初始化失败: {e},相机功能将不可用")
self.camera_client = None
# 声纹识别客户端
if self.sv_enabled and self.sv_model_path:
try:
self.sv_client = SpeakerVerificationClient(
model_path=self.sv_model_path,
threshold=self.sv_threshold,
speaker_db_path=self.sv_speaker_db_path,
logger=self.get_logger()
)
if not self.sv_client.is_available():
self.get_logger().warning("声纹识别模型不可用,将禁用声纹功能")
self.sv_client = None
self.sv_enabled = False
except Exception as e:
self.get_logger().warning(f"声纹识别初始化失败: {e},声纹功能将不可用")
self.sv_client = None
self.sv_enabled = False
else:
self.sv_client = None
def _start_threads(self):
"""启动4个线程"""
@@ -286,6 +357,17 @@ class RobotSpeakerNode(Node):
daemon=True
)
self.tts_thread.start()
# 线程5: 声纹识别线程(如果启用)
if self.sv_enabled and self.sv_client:
self.sv_thread = threading.Thread(
target=self._sv_worker,
name="SVThread",
daemon=True
)
self.sv_thread.start()
else:
self.sv_thread = None
def _handle_interrupt_command(self, msg: String):
"""处理ROS中断命令"""
@@ -301,17 +383,250 @@ class RobotSpeakerNode(Node):
self._drain_queue(self.text_queue)
# TTS 队列不清空,由 TTS 线程根据 interrupt_event 自行中断播放
# ==================== 状态机方法 ====================
def _change_state(self, new_state: ConversationState, reason: str = ""):
    """Transition the conversation state machine and log the transition.

    The optional *reason* is appended to the log line when non-empty.
    """
    with self.state_lock:
        previous = self.conversation_state
        self.conversation_state = new_state
        suffix = f": {reason}" if reason else ""
        self.get_logger().info(f"[状态机] {previous.value} -> {new_state.value}{suffix}")
def _get_state(self) -> ConversationState:
    """Return a thread-safe snapshot of the current conversation state."""
    with self.state_lock:
        current = self.conversation_state
    return current
def _get_silence_threshold(self) -> int:
    """Return the dynamic end-of-speech silence threshold in milliseconds.

    While REGISTERING a shorter 500 ms cutoff is used so enrollment reacts
    faster; every other state uses the configured default.
    """
    in_registration = self._get_state() == ConversationState.REGISTERING
    return 500 if in_registration else self.silence_duration_ms
def _should_put_audio_to_queue(self) -> bool:
    """Whether captured audio may be forwarded to ASR in the current state.

    ASR is allowed in IDLE (listening for speech), CHECK_VOICE (the text is
    needed alongside voiceprint verification) and AUTHORIZED (normal
    interaction). It is blocked only while REGISTERING, when audio is
    captured purely for voiceprint enrollment.
    """
    return self._get_state() is not ConversationState.REGISTERING
# ==================== 录音线程回调 ====================
def _on_speech_start(self):
    """Recording-thread callback: voice activity detected (speech started).

    Drives the state machine on first voice: IDLE moves to CHECK_VOICE
    (voiceprint enabled) or straight to AUTHORIZED (voiceprint disabled);
    while REGISTERING it only stamps the enrollment start time; in
    CHECK_VOICE / AUTHORIZED it (re)starts the verification capture buffer.
    """
    self.get_logger().info("[录音线程] 检测到人声,开始录音")
    state = self._get_state()
    if state == ConversationState.IDLE:
        # IDLE -> CHECK_VOICE
        if self.sv_enabled and self.sv_client:
            # Start capturing audio for speaker verification.
            with self.sv_lock:
                self.sv_recording = True
                self.sv_audio_buffer.clear()
            self.get_logger().debug("[声纹] 开始录音用于声纹验证")
            self._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
        else:
            # Voiceprint disabled: go straight to AUTHORIZED.
            self._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
    elif state == ConversationState.REGISTERING:
        # While enrolling, record the start time for the timeout check.
        self.sv_registration_start_time = time.time()
        self.get_logger().info("[声纹注册] 检测到人声开始,记录开始时间")
    elif state == ConversationState.CHECK_VOICE:
        # Still verifying: keep capturing audio for speaker verification.
        if self.sv_enabled:
            with self.sv_lock:
                self.sv_recording = True
                self.sv_audio_buffer.clear()
            self.get_logger().debug("[声纹] 继续录音用于声纹验证")
    elif state == ConversationState.AUTHORIZED:
        # Authorized: capture audio so the current user can be re-verified.
        if self.sv_enabled:
            with self.sv_lock:
                self.sv_recording = True
                self.sv_audio_buffer.clear()
            self.get_logger().debug("[声纹] 开始录音用于声纹验证")
def _on_audio_chunk_for_sv(self, audio_chunk: bytes):
    """Recording-thread audio-chunk callback: buffer audio for voiceprint use.

    Feeds two independent sinks:
      * the verification buffer (while sv_recording is set), and
      * the enrollment buffer while REGISTERING — but not while TTS is
        playing, and subject to a 10 s timeout that force-finishes the
        enrollment.
    """
    state = self._get_state()
    # Verification capture (CHECK_VOICE / AUTHORIZED states).
    if self.sv_enabled and self.sv_recording:
        try:
            audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
            with self.sv_lock:
                self.sv_audio_buffer.extend(audio_array)
        except Exception as e:
            self.get_logger().debug(f"[声纹] 录音失败: {e}")
    # Enrollment capture: REGISTERING, not yet processing, and TTS idle.
    # Important: recording only starts after the TTS prompt finishes so the
    # first utterance captured is the actual voiceprint sample.
    if state == ConversationState.REGISTERING and not self.sv_registration_processing:
        # Skip while TTS is playing to avoid capturing TTS output or noise.
        if self.tts_playing_event.is_set():
            return  # TTS still playing: do not record
        # Timeout check: force enrollment processing after 10 s of recording.
        if self.sv_registration_start_time is not None:
            elapsed_time = time.time() - self.sv_registration_start_time
            if elapsed_time >= 10.0:  # 10 s timeout
                buffer_size = len(self.sv_registration_audio_buffer)
                buffer_sec = buffer_size / self.sample_rate
                self.get_logger().warning(f"[声纹注册] 录音超时({elapsed_time:.2f}秒),强制处理注册,音频长度: {buffer_sec:.2f}")
                # Force enrollment processing now.
                self._force_process_registration()
                return
        try:
            audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
            self.sv_registration_audio_buffer.extend(audio_array)
            # Progress log every 10000 samples (~0.625 s).
            buffer_size = len(self.sv_registration_audio_buffer)
            if buffer_size % 10000 == 0:
                buffer_sec = buffer_size / self.sample_rate
                self.get_logger().info(f"[声纹注册] 录音中,当前音频长度: {buffer_sec:.2f}秒 ({buffer_size} 样本)")
        except Exception as e:
            self.get_logger().debug(f"[声纹注册] 录音失败: {e}")
def _force_process_registration(self):
    """Force voiceprint-enrollment processing (timeout path).

    Called from the audio-chunk callback when enrollment recording has
    exceeded its time budget. Snapshots and clears the enrollment buffer,
    then either spawns the registration worker (>= 3 s of audio) or aborts
    back to IDLE with a TTS prompt.

    Fix: `tts_queue.put(..., timeout=0.2)` can raise `queue.Full`; previously
    that exception would propagate out of the audio callback and leave
    `sv_registration_processing` stuck True (state machine stuck in
    REGISTERING). The put is now guarded so the flag and state are always
    reset on the failure path.
    """
    import queue  # local: only needed for the Full exception type

    if self.sv_registration_processing:
        return  # another trigger (speech_end or timeout) already won
    self.sv_registration_processing = True
    audio_list = list(self.sv_registration_audio_buffer)
    buffer_size = len(audio_list)
    buffer_sec = buffer_size / self.sample_rate
    self.get_logger().info(f"[声纹注册] 强制处理注册,录音缓冲区大小: {buffer_size} 样本({buffer_sec:.2f}秒)")
    self.sv_registration_audio_buffer.clear()
    self.sv_registration_start_time = None  # stop the timeout clock
    # Enrollment needs at least 3 seconds of audio.
    min_samples = int(self.sample_rate * 3.0)  # 3 s
    if buffer_size >= min_samples:
        self.get_logger().info(f"[声纹注册] ✓ 音频长度足够,开始注册,音频长度: {buffer_size} 样本({buffer_sec:.2f}秒)")
        # Register in a separate thread so the audio callback never blocks.
        threading.Thread(
            target=self._process_sv_registration,
            args=(audio_list,),
            daemon=True,
            name="SVRegistrationThread"
        ).start()
    else:
        self.get_logger().warning(f"[声纹注册] ✗ 音频太短: {buffer_size} < {min_samples}({buffer_sec:.2f}秒 < 3秒),注册失败")
        prompt_text = "音频太短注册失败请说话大于3秒"
        try:
            self.tts_queue.put(prompt_text, timeout=0.2)
        except queue.Full:
            # A full TTS queue must not prevent the state reset below.
            self.get_logger().warning("[声纹注册] TTS队列已满,放弃播放提示")
        # Enrollment failed: reset the flag and return to IDLE.
        self.sv_registration_processing = False
        self._change_state(ConversationState.IDLE, "注册失败")
def _on_speech_end(self):
    """Recording-thread callback: end of speech (silence persisted long enough).

    Stops the current ASR recognition and then, per state:
      * REGISTERING — process the enrollment buffer directly (no ASR);
      * CHECK_VOICE — trigger voiceprint verification (unless the ASR
        callback already did, or the speaker DB is empty);
      * AUTHORIZED — hand the captured audio to the verification thread.
    """
    self.get_logger().info("[录音线程] 检测到说话结束触发ASR识别")
    if self.asr_client and self.asr_client.running:
        self.asr_client.stop_current_recognition()
    self.get_logger().info("[录音线程] 检测到说话结束")
    state = self._get_state()
    self.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
    if state == ConversationState.REGISTERING:
        # REGISTERING: process the enrollment straight from the buffer; ASR
        # output is not used in this state.
        if self.sv_registration_processing:
            return  # already being processed (e.g. by the timeout path)
        self.sv_registration_processing = True
        self.get_logger().info("[声纹注册] 说话结束,开始处理注册,停止录音")
        audio_list = list(self.sv_registration_audio_buffer)
        buffer_size = len(audio_list)
        buffer_sec = buffer_size / self.sample_rate
        self.get_logger().info(f"[声纹注册] 录音缓冲区大小: {buffer_size} 样本({buffer_sec:.2f}秒)")
        self.sv_registration_audio_buffer.clear()
        self.sv_registration_start_time = None  # stop the timeout clock
        # Enrollment needs more than 3 seconds of audio.
        min_samples = int(self.sample_rate * 3.0)  # 3 s
        if buffer_size >= min_samples:
            self.get_logger().info(f"[声纹注册] ✓ 音频长度足够,开始注册,音频长度: {buffer_size} 样本({buffer_sec:.2f}秒)")
            # Run registration in its own thread so this callback never blocks.
            threading.Thread(
                target=self._process_sv_registration,
                args=(audio_list,),
                daemon=True,
                name="SVRegistrationThread"
            ).start()
        else:
            self.get_logger().warning(f"[声纹注册] ✗ 音频太短: {buffer_size} < {min_samples}{buffer_sec:.2f}秒 < 3秒注册失败")
            prompt_text = "音频太短注册失败请说话大于3秒"
            self.tts_queue.put(prompt_text, timeout=0.2)
            # Enrollment failed: back to IDLE.
            self.sv_registration_processing = False  # reset the processing flag
            self._change_state(ConversationState.IDLE, "注册失败")
        return
    elif state == ConversationState.CHECK_VOICE:
        # CHECK_VOICE: speech ended — stop ASR, then kick off verification.
        # NOTE(review): stop_current_recognition() was already called above;
        # this second call looks redundant — confirm it is idempotent.
        if self.asr_client and self.asr_client.running:
            self.asr_client.stop_current_recognition()
        if self.sv_enabled and self.sv_client:
            # Empty DB: nothing to match — publish an UNKNOWN result at once.
            speaker_count = self.sv_client.get_speaker_count()
            if speaker_count == 0:
                self.get_logger().info("[声纹] CHECK_VOICE状态数据库为空跳过声纹验证")
                # Set the result directly so the process worker never waits.
                with self.sv_lock:
                    self.current_speaker_id = None
                    self.current_speaker_state = SpeakerState.UNKNOWN
                self.sv_result_ready_event.set()
            # Only trigger if the ASR callback has not already done so.
            elif not self.sv_speech_end_event.is_set():
                with self.sv_lock:
                    self.sv_recording = False
                    buffer_size = len(self.sv_audio_buffer)
                self.get_logger().info(f"[声纹] VAD触发验证缓冲区大小: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
                self.sv_speech_end_event.set()
            else:
                self.get_logger().debug("[声纹] 声纹验证已由ASR触发跳过VAD触发")
        # The state transition happens after verification, in _process_worker.
        return
    elif state == ConversationState.AUTHORIZED:
        # AUTHORIZED: normal flow; optionally re-verify the current user.
        if self.asr_client and self.asr_client.running:
            self.asr_client.stop_current_recognition()
        if self.sv_enabled:
            with self.sv_lock:
                self.sv_recording = False
                buffer_size = len(self.sv_audio_buffer)
            self.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
            self.sv_speech_end_event.set()
        return
    # IDLE needs no special handling here.
def _on_new_segment(self):
"""录音线程检测到新的人声段立即中断TTS"""
@@ -329,6 +644,38 @@ class RobotSpeakerNode(Node):
return
text_clean = text.strip()
self.get_logger().info(f"[ASR] 识别完成: {text_clean}")
state = self._get_state()
# 规则1REGISTERING状态不允许ASR由VAD的speech_end处理这里直接忽略
if state == ConversationState.REGISTERING:
self.get_logger().warning("[ASR] REGISTERING状态下收到ASR文本忽略应由VAD的speech_end处理")
return
# 规则2CHECK_VOICE状态下如果ASR识别完成但VAD还没有触发speech_end主动触发声纹验证
if state == ConversationState.CHECK_VOICE:
if self.sv_enabled and self.sv_client:
# 检查数据库是否为空,如果为空则跳过验证
speaker_count = self.sv_client.get_speaker_count()
if speaker_count == 0:
self.get_logger().info("[ASR] CHECK_VOICE状态数据库为空跳过声纹验证")
# 直接设置结果,避免等待验证
with self.sv_lock:
self.current_speaker_id = None
self.current_speaker_state = SpeakerState.UNKNOWN
self.sv_result_ready_event.set()
# 检查是否已经触发了声纹验证通过检查sv_speech_end_event是否已设置
elif not self.sv_speech_end_event.is_set():
# 如果还没有触发,主动触发声纹验证
self.get_logger().info("[ASR] CHECK_VOICE状态ASR识别完成主动触发声纹验证")
with self.sv_lock:
self.sv_recording = False
buffer_size = len(self.sv_audio_buffer)
if buffer_size > 0:
self.get_logger().info(f"[声纹] ASR触发验证缓冲区大小: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
self.sv_speech_end_event.set()
# 其他状态,将文本放入队列
self.text_queue.put(text_clean, timeout=1.0)
def _on_asr_text_update(self, text: str):
@@ -360,7 +707,9 @@ class RobotSpeakerNode(Node):
# ASR线程只检查中断不清除让主线程清除
if self.interrupt_event.is_set():
continue
if self.asr_client and self.asr_client.running:
# 检查是否应该发送音频到ASR由_should_put_audio_to_queue控制
if self._should_put_audio_to_queue() and self.asr_client and self.asr_client.running:
self.asr_client.send_audio(audio_chunk)
def _process_worker(self):
@@ -388,8 +737,109 @@ class RobotSpeakerNode(Node):
if self.use_llm and self.history:
self.history.cancel_turn()
continue
# 检查状态机状态
current_state = self._get_state()
# 根据状态处理文本
if current_state == ConversationState.CHECK_VOICE:
# CheckVoice状态先检查唤醒词如果开启再检查注册指令最后检查声纹验证结果
# 步骤1: 如果开启唤醒词,先检查唤醒词
if self.use_wake_word:
self.get_logger().info(f"[主线程] CHECK_VOICE状态检查唤醒词文本: {text}")
processed_text = self._handle_wake_word(text)
if not processed_text:
# 未检测到唤醒词直接回到Idle状态
self.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{self.wake_word}'回到Idle状态")
self._change_state(ConversationState.IDLE, "未检测到唤醒词")
continue # 不处理当前文本
# 检测到唤醒词,继续检查
self.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
text = processed_text # 使用处理后的文本(移除唤醒词)
# 步骤1.5: 检查是否有注册指令(优先级最高,在声纹验证之前)
if self.sv_enabled and self._check_register_intent(text):
if self.sv_client:
self.get_logger().info("[声纹注册] CHECK_VOICE状态检测到注册指令直接进入注册状态")
self.sv_registration_audio_buffer.clear()
self.sv_registration_processing = False
self.sv_registration_start_time = None # 重置开始时间
self._change_state(ConversationState.REGISTERING, "用户请求声纹注册")
prompt_text = "请说话至少3秒用于声纹注册"
self.tts_queue.put(prompt_text, timeout=0.2)
else:
self.get_logger().warning("[声纹注册] 声纹功能未启用,无法注册")
prompt_text = "声纹功能未启用"
self.tts_queue.put(prompt_text, timeout=0.2)
self._change_state(ConversationState.IDLE, "声纹功能未启用")
continue
# 步骤2: 检查声纹验证结果
if self.sv_enabled and self.sv_client:
# 等待声纹验证结果
# 规则2必须等待声纹结果ready
# 先清除事件,确保等待的是新结果
self.sv_result_ready_event.clear()
self.get_logger().info("[主线程] CHECK_VOICE状态等待声纹验证结果...")
if not self.sv_result_ready_event.wait(timeout=2.0):
# 声纹未完成拒绝本轮回到Idle状态
self.get_logger().warning("[主线程] CHECK_VOICE状态声纹结果未ready超时2秒拒绝本轮")
# 清理声纹缓冲区,避免残留数据影响下一轮
with self.sv_lock:
self.sv_audio_buffer.clear()
self._change_state(ConversationState.IDLE, "声纹结果未ready")
continue
self.get_logger().info("[主线程] CHECK_VOICE状态声纹结果ready继续处理")
self.sv_result_ready_event.clear() # 清除事件,准备下次
with self.sv_lock:
speaker_id = self.current_speaker_id
speaker_state = self.current_speaker_state
if speaker_id and speaker_state == SpeakerState.VERIFIED:
# 声纹匹配成功进入Authorized状态
self.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}")
self._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
# 继续处理文本(正常交互)
else:
# 声纹匹配失败,不是已注册用户
self.get_logger().info("[主线程] 声纹验证失败,不是已注册用户")
# 如果数据库为空,提示注册
if self.sv_client.get_speaker_count() == 0:
prompt_text = "声纹数据库为空,请说'注册声纹'进行注册"
self.tts_queue.put(prompt_text, timeout=0.2)
self._change_state(ConversationState.IDLE, "声纹数据库为空")
else:
prompt_text = "声纹验证失败,不是已注册用户"
self.tts_queue.put(prompt_text, timeout=0.2)
self._change_state(ConversationState.IDLE, "声纹验证失败")
continue # 不处理当前文本
else:
# 未启用声纹直接进入Authorized状态
self._change_state(ConversationState.AUTHORIZED, "未启用声纹")
elif current_state == ConversationState.REGISTERING:
# Registering状态不应该收到文本由VAD的speech_end处理
self.get_logger().warning("[主线程] Registering状态收到文本忽略应由VAD的speech_end处理")
continue
elif current_state == ConversationState.AUTHORIZED:
# Authorized状态正常处理用户请求
pass # 继续处理
elif current_state == ConversationState.IDLE:
# Idle状态不应该收到文本
self.get_logger().warning("[主线程] Idle状态收到文本忽略")
continue
# 步骤2: 唤醒词处理
# 主线程:检查中断,如果中断则清除事件并跳过
if self._check_interrupt(auto_clear=True):
if self.use_llm and self.history:
self.history.cancel_turn()
continue
# 步骤2: 唤醒词处理AUTHORIZED状态下
processed_text = self._handle_wake_word(text)
if not processed_text:
self.get_logger().info("[主线程] 唤醒词过滤后为空,跳过处理")
@@ -412,7 +862,23 @@ class RobotSpeakerNode(Node):
self.history.cancel_turn()
continue
# 步骤2.6: 检测相机指令
# 步骤2.6: 检测声纹注册指令(优先级最高)
if self.sv_enabled and self._check_register_intent(processed_text):
if self.sv_client:
self.get_logger().info("[声纹注册] 检测到注册指令,进入注册状态")
self.sv_registration_audio_buffer.clear()
self.sv_registration_processing = False
self.sv_registration_start_time = None # 重置开始时间
self._change_state(ConversationState.REGISTERING, "用户请求声纹注册")
prompt_text = "请说话至少3秒用于声纹注册"
self.tts_queue.put(prompt_text, timeout=0.2)
else:
self.get_logger().warning("[声纹注册] 声纹功能未启用,无法注册")
prompt_text = "声纹功能未启用"
self.tts_queue.put(prompt_text, timeout=0.2)
continue
# 步骤2.7: 检测相机指令
need_camera, user_text = self._check_camera_command(processed_text)
if need_camera:
self.get_logger().info(f"[相机指令] 检测到拍照指令,将进行多模态推理")
@@ -784,6 +1250,15 @@ class RobotSpeakerNode(Node):
self.get_logger().info("[TTS播放线程] 播放完成")
else:
self.get_logger().info("[TTS播放线程] 播放被中断")
# 在清除播放标志之前,检查是否是注册提示播放完成
state = self._get_state()
if state == ConversationState.REGISTERING and success:
# TTS播放完成清空注册缓冲区准备开始录音
# 这样确保TTS播放完成后检测到的第一段语音就是声纹信息
self.sv_registration_audio_buffer.clear()
self.get_logger().info("[声纹注册] TTS提示播放完成已清空缓冲区准备开始录音")
self.tts_playing_event.clear()
# 播放完成后检查中断,如果被中断则清空队列
@@ -793,6 +1268,168 @@ class RobotSpeakerNode(Node):
# 清除中断事件,准备下一次处理
self.interrupt_event.clear()
def _sv_worker(self):
    """
    Thread 5: speaker-verification worker — non-realtime, low frequency (CAM++).

    Constraints: must not interfere with recording or ASR and must not drive
    TTS; it only updates the shared current_speaker_id / current_speaker_state.

    Flow:
      - audio chunks arrive via _on_audio_chunk_for_sv into sv_audio_buffer
      - wait for the speech_end event (sv_speech_end_event)
      - take the accumulated voiced audio (post-VAD) from the buffer
      - extract a speaker embedding with CAM++
      - match against / register in the speaker DB
      - publish via current_speaker_id (shared state, write-only here)
    """
    self.get_logger().info("[声纹识别线程] 启动")
    min_audio_samples = 8000  # at least 0.5 s of audio
    while not self.stop_event.is_set():
        try:
            # Wait for speech_end; each utterance is processed exactly once.
            if self.sv_speech_end_event.wait(timeout=0.1):
                self.sv_speech_end_event.clear()  # consume the event
                # Snapshot and clear the capture buffer.
                with self.sv_lock:
                    audio_list = list(self.sv_audio_buffer)
                    buffer_size = len(audio_list)
                    self.sv_audio_buffer.clear()  # empty for the next utterance
                self.get_logger().info(f"[声纹识别] 收到speech_end事件录音长度: {buffer_size} 样本({buffer_size/self.sample_rate:.2f}秒)")
                # Clear the ready flag: a result is now being computed.
                self.sv_result_ready_event.clear()
                # Empty DB: nothing to match — publish UNKNOWN immediately.
                speaker_count = self.sv_client.get_speaker_count()
                if speaker_count == 0:
                    self.get_logger().info("[声纹识别] 数据库为空跳过验证直接设置UNKNOWN状态")
                    with self.sv_lock:
                        self.current_speaker_id = None
                        self.current_speaker_state = SpeakerState.UNKNOWN
                    # Signal the process worker that the result is ready.
                    self.sv_result_ready_event.set()
                    continue  # skip the rest of this iteration
                # Enough audio to attempt verification?
                if buffer_size >= min_audio_samples:
                    # Convert to a numpy array.
                    audio_array = np.array(audio_list, dtype=np.int16)
                    # Extract the embedding (low frequency: once per utterance).
                    embedding, success = self.sv_client.extract_embedding(
                        audio_array,
                        sample_rate=self.sample_rate
                    )
                    if not success or embedding is None:
                        self.get_logger().debug("[声纹识别] 提取embedding失败")
                        with self.sv_lock:
                            self.current_speaker_id = None
                            self.current_speaker_state = SpeakerState.ERROR
                    else:
                        # Match the speaker (once per utterance).
                        speaker_id, match_state = self.sv_client.match_speaker(embedding)
                        # Publish id/state (write-only; no control decisions here).
                        with self.sv_lock:
                            self.current_speaker_id = speaker_id
                            self.current_speaker_state = match_state
                        if match_state == SpeakerState.VERIFIED:
                            self.get_logger().info(f"[声纹识别] ✓ 识别到说话人: {speaker_id}")
                        elif match_state == SpeakerState.REJECTED:
                            self.get_logger().debug("[声纹识别] ✗ 未匹配到已知说话人(相似度不足)")
                        else:
                            self.get_logger().debug(f"[声纹识别] ? 状态: {match_state.value}")
                else:
                    self.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
                    # Publish UNKNOWN even for short audio so waiters never hang.
                    with self.sv_lock:
                        self.current_speaker_id = None
                        self.current_speaker_state = SpeakerState.UNKNOWN
                # Signal that the verification result is ready.
                self.sv_result_ready_event.set()
        except Exception as e:
            self.get_logger().error(f"[声纹识别线程] 错误: {e}")
            time.sleep(0.1)
def _process_sv_registration(self, audio_list: list):
    """Process a voiceprint enrollment (runs in its own worker thread).

    Extracts an embedding from the captured audio, registers it under a
    timestamp-derived speaker id, announces the outcome via TTS, and returns
    the state machine to IDLE. Always clears sv_registration_processing on
    exit (finally).
    """
    audio_length_sec = 0.0
    try:
        # Convert the sample list to a numpy array.
        audio_array = np.array(audio_list, dtype=np.int16)
        audio_length_sec = len(audio_array) / self.sample_rate
        self.get_logger().info(f"[声纹注册] 开始处理注册,音频长度: {len(audio_array)} 样本({audio_length_sec:.2f}秒)")
        # Extract the speaker embedding.
        embedding, success = self.sv_client.extract_embedding(
            audio_array,
            sample_rate=self.sample_rate
        )
        if not success or embedding is None:
            self.get_logger().error("[声纹注册] ✗ 提取embedding失败")
            prompt_text = "声纹注册失败,无法提取特征"
            self.tts_queue.put(prompt_text, timeout=0.2)
            # Enrollment failed: back to IDLE.
            self._change_state(ConversationState.IDLE, "注册失败")
            self.sv_registration_processing = False
            return
        self.get_logger().info(f"[声纹注册] ✓ 成功提取embedding维度: {len(embedding) if embedding is not None else 0}")
        # Auto-generate a speaker id from the current timestamp.
        speaker_id = f"user_{int(time.time())}"
        # Speaker count before registration (for the success report).
        speaker_count_before = self.sv_client.get_speaker_count()
        # Register the speaker.
        success = self.sv_client.register_speaker(speaker_id, embedding)
        if success:
            # Speaker count after registration.
            speaker_count_after = self.sv_client.get_speaker_count()
            # Verbose success report.
            self.get_logger().info("=" * 60)
            self.get_logger().info("[声纹注册] ✓✓✓ 注册成功 ✓✓✓")
            self.get_logger().info(f"[声纹注册] 说话人ID: {speaker_id}")
            self.get_logger().info(f"[声纹注册] 音频长度: {audio_length_sec:.2f}秒 ({len(audio_array)} 样本)")
            self.get_logger().info(f"[声纹注册] 注册前数据库说话人数量: {speaker_count_before}")
            self.get_logger().info(f"[声纹注册] 注册后数据库说话人数量: {speaker_count_after}")
            self.get_logger().info(f"[声纹注册] 注册时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
            self.get_logger().info("=" * 60)
            prompt_text = "声纹注册完成"
            self.tts_queue.put(prompt_text, timeout=0.2)
            # Success: back to IDLE to await the next interaction.
            self._change_state(ConversationState.IDLE, "注册成功")
        else:
            self.get_logger().error("[声纹注册] ✗ 注册失败register_speaker返回False")
            prompt_text = "声纹注册失败"
            self.tts_queue.put(prompt_text, timeout=0.2)
            # Failure: back to IDLE.
            self._change_state(ConversationState.IDLE, "注册失败")
    except Exception as e:
        self.get_logger().error(f"[声纹注册] ✗ 处理失败,异常: {e}")
        self.get_logger().error(f"[声纹注册] 异常类型: {type(e).__name__}")
        import traceback
        self.get_logger().error(f"[声纹注册] 异常堆栈:\n{traceback.format_exc()}")
        prompt_text = "声纹注册过程出错"
        self.tts_queue.put(prompt_text, timeout=0.2)
        # Error: back to IDLE.
        self._change_state(ConversationState.IDLE, "注册出错")
    finally:
        self.sv_registration_processing = False
# ==================== 工具函数 ====================
def _check_interrupt(self, auto_clear: bool = False) -> bool:
@@ -832,6 +1469,17 @@ class RobotSpeakerNode(Node):
py_list = pinyin(''.join(chars), style=Style.NORMAL)
return ' '.join([item[0] for item in py_list])
def _check_register_intent(self, text: str) -> bool:
    """Return True when *text* contains a voiceprint-enrollment keyword.

    Matching is a case-insensitive substring test against a fixed keyword
    list (this intent has top priority in the command pipeline).
    """
    normalized = text.lower().strip()
    register_keywords = ("注册声纹", "开启声纹注册", "录入声纹", "声纹注册", "注册",
                         "录入我的声纹", "录入声纹信息", "开始录入声纹", "开始声纹验证")
    return any(keyword in normalized for keyword in register_keywords)
def _start_session(self):
"""开始会话"""
with self.session_lock:
@@ -909,8 +1557,11 @@ class RobotSpeakerNode(Node):
self.stop_event.set()
self.interrupt_event.set()
for thread in [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]:
if thread.is_alive():
threads_to_join = [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]
if self.sv_thread:
threads_to_join.append(self.sv_thread)
for thread in threads_to_join:
if thread and thread.is_alive():
thread.join(timeout=2.0)
if hasattr(self, 'asr_client') and self.asr_client:
@@ -923,6 +1574,14 @@ class RobotSpeakerNode(Node):
if hasattr(self, 'camera_client') and self.camera_client:
self.camera_client.cleanup()
# 清理声纹识别资源
if hasattr(self, 'sv_client') and self.sv_client:
try:
self.sv_client.save_speakers()
self.sv_client.cleanup()
except Exception as e:
self.get_logger().warning(f"清理声纹识别资源时出错: {e}")
super().destroy_node()

View File

@@ -0,0 +1,386 @@
"""
声纹识别模块
"""
import numpy as np
import threading
import tempfile
import os
import wave
import time
import json
from enum import Enum
class SpeakerState(Enum):
    """Outcome states of a speaker-verification attempt."""

    UNKNOWN = "unknown"    # no decision yet (e.g. empty DB or too little audio)
    VERIFIED = "verified"  # matched a registered speaker
    REJECTED = "rejected"  # no registered speaker cleared the threshold
    ERROR = "error"        # embedding extraction or matching failed
class SpeakerVerificationClient:
"""声纹识别客户端 - 非实时、低频处理"""
def __init__(self, model_path: str, threshold: float, speaker_db_path: str = None, logger=None):
    """Load the speaker-verification model and, optionally, a speaker DB.

    Args:
        model_path: path to the speaker-verification model directory
            (loaded through funasr's AutoModel on CPU).
        threshold: default similarity acceptance threshold.
        speaker_db_path: optional JSON file used to load/persist registered
            speakers; when None, the DB lives only in memory.
        logger: optional ROS2-style logger; falls back to print() when None.

    Raises:
        Whatever funasr raises if the model cannot be loaded — callers are
        expected to catch this and disable the voiceprint feature.
    """
    self.model_path = model_path
    self.threshold = threshold
    self.speaker_db_path = speaker_db_path
    self.logger = logger
    # {speaker_id: {"embedding": np.ndarray, "env": str, "threshold": float, "registered_at": float}}
    self.speaker_db = {}
    self._lock = threading.Lock()
    # The embedding dimension is locked in on the first successful extraction
    # so one anomalous embedding cannot get every later comparison rejected.
    # Common sizes: 192 (CAM++), 256. Stores the int size, not a shape tuple.
    self._expected_embedding_dim = None
    # Imported lazily so importing this module does not itself require funasr.
    from funasr import AutoModel
    self.model = AutoModel(model=self.model_path, device="cpu")
    if self.logger:
        self.logger.info(f"声纹模型已加载: {self.model_path}, 阈值: {self.threshold}")
    if self.speaker_db_path:
        self.load_speakers()
def _log(self, level: str, msg: str):
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
if self.logger:
try:
# 使用映射字典避免动态获取方法导致ROS2 logger错误
log_methods = {
"debug": self.logger.debug,
"info": self.logger.info,
"warning": self.logger.warning,
"error": self.logger.error,
"fatal": self.logger.fatal
}
log_method = log_methods.get(level.lower(), self.logger.info)
log_method(msg)
except ValueError as e:
# ROS2 logger在多线程环境中可能出现"Logger severity cannot be changed between calls"错误
# 如果出现此错误降级使用info级别
if "severity cannot be changed" in str(e):
try:
self.logger.info(f"[声纹-{level.upper()}] {msg}")
except:
print(f"[声纹-{level.upper()}] {msg}")
else:
raise
else:
print(f"[声纹] {msg}")
def _write_temp_wav(self, audio_data: np.ndarray, sample_rate: int = 16000):
"""将numpy音频数组写入临时wav文件"""
audio_int16 = audio_data.astype(np.int16)
fd, temp_path = tempfile.mkstemp(suffix='.wav', prefix='sv_')
os.close(fd)
with wave.open(temp_path, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_int16.tobytes())
return temp_path
def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
    """Extract a speaker embedding from raw int16 audio (once per utterance).

    Args:
        audio_data: mono int16 PCM samples.
        sample_rate: sampling rate of *audio_data* in Hz.

    Returns:
        (embedding, True) on success; (None, False) on failure — audio
        shorter than 0.5 s, model errors, a zero-length embedding, or a
        dimension that does not match the one locked in by the first
        successful call.

    Fixes: removed the unused `import torch` (only methods on the returned
    tensor object are called); narrowed the bare `except:` in the temp-file
    cleanup to `OSError`.
    """
    # Reject clips under 0.5 s: too little speech for a stable embedding.
    if len(audio_data) < int(sample_rate * 0.5):
        return None, False
    temp_wav_path = None
    try:
        temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
        result = self.model.generate(input=temp_wav_path)
        # funasr returns list[dict]; 'spk_embedding' is a torch.Tensor of
        # shape [1, D] — no torch import needed to call its methods.
        embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0]  # [1, D] -> [D]
        # Validate the embedding dimension.
        embedding_dim = len(embedding)
        if embedding_dim == 0:
            return None, False
        # Lock the dimension on first success so one anomalous embedding
        # cannot poison every later comparison.
        if self._expected_embedding_dim is not None:
            if embedding_dim != self._expected_embedding_dim:
                self._log("error", f"embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
                return None, False
        else:
            self._expected_embedding_dim = embedding_dim
            self._log("info", f"固定embedding维度: {embedding_dim}")
        return embedding, True
    except Exception as e:
        self._log("error", f"提取embedding失败: {e}")
        return None, False
    finally:
        if temp_wav_path and os.path.exists(temp_wav_path):
            try:
                os.unlink(temp_wav_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass
def register_speaker(self, speaker_id: str, embedding: np.ndarray,
env: str = "near", threshold: float = None) -> bool:
"""
注册说话人
关键修复注册时对embedding进行归一化一劳永逸
"""
embedding_dim = len(embedding)
if embedding_dim == 0:
return False
# 校验维度一致性
if self._expected_embedding_dim is not None:
if embedding_dim != self._expected_embedding_dim:
self._log("error", f"注册失败embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
return False
else:
# 如果还没有固定维度使用当前embedding的维度
self._expected_embedding_dim = embedding_dim
# 关键修复注册时归一化embedding
embedding_norm = np.linalg.norm(embedding)
if embedding_norm == 0:
self._log("error", f"注册失败embedding范数为0")
return False
embedding_normalized = embedding / embedding_norm
speaker_threshold = threshold if threshold is not None else self.threshold
with self._lock:
# 使用dict结构存储便于序列化
self.speaker_db[speaker_id] = {
"embedding": embedding_normalized, # 已归一化
"env": env,
"threshold": speaker_threshold,
"registered_at": time.time()
}
self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")
# 在锁外调用save_speakers避免死锁save_speakers内部会获取锁
save_result = self.save_speakers()
if not save_result:
# 使用info级别而不是warning避免ROS2 logger在多线程环境中的问题
self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
return True
def match_speaker(self, embedding: np.ndarray):
"""
匹配说话人(一句话只调用一次)
返回: (speaker_id: str | None, state: SpeakerState)
"""
if not self.speaker_db:
return None, SpeakerState.UNKNOWN
embedding_dim = len(embedding)
if embedding_dim == 0:
return None, SpeakerState.ERROR
# 校验维度一致性
if self._expected_embedding_dim is not None and embedding_dim != self._expected_embedding_dim:
return None, SpeakerState.ERROR
# 归一化当前embedding注册时已归一化这里只需要归一化当前embedding
embedding_norm = np.linalg.norm(embedding)
if embedding_norm == 0:
return None, SpeakerState.ERROR
embedding_normalized = embedding / embedding_norm
best_match = None
best_score = -1.0
best_threshold = self.threshold
with self._lock:
for speaker_id, speaker_data in self.speaker_db.items():
# 从dict结构读取
ref_embedding = speaker_data["embedding"] # 注册时已归一化
if len(ref_embedding) != embedding_dim:
continue
# 直接计算cosine similarity两个向量都已归一化
score = np.dot(embedding_normalized, ref_embedding)
if score > best_score:
best_score = score
best_match = speaker_id
best_threshold = speaker_data.get("threshold", self.threshold)
return (best_match, SpeakerState.VERIFIED) if best_score >= best_threshold else (None, SpeakerState.REJECTED)
def is_available(self) -> bool:
return self.model is not None
    def cleanup(self):
        """Release resources. Currently a no-op — nothing is freed here.

        NOTE(review): presumably the funasr model needs no explicit teardown;
        confirm if GPU/device resources are ever held.
        """
        pass
def get_speaker_count(self) -> int:
with self._lock:
return len(self.speaker_db)
def remove_speaker(self, speaker_id: str) -> bool:
with self._lock:
if speaker_id not in self.speaker_db:
return False
del self.speaker_db[speaker_id]
self.save_speakers()
return True
    def load_speakers(self) -> bool:
        """Load registered voiceprints from ``self.speaker_db_path``.

        The database is stored as JSON:
        ``{speaker_id: {"embedding": [floats], "env": str, "threshold": float,
        "registered_at": float}}``. A legacy pickle file is detected (via a
        JSON decode failure), migrated in memory and re-saved as JSON.

        Returns:
            True when the JSON database was loaded; False when the path is
            unset, the file is missing, or loading failed.
            NOTE(review): the successful pickle-migration path falls through
            without a ``return True``, so it returns None (falsy) — callers
            checking the boolean result will see the migration as a failure.
        """
        if not self.speaker_db_path:
            return False
        if not os.path.exists(self.speaker_db_path):
            self._log("info", f"声纹数据库文件不存在: {self.speaker_db_path},将创建新数据库")
            return False
        try:
            with open(self.speaker_db_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            with self._lock:
                # Convert each stored list back to a numpy array, validating
                # the embedding dimension as we go.
                for speaker_id, speaker_data in data.items():
                    embedding_list = speaker_data["embedding"]
                    embedding_array = np.array(embedding_list, dtype=np.float32)
                    # Reject empty embeddings outright.
                    embedding_dim = len(embedding_array)
                    if embedding_dim == 0:
                        self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
                        continue
                    # Pin the expected dimension on the first valid entry;
                    # skip any entry that disagrees with it.
                    if self._expected_embedding_dim is None:
                        self._expected_embedding_dim = embedding_dim
                    elif embedding_dim != self._expected_embedding_dim:
                        self._log("warning", f"跳过维度不匹配的声纹: {speaker_id} (期望{self._expected_embedding_dim}, 实际{embedding_dim})")
                        continue
                    # Re-normalize for backward compatibility with entries
                    # saved before normalize-at-registration was introduced.
                    embedding_norm = np.linalg.norm(embedding_array)
                    if embedding_norm > 0:
                        embedding_array = embedding_array / embedding_norm
                    self.speaker_db[speaker_id] = {
                        "embedding": embedding_array,
                        "env": speaker_data.get("env", "near"),
                        "threshold": speaker_data.get("threshold", self.threshold),
                        "registered_at": speaker_data.get("registered_at", time.time())
                    }
            count = len(self.speaker_db)
            if self._expected_embedding_dim:
                self._log("info", f"已加载 {count} 个已注册说话人embedding维度: {self._expected_embedding_dim}")
            else:
                self._log("info", f"已加载 {count} 个已注册说话人")
            return True
        except json.JSONDecodeError as e:
            # Not JSON — try the legacy pickle format and migrate it.
            try:
                import pickle
                # NOTE(review): pickle.load is unsafe on untrusted files;
                # acceptable only because this reads our own database file.
                with open(self.speaker_db_path, 'rb') as f:
                    old_data = pickle.load(f)
                self._log("warning", "检测到旧的pickle格式正在迁移...")
                # Migration: convert legacy entries to the dict-based format.
                with self._lock:
                    for speaker_id, speaker_info in old_data.items():
                        if hasattr(speaker_info, 'embedding'):
                            # Legacy format: a SpeakerInfo-style object.
                            embedding = speaker_info.embedding
                            embedding_norm = np.linalg.norm(embedding)
                            if embedding_norm > 0:
                                embedding = embedding / embedding_norm
                            self.speaker_db[speaker_id] = {
                                "embedding": embedding,
                                "env": getattr(speaker_info, 'env', 'near'),
                                "threshold": getattr(speaker_info, 'threshold', self.threshold),
                                "registered_at": getattr(speaker_info, 'registered_at', time.time())
                            }
                        else:
                            # Possibly already a dict-shaped legacy entry.
                            embedding = speaker_info.get("embedding")
                            if embedding is not None:
                                embedding_norm = np.linalg.norm(embedding)
                                if embedding_norm > 0:
                                    embedding = embedding / embedding_norm
                                self.speaker_db[speaker_id] = {
                                    "embedding": embedding,
                                    "env": speaker_info.get("env", "near"),
                                    "threshold": speaker_info.get("threshold", self.threshold),
                                    "registered_at": speaker_info.get("registered_at", time.time())
                                }
                # Persist in the new JSON format (save_speakers takes the
                # lock itself, so it runs after the with-block above).
                self.save_speakers()
                self._log("info", "已迁移到新格式")
            except Exception as e2:
                self._log("error", f"加载声纹数据库失败JSON和pickle都失败: {e}, {e2}")
                return False
        except Exception as e:
            self._log("error", f"加载声纹数据库失败: {e}")
            return False
def save_speakers(self) -> bool:
"""
保存已注册的声纹到文件
使用JSON格式更安全、可迁移
"""
if not self.speaker_db_path:
self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
return False
try:
db_dir = os.path.dirname(self.speaker_db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
# 转换为JSON可序列化的格式
json_data = {}
with self._lock:
for speaker_id, speaker_data in self.speaker_db.items():
json_data[speaker_id] = {
"embedding": speaker_data["embedding"].tolist(), # numpy array -> list
"env": speaker_data.get("env", "near"),
"threshold": speaker_data.get("threshold", self.threshold),
"registered_at": speaker_data.get("registered_at", time.time())
}
# 使用临时文件 + 原子替换,避免写入过程中崩溃导致数据丢失
temp_path = self.speaker_db_path + ".tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
# 原子替换
os.replace(temp_path, self.speaker_db_path)
self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
return True
except Exception as e:
import traceback
self._log("error", f"保存声纹数据库失败: {e}")
self._log("error", f"保存路径: {self.speaker_db_path}")
self._log("error", f"错误详情: {traceback.format_exc()}")
# 清理临时文件
temp_path = self.speaker_db_path + ".tmp"
if os.path.exists(temp_path):
try:
os.unlink(temp_path)
except:
pass
return False