代码重构,区分声纹注册和主节点

This commit is contained in:
lxy
2026-01-16 10:40:40 +08:00
parent eb91e2f139
commit 0c118412ec
33 changed files with 2417 additions and 1788 deletions

View File

@@ -19,6 +19,20 @@ pip3 install -r requirements.txt --break-system-packages
```
## 编译启动
1. 注册声纹
- 启动节点后可以说:“二狗,今天天气真好”,开始注册声纹
- 注意:要包含唤醒词,语句之间不要停顿,时长尽量大于3秒
```bash
cd ~/ros_learn/hivecore_robot_voice
colcon build
source install/setup.bash
ros2 run robot_speaker register_speaker_node
```
2. 主节点
- 启动节点后每句交互包含唤醒词,唤醒词和语句之间不要有停顿
- 二狗拍照看看开启图文交互
- 支持已注册声纹用户打断
```bash
cd ~/ros_learn/hivecore_robot_voice
colcon build
@@ -78,3 +92,4 @@ rs-enumerate-devices -c
```bash
modelscope download --model iic/speech_campplus_sv_zh-cn_16k-common --local_dir [指定路径]
```

18
config/knowledge.json Normal file
View File

@@ -0,0 +1,18 @@
{
"entries": [
{
"id": "robot_identity",
"patterns": [
"ni shi shei"
],
"answer": "我叫二狗,是蜂核科技的机器人,很高兴为你服务"
},
{
"id": "wake_word",
"patterns": [
"ni de ming zi"
],
"answer": "我的名字是二狗"
}
]
}

View File

@@ -197,5 +197,403 @@
"env": "near",
"threshold": 0.4,
"registered_at": 1768311644.5742264
},
"user_1768529827": {
"embedding": [
0.0077949948608875275,
-0.012852567248046398,
0.0014490776229649782,
0.088177390396595,
-0.052150458097457886,
-0.1070166826248169,
-0.051932964473962784,
0.040730226784944534,
0.09491471946239471,
-0.10504328459501266,
-0.17986123263835907,
0.06056514009833336,
0.0002809118013828993,
-0.05353177338838577,
-0.08724740147590637,
-0.01057526096701622,
-0.10766296088695526,
0.024376090615987778,
-0.11535818874835968,
0.12653452157974243,
-0.0063497889786958694,
-0.02372283861041069,
-0.049704890698194504,
0.01079346239566803,
-0.10683158040046692,
0.00932641327381134,
0.043871842324733734,
0.04073511064052582,
0.005968529265373945,
0.05397576093673706,
0.07122175395488739,
0.06804963946342468,
-0.058389563113451004,
-0.03463176265358925,
-0.06834574788808823,
-0.09127284586429596,
-0.09805246442556381,
-0.015370666980743408,
-0.07054834067821503,
-0.07520422339439392,
-0.0502505861222744,
0.01580144092440605,
0.04316972196102142,
-0.010298517532646656,
-0.09042523056268692,
-0.03399325907230377,
0.03738871216773987,
0.09461583197116852,
0.07643604278564453,
-0.04089711233973503,
0.14397914707660675,
-0.03218085318803787,
-0.03981873393058777,
-0.05353623256087303,
-0.06475386023521423,
0.047925639897584915,
0.008481102995574474,
0.09522885829210281,
0.05679373815655708,
0.021448519080877304,
0.04586802423000336,
0.007880095392465591,
-0.08111433684825897,
-0.030093876644968987,
0.18197935819625854,
0.049670975655317307,
-0.029350068420171738,
0.1003178134560585,
0.05890532210469246,
-0.0418926365673542,
-0.015124992467463017,
-0.0016869385726749897,
0.029022999107837677,
0.10370466858148575,
-0.07392475008964539,
-0.041242245584726334,
0.0948185846209526,
0.0766805037856102,
0.12104924768209457,
0.07941737771034241,
-0.024586958810687065,
-0.005290709435939789,
0.08198735862970352,
-0.15709130465984344,
0.11847008019685745,
0.01280289888381958,
0.09401026368141174,
0.10199982672929764,
0.00811630580574274,
0.09336159378290176,
-0.1219155564904213,
0.00885648000985384,
0.08536995947360992,
-0.031735390424728394,
-0.02445235848426819,
0.17981232702732086,
0.05046188458800316,
-0.012413986958563328,
-0.16514025628566742,
-0.09369593858718872,
0.03961285203695297,
-0.024150250479578972,
0.024869512766599655,
0.009099201299250126,
0.0023227918427437544,
0.005291149020195007,
-0.08285452425479889,
0.02174258604645729,
-0.00018321558309253305,
-0.01761690340936184,
-0.13327360153198242,
0.07804469764232635,
-0.03172646835446358,
0.05993621423840523,
-0.0034280805848538876,
0.09203101694583893,
0.04720155894756317,
-0.12012632191181183,
-0.028879230841994286,
-0.04471825063228607,
-0.08928379416465759,
-0.055793069303035736,
-0.0230169165879488,
0.04459748789668083,
-0.08481008559465408,
0.09873232245445251,
-0.057500336319208145,
-0.05438977852463722,
0.06309207528829575,
-0.045493170619010925,
-0.0636027380824089,
-0.03580763190984726,
-0.043026816099882126,
0.04125182330608368,
-0.06327074766159058,
0.02830875851213932,
-0.0697140172123909,
-0.11324217170476913,
-0.02744743973016739,
-0.09659717977046967,
-0.036915868520736694,
0.06836548447608948,
-0.19481360912322998,
-0.08151774108409882,
0.013570327311754227,
-0.013908851891756058,
-0.02302597463130951,
-0.14017312228679657,
-0.0654999315738678,
0.0582318976521492,
-0.023702487349510193,
-0.046911414712667465,
-0.02062028832733631,
0.09885907918214798,
-0.010111358016729355,
-0.009303858503699303,
-0.07802718877792358,
0.09181840717792511,
-0.00822418462485075,
-0.024477459490299225,
0.04909557104110718,
0.024657243862748146,
0.08074013143777847,
0.10684694349765778,
-0.009657780639827251,
0.04053448513150215,
-0.054968591779470444,
0.09773849695920944,
-0.019937219098210335,
-0.11860335618257523,
-0.12553851306438446,
0.0016870739636942744,
0.07446407526731491,
-0.12183381617069244,
-0.07524612545967102,
0.06794209778308868,
-0.04324038699269295,
-0.018201345577836037,
-0.08356837183237076,
0.08218713104724884,
-0.1253940612077713,
-0.05880133807659149,
0.11516888439655304,
-0.007864559069275856,
0.06438153237104416,
-0.06551646441221237,
0.11812424659729004,
-0.07544125616550446,
0.033888354897499084,
0.02552076056599617,
0.019394448027014732,
-0.009682931937277317
],
"env": "near",
"threshold": 0.55,
"registered_at": 1768529827.4784193
},
"user_1768530001": {
"embedding": [
-0.02827363647520542,
0.04181317239999771,
-0.07721243053674698,
0.031220311298966408,
-0.006549456622451544,
-0.045262161642313004,
-0.06796529144048691,
0.10546170920133591,
-0.054266564548015594,
-0.04982651397585869,
0.008982052095234394,
0.0887555256485939,
-0.03736695274710655,
-0.027568811550736427,
-0.01881324127316475,
-0.030173255130648613,
-0.03817622363567352,
-0.027703644707798958,
-0.020354237407445908,
0.08958664536476135,
0.027346525341272354,
-0.007979321293532848,
-0.01638970896601677,
0.14815205335617065,
-0.029478076845407486,
0.0968138799071312,
0.011266525834798813,
0.10481037944555283,
0.006314543075859547,
-0.07480890303850174,
-0.126618891954422,
0.054260920733213425,
-0.054261378943920135,
0.02066616155207157,
0.056972429156303406,
-0.02620418183505535,
-0.08435375243425369,
-0.06768523901700974,
-0.001804384752176702,
-0.03350691497325897,
-0.06783927977085114,
0.09583555907011032,
0.042077258229255676,
-0.03811662644147873,
-0.09298640489578247,
0.11314687132835388,
0.06972789764404297,
-0.10421980172395706,
0.02739877998828888,
-0.06242597475647926,
0.06683704257011414,
0.030034003779292107,
-0.04094783961772919,
0.08657337725162506,
0.02882716991007328,
0.07672230899333954,
-0.0162385031580925,
0.12335177510976791,
-0.07505486160516739,
0.05924128741025925,
0.02278822474181652,
0.051575034856796265,
-0.07616295665502548,
-0.049982234835624695,
-0.021159915253520012,
0.023469945415854454,
-0.008445728570222855,
0.18868982791900635,
0.10217619687318802,
0.0029947187285870314,
0.003596147522330284,
-0.010885344818234444,
0.002336243400350213,
-0.06228164955973625,
-0.09452632069587708,
0.06288570165634155,
0.09799493104219437,
0.05772380530834198,
-0.012649190612137318,
0.037833958864212036,
-0.07815677672624588,
0.11595622450113297,
-0.006132716778665781,
-0.047689273953437805,
0.10451581329107285,
0.12618094682693481,
-0.012135603465139866,
-0.14452683925628662,
-0.011882219463586807,
0.05687599256634712,
-0.10221579670906067,
0.09555421024560928,
0.050166770815849304,
0.026791365817189217,
0.0343380831182003,
0.0643647089600563,
-0.09814899414777756,
-0.01735001988708973,
0.0002968672488350421,
-0.16691210865974426,
-0.044747937470674515,
0.10229559987783432,
0.01551489345729351,
0.0614253506064415,
-0.012457458302378654,
-0.059297215193510056,
-0.0662546306848526,
0.06900843977928162,
-0.15012530982494354,
0.14357514679431915,
-0.08563537150621414,
0.1512402445077896,
-0.05548126623034477,
-0.13191379606723785,
0.02588576264679432,
-0.007292638067156076,
-0.033004030585289,
-0.08764250576496124,
-0.04006534814834595,
0.001069005811586976,
0.0708790197968483,
-0.11471016705036163,
-0.08249906450510025,
-0.07923658937215805,
-0.029890256002545357,
0.027568599209189415,
-0.00042784016113728285,
0.01911524124443531,
0.002947323489934206,
-0.058468904346227646,
0.0006662740488536656,
-0.09472604095935822,
-0.07827164232730865,
0.05823435261845589,
-0.022661248221993446,
0.007729553151875734,
0.044511985033750534,
-0.17424426972866058,
-0.054321326315402985,
-0.010871038772165775,
-0.04280569776892662,
0.01373684499412775,
-0.03464324399828911,
0.0012510031228885055,
-0.13786448538303375,
0.13943427801132202,
0.07161138951778412,
-0.0017689999658614397,
-0.0330035537481308,
0.01767006888985634,
-0.06832484155893326,
-0.16906532645225525,
-0.08673631399869919,
0.016205811873078346,
-0.040736377239227295,
-0.053034041076898575,
-0.057571377605199814,
-0.018383856862783432,
0.029812879860401154,
-0.005708644632250071,
0.07977750152349472,
0.03715944290161133,
0.029830463230609894,
-0.15909501910209656,
0.10081987082958221,
0.07019384205341339,
0.05683498457074165,
0.008955223485827446,
-0.06697771698236465,
0.044268134981393814,
0.08812808990478516,
-0.17523430287837982,
0.05148027464747429,
-0.11579684168100357,
-0.06281758099794388,
-0.08106749504804611,
-0.07915353775024414,
0.03760797902941704,
-0.059639666229486465,
0.012170189991593361,
-0.028386766090989113,
-0.043592486530542374,
0.029122747480869293,
0.052276406437158585,
0.06929390132427216,
-0.10774848610162735,
0.06797030568122864,
-0.017512541264295578,
0.07446594536304474,
-0.07573172450065613,
-0.15186654031276703,
-0.03710319101810455
],
"env": "near",
"threshold": 0.55,
"registered_at": 1768530001.2158406
}
}

View File

@@ -11,6 +11,7 @@ dashscope:
temperature: 0.7
max_tokens: 4096
max_history: 10
summary_trigger: 3
tts:
model: "cosyvoice-v3-flash"
voice: "longanyang"
@@ -25,6 +26,8 @@ audio:
soundcard:
card_index: 1 # USB Audio Device (card 1)
device_index: 0 # USB Audio [USB Audio] (device 0)
# card_index: -1 # 使用默认声卡
# device_index: -1 # 使用默认输出设备
sample_rate: 44100 # 输出采样率44.1kHz支持48000或44100
channels: 2 # 输出声道数立体声2声道FL+FR
volume: 1.0 # 音量比例0.0-1.00.2表示20%音量)
@@ -64,5 +67,3 @@ camera:
image:
jpeg_quality: 85 # JPEG压缩质量0-10085是质量和大小平衡点
max_size: "1280x720" # 最大尺寸
commands:
capture_keywords: "pai zhao,pai ge zhao,pai zhang zhao pian,pai zhang,da kai xiang ji,kan zhe li,zhao xiang" # 拍照相关指令(拼音,逗号分隔)

View File

@@ -1,16 +1,16 @@
dashscope>=1.20.0
openai>=1.0.0
pyaudio>=0.2.11
webrtcvad>=2.0.10 # WebRTC VAD语音活动检测不包含回声消除
webrtcvad>=2.0.10
pypinyin>=0.49.0
rclpy>=3.0.0
pyrealsense2>=2.54.0
Pillow>=10.0.0
numpy>=1.24.0
# 回声消除库(可选):
# aec-audio-processing - 专门用于回声消除的WebRTC库API简单推荐
# pip install aec-audio-processing
# 如果未安装,将使用内置的简单自适应算法
PyYAML>=6.0
aec-audio-processing
modelscope>=1.33.0
datasets>=3.6.0

View File

@@ -0,0 +1,4 @@
"""核心模块"""

View File

@@ -0,0 +1,10 @@
from enum import Enum
class ConversationState(Enum):
    """Conversation state machine for the main voice node."""
    IDLE = "idle"                # waiting for wake word / any voice activity
    CHECK_VOICE = "check_voice"  # user is speaking -> verify the voiceprint
    AUTHORIZED = "authorized"    # speaker verified as a registered user

View File

@@ -0,0 +1,127 @@
from dataclasses import dataclass
from typing import Optional
from pypinyin import pinyin, Style
@dataclass
class IntentResult:
    """Routing decision for a single user utterance."""
    intent: str  # "skill_sequence" | "kb_qa" | "chat_text" | "chat_camera"
    text: str
    need_camera: bool
    camera_mode: Optional[str]  # "head" | "left_hand" | "right_hand" | None
    system_prompt: Optional[str]


class IntentRouter:
    """Keyword-based intent router.

    Utterances are matched by converting their CJK characters to
    space-separated pinyin and doing substring checks against small
    keyword lists (camera capture, skill sequence, knowledge base).
    """

    def __init__(self):
        # Keyword tables, expressed in pinyin.
        self.camera_capture_keywords = [
            "pai zhao", "pai ge zhao", "pai zhang zhao"
        ]
        self.skill_keywords = [
            "ban xiang zi"
        ]
        self.kb_keywords = [
            "ni shi shei", "ni de ming zi"
        ]

    def to_pinyin(self, text: str) -> str:
        """Return the lowercase, space-joined pinyin of *text*'s CJK chars.

        Non-CJK characters are discarded; returns "" when none remain.
        """
        han = ''.join(c for c in text if '\u4e00' <= c <= '\u9fa5')
        if not han:
            return ""
        syllables = pinyin(han, style=Style.NORMAL)
        joined = ' '.join(s[0] for s in syllables)
        return joined.lower().strip()

    def is_skill_sequence_intent(self, text: str) -> bool:
        """True when the utterance contains a skill-sequence keyword."""
        as_pinyin = self.to_pinyin(text)
        return any(kw in as_pinyin for kw in self.skill_keywords)

    def check_camera_command(self, text: str) -> tuple[bool, Optional[str]]:
        """Return (needs_camera, camera_mode) for *text*."""
        if not text:
            return False, None
        as_pinyin = self.to_pinyin(text)
        if any(kw in as_pinyin for kw in self.camera_capture_keywords):
            return True, self.detect_camera_mode(text)
        return False, None

    def detect_camera_mode(self, text: str) -> str:
        """Pick the camera the user referred to; the head camera is default."""
        as_pinyin = self.to_pinyin(text)
        # Checked in priority order: left hand, right hand, head.
        mode_table = (
            ("left_hand", ("zuo shou", "zuo bi", "zuo bian")),
            ("right_hand", ("you shou", "you bi", "you bian")),
            ("head", ("tou", "nao dai")),
        )
        for mode, keys in mode_table:
            if any(kw in as_pinyin for kw in keys):
                return mode
        return "head"

    def build_skill_prompt(self) -> str:
        """System prompt for the skill-sequence planner."""
        return (
            "你是机器人任务规划器。\n"
            "本任务必须拍照。请根据用户请求选择使用哪个相机拍照(默认头部相机),并结合当前环境信息生成简洁、可执行的技能序列。"
        )

    def build_chat_prompt(self, need_camera: bool) -> str:
        """System prompt for plain chat, with or without an attached image."""
        if need_camera:
            return (
                "你是一个智能语音助手。\n"
                "请结合图片内容简短回答。"
            )
        return (
            "你是一个智能语音助手。\n"
            "请自然、简短地与用户对话。"
        )

    def build_kb_prompt(self) -> str:
        """System prompt for knowledge-base Q&A."""
        return (
            "你是蜂核科技的员工。\n"
            "请基于知识库信息回答用户问题,回答要准确简洁。"
        )

    def build_default_system_prompt(self) -> str:
        """Fallback system prompt used outside the specialised intents."""
        return (
            "你是一个智能语音助手。\n"
            "- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
            "- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
            "请根据对话模式调整你的回答风格。"
        )

    def route(self, text: str) -> IntentResult:
        """Map *text* to an IntentResult: skill sequence, KB Q&A, or chat."""
        need_camera, camera_mode = self.check_camera_command(text)
        as_pinyin = self.to_pinyin(text)
        if self.is_skill_sequence_intent(text):
            # Skill sequences always photograph; default to the head camera.
            return IntentResult(
                intent="skill_sequence",
                text=text,
                need_camera=True,
                camera_mode=camera_mode if camera_mode is not None else "head",
                system_prompt=self.build_skill_prompt(),
            )
        if any(kw in as_pinyin for kw in self.kb_keywords):
            return IntentResult(
                intent="kb_qa",
                text=text,
                need_camera=False,
                camera_mode=None,
                system_prompt=self.build_kb_prompt(),
            )
        return IntentResult(
            intent="chat_camera" if need_camera else "chat_text",
            text=text,
            need_camera=need_camera,
            camera_mode=camera_mode,
            system_prompt=self.build_chat_prompt(need_camera),
        )

View File

@@ -0,0 +1,246 @@
import threading
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeCallbacks:
    """Callback layer for the main voice node.

    Holds no state of its own: every callback reads/mutates shared state on
    the owning ``node`` (locks, queues, events, SV buffers). Methods here are
    invoked from the recording/VAD thread and from the ASR client thread, so
    all shared access goes through the node's locks.
    """
    # ==================== Initialization & internal utilities ====================
    def __init__(self, node):
        # node: the owning ROS node carrying all shared state and helpers.
        self.node = node
    def _mark_utterance_processed(self) -> bool:
        """Atomically claim the current utterance id.

        Returns False when this utterance was already claimed, so the same
        speech segment never triggers speaker verification twice.
        """
        node = self.node
        with node.utterance_lock:
            if node.current_utterance_id == node.last_processed_utterance_id:
                return False
            node.last_processed_utterance_id = node.current_utterance_id
            return True
    def _trigger_sv_for_check_voice(self, source: str):
        """Kick off speaker verification while in CHECK_VOICE.

        ``source`` ("VAD" or "ASR") is used only for log messages. No-op when
        SV is disabled, the utterance was already processed, or the speaker
        database is empty.
        """
        node = self.node
        if not (node.sv_enabled and node.sv_client):
            return
        if not self._mark_utterance_processed():
            return
        if node._handle_empty_speaker_db():
            node.get_logger().info(f"[声纹] CHECK_VOICE状态数据库为空跳过声纹验证来源: {source}")
            return
        if not node.sv_speech_end_event.is_set():
            # Stop buffering and hand the collected audio to the SV worker.
            with node.sv_lock:
                node.sv_recording = False
                buffer_size = len(node.sv_audio_buffer)
            node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
            if buffer_size > 0:
                node.sv_speech_end_event.set()
        else:
            node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source}")
    # ==================== Business-logic delegation ====================
    # Thin pass-throughs so worker threads depend on this facade only.
    def handle_interrupt_command(self, msg):
        return self.node._handle_interrupt_command(msg)
    def check_interrupt_and_cancel_turn(self) -> bool:
        return self.node._check_interrupt_and_cancel_turn()
    def handle_wake_word(self, text: str) -> str:
        return self.node._handle_wake_word(text)
    def check_shutup_command(self, text: str) -> bool:
        return self.node._check_shutup_command(text)
    def check_camera_command(self, text: str):
        return self.node.intent_router.check_camera_command(text)
    def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
        return self.node._llm_process_stream_with_camera(user_text, need_camera)
    def put_tts_text(self, text: str):
        return self.node._put_tts_text(text)
    def force_stop_tts(self):
        return self.node._force_stop_tts()
    def drain_queue(self, q):
        return self.node._drain_queue(q)
    # ==================== Recording / VAD callbacks ====================
    def get_silence_threshold(self) -> int:
        """Return the current silence threshold in milliseconds."""
        node = self.node
        return node.silence_duration_ms
    def should_put_audio_to_queue(self) -> bool:
        """Decide (per state machine) whether recorded audio is fed to ASR.

        NOTE(review): all three defined states are allowed here, so this
        currently always returns True — confirm whether that is intended.
        """
        node = self.node
        state = node._get_state()
        if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
                     ConversationState.AUTHORIZED]:
            return True
        return False
    def on_speech_start(self):
        """Recording thread detected voice onset; start a new utterance."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到人声,开始录音")
        with node.utterance_lock:
            node.current_utterance_id += 1
        state = node._get_state()
        if state == ConversationState.IDLE:
            # Idle -> CheckVoice
            if node.sv_enabled and node.sv_client:
                # Start buffering audio for speaker verification.
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")
                node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
            else:
                node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")
        elif state == ConversationState.CHECK_VOICE:
            # CHECK_VOICE: keep buffering for speaker verification.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 继续录音用于声纹验证")
        elif state == ConversationState.AUTHORIZED:
            # AUTHORIZED: buffer again to re-verify the current speaker.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")
    def on_audio_chunk_for_sv(self, audio_chunk: bytes):
        """Recording-thread audio chunk callback: append raw int16 PCM to the
        SV buffer while sv_recording is active (CHECK_VOICE / AUTHORIZED)."""
        node = self.node
        state = node._get_state()  # NOTE(review): currently unused — confirm
        if node.sv_enabled and node.sv_recording:
            try:
                audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                with node.sv_lock:
                    node.sv_audio_buffer.extend(audio_array)
            except Exception as e:
                node.get_logger().debug(f"[声纹] 录音失败: {e}")
    def on_speech_end(self):
        """Recording thread detected end of speech (sustained silence)."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到说话结束")
        state = node._get_state()
        node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")
        if state == ConversationState.CHECK_VOICE:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            self._trigger_sv_for_check_voice("VAD")
            return
        elif state == ConversationState.AUTHORIZED:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = False
                    buffer_size = len(node.sv_audio_buffer)
                node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
                node.sv_speech_end_event.set()
                # If TTS is playing, wait (asynchronously) for the SV result
                # and only interrupt TTS when verification passes. A separate
                # thread is used so the recording thread never blocks.
                if node.tts_playing_event.is_set():
                    node.get_logger().info("[打断] TTS播放中用户说话结束异步等待声纹验证结果...")
                    def _check_sv_and_interrupt():
                        # Wait for the SV result; at most 2 seconds.
                        with node.sv_result_cv:
                            current_seq = node.sv_result_seq
                            if node.sv_result_cv.wait_for(
                                lambda: node.sv_result_seq > current_seq,
                                timeout=2.0
                            ):
                                # Verification finished; inspect the outcome.
                                with node.sv_lock:
                                    speaker_id = node.current_speaker_id
                                    speaker_state = node.current_speaker_state
                                if speaker_id and speaker_state == SpeakerState.VERIFIED:
                                    node.get_logger().info(f"[打断] 声纹验证通过({speaker_id})中断TTS播放")
                                    node._interrupt_tts("检测到人声(已授权用户,说话结束)")
                                else:
                                    node.get_logger().debug(f"[打断] 声纹验证未通过不中断TTS状态: {speaker_state.value}")
                            else:
                                node.get_logger().warning("[打断] 声纹验证超时不中断TTS")
                    # Run in its own thread so the recording thread stays real-time.
                    threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
            return
    def on_new_segment(self):
        """Recording thread detected a fresh voice segment.

        During TTS playback the segment is only buffered for verification
        (interrupt happens later in on_speech_end); otherwise a verified
        speaker interrupts TTS immediately.
        """
        node = self.node
        state = node._get_state()
        if state == ConversationState.AUTHORIZED:
            # While TTS plays, do not interrupt yet: buffer for SV and decide
            # after speech_end. This avoids TTS-echo false triggers while
            # still supporting genuine user barge-in.
            if node.tts_playing_event.is_set():
                node.get_logger().debug("[打断] TTS播放中检测到人声开始录音用于声纹验证等待说话结束后验证")
                # Buffering already started in on_speech_start; nothing to do.
            else:
                # TTS idle: interrupt immediately if the speaker is verified.
                if node.sv_enabled and node.sv_client:
                    with node.sv_lock:
                        current_speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                    if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
                        node._interrupt_tts("检测到人声(已授权用户)")
                        node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言中断TTS播放")
                    else:
                        node.get_logger().debug(f"[打断] 检测到人声但声纹未验证或未匹配不中断TTS当前状态: {speaker_state.value}")
                else:
                    # SV disabled: interrupt unconditionally (legacy behaviour).
                    node._interrupt_tts("检测到人声(未启用声纹)")
                    node.get_logger().info("[打断] 检测到人声中断TTS播放")
        else:
            node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value}非已授权用户不允许打断TTS")
    def on_heartbeat(self):
        """Recording-thread silence heartbeat."""
        self.node.get_logger().info("[录音线程] 静音中")
    # ==================== ASR callbacks ====================
    def on_asr_sentence_end(self, text: str):
        """ASR sentence_end callback: forward recognised text to the queue.

        In CHECK_VOICE, additionally triggers speaker verification in case
        ASR finished before VAD reported speech_end.
        """
        node = self.node
        if not text or not text.strip():
            return
        text_clean = text.strip()
        node.get_logger().info(f"[ASR] 识别完成: {text_clean}")
        state = node._get_state()
        if state == ConversationState.CHECK_VOICE:
            if node.sv_enabled and node.sv_client:
                node.get_logger().info("[ASR] CHECK_VOICE状态ASR识别完成主动触发声纹验证")
                self._trigger_sv_for_check_voice("ASR")
        # Forward the text regardless of state (the process worker filters).
        # NOTE(review): blocking put with timeout=1.0 raises queue.Full when
        # the queue stays full — confirm the caller tolerates that.
        node.text_queue.put(text_clean, timeout=1.0)
    def on_asr_text_update(self, text: str):
        """ASR incremental-text callback — debug logging only."""
        if not text or not text.strip():
            return
        self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")

View File

@@ -0,0 +1,179 @@
import queue
import time
import numpy as np
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.perception.speaker_verifier import SpeakerState
class NodeWorkers:
    """Long-running worker-thread bodies for the main voice node.

    Each method is the target of one thread; all shared state lives on the
    owning ``node`` and is accessed under the node's locks/events.
    """
    def __init__(self, node):
        # node: the owning ROS node carrying queues, events and clients.
        self.node = node
    def recording_worker(self):
        """Thread 1: recording — the only real-time thread."""
        node = self.node
        node.get_logger().info("[录音线程] 启动")
        node.audio_recorder.record_with_vad()
    def asr_worker(self):
        """Thread 2: ASR inference — audio in, text out, nothing else."""
        node = self.node
        node.get_logger().info("[ASR推理线程] 启动")
        while not node.stop_event.is_set():
            try:
                audio_chunk = node.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            # Drop audio while an interrupt is pending.
            if node.interrupt_event.is_set():
                continue
            if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
                node.asr_client.send_audio(audio_chunk)
    def process_worker(self):
        """Thread 3: main business-logic loop over recognised text."""
        node = self.node
        node.get_logger().info("[主线程] 启动")
        while not node.stop_event.is_set():
            try:
                text = node.text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            node.get_logger().info(f"[主线程] 收到识别文本: {text}")
            current_state = node._get_state()
            if current_state == ConversationState.CHECK_VOICE:
                # Wake-word gate first, then speaker verification.
                if node.use_wake_word:
                    node.get_logger().info(f"[主线程] CHECK_VOICE状态检查唤醒词文本: {text}")
                    processed_text = node.callbacks.handle_wake_word(text)
                    if not processed_text:
                        node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'回到Idle状态")
                        node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                        continue
                    node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
                    text = processed_text
                if node.sv_enabled and node.sv_client:
                    node.get_logger().info("[主线程] CHECK_VOICE状态等待声纹验证结果...")
                    with node.sv_result_cv:
                        current_seq = node.sv_result_seq
                        # Wait for the SV worker to publish a new result.
                        if not node.sv_result_cv.wait_for(
                            lambda: node.sv_result_seq > current_seq,
                            timeout=2.0
                        ):
                            node.get_logger().warning("[主线程] CHECK_VOICE状态声纹结果未ready超时2秒拒绝本轮")
                            with node.sv_lock:
                                node.sv_audio_buffer.clear()
                            node._change_state(ConversationState.IDLE, "声纹结果未ready")
                            continue
                    node.get_logger().info("[主线程] CHECK_VOICE状态声纹结果ready继续处理")
                    with node.sv_lock:
                        speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                        score = node.current_speaker_score
                    if speaker_id and speaker_state == SpeakerState.VERIFIED:
                        node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
                        node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
                    else:
                        node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
                        node.callbacks.put_tts_text("声纹验证失败")
                        node._change_state(ConversationState.IDLE, "声纹验证失败")
                        continue
                else:
                    node._change_state(ConversationState.AUTHORIZED, "未启用声纹")
            elif current_state == ConversationState.AUTHORIZED:
                # During TTS playback only a VAD-verified barge-in may interrupt.
                if node.tts_playing_event.is_set():
                    node.get_logger().debug("[主线程] AUTHORIZED状态TTS播放中忽略ASR识别结果只有VAD检测到已授权用户人声才能中断")
                    continue
            elif current_state == ConversationState.IDLE:
                node.get_logger().warning("[主线程] Idle状态收到文本忽略")
                continue
            if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
                processed_text = node.callbacks.handle_wake_word(text)
                if not processed_text:
                    node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                    continue
                text = processed_text
            if node.callbacks.check_shutup_command(text):
                node.get_logger().info("[主线程] 检测到闭嘴指令")
                node.interrupt_event.set()
                node.callbacks.force_stop_tts()
                node._change_state(ConversationState.IDLE, "用户闭嘴指令")
                continue
            intent_payload = node.intent_router.route(text)
            node._handle_intent(intent_payload)
            # NOTE(review): current_state was captured before the CHECK_VOICE
            # -> AUTHORIZED transition above, so a just-authorized turn does
            # NOT refresh session_start_time here — confirm intended.
            if current_state == ConversationState.AUTHORIZED:
                node.session_start_time = time.time()
    def sv_worker(self):
        """Thread 5: speaker verification — non-real-time, low frequency (CAM++)."""
        node = self.node
        node.get_logger().info("[声纹识别线程] 启动")
        # Minimum audio before attempting an embedding (samples).
        min_audio_samples = 8000
        while not node.stop_event.is_set():
            try:
                if node.sv_speech_end_event.wait(timeout=0.1):
                    node.sv_speech_end_event.clear()
                    with node.sv_lock:
                        audio_list = list(node.sv_audio_buffer)
                        buffer_size = len(audio_list)
                        node.sv_audio_buffer.clear()
                    node.get_logger().info(f"[声纹识别] 收到speech_end事件录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
                    if node._handle_empty_speaker_db():
                        # NOTE(review): this `continue` skips the
                        # sv_result_seq increment below, so any waiter hits
                        # its timeout in this case — confirm intended.
                        node.get_logger().info("[声纹识别] 数据库为空跳过验证直接设置UNKNOWN状态")
                        continue
                    if buffer_size >= min_audio_samples:
                        audio_array = np.array(audio_list, dtype=np.int16)
                        embedding, success = node.sv_client.extract_embedding(
                            audio_array,
                            sample_rate=node.sample_rate
                        )
                        if not success or embedding is None:
                            node.get_logger().debug("[声纹识别] 提取embedding失败")
                            with node.sv_lock:
                                node.current_speaker_id = None
                                node.current_speaker_state = SpeakerState.ERROR
                                node.current_speaker_score = 0.0
                        else:
                            speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
                            with node.sv_lock:
                                node.current_speaker_id = speaker_id
                                node.current_speaker_state = match_state
                                node.current_speaker_score = score
                            if match_state == SpeakerState.VERIFIED:
                                node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
                            elif match_state == SpeakerState.REJECTED:
                                node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
                            else:
                                node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
                    else:
                        node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
                        with node.sv_lock:
                            node.current_speaker_id = None
                            node.current_speaker_state = SpeakerState.UNKNOWN
                            node.current_speaker_score = 0.0
                    # Publish the result to any thread waiting on the CV.
                    with node.sv_result_cv:
                        node.sv_result_seq += 1
                        node.sv_result_cv.notify_all()
            except Exception as e:
                node.get_logger().error(f"[声纹识别线程] 错误: {e}")
                time.sleep(0.1)

View File

@@ -0,0 +1,399 @@
"""
声纹注册独立节点:运行完成后退出
"""
import collections
import os
import queue
import threading
import time
import yaml
import numpy as np
import rclpy
from rclpy.node import Node
from ament_index_python.packages import get_package_share_directory
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
from robot_speaker.models.asr.dashscope import DashScopeASR
from pypinyin import pinyin, Style
class RegisterSpeakerNode(Node):
    def __init__(self):
        """One-shot voiceprint registration node.

        Flow: wait for the wake word (via ASR) -> record a speech sample
        (wake word included) -> extract an embedding and store it in the
        speaker DB -> signal completion and exit.
        """
        super().__init__('register_speaker_node')
        self._load_config()
        self.stop_event = threading.Event()
        self.processing = False  # guards against re-entrant registration
        self.buffer_lock = threading.Lock()
        self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
        # State: waiting for wake word -> waiting for voiceprint speech
        self.waiting_for_wake_word = True
        self.waiting_for_voiceprint = False
        # Audio and text queues feeding ASR
        self.audio_queue = queue.Queue()
        self.text_queue = queue.Queue()
        self.vad_detector = VADDetector(
            mode=self.vad_mode,
            sample_rate=self.sample_rate
        )
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            vad_detector=self.vad_detector,
            audio_queue=self.audio_queue,  # feeds ASR for wake-word detection
            silence_duration_ms=self.silence_duration_ms,
            min_energy_threshold=self.min_energy_threshold,
            heartbeat_interval=self.audio_microphone_heartbeat_interval,
            on_heartbeat=self._on_heartbeat,
            is_playing=lambda: False,
            on_new_segment=None,
            on_speech_start=self._on_speech_start,
            on_speech_end=self._on_speech_end,
            stop_flag=self.stop_event.is_set,
            on_audio_chunk=self._on_audio_chunk,
            should_put_to_queue=self._should_put_to_queue,
            get_silence_threshold=lambda: self.silence_duration_ms,
            enable_echo_cancellation=False,
            reference_signal_buffer=None,
            logger=self.get_logger()
        )
        # ASR client — used for wake-word detection
        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )
        self.asr_client.on_sentence_end = self._on_asr_sentence_end
        self.asr_client.start()
        # ASR worker thread
        self.asr_thread = threading.Thread(
            target=self._asr_worker,
            name="RegisterASRThread",
            daemon=True
        )
        self.asr_thread.start()
        # Text-processing worker thread
        self.text_thread = threading.Thread(
            target=self._text_worker,
            name="RegisterTextThread",
            daemon=True
        )
        self.text_thread.start()
        self.sv_client = SpeakerVerificationClient(
            model_path=self.sv_model_path,
            threshold=self.sv_threshold,
            speaker_db_path=self.sv_speaker_db_path,
            logger=self.get_logger()
        )
        self.get_logger().info("声纹注册节点启动,请说'er gou......'唤醒注册")
        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record_with_vad,
            name="RegisterRecordingThread",
            daemon=True
        )
        self.recording_thread.start()
        # Poll for completion; _check_done (not shown here) presumably
        # watches stop_event and shuts the node down — TODO confirm.
        self.timer = self.create_timer(0.2, self._check_done)
def _load_config(self):
config_file = os.path.join(
get_package_share_directory('robot_speaker'),
'config',
'voice.yaml'
)
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
dashscope = config['dashscope']
audio = config['audio']
mic = audio['microphone']
vad = config['vad']
system = config['system']
self.dashscope_api_key = dashscope['api_key']
self.asr_model = dashscope['asr']['model']
self.asr_url = dashscope['asr']['url']
self.input_device_index = mic['device_index']
self.sample_rate = mic['sample_rate']
self.channels = mic['channels']
self.chunk = mic['chunk']
self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
self.vad_mode = vad['vad_mode']
self.silence_duration_ms = vad['silence_duration_ms']
self.min_energy_threshold = vad['min_energy_threshold']
self.sv_model_path = os.path.expanduser(system['sv_model_path'])
self.sv_threshold = system['sv_threshold']
self.sv_speaker_db_path = system['sv_speaker_db_path']
self.sv_buffer_size = system['sv_buffer_size']
self.wake_word = system['wake_word']
def _should_put_to_queue(self) -> bool:
"""判断是否应该将音频放入ASR队列仅在等待唤醒词时"""
return self.waiting_for_wake_word
def _on_heartbeat(self):
if self.waiting_for_wake_word:
self.get_logger().info("[注册录音] 等待唤醒词'er gou'...")
elif self.waiting_for_voiceprint:
self.get_logger().info("[注册录音] 等待声纹语音...")
def _on_speech_start(self):
if self.waiting_for_wake_word:
# 等待唤醒词时,开始录音(可能包含唤醒词)
self.get_logger().info("[注册录音] 检测到人声,开始录音")
elif self.waiting_for_voiceprint:
self.get_logger().info("[注册录音] 检测到人声,继续录音(用于声纹注册)")
# 注意:不清空缓冲区,保留包含唤醒词的音频
    def _on_audio_chunk(self, audio_chunk: bytes):
        """Recorder callback: append raw int16 PCM to the registration buffer.

        Every chunk — wake word included — is kept for voiceprint extraction.
        """
        try:
            audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
            with self.buffer_lock:
                self.audio_buffer.extend(audio_array)
        except Exception as e:
            self.get_logger().debug(f"[注册录音] 录音失败: {e}")
def _on_speech_end(self):
# 如果还在等待唤醒词,不处理
if self.waiting_for_wake_word:
return
# 如果已经在处理,不重复处理
if self.processing:
return
# 等待声纹语音时用户说话结束使用当前音频即使不足3秒
if self.waiting_for_voiceprint:
self._process_voiceprint_audio(use_current_audio_if_short=True)
def _process_voiceprint_audio(self, use_current_audio_if_short: bool = False):
    """Process the buffered audio and enroll the speaker's voiceprint.

    Uses the user's complete first utterance (wake word included). When at
    least 3 s of audio is available only the trailing 3 s are used;
    shorter audio is accepted only when the caller signals the user has
    already finished speaking.

    Args:
        use_current_audio_if_short: accept a buffer shorter than 3 s
            (used when speech has already ended).
    """
    # Re-entrancy guard: VAD callbacks may fire while enrollment is running.
    if self.processing:
        return
    self.processing = True
    with self.buffer_lock:
        audio_list = list(self.audio_buffer)
        buffer_size = len(audio_list)
    buffer_sec = buffer_size / self.sample_rate
    self.get_logger().info(f"[注册录音] 当前音频长度: {buffer_sec:.2f}秒")
    # Three seconds of audio are required for enrollment.
    required_samples = int(self.sample_rate * 3.0)
    if buffer_size < required_samples:
        if use_current_audio_if_short:
            # The user already stopped talking: enroll with what we have.
            self.get_logger().info(f"[注册录音] 音频不足3秒当前{buffer_sec:.2f}秒),但用户已说完,使用当前音频进行注册")
            audio_to_use = audio_list
        else:
            # Not enough yet: keep recording and retry later.
            self.get_logger().info(f"[注册录音] 音频不足3秒当前{buffer_sec:.2f}秒),等待继续录音...")
            self.processing = False
            return
    else:
        # Enough audio: keep only the trailing 3 s.
        audio_to_use = audio_list[-required_samples:]
        self.get_logger().info(f"[注册录音] 使用完整的第一段语音截取最后3秒用于注册")
    # The buffered audio is no longer needed.
    with self.buffer_lock:
        self.audio_buffer.clear()
    try:
        audio_array = np.array(audio_to_use, dtype=np.int16)
        embedding, success = self.sv_client.extract_embedding(
            audio_array,
            sample_rate=self.sample_rate
        )
        if not success or embedding is None:
            self.get_logger().error("[注册录音] 提取embedding失败")
            self.processing = False
            return
        # Speaker id derived from the wall clock; unique per second.
        speaker_id = f"user_{int(time.time())}"
        if self.sv_client.register_speaker(speaker_id, embedding):
            self.get_logger().info(f"[注册录音] 注册成功用户ID: {speaker_id},准备退出")
            # Enrollment succeeded: request node shutdown (picked up by _check_done).
            self.stop_event.set()
        else:
            self.get_logger().error("[注册录音] 注册失败")
            self.processing = False
    except Exception as e:
        self.get_logger().error(f"[注册录音] 注册异常: {e}")
        self.processing = False
def _extract_speech_segments(self, audio_array: np.ndarray, frame_size: int = 1024) -> list:
"""使用能量检测提取人声片段(过滤静音)"""
speech_segments = []
frame_samples = frame_size
total_frames = 0
speech_frames = 0
for i in range(0, len(audio_array), frame_samples):
frame = audio_array[i:i + frame_samples]
if len(frame) < frame_samples:
break
total_frames += 1
# 计算帧的能量RMS对于int16音频
frame_float = frame.astype(np.float32)
energy = np.sqrt(np.mean(frame_float ** 2))
# 使用更低的阈值来检测人声(降低阈值,避免误判静音)
# 阈值可以动态调整,或者使用自适应阈值
threshold = self.min_energy_threshold * 0.5 # 降低阈值到原来的50%
# 如果能量超过阈值,认为是人声
if energy >= threshold:
speech_segments.append((i, i + frame_samples))
speech_frames += 1
# 调试信息
if total_frames > 0:
speech_ratio = speech_frames / total_frames
self.get_logger().debug(f"[注册录音] 能量检测: 总帧数={total_frames}, 人声帧数={speech_frames}, 人声比例={speech_ratio:.2%}, 阈值={self.min_energy_threshold}")
return speech_segments
def _merge_speech_segments(self, audio_array: np.ndarray, segments: list, min_samples: int) -> np.ndarray:
"""合并人声片段,返回连续的人声音频"""
if not segments:
return np.array([], dtype=np.int16)
# 合并相邻的片段
merged_segments = []
current_start, current_end = segments[0]
for start, end in segments[1:]:
if start <= current_end + 1024: # 允许小间隙1帧
current_end = end
else:
merged_segments.append((current_start, current_end))
current_start, current_end = start, end
merged_segments.append((current_start, current_end))
# 从后往前选择片段直到达到3秒
selected_audio = []
total_samples = 0
for start, end in reversed(merged_segments):
segment_audio = audio_array[start:end]
selected_audio.insert(0, segment_audio)
total_samples += len(segment_audio)
if total_samples >= min_samples:
break
if not selected_audio:
return np.array([], dtype=np.int16)
return np.concatenate(selected_audio)
def _asr_worker(self):
    """ASR worker thread: forwards queued audio chunks to the ASR client.

    Runs until ``stop_event`` is set; the short queue timeout keeps the
    loop responsive to shutdown requests.
    """
    while not self.stop_event.is_set():
        try:
            audio_chunk = self.audio_queue.get(timeout=0.1)
            # Only feed the recognizer while its session is actually open.
            if self.asr_client and self.asr_client.running:
                self.asr_client.send_audio(audio_chunk)
        except queue.Empty:
            continue
        except Exception as e:
            self.get_logger().error(f"[注册ASR] 处理异常: {e}")
def _on_asr_sentence_end(self, text: str):
"""ASR识别完成回调"""
if text and text.strip():
self.text_queue.put(text.strip())
def _text_worker(self):
    """Text worker thread: scans recognized sentences for the wake word.

    Only inspects text while still in the wake-word phase; afterwards
    recognized text is ignored (registration uses raw audio instead).
    """
    while not self.stop_event.is_set():
        try:
            text = self.text_queue.get(timeout=0.1)
            if self.waiting_for_wake_word:
                self._check_wake_word(text)
        except queue.Empty:
            continue
        except Exception as e:
            self.get_logger().error(f"[注册文本] 处理异常: {e}")
def _to_pinyin(self, text: str) -> str:
    """Convert the CJK characters of *text* to space-separated pinyin.

    Non-Chinese characters are discarded; returns "" when none remain.
    """
    han_chars = [ch for ch in text if '\u4e00' <= ch <= '\u9fa5']
    if not han_chars:
        return ""
    syllables = pinyin(han_chars, style=Style.NORMAL)
    return ' '.join(item[0] for item in syllables).lower().strip()
def _check_wake_word(self, text: str):
    """Check recognized text for the wake word; on a hit, enter voiceprint capture.

    Matching is done on pinyin tokens so homophone ASR mistakes still trigger.
    """
    text_pinyin = self._to_pinyin(text)
    wake_word_pinyin = self.wake_word.lower().strip()
    self.get_logger().info(f"[注册唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")
    if not wake_word_pinyin:
        return
    text_pinyin_parts = text_pinyin.split() if text_pinyin else []
    wake_word_parts = wake_word_pinyin.split()
    # Slide over the pinyin tokens looking for the wake-word sequence.
    for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
        if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
            self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'")
            self.get_logger().info("=" * 50)
            self.get_logger().info("[声纹注册] 开始注册声纹将截取3秒音频用于注册")
            self.get_logger().info("=" * 50)
            self.waiting_for_wake_word = False
            self.waiting_for_voiceprint = True
            # ASR is no longer needed once the wake word is found.
            if self.asr_client:
                self.asr_client.stop_current_recognition()
            # The user may already have finished the whole utterance
            # (wake word included), so try to enroll right away.
            self._process_voiceprint_audio()
            return
def _check_done(self):
    """Shut the node down once registration has completed.

    NOTE(review): appears to be a periodic callback (likely a ROS timer);
    confirm where it is scheduled -- the timer setup is outside this view.
    """
    if self.stop_event.is_set():
        self.get_logger().info("注册完成,节点退出")
        # Release resources before tearing down the node.
        if self.asr_client:
            self.asr_client.stop()
        self.destroy_node()
        rclpy.shutdown()
def main(args=None):
    """Entry point: spin the registration node until it requests shutdown.

    Fix: the original never destroyed the node or shut rclpy down when
    ``spin`` raised (e.g. Ctrl-C), leaking the node and the rclpy context.
    Cleanup is skipped when ``_check_done`` already shut everything down.
    """
    rclpy.init(args=args)
    node = RegisterSpeakerNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # _check_done() may already have destroyed the node and shut rclpy down.
        if rclpy.ok():
            node.destroy_node()
            rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,803 @@
"""
语音交互节点
"""
import rclpy
from rclpy.node import Node
from std_msgs.msg import String
import threading
import queue
import time
import re
import base64
import io
import numpy as np
from PIL import Image
import subprocess
import collections
import os
import yaml
import json
from ament_index_python.packages import get_package_share_directory
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
from robot_speaker.models.asr.dashscope import DashScopeASR
from robot_speaker.models.tts.dashscope import DashScopeTTSClient
from robot_speaker.models.llm.dashscope import DashScopeLLM
from robot_speaker.understanding.context_manager import ConversationHistory
from robot_speaker.core.types import LLMMessage, TTSRequest
from robot_speaker.perception.camera_client import CameraClient
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient, SpeakerState
from robot_speaker.perception.echo_cancellation import ReferenceSignalBuffer
from robot_speaker.core.conversation_state import ConversationState
from robot_speaker.core.node_workers import NodeWorkers
from robot_speaker.core.node_callbacks import NodeCallbacks
from robot_speaker.core.intent_router import IntentRouter, IntentResult
class RobotSpeakerNode(Node):
    """Main voice-interaction node: wake word, ASR, LLM, TTS and speaker verification."""

    # ==================== Initialization ====================
    def __init__(self):
        super().__init__('robot_speaker_node')
        # Load parameters straight from the YAML config file.
        self._load_config()
        # Inter-thread queues.
        self.audio_queue = queue.Queue()  # recording thread -> ASR thread
        self.text_queue = queue.Queue()   # ASR thread -> main thread
        self.tts_queue = queue.Queue()    # main thread -> TTS thread
        # Thread-synchronization events.
        self.interrupt_event = threading.Event()    # interrupt flag
        self.stop_event = threading.Event()         # shutdown flag
        self.tts_playing_event = threading.Event()  # TTS playback in progress
        # Session management.
        self.session_active = False
        self.session_start_time = 0.0
        self.session_lock = threading.Lock()
        # Conversation state machine.
        self.conversation_state = ConversationState.IDLE  # current state
        self.state_lock = threading.Lock()  # guards conversation_state
        # Speaker-verification shared state (written by the SV thread, read elsewhere).
        self.current_speaker_id = None              # current speaker id (read-only elsewhere)
        self.current_speaker_state = SpeakerState.UNKNOWN  # current speaker status
        self.current_speaker_score = 0.0            # current similarity score
        self.sv_lock = threading.Lock()             # guards the SV shared state above
        self.sv_speech_end_event = threading.Event()    # tells SV thread a speech segment ended
        self.sv_result_ready_event = threading.Event()  # kept for compatibility (no longer used for sync)
        self.sv_result_lock = threading.Lock()      # guards the SV result sequence number
        self.sv_result_cv = threading.Condition(self.sv_result_lock)
        self.sv_result_seq = 0
        # The SV buffer is sized from config, so it is created in _init_components.
        self.sv_audio_buffer = None  # SV recording buffer, built in _init_components
        self.sv_recording = False    # currently recording for speaker verification?
        # Voiceprint / utterance bookkeeping.
        self.utterance_lock = threading.Lock()
        self.current_utterance_id = 0
        self.last_processed_utterance_id = 0
        self.intent_router = IntentRouter()
        self.callbacks = NodeCallbacks(self)
        # Build VAD, recorder, ASR, LLM, TTS, camera and SV components.
        self._init_components()
        self.workers = NodeWorkers(self)
        # Warn early when no voiceprint is enrolled yet.
        if self.sv_enabled and self.sv_client:
            speaker_count = self.sv_client.get_speaker_count()
            if speaker_count == 0:
                self.get_logger().info("声纹数据库为空,请注册声纹")
        # ROS subscription for external interrupt commands.
        self.interrupt_sub = self.create_subscription(
            String, 'interrupt_command', self.callbacks.handle_interrupt_command, self.system_interrupt_command_queue_depth
        )
        # Spin up all worker threads.
        self._start_threads()
        self.get_logger().info("语音节点已启动")
    # ==================== Configuration loading ====================
    def _load_config(self):
        """Load all runtime parameters directly from the voice.yaml config file."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
        # Audio parameters
        audio = config['audio']
        mic = audio['microphone']
        soundcard = audio['soundcard']
        echo = audio['echo_cancellation']
        tts_audio = audio['tts']
        self.input_device_index = mic['device_index']
        self.output_card_index = soundcard['card_index']
        self.output_device_index = soundcard['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']
        self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
        self.output_sample_rate = soundcard['sample_rate']
        self.output_channels = soundcard['channels']
        self.output_volume = soundcard['volume']
        self.audio_echo_cancellation_max_duration_ms = echo['max_duration_ms']
        self.audio_tts_source_sample_rate = tts_audio['source_sample_rate']
        self.audio_tts_source_channels = tts_audio['source_channels']
        self.audio_tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']
        # VAD parameters
        vad = config['vad']
        self.vad_mode = vad['vad_mode']
        self.silence_duration_ms = vad['silence_duration_ms']
        self.min_energy_threshold = vad['min_energy_threshold']
        # DashScope parameters
        dashscope = config['dashscope']
        self.dashscope_api_key = dashscope['api_key']
        self.asr_model = dashscope['asr']['model']
        self.asr_url = dashscope['asr']['url']
        self.llm_model = dashscope['llm']['model']
        self.llm_base_url = dashscope['llm']['base_url']
        self.llm_temperature = dashscope['llm']['temperature']
        self.llm_max_tokens = dashscope['llm']['max_tokens']
        self.llm_max_history = dashscope['llm']['max_history']
        self.llm_summary_trigger = dashscope['llm']['summary_trigger']
        self.tts_model = dashscope['tts']['model']
        self.tts_voice = dashscope['tts']['voice']
        # System parameters
        system = config['system']
        self.use_llm = system['use_llm']
        self.use_wake_word = system['use_wake_word']
        self.wake_word = system['wake_word']
        self.session_timeout = system['session_timeout']
        self.system_shutup_keywords = system['shutup_keywords']
        self.system_interrupt_command_queue_depth = system['interrupt_command_queue_depth']
        self.sv_enabled = system['sv_enabled']
        self.sv_model_path = os.path.expanduser(system['sv_model_path'])
        self.sv_threshold = system['sv_threshold']
        self.sv_speaker_db_path = system['sv_speaker_db_path']
        self.sv_buffer_size = system['sv_buffer_size']
        # Camera parameters
        camera = config['camera']
        self.camera_serial_number = camera['serial_number']
        self.camera_rgb_width = camera['rgb']['width']
        self.camera_rgb_height = camera['rgb']['height']
        self.camera_rgb_fps = camera['rgb']['fps']
        self.camera_rgb_format = camera['rgb']['format']
        self.camera_image_jpeg_quality = camera['image']['jpeg_quality']
        self.camera_image_max_size = camera['image']['max_size']
        # Knowledge-base file shipped with the package.
        self.knowledge_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'knowledge.json'
        )
    # ==================== Component initialization ====================
    def _init_components(self):
        """Create every subsystem: KB, VAD, recorder, ASR, LLM, TTS, camera, SV."""
        self.shutup_keywords = [k.strip() for k in self.system_shutup_keywords.split(',') if k.strip()]
        # Knowledge base: pinyin pattern -> canned answer.
        self.kb_answers_map = {}
        if self.knowledge_file and os.path.exists(self.knowledge_file):
            try:
                with open(self.knowledge_file, 'r') as f:
                    kb_data = json.load(f)
                entries = kb_data["entries"]
                for entry in entries:
                    patterns = entry["patterns"]
                    answer = entry["answer"]
                    if not answer.strip():
                        continue
                    for pattern in patterns:
                        key = pattern.strip().lower()
                        if key:
                            self.kb_answers_map[key] = answer.strip()
                self.get_logger().info(f"知识库已加载: {len(self.kb_answers_map)}")
            except Exception as e:
                self.get_logger().warning(f"知识库加载失败: {e}")
        # Rolling buffer of recent audio for speaker verification.
        self.sv_audio_buffer = collections.deque(maxlen=self.sv_buffer_size)
        self.vad_detector = VADDetector(
            mode=self.vad_mode,
            sample_rate=self.sample_rate
        )
        # Reference-signal buffer for echo cancellation; playback runs at 44100 Hz
        # but the microphone captures at 16 kHz.
        self.reference_signal_buffer = ReferenceSignalBuffer(
            max_duration_ms=self.audio_echo_cancellation_max_duration_ms,
            sample_rate=self.sample_rate,
            channels=self.output_channels
        )
        # Recorder - pushes audio chunks straight onto the queue.
        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            vad_detector=self.vad_detector,
            audio_queue=self.audio_queue,
            silence_duration_ms=self.silence_duration_ms,
            min_energy_threshold=self.min_energy_threshold,
            heartbeat_interval=self.audio_microphone_heartbeat_interval,
            on_heartbeat=self.callbacks.on_heartbeat,
            is_playing=self.tts_playing_event.is_set,
            on_new_segment=self.callbacks.on_new_segment,
            on_speech_start=self.callbacks.on_speech_start,
            on_speech_end=self.callbacks.on_speech_end,
            stop_flag=self.stop_event.is_set,
            on_audio_chunk=self.callbacks.on_audio_chunk_for_sv if self.sv_enabled else None,  # SV recording callback
            should_put_to_queue=self.callbacks.should_put_audio_to_queue,  # gate for forwarding audio to ASR
            get_silence_threshold=self.callbacks.get_silence_threshold,  # dynamic silence-threshold callback
            enable_echo_cancellation=True,  # echo cancellation on
            reference_signal_buffer=self.reference_signal_buffer,  # shared AEC reference buffer
            logger=self.get_logger()
        )
        # ASR client - streaming recognition.
        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )
        self.asr_client.on_sentence_end = self.callbacks.on_asr_sentence_end
        self.asr_client.on_text_update = self.callbacks.on_asr_text_update
        self.asr_client.start()
        # LLM client (optional).
        if self.use_llm:
            self.llm_client = DashScopeLLM(
                api_key=self.dashscope_api_key,
                model=self.llm_model,
                base_url=self.llm_base_url,
                temperature=self.llm_temperature,
                max_tokens=self.llm_max_tokens,
                name="LLM-chat",
                logger=self.get_logger()
            )
            self.history = ConversationHistory(
                max_history=self.llm_max_history,
                summary_trigger=self.llm_summary_trigger
            )
        else:
            self.llm_client = None
            self.history = None
        # TTS client.
        self.get_logger().info(f"TTS配置: model={self.tts_model}, voice={self.tts_voice}")
        self.get_logger().info(f"音频输出配置: sample_rate={self.output_sample_rate}, channels={self.output_channels}")
        self.tts_client = DashScopeTTSClient(
            api_key=self.dashscope_api_key,
            model=self.tts_model,
            voice=self.tts_voice,
            card_index=self.output_card_index,
            device_index=self.output_device_index,
            output_sample_rate=self.output_sample_rate,
            output_channels=self.output_channels,
            output_volume=self.output_volume,
            tts_source_sample_rate=self.audio_tts_source_sample_rate,
            tts_source_channels=self.audio_tts_source_channels,
            tts_ffmpeg_thread_queue_size=self.audio_tts_ffmpeg_thread_queue_size,
            reference_signal_buffer=self.reference_signal_buffer,  # shared AEC reference buffer
            logger=self.get_logger()
        )
        # Camera client (kept running by default); camera is optional.
        try:
            self.camera_client = CameraClient(
                serial_number=self.camera_serial_number,
                width=self.camera_rgb_width,
                height=self.camera_rgb_height,
                fps=self.camera_rgb_fps,
                format=self.camera_rgb_format,
                logger=self.get_logger()
            )
            self.camera_client.initialize()
        except Exception as e:
            self.get_logger().warning(f"相机初始化失败: {e},相机功能将不可用")
            self.camera_client = None
        # Speaker-verification client; failures disable SV instead of crashing.
        if self.sv_enabled and self.sv_model_path:
            try:
                self.sv_client = SpeakerVerificationClient(
                    model_path=self.sv_model_path,
                    threshold=self.sv_threshold,
                    speaker_db_path=self.sv_speaker_db_path,
                    logger=self.get_logger()
                )
            except Exception as e:
                self.get_logger().warning(f"声纹识别初始化失败: {e},声纹功能将不可用")
                self.sv_client = None
                self.sv_enabled = False
        else:
            self.sv_client = None
# ==================== 线程启动 ====================
def _start_threads(self):
"""启动线程"""
# 线程1: 录音线程
self.recording_thread = threading.Thread(
target=self.workers.recording_worker,
name="RecordingThread",
daemon=True
)
self.recording_thread.start()
# 线程2: ASR推理线程
self.asr_thread = threading.Thread(
target=self.workers.asr_worker,
name="ASRThread",
daemon=True
)
self.asr_thread.start()
# 线程3: 主线程 - 处理业务逻辑
self.process_thread = threading.Thread(
target=self.workers.process_worker,
name="ProcessThread",
daemon=True
)
self.process_thread.start()
# 线程4: TTS播放线程
self.tts_thread = threading.Thread(
target=self._tts_worker,
name="TTSThread",
daemon=True
)
self.tts_thread.start()
# 线程5: 声纹识别线程(如果启用)
if self.sv_enabled and self.sv_client:
self.sv_thread = threading.Thread(
target=self.workers.sv_worker,
name="SVThread",
daemon=True
)
self.sv_thread.start()
else:
self.sv_thread = None
    # ==================== TTS playback thread ====================
    def _tts_worker(self):
        """
        Thread 4: TTS playback thread - playback only.

        Pulls text off ``tts_queue`` and synthesizes it, honoring the
        interrupt event before, during and after playback; on interrupt
        the remaining queue is drained and the flag cleared.
        """
        self.get_logger().info("[TTS播放线程] 启动")
        while not self.stop_event.is_set():
            try:
                text = self.tts_queue.get(timeout=1.0)
            except queue.Empty:
                if self.interrupt_event.is_set():
                    self.get_logger().debug("[TTS播放线程] 检测到中断事件")
                continue
            # Interrupted while waiting: drop this text without playing it.
            if self.interrupt_event.is_set():
                self.get_logger().info("[TTS播放线程] 中断播放,跳过文本")
                continue
            if not text or not str(text).strip():
                continue
            text_str = str(text).strip()
            text_len = len(text_str)
            self.get_logger().info(f"[TTS播放线程] 开始播放: {text_str[:100]}... (总长度: {text_len}字符)")
            self.tts_playing_event.set()
            request = TTSRequest(text=text_str, voice=None)
            # The synthesizer polls interrupt_check so playback can be cut short.
            success = self.tts_client.synthesize(
                request,
                interrupt_check=lambda: self.interrupt_event.is_set()
            )
            if success:
                self.get_logger().info("[TTS播放线程] 播放完成")
            else:
                self.get_logger().info("[TTS播放线程] 播放被中断")
            self.tts_playing_event.clear()
            if self.interrupt_event.is_set():
                self.get_logger().info("[TTS播放线程] 播放完成后检测到中断,清空队列")
                self._drain_queue(self.tts_queue)
                self.interrupt_event.clear()
# ==================== 状态机方法 ====================
def _change_state(self, new_state: ConversationState, reason: str | None = None):
"""改变状态机状态"""
with self.state_lock:
old_state = self.conversation_state
self.conversation_state = new_state
if reason:
self.get_logger().info(f"[状态机] {old_state.value} -> {new_state.value}: {reason}")
else:
self.get_logger().info(f"[状态机] {old_state.value} -> {new_state.value}")
def _get_state(self) -> ConversationState:
"""获取当前状态"""
with self.state_lock:
return self.conversation_state
# ==================== LLM处理含拍照 ====================
def _encode_image_to_base64(self, image_data: np.ndarray, quality: int = 85) -> str:
"""将numpy图像数组编码为base64字符串"""
try:
if image_data.shape[2] == 3:
pil_image = Image.fromarray(image_data, 'RGB')
else:
pil_image = Image.fromarray(image_data)
buffer = io.BytesIO()
pil_image.save(buffer, format='JPEG', quality=quality)
image_bytes = buffer.getvalue()
base64_str = base64.b64encode(image_bytes).decode('utf-8')
return base64_str
except Exception as e:
self.get_logger().error(f"图像编码失败: {e}")
return ""
    def _llm_process_stream_with_camera(self, user_text: str, need_camera: bool, system_prompt: str | None = None) -> str:
        """Streaming LLM call, optionally multimodal (text plus one camera frame).

        Args:
            user_text: the user's utterance (only used for logging/images here).
            need_camera: capture one frame and attach it to the request.
            system_prompt: optional system-prompt override; defaults to the
                intent router's default prompt when history has none.

        Returns:
            The assistant reply text, or "" when interrupted or LLM unavailable.
        """
        if not self.llm_client or not self.history:
            return ""
        messages = list(self.history.get_messages())
        # Inject a system message only when the history does not carry one.
        has_system_msg = any(msg.role == "system" for msg in messages)
        if not has_system_msg:
            if not system_prompt:
                system_prompt = self.intent_router.build_default_system_prompt()
            messages.insert(0, LLMMessage(role="system", content=system_prompt))
        full_reply = ""
        tts_text_buffer = ""
        image_base64_list = []
        def on_token(token: str):
            # Accumulate streamed tokens; stop quietly once interrupted.
            nonlocal full_reply, tts_text_buffer
            if self.interrupt_event.is_set():
                self.get_logger().info("[LLM流式处理] on_token回调中检测到中断停止处理")
                return
            full_reply += token
            tts_text_buffer += token
        # Optionally grab a single frame and attach it as base64 JPEG.
        if need_camera and self.camera_client:
            with self.camera_client.capture_context() as image_data:
                if image_data is not None:
                    image_base64 = self._encode_image_to_base64(
                        image_data,
                        quality=self.camera_image_jpeg_quality
                    )
                    if image_base64:
                        image_base64_list.append(image_base64)
                        self.get_logger().info("[相机] 已拍照")
        if image_base64_list:
            self.get_logger().info(
                f"[多模态] 准备发送给LLM: {len(image_base64_list)}张图片,用户文本: {user_text[:50]}"
            )
            for idx, img_b64 in enumerate(image_base64_list):
                self.get_logger().debug(f"[多模态] 图片#{idx+1} base64长度: {len(img_b64)}")
        reply = self.llm_client.chat_stream(
            messages,
            on_token=on_token,
            images=image_base64_list if image_base64_list else None,
            interrupt_check=lambda: self.interrupt_event.is_set()
        )
        if self.interrupt_event.is_set() or (reply is None):
            if self.interrupt_event.is_set():
                self.get_logger().info("[LLM流式处理] 处理被中断")
            return ""
        # Free the captured images promptly.
        if image_base64_list:
            for img_b64 in image_base64_list:
                del img_b64
            image_base64_list.clear()
            self.get_logger().info("[相机] 已删除照片")
        # Reconcile the streamed buffer against the final reply; the final
        # reply wins when the two disagree in length.
        if reply and reply.strip():
            tts_text_to_send = reply.strip()
            tts_buffer_len = len(tts_text_buffer.strip()) if tts_text_buffer else 0
            reply_len = len(tts_text_to_send)
            if tts_buffer_len != reply_len:
                self.get_logger().info(
                    f"[流式TTS] tts_text_buffer({tts_buffer_len}字符)和reply({reply_len}字符)长度不一致使用reply作为TTS文本"
                )
        elif tts_text_buffer and tts_text_buffer.strip():
            tts_text_to_send = tts_text_buffer.strip()
            self.get_logger().warning(
                f"[流式TTS] reply为空使用tts_text_buffer({len(tts_text_to_send)}字符)作为TTS文本"
            )
        else:
            tts_text_to_send = ""
            self.get_logger().warning("[流式TTS] reply和tts_text_buffer都为空无法发送TTS文本")
        if not self.interrupt_event.is_set() and tts_text_to_send:
            text_len = len(tts_text_to_send)
            self.get_logger().info(
                f"[流式TTS] 发送完整文本到TTS队列: {tts_text_to_send[:100]}... (总长度: {text_len}字符)"
            )
            if text_len > 100:
                self.get_logger().debug(f"[流式TTS] 完整文本内容: {tts_text_to_send}")
            self._put_tts_text(tts_text_to_send)
        return reply.strip() if reply else ""
    # ==================== Interrupt & TTS utilities ====================
    def _force_stop_tts(self):
        """Hard-stop TTS playback by SIGKILL-ing the recorded ffmpeg PID.

        Also drains the pending queue and raises the interrupt flag so no
        further text is played.
        """
        self._drain_queue(self.tts_queue)
        self.interrupt_event.set()
        if self.tts_client and self.tts_client.current_ffmpeg_pid:
            try:
                pid = self.tts_client.current_ffmpeg_pid
                os.kill(pid, 9)  # SIGKILL
                self.get_logger().info(f"[强制停止TTS] 已终止ffmpeg进程PID={pid}")
                self.tts_client.current_ffmpeg_pid = None
            except ProcessLookupError:
                # Process already exited between the check and the kill.
                self.get_logger().debug(f"[强制停止TTS] ffmpeg进程已不存在PID={pid}")
                self.tts_client.current_ffmpeg_pid = None
            except Exception as e:
                self.get_logger().warning(f"[强制停止TTS] 终止ffmpeg进程失败: {e}")
def _check_interrupt(self, auto_clear: bool = False) -> bool:
"""
检查中断标志
"""
if self.interrupt_event.is_set():
if auto_clear:
self.interrupt_event.clear()
return True
return False
def _check_interrupt_and_cancel_turn(self) -> bool:
"""检查中断并取消轮次(统一处理中断后的清理)"""
if self._check_interrupt(auto_clear=True):
if self.use_llm and self.history:
self.history.cancel_turn()
return True
return False
    # ==================== Registration / session / wake word ====================
    def _handle_empty_speaker_db(self) -> bool:
        """Handle an empty voiceprint database (unified path).

        Resets the shared speaker state to UNKNOWN and signals the legacy
        result event so waiters are released. Returns True when the DB is
        empty; False otherwise or when SV is disabled.
        """
        if not (self.sv_enabled and self.sv_client):
            return False
        speaker_count = self.sv_client.get_speaker_count()
        if speaker_count == 0:
            with self.sv_lock:
                self.current_speaker_id = None
                self.current_speaker_state = SpeakerState.UNKNOWN
                self.current_speaker_score = 0.0
            self.sv_result_ready_event.set()
            return True
        return False
def _put_tts_text(self, text: str):
"""统一处理TTS队列put带异常处理"""
try:
self.tts_queue.put(text, timeout=0.2)
self.get_logger().debug(f"[TTS队列] 文本已成功放入队列: {text[:50]}... (队列大小: {self.tts_queue.qsize()})")
except Exception as e:
self.get_logger().error(f"[TTS队列] 放入队列失败: {e}, 文本: {text[:50]}")
    def _interrupt_tts(self, reason: str):
        """Interrupt TTS playback.

        Only sets the interrupt event; the queue is deliberately NOT drained
        here so the TTS thread can notice the flag and stop playback itself.
        """
        self.get_logger().info(f"[中断] {reason}")
        self.interrupt_event.set()
@staticmethod
def _drain_queue(q: queue.Queue):
"""清空队列"""
while True:
try:
q.get_nowait()
except queue.Empty:
break
def _start_session(self):
"""开始会话"""
with self.session_lock:
self.session_active = True
self.session_start_time = time.time()
def _reset_session(self):
"""重置会话"""
with self.session_lock:
self.session_start_time = time.time()
def _is_session_active(self) -> bool:
"""检查会话是否活跃"""
with self.session_lock:
if not self.session_active:
return False
if time.time() - self.session_start_time >= self.session_timeout:
self.session_active = False
return False
return True
    # ==================== Intent handling ====================
    def _handle_wake_word(self, text: str) -> str:
        """Require/strip the wake word from recognized text.

        Returns the remaining text when the wake word is present (or when
        wake words are disabled / a session is already active); returns ""
        when the text must be ignored. Matching is done in pinyin so
        homophone ASR output still triggers.
        """
        if not self.use_wake_word:
            return text.strip()
        # Within an active session no wake word is needed; refresh it instead.
        if self._is_session_active():
            self._reset_session()
            return text.strip()
        text_pinyin = self.intent_router.to_pinyin(text)
        wake_word_pinyin = self.wake_word.lower().strip()
        self.get_logger().info(f"[唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")
        if not wake_word_pinyin:
            self.get_logger().info("[唤醒词] 唤醒词为空,过滤文本")
            return ""
        text_pinyin_parts = text_pinyin.split() if text_pinyin else []
        wake_word_parts = wake_word_pinyin.split()
        # Find the wake-word token sequence within the pinyin tokens.
        start_idx = -1
        for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
            if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
                start_idx = i
                break
        if start_idx == -1:
            self.get_logger().info(f"[唤醒词] 未检测到唤醒词 '{self.wake_word}',过滤文本")
            return ""
        # Rebuild the text, dropping the Han characters that formed the wake
        # word. NOTE(review): assumes pinyin token index == Han character
        # index, i.e. to_pinyin keeps only Han chars, one token each.
        removed = 0
        new_text = ""
        for c in text:
            if '\u4e00' <= c <= '\u9fa5':
                if removed < start_idx or removed >= start_idx + len(wake_word_parts):
                    new_text += c
                removed += 1
            else:
                new_text += c
        self._start_session()
        return new_text.strip()
def _check_shutup_command(self, text: str) -> bool:
"""检查闭嘴指令"""
if not text:
return False
text_lower = text.lower()
text_pinyin = self.intent_router.to_pinyin(text)
for keyword in self.shutup_keywords:
kw = keyword.lower().strip()
if not kw:
continue
if kw in text_lower or (text_pinyin and kw in text_pinyin):
return True
return False
    def _handle_intent(self, intent_payload: IntentResult):
        """Dispatch a routed intent: canned KB answers vs. LLM processing."""
        intent = intent_payload.intent
        text = intent_payload.text
        need_camera = intent_payload.need_camera
        system_prompt = intent_payload.system_prompt
        if intent == "kb_qa":
            # Knowledge-base QA: the pinyin form is the lookup key.
            answer = None
            text_pinyin = self.intent_router.to_pinyin(text)
            if text_pinyin:
                answer = self.kb_answers_map.get(text_pinyin)
            if answer:
                # Answers may embed the wake word via a template placeholder.
                if "{wake_word}" in answer:
                    answer = answer.replace("{wake_word}", self.wake_word or "")
                self._put_tts_text(answer)
            else:
                # No KB hit: deliberately stay silent, do not fall back to the LLM.
                pass
            return
        if self.use_llm and self.llm_client:
            if self.history:
                self.history.start_turn(text)
            reply = self._llm_process_stream_with_camera(
                text,
                need_camera=need_camera,
                system_prompt=system_prompt
            )
            if reply:
                if self.history:
                    self.history.commit_turn(reply)
            else:
                # Interrupted or empty reply: roll the turn back out of history.
                if self.history:
                    self.history.cancel_turn()
        else:
            self.get_logger().warning("[主线程] 未启用LLM无法处理文本")
    # ==================== Resource cleanup ====================
    def destroy_node(self):
        """Stop threads, audio, camera and SV resources, then destroy the node."""
        self.get_logger().info("语音节点正在关闭...")
        self.stop_event.set()
        self.interrupt_event.set()
        self.get_logger().info("强制停止TTS播放...")
        self._force_stop_tts()
        self._drain_queue(self.tts_queue)
        # Give every worker thread a bounded chance to exit.
        threads_to_join = [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]
        if self.sv_thread:
            threads_to_join.append(self.sv_thread)
        for thread in threads_to_join:
            if thread and thread.is_alive():
                thread.join(timeout=1.0)
        # Kill any ffmpeg playback that started while threads were winding down.
        self._force_stop_tts()
        if hasattr(self, 'asr_client') and self.asr_client:
            self.asr_client.stop()
        if hasattr(self, 'audio_recorder') and self.audio_recorder:
            self.audio_recorder.cleanup()
        if hasattr(self, 'camera_client') and self.camera_client:
            self.camera_client.cleanup()
        if hasattr(self, 'sv_client') and self.sv_client:
            try:
                # Persist enrolled voiceprints before releasing the model.
                self.sv_client.save_speakers()
                self.sv_client.cleanup()
            except Exception as e:
                self.get_logger().warning(f"清理声纹识别资源时出错: {e}")
        super().destroy_node()
def _init_ros(args):
    """Initialize the rclpy context."""
    rclpy.init(args=args)


def _create_node():
    """Construct the speaker node."""
    return RobotSpeakerNode()


def _run_node(node):
    """Block spinning the node until shutdown."""
    rclpy.spin(node)


def _cleanup_node(node):
    """Destroy the node if it was created."""
    if node:
        node.destroy_node()


def _shutdown_ros():
    """Shut the rclpy context down if it is still up."""
    if rclpy.ok():
        rclpy.shutdown()
# ==================== Entry point ====================
def main(args=None):
    """Entry point: create the speaker node and spin until shutdown.

    Fix: the original ran the cleanup helpers only on a clean return from
    ``spin``; any exception (including Ctrl-C) skipped ``_cleanup_node``
    and ``_shutdown_ros``, leaking the node and the rclpy context.
    """
    _init_ros(args)
    node = None
    try:
        node = _create_node()
        _run_node(node)
    except KeyboardInterrupt:
        pass
    finally:
        _cleanup_node(node)
        _shutdown_ros()


if __name__ == '__main__':
    main()

View File

@@ -1,157 +0,0 @@
"""
回声消除模块
mic = 人声 + 扬声器回声 + 环境噪声
ref = 声卡原始播放音频
AEC(mic, ref) → 去掉 ref 在 mic 中的那一部分
"""
import numpy as np
import struct
import threading
from collections import deque
import aec_audio_processing
class EchoCanceller:
    """Acoustic echo canceller wrapping ``aec_audio_processing.AudioProcessor``.

    mic = voice + speaker echo + ambient noise; ref = the raw audio the
    soundcard is playing. AEC(mic, ref) removes the ref component from mic.
    """

    def __init__(self, sample_rate: int, frame_size: int, channels: int,
                 ref_channels: int, logger=None):
        self.sample_rate = sample_rate
        self.frame_size = frame_size
        self.channels = channels
        self.ref_channels = ref_channels
        self.logger = logger
        self.aec = None
        self.aec_frame_size = None  # frame size AudioProcessor expects (fixed 10 ms = 160 samples at 16 kHz)
        # Initialize the AudioProcessor from aec-audio-processing; on failure
        # AEC is disabled and process() becomes a pass-through.
        try:
            self.aec = aec_audio_processing.AudioProcessor(
                enable_aec=True,   # echo cancellation
                enable_ns=False,   # noise suppression
                enable_agc=False   # automatic gain control
            )
            # Stream format: the microphone input (e.g. 1 channel, 16 kHz).
            self.aec.set_stream_format(
                sample_rate_in=sample_rate,
                channel_count_in=channels,
                sample_rate_out=sample_rate,
                channel_count_out=channels
            )
            # Reverse-stream format: the reference (playback) signal,
            # e.g. 2 channels resampled to 16 kHz, used to subtract the echo.
            self.aec.set_reverse_stream_format(sample_rate, ref_channels)
            # Frame size AudioProcessor expects (fixed 10 ms).
            self.aec_frame_size = self.aec.get_frame_size()
            if logger:
                logger.info(f"AudioProcessor期望的帧大小: {self.aec_frame_size} 样本 ({self.aec_frame_size / sample_rate * 1000}ms)")
        except Exception as e:
            if logger:
                logger.warning(f"aec_audio_processing 初始化失败: {e},将禁用回声消除")
            self.aec = None

    def process(self, mic_signal: bytes, ref_signal: bytes = None) -> bytes:
        """Remove the reference signal's echo from the mic signal.

        Runs synchronously on the recording thread; returns the input
        unchanged when AEC is unavailable or no reference is supplied.
        """
        if self.aec is None or ref_signal is None or self.aec_frame_size is None:
            return mic_signal
        # Remember the original length so padding can be trimmed afterwards.
        original_mic_len = len(mic_signal)
        # AudioProcessor wants fixed 10 ms frames, so large chunks are split.
        # Mic (1 ch): 160 samples * 1 channel * 2 bytes = 320 bytes.
        mic_frame_bytes = self.aec_frame_size * self.channels * 2
        # Reference (2 ch): 160 samples * 2 channels * 2 bytes = 640 bytes.
        ref_frame_bytes = self.aec_frame_size * self.ref_channels * 2
        # Zero-pad both signals to a whole number of frames.
        if len(mic_signal) % mic_frame_bytes != 0:
            padding = mic_frame_bytes - (len(mic_signal) % mic_frame_bytes)
            mic_signal = mic_signal + b'\x00' * padding
        if len(ref_signal) % ref_frame_bytes != 0:
            padding = ref_frame_bytes - (len(ref_signal) % ref_frame_bytes)
            ref_signal = ref_signal + b'\x00' * padding
        # Process the big chunk (e.g. 1024 samples) as successive 10 ms frames.
        try:
            num_frames = len(mic_signal) // mic_frame_bytes
            output_chunks = []
            for i in range(num_frames):
                mic_chunk = mic_signal[i * mic_frame_bytes:(i + 1) * mic_frame_bytes]
                ref_chunk = ref_signal[i * ref_frame_bytes:(i + 1) * ref_frame_bytes]
                self.aec.process_reverse_stream(ref_chunk)
                output_chunk = self.aec.process_stream(mic_chunk)
                # AudioProcessor.process_stream returns bytes.
                output_chunks.append(output_chunk)
            result = b''.join(output_chunks)
            return result[:original_mic_len]
        except Exception as e:
            if self.logger:
                self.logger.warning(f"回声消除处理失败: {e}")
            return mic_signal[:original_mic_len]
class ReferenceSignalBuffer:
    """Ring buffer of the audio the soundcard is playing (AEC reference)."""

    def __init__(self, max_duration_ms: int, sample_rate: int, channels: int):
        capacity = int(sample_rate * max_duration_ms / 1000)
        self.sample_rate = sample_rate
        # Channel count of the reference signal (playback channels, e.g. stereo).
        self.channels = channels
        self.buffer = deque(maxlen=capacity * channels)
        self.lock = threading.Lock()

    def add_reference(self, audio_data: bytes, source_sample_rate: int = None, source_channels: int = 1):
        """Append played audio.

        Resamples to the mic rate when the source rate differs, and widens
        mono input to the playback channel count by duplicating samples.
        """
        if not audio_data:
            return
        with self.lock:
            # Resample TTS source rate -> mic rate so mic and ref line up.
            if source_sample_rate and source_sample_rate != self.sample_rate:
                audio_data = self._resample(audio_data, source_sample_rate, self.sample_rate)
            values = struct.unpack(f'<{len(audio_data) // 2}h', audio_data)
            if source_channels == 1 and self.channels == 2:
                # Duplicate each mono sample into left and right channels.
                widened = []
                for value in values:
                    widened.append(value)
                    widened.append(value)
                values = widened
            self.buffer.extend(values)

    def get_reference(self, num_samples: int) -> bytes:
        """Return the most recent ``num_samples`` frames as int16 bytes.

        Channel count is taken into account; the result is zero-padded at
        the front when the buffer holds less than requested.
        """
        with self.lock:
            if not self.buffer:
                return b'\x00' * (num_samples * self.channels * 2)
            needed = num_samples * self.channels
            tail = list(self.buffer)
            if len(tail) > needed:
                tail = tail[-needed:]
            if len(tail) < needed:
                tail = [0] * (needed - len(tail)) + tail
            return struct.pack(f'<{len(tail)}h', *tail)

    def clear(self):
        """Drop all buffered reference audio."""
        with self.lock:
            self.buffer.clear()

    def _resample(self, audio_data: bytes, source_rate: int, target_rate: int) -> bytes:
        """Linear-interpolation resample of int16 PCM bytes."""
        if source_rate == target_rate:
            return audio_data
        pcm = np.frombuffer(audio_data, dtype=np.int16)
        ratio = target_rate / source_rate
        positions = np.linspace(0, len(pcm) - 1, int(len(pcm) * ratio))
        resampled = np.interp(positions, np.arange(len(pcm)), pcm.astype(np.float32))
        return resampled.astype(np.int16).tobytes()

View File

@@ -0,0 +1,4 @@
"""模型层"""

View File

@@ -0,0 +1,4 @@
"""ASR模型"""

View File

@@ -0,0 +1,12 @@
class ASRClient:
    """Abstract base for streaming ASR (speech-to-text) backends.

    Concrete implementations (e.g. DashScopeASR) own the realtime session
    and consume raw audio chunks pushed via :meth:`send_audio`.
    """
    def start(self) -> bool:
        """Open the recognition session; return True on success."""
        raise NotImplementedError
    def stop(self) -> bool:
        """Close the recognition session; return True on success."""
        raise NotImplementedError
    def send_audio(self, audio_data: bytes) -> bool:
        """Feed one chunk of raw audio; return True if it was accepted."""
        raise NotImplementedError

View File

@@ -7,9 +7,10 @@ import threading
import dashscope
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
from robot_speaker.models.asr.base import ASRClient
class DashScopeASR:
class DashScopeASR(ASRClient):
"""DashScope实时ASR识别器封装"""
def __init__(self, api_key: str,
@@ -65,15 +66,11 @@ class DashScopeASR:
callback.conversation = self.conversation
self.conversation.connect()
# 自定义文本语料增强识别,听不清的时候高概率说这个词
# custom_text = "二狗"
transcription_params = TranscriptionParams(
language='zh',
sample_rate=self.sample_rate,
input_audio_format="pcm",
# corpus_text=custom_text,
)
# 本地 VAD → 只控制 TTS 打断

View File

@@ -0,0 +1,4 @@
"""LLM模型"""

View File

@@ -0,0 +1,14 @@
from robot_speaker.core.types import LLMMessage
class LLMClient:
    """Abstract base for chat LLM backends."""
    def chat(self, messages: list[LLMMessage]) -> str | None:
        """Run one blocking chat completion; return the reply text, or None."""
        raise NotImplementedError
    def chat_stream(self, messages: list[LLMMessage],
                    on_token=None,
                    interrupt_check=None) -> str | None:
        """Stream a chat completion; return the full reply text, or None.

        on_token: optional callback invoked per streamed token — presumably
            used to drive incremental TTS; confirm in implementations.
        interrupt_check: optional zero-arg callable; NOTE(review): looks like
            a truthy return should abort streaming early — confirm.
        """
        raise NotImplementedError

View File

@@ -3,18 +3,9 @@ LLM大语言模型模块
支持多模态文本+图像
"""
from openai import OpenAI
from .types import LLMMessage
from typing import Optional, List
class LLMClient:
def chat(self, messages: list[LLMMessage]) -> str | None:
raise NotImplementedError
def chat_stream(self, messages: list[LLMMessage],
on_token=None,
interrupt_check=None) -> str | None:
raise NotImplementedError
from robot_speaker.core.types import LLMMessage
from robot_speaker.models.llm.base import LLMClient
class DashScopeLLM(LLMClient):

View File

@@ -0,0 +1,4 @@
"""TTS模型"""

View File

@@ -0,0 +1,13 @@
from robot_speaker.core.types import TTSRequest
class TTSClient:
    """Abstract base class for TTS (text-to-speech) clients."""
    def synthesize(self, request: TTSRequest,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        """Synthesize speech for ``request``; return True on success.

        on_chunk: optional callback receiving synthesized audio chunks —
            presumably raw bytes for playback; confirm in implementations.
        interrupt_check: optional zero-arg callable; NOTE(review): a truthy
            return appears intended to abort synthesis early — confirm.
        """
        raise NotImplementedError

View File

@@ -4,16 +4,8 @@ TTS语音合成模块
import subprocess
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
from .types import TTSRequest
class TTSClient:
"""TTS客户端抽象基类"""
def synthesize(self, request: TTSRequest,
on_chunk=None,
interrupt_check=None) -> bool:
raise NotImplementedError
from robot_speaker.core.types import TTSRequest
from robot_speaker.models.tts.base import TTSClient
class DashScopeTTSClient(TTSClient):
@@ -205,8 +197,8 @@ class _TTSCallback(ResultCallback):
# 参考信号处理失败不应影响播放
self.tts_client._log("warning", f"参考信号处理失败: {e}")
if self.on_chunk:
self.on_chunk(data)
if self.on_chunk:
self.on_chunk(data)
def cleanup(self):
"""清理资源"""
@@ -221,21 +213,26 @@ class _TTSCallback(ResultCallback):
except:
pass
# 等待进程自然结束(最多2秒)
# 等待进程自然结束(根据文本长度估算最少10秒最多30秒)
# 假设平均语速3-4字/秒,加上缓冲时间
if self._proc.poll() is None:
try:
self._proc.wait(timeout=2.0)
# 增加等待时间确保ffmpeg播放完成
# 对于长文本,可能需要更长时间
self._proc.wait(timeout=30.0)
except:
# 超时后强制终止
try:
self._proc.terminate()
self._proc.wait(timeout=0.5)
except:
# 超时后,如果进程还在运行,说明可能卡住了,强制终止
if self._proc.poll() is None:
self.tts_client._log("warning", "ffmpeg播放超时强制终止")
try:
self._proc.kill()
self._proc.wait(timeout=0.1)
self._proc.terminate()
self._proc.wait(timeout=1.0)
except:
pass
try:
self._proc.kill()
self._proc.wait(timeout=0.1)
except:
pass
# 清空PID记录
if self.tts_client.current_ffmpeg_pid == self._proc.pid:

View File

@@ -0,0 +1,4 @@
"""感知层"""

View File

@@ -18,13 +18,7 @@ class VADDetector:
class AudioRecorder:
"""
音频录音器 - 录音线程
功能
1. VAD + 能量检测
2. 检测到人声 立即中断TTS
3. 音频chunk 队列直接发送不保存文件
"""
"""音频录音器 - 录音线程"""
def __init__(self, device_index: int, sample_rate: int, channels: int,
chunk: int, vad_detector: VADDetector,
@@ -102,11 +96,7 @@ class AudioRecorder:
self.echo_canceller = None
def record_with_vad(self):
"""
录音线程VAD + 能量检测
- 检测到人声 立即中断TTS
- 音频chunk 队列直接发送不保存文件
"""
"""录音线程VAD + 能量检测"""
if self.on_heartbeat:
self.on_heartbeat()

View File

@@ -1,6 +1,5 @@
"""
相机模块 - RealSense相机封装
相机默认一直运行只在用户说拍照时捕获图像
"""
import numpy as np
import contextlib
@@ -35,7 +34,6 @@ class CameraClient:
def initialize(self) -> bool:
"""
初始化并启动相机管道
相机启动后会一直运行直到调用 cleanup()
"""
if self._is_initialized:
return True
@@ -58,7 +56,6 @@ class CameraClient:
self.fps
)
# 启动管道,相机开始运行
self.pipeline.start(self.config)
self._is_initialized = True
self._log("info", f"相机已启动并保持运行: {self.width}x{self.height}@{self.fps}fps")
@@ -80,7 +77,6 @@ class CameraClient:
def capture_rgb(self) -> np.ndarray | None:
"""
从运行中的相机管道捕获一帧RGB图像
相机管道必须已经通过 initialize() 启动
"""
if not self._is_initialized:
self._log("error", "相机未初始化,无法捕获图像")

View File

@@ -0,0 +1,98 @@
import collections
import numpy as np
class ReferenceSignalBuffer:
"""参考信号缓冲区"""
def __init__(self, sample_rate: int, channels: int, max_duration_ms: int | None = None,
buffer_seconds: float = 5.0):
self.sample_rate = int(sample_rate)
self.channels = int(channels)
if max_duration_ms is not None:
buffer_seconds = max(float(max_duration_ms) / 1000.0, 0.1)
self.max_samples = int(self.sample_rate * buffer_seconds)
self._buffer = collections.deque(maxlen=self.max_samples * self.channels)
def add_reference(self, data: bytes, source_sample_rate: int, source_channels: int):
if source_sample_rate != self.sample_rate or source_channels != self.channels:
return
samples = np.frombuffer(data, dtype=np.int16)
self._buffer.extend(samples.tolist())
def get_reference(self, num_samples: int) -> bytes:
needed = int(num_samples) * self.channels
if needed <= 0:
return b""
if len(self._buffer) < needed:
data = list(self._buffer) + [0] * (needed - len(self._buffer))
else:
data = list(self._buffer)[-needed:]
return np.array(data, dtype=np.int16).tobytes()
class EchoCanceller:
    """Acoustic echo canceller backed by the optional ``aec-audio-processing``
    package.

    When the package is missing or fails to initialise, the canceller degrades
    to a pass-through: :meth:`process` returns the mic audio untouched.
    """
    def __init__(self, sample_rate: int, frame_size: int, channels: int, ref_channels: int, logger=None):
        self.sample_rate = int(sample_rate)
        self.frame_size = int(frame_size)
        self.channels = int(channels)
        self.ref_channels = int(ref_channels)
        self.logger = logger
        self.aec = None
        self._process_reverse = None
        # The processor consumes audio in 10 ms slices of int16 samples.
        samples_per_10ms = int(self.sample_rate / 100)
        self._frame_bytes = samples_per_10ms * self.channels * 2
        self._ref_frame_bytes = samples_per_10ms * self.ref_channels * 2
        try:
            from aec_audio_processing import AudioProcessor
            self.aec = AudioProcessor(enable_aec=True, enable_ns=False, enable_agc=False)
            self.aec.set_stream_format(self.sample_rate, self.channels)
            if hasattr(self.aec, "set_reverse_stream_format"):
                self.aec.set_reverse_stream_format(self.sample_rate, self.ref_channels)
            if hasattr(self.aec, "set_stream_delay"):
                self.aec.set_stream_delay(0)
            # The reverse-stream API name differs between versions; probe both.
            if hasattr(self.aec, "process_reverse_stream"):
                self._process_reverse = self.aec.process_reverse_stream
            elif hasattr(self.aec, "process_reverse"):
                self._process_reverse = self.aec.process_reverse
        except Exception:
            self.aec = None
    def process(self, mic_data: bytes, ref_data: bytes) -> bytes:
        """Run AEC on ``mic_data`` frame by frame; pass through on any failure."""
        if not self.aec:
            return mic_data
        if not mic_data:
            return mic_data
        try:
            mic_len = len(mic_data)
            fb = self._frame_bytes
            rfb = self._ref_frame_bytes
            n_frames = -(-mic_len // fb)  # ceiling division
            cleaned = []
            for idx in range(n_frames):
                # zero-pad the trailing partial frame to a full 10 ms slice
                mic_frame = mic_data[idx * fb:(idx + 1) * fb].ljust(fb, b"\x00")
                if ref_data:
                    ref_frame = ref_data[idx * rfb:(idx + 1) * rfb].ljust(rfb, b"\x00")
                    if self._process_reverse:
                        self._process_reverse(ref_frame)
                result = self.aec.process_stream(mic_frame)
                cleaned.append(mic_frame if result is None else result)
            # trim padding back to the original mic length
            return b"".join(cleaned)[:mic_len]
        except Exception as e:
            if self.logger:
                self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
            return mic_data

View File

@@ -29,12 +29,8 @@ class SpeakerVerificationClient:
self.logger = logger
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "threshold": float, "registered_at": float}}
self._lock = threading.Lock()
# 固定embedding维度避免第一次异常embedding导致后续全部被拒绝
# 常见维度192 (CAM++), 256 等
self._expected_embedding_dim = None # 只存储维度大小不存储shape元组
from funasr import AutoModel
# 确保模型路径是绝对路径(展开 ~
model_path = os.path.expanduser(self.model_path)
self.model = AutoModel(model=model_path, device="cpu")
if self.logger:
@@ -47,7 +43,6 @@ class SpeakerVerificationClient:
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
if self.logger:
try:
# 使用映射字典避免动态获取方法导致ROS2 logger错误
log_methods = {
"debug": self.logger.debug,
"info": self.logger.info,
@@ -58,8 +53,6 @@ class SpeakerVerificationClient:
log_method = log_methods.get(level.lower(), self.logger.info)
log_method(msg)
except ValueError as e:
# ROS2 logger在多线程环境中可能出现"Logger severity cannot be changed between calls"错误
# 如果出现此错误降级使用info级别
if "severity cannot be changed" in str(e):
try:
self.logger.info(f"[声纹-{level.upper()}] {msg}")
@@ -88,9 +81,6 @@ class SpeakerVerificationClient:
def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
"""
提取说话人embedding低频调用一句话只调用一次
返回: (embedding: np.ndarray | None, success: bool)
- 成功返回 (embedding, True)
- 失败返回 (None, False)
"""
if len(audio_data) < int(sample_rate * 0.5):
return None, False
@@ -100,25 +90,13 @@ class SpeakerVerificationClient:
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
result = self.model.generate(input=temp_wav_path)
# funasr返回格式list[dict]dict键为'spk_embedding'值为torch.Tensor(shape=[1, 192])
import torch
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
# 校验embedding维度
embedding_dim = len(embedding)
if embedding_dim == 0:
return None, False
# 如果已有注册的声纹,校验维度是否一致
if self._expected_embedding_dim is not None:
if embedding_dim != self._expected_embedding_dim:
self._log("error", f"embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
return None, False
else:
# 第一次成功提取固定维度只在第一次成功时设置避免异常embedding污染
self._expected_embedding_dim = embedding_dim
self._log("info", f"固定embedding维度: {embedding_dim}")
return embedding, True
except Exception as e:
self._log("error", f"提取embedding失败: {e}")
@@ -134,22 +112,10 @@ class SpeakerVerificationClient:
env: str = "near", threshold: float = None) -> bool:
"""
注册说话人
关键修复注册时对embedding进行归一化一劳永逸
"""
embedding_dim = len(embedding)
if embedding_dim == 0:
return False
# 校验维度一致性
if self._expected_embedding_dim is not None:
if embedding_dim != self._expected_embedding_dim:
self._log("error", f"注册失败embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
return False
else:
# 如果还没有固定维度使用当前embedding的维度
self._expected_embedding_dim = embedding_dim
# 关键修复注册时归一化embedding
embedding_norm = np.linalg.norm(embedding)
if embedding_norm == 0:
self._log("error", f"注册失败embedding范数为0")
@@ -159,26 +125,21 @@ class SpeakerVerificationClient:
speaker_threshold = threshold if threshold is not None else self.threshold
with self._lock:
# 使用dict结构存储便于序列化
self.speaker_db[speaker_id] = {
"embedding": embedding_normalized, # 已归一化
"env": env,
"embedding": embedding_normalized,
"env": env, # 添加 env 字段
"threshold": speaker_threshold,
"registered_at": time.time()
}
self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")
# 在锁外调用save_speakers避免死锁save_speakers内部会获取锁
save_result = self.save_speakers()
if not save_result:
# 使用info级别而不是warning避免ROS2 logger在多线程环境中的问题
self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
return True
def match_speaker(self, embedding: np.ndarray):
"""
匹配说话人一句话只调用一次
返回: (speaker_id: str | None, state: SpeakerState, score: float, threshold: float)
"""
if not self.speaker_db:
return None, SpeakerState.UNKNOWN, 0.0, self.threshold
@@ -186,12 +147,7 @@ class SpeakerVerificationClient:
embedding_dim = len(embedding)
if embedding_dim == 0:
return None, SpeakerState.ERROR, 0.0, self.threshold
# 校验维度一致性
if self._expected_embedding_dim is not None and embedding_dim != self._expected_embedding_dim:
return None, SpeakerState.ERROR, 0.0, self.threshold
# 归一化当前embedding注册时已归一化这里只需要归一化当前embedding
embedding_norm = np.linalg.norm(embedding)
if embedding_norm == 0:
return None, SpeakerState.ERROR, 0.0, self.threshold
@@ -203,13 +159,7 @@ class SpeakerVerificationClient:
with self._lock:
for speaker_id, speaker_data in self.speaker_db.items():
# 从dict结构读取
ref_embedding = speaker_data["embedding"] # 注册时已归一化
if len(ref_embedding) != embedding_dim:
continue
# 直接计算cosine similarity两个向量都已归一化
ref_embedding = speaker_data["embedding"]
score = np.dot(embedding_normalized, ref_embedding)
if score > best_score:
@@ -242,8 +192,6 @@ class SpeakerVerificationClient:
def load_speakers(self) -> bool:
"""
从文件加载已注册的声纹
使用JSON格式存储更安全可迁移
格式: {speaker_id: {"embedding": [list], "env": str, "threshold": float, "registered_at": float}}
"""
if not self.speaker_db_path:
return False
@@ -257,25 +205,14 @@ class SpeakerVerificationClient:
data = json.load(f)
with self._lock:
# 转换list回numpy array并校验维度
for speaker_id, speaker_data in data.items():
embedding_list = speaker_data["embedding"]
embedding_array = np.array(embedding_list, dtype=np.float32)
# 校验维度
embedding_dim = len(embedding_array)
if embedding_dim == 0:
self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
continue
# 设置期望维度(如果还没有)
if self._expected_embedding_dim is None:
self._expected_embedding_dim = embedding_dim
elif embedding_dim != self._expected_embedding_dim:
self._log("warning", f"跳过维度不匹配的声纹: {speaker_id} (期望{self._expected_embedding_dim}, 实际{embedding_dim})")
continue
# 确保embedding已归一化
embedding_norm = np.linalg.norm(embedding_array)
if embedding_norm > 0:
embedding_array = embedding_array / embedding_norm
@@ -288,10 +225,7 @@ class SpeakerVerificationClient:
}
count = len(self.speaker_db)
if self._expected_embedding_dim:
self._log("info", f"已加载 {count} 个已注册说话人embedding维度: {self._expected_embedding_dim}")
else:
self._log("info", f"已加载 {count} 个已注册说话人")
self._log("info", f"已加载 {count} 个已注册说话人")
return True
except Exception as e:
self._log("error", f"加载声纹数据库失败: {e}")
@@ -300,7 +234,6 @@ class SpeakerVerificationClient:
def save_speakers(self) -> bool:
"""
保存已注册的声纹到文件
使用JSON格式更安全可迁移
"""
if not self.speaker_db_path:
self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
@@ -310,24 +243,20 @@ class SpeakerVerificationClient:
db_dir = os.path.dirname(self.speaker_db_path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
# 转换为JSON可序列化的格式
json_data = {}
with self._lock:
for speaker_id, speaker_data in self.speaker_db.items():
json_data[speaker_id] = {
"embedding": speaker_data["embedding"].tolist(), # numpy array -> list
"env": speaker_data["env"],
"env": speaker_data.get("env", "near"), # 兼容旧数据,默认使用 "near"
"threshold": speaker_data["threshold"],
"registered_at": speaker_data["registered_at"]
}
# 使用临时文件 + 原子替换,避免写入过程中崩溃导致数据丢失
temp_path = self.speaker_db_path + ".tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
# 原子替换
os.replace(temp_path, self.speaker_db_path)
self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
@@ -337,7 +266,6 @@ class SpeakerVerificationClient:
self._log("error", f"保存声纹数据库失败: {e}")
self._log("error", f"保存路径: {self.speaker_db_path}")
self._log("error", f"错误详情: {traceback.format_exc()}")
# 清理临时文件
temp_path = self.speaker_db_path + ".tmp"
if os.path.exists(temp_path):
try:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,4 @@
"""理解层"""

View File

@@ -1,22 +1,14 @@
"""
对话历史管理模块
"""
from .types import LLMMessage
from robot_speaker.core.types import LLMMessage
import threading
class ConversationHistory:
"""
对话历史管理器 - 实时语音友好版本
"""对话历史管理器 - 实时语音"""
使用待确认机制确保历史完整性
1. start_turn() - 开始新轮次暂存用户消息
2. get_messages() - 获取历史包含待确认的用户消息用于LLM上下文
3. commit_turn() - 确认轮次完成写入历史
4. cancel_turn() - 取消当前轮次丢弃待确认消息
"""
def __init__(self, max_history: int = 3, summary_trigger: int = 3):
def __init__(self, max_history: int, summary_trigger: int):
self.max_history = max_history
self.summary_trigger = summary_trigger
self.conversation_history: list[LLMMessage] = []
@@ -27,53 +19,38 @@ class ConversationHistory:
self._lock = threading.Lock() # 线程安全锁
def start_turn(self, user_content: str):
"""
开始一个新的对话轮次,暂存用户消息等待LLM完成后确认写入历史
"""
"""开始一个新的对话轮次,暂存用户消息等待LLM完成后确认写入历史"""
with self._lock:
# 如果有未确认的轮次,新消息会覆盖它(不写入历史,防止半句污染)
# 这是正常的场景,比如用户快速连续说话,不需要特殊处理
self._pending_user_message = LLMMessage(role="user", content=user_content)
def commit_turn(self, assistant_content: str) -> bool:
"""
确认当前轮次完成将用户和助手消息写入历史
"""
"""确认当前轮次完成将usr和assistant消息写入历史"""
with self._lock:
if self._pending_user_message is None:
return False
# 只有助手回复非空时才写入历史
if not assistant_content or not assistant_content.strip():
self._pending_user_message = None
return False
# 写入用户消息和助手回复
self.conversation_history.append(self._pending_user_message)
self.conversation_history.append(
LLMMessage(role="assistant", content=assistant_content.strip())
)
# 清空待确认消息
self._pending_user_message = None
# 检查是否需要压缩
self._maybe_compress()
return True
def cancel_turn(self):
"""
取消当前待确认的轮次丢弃待确认的用户消息,用于处理中断情况防止不完整内容污染历史
"""
"""取消当前待确认的轮次,丢弃待确认的用户消息,用于处理中断情况,防止不完整内容污染历史"""
with self._lock:
if self._pending_user_message is not None:
self._pending_user_message = None
def add_message(self, role: str, content: str):
"""
直接添加消息向后兼容但推荐使用 start_turn/commit_turn
注意此方法会立即写入历史不会经过待确认机制
"""
"""直接添加消息"""
with self._lock:
# 如果有待确认的轮次,先取消它
self.cancel_turn()
@@ -81,21 +58,16 @@ class ConversationHistory:
self._maybe_compress()
def get_messages(self) -> list[LLMMessage]:
"""
获取消息列表包含摘要和待确认的用户消息
"""
"""获取消息列表"""
with self._lock:
messages = []
# 添加摘要
if self.summary:
messages.append(LLMMessage(role="system", content=self.summary))
# 添加历史消息
if self.max_history > 0:
messages.extend(self.conversation_history[-self.max_history * 2:])
# 添加待确认的用户消息用于LLM上下文但不写入历史
if self._pending_user_message is not None:
messages.append(self._pending_user_message)

View File

@@ -1,4 +1,4 @@
from setuptools import setup
from setuptools import setup, find_packages
import os
from glob import glob
@@ -7,14 +7,14 @@ package_name = 'robot_speaker'
setup(
name=package_name,
version='0.0.1',
packages=[package_name],
packages=find_packages(where='.'),
package_dir={'': '.'},
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
],
install_requires=[
'setuptools',
@@ -28,7 +28,8 @@ setup(
tests_require=['pytest'],
entry_points={
'console_scripts': [
'robot_speaker_node = robot_speaker.robot_speaker_node:main',
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
],
},
)