代码重构,区分声纹注册和主节点
This commit is contained in:
15
README.md
15
README.md
@@ -19,6 +19,20 @@ pip3 install -r requirements.txt --break-system-packages
|
||||
```
|
||||
|
||||
## 编译启动
|
||||
1. 注册声纹
|
||||
- 启动节点后可以说:二狗今天天气真好开始注册声纹
|
||||
- 注意:要包含唤醒词,语句不要停顿,尽量大于3秒
|
||||
```bash
|
||||
cd ~/ros_learn/hivecore_robot_voice
|
||||
colcon build
|
||||
source install/setup.bash
|
||||
ros2 run robot_speaker register_speaker_node
|
||||
```
|
||||
|
||||
2. 主节点
|
||||
- 启动节点后每句交互包含唤醒词,唤醒词和语句之间不要有停顿
|
||||
- 二狗拍照看看开启图文交互
|
||||
- 支持已注册声纹用户打断
|
||||
```bash
|
||||
cd ~/ros_learn/hivecore_robot_voice
|
||||
colcon build
|
||||
@@ -78,3 +92,4 @@ rs-enumerate-devices -c
|
||||
```bash
|
||||
modelscope download --model iic/speech_campplus_sv_zh-cn_16k-common --local_dir [指定路径]
|
||||
```
|
||||
|
||||
|
||||
18
config/knowledge.json
Normal file
18
config/knowledge.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"entries": [
|
||||
{
|
||||
"id": "robot_identity",
|
||||
"patterns": [
|
||||
"ni shi shei"
|
||||
],
|
||||
"answer": "我叫二狗,是蜂核科技的机器人,很高兴为你服务"
|
||||
},
|
||||
{
|
||||
"id": "wake_word",
|
||||
"patterns": [
|
||||
"ni de ming zi"
|
||||
],
|
||||
"answer": "我的名字是二狗"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -197,5 +197,403 @@
|
||||
"env": "near",
|
||||
"threshold": 0.4,
|
||||
"registered_at": 1768311644.5742264
|
||||
},
|
||||
"user_1768529827": {
|
||||
"embedding": [
|
||||
0.0077949948608875275,
|
||||
-0.012852567248046398,
|
||||
0.0014490776229649782,
|
||||
0.088177390396595,
|
||||
-0.052150458097457886,
|
||||
-0.1070166826248169,
|
||||
-0.051932964473962784,
|
||||
0.040730226784944534,
|
||||
0.09491471946239471,
|
||||
-0.10504328459501266,
|
||||
-0.17986123263835907,
|
||||
0.06056514009833336,
|
||||
0.0002809118013828993,
|
||||
-0.05353177338838577,
|
||||
-0.08724740147590637,
|
||||
-0.01057526096701622,
|
||||
-0.10766296088695526,
|
||||
0.024376090615987778,
|
||||
-0.11535818874835968,
|
||||
0.12653452157974243,
|
||||
-0.0063497889786958694,
|
||||
-0.02372283861041069,
|
||||
-0.049704890698194504,
|
||||
0.01079346239566803,
|
||||
-0.10683158040046692,
|
||||
0.00932641327381134,
|
||||
0.043871842324733734,
|
||||
0.04073511064052582,
|
||||
0.005968529265373945,
|
||||
0.05397576093673706,
|
||||
0.07122175395488739,
|
||||
0.06804963946342468,
|
||||
-0.058389563113451004,
|
||||
-0.03463176265358925,
|
||||
-0.06834574788808823,
|
||||
-0.09127284586429596,
|
||||
-0.09805246442556381,
|
||||
-0.015370666980743408,
|
||||
-0.07054834067821503,
|
||||
-0.07520422339439392,
|
||||
-0.0502505861222744,
|
||||
0.01580144092440605,
|
||||
0.04316972196102142,
|
||||
-0.010298517532646656,
|
||||
-0.09042523056268692,
|
||||
-0.03399325907230377,
|
||||
0.03738871216773987,
|
||||
0.09461583197116852,
|
||||
0.07643604278564453,
|
||||
-0.04089711233973503,
|
||||
0.14397914707660675,
|
||||
-0.03218085318803787,
|
||||
-0.03981873393058777,
|
||||
-0.05353623256087303,
|
||||
-0.06475386023521423,
|
||||
0.047925639897584915,
|
||||
0.008481102995574474,
|
||||
0.09522885829210281,
|
||||
0.05679373815655708,
|
||||
0.021448519080877304,
|
||||
0.04586802423000336,
|
||||
0.007880095392465591,
|
||||
-0.08111433684825897,
|
||||
-0.030093876644968987,
|
||||
0.18197935819625854,
|
||||
0.049670975655317307,
|
||||
-0.029350068420171738,
|
||||
0.1003178134560585,
|
||||
0.05890532210469246,
|
||||
-0.0418926365673542,
|
||||
-0.015124992467463017,
|
||||
-0.0016869385726749897,
|
||||
0.029022999107837677,
|
||||
0.10370466858148575,
|
||||
-0.07392475008964539,
|
||||
-0.041242245584726334,
|
||||
0.0948185846209526,
|
||||
0.0766805037856102,
|
||||
0.12104924768209457,
|
||||
0.07941737771034241,
|
||||
-0.024586958810687065,
|
||||
-0.005290709435939789,
|
||||
0.08198735862970352,
|
||||
-0.15709130465984344,
|
||||
0.11847008019685745,
|
||||
0.01280289888381958,
|
||||
0.09401026368141174,
|
||||
0.10199982672929764,
|
||||
0.00811630580574274,
|
||||
0.09336159378290176,
|
||||
-0.1219155564904213,
|
||||
0.00885648000985384,
|
||||
0.08536995947360992,
|
||||
-0.031735390424728394,
|
||||
-0.02445235848426819,
|
||||
0.17981232702732086,
|
||||
0.05046188458800316,
|
||||
-0.012413986958563328,
|
||||
-0.16514025628566742,
|
||||
-0.09369593858718872,
|
||||
0.03961285203695297,
|
||||
-0.024150250479578972,
|
||||
0.024869512766599655,
|
||||
0.009099201299250126,
|
||||
0.0023227918427437544,
|
||||
0.005291149020195007,
|
||||
-0.08285452425479889,
|
||||
0.02174258604645729,
|
||||
-0.00018321558309253305,
|
||||
-0.01761690340936184,
|
||||
-0.13327360153198242,
|
||||
0.07804469764232635,
|
||||
-0.03172646835446358,
|
||||
0.05993621423840523,
|
||||
-0.0034280805848538876,
|
||||
0.09203101694583893,
|
||||
0.04720155894756317,
|
||||
-0.12012632191181183,
|
||||
-0.028879230841994286,
|
||||
-0.04471825063228607,
|
||||
-0.08928379416465759,
|
||||
-0.055793069303035736,
|
||||
-0.0230169165879488,
|
||||
0.04459748789668083,
|
||||
-0.08481008559465408,
|
||||
0.09873232245445251,
|
||||
-0.057500336319208145,
|
||||
-0.05438977852463722,
|
||||
0.06309207528829575,
|
||||
-0.045493170619010925,
|
||||
-0.0636027380824089,
|
||||
-0.03580763190984726,
|
||||
-0.043026816099882126,
|
||||
0.04125182330608368,
|
||||
-0.06327074766159058,
|
||||
0.02830875851213932,
|
||||
-0.0697140172123909,
|
||||
-0.11324217170476913,
|
||||
-0.02744743973016739,
|
||||
-0.09659717977046967,
|
||||
-0.036915868520736694,
|
||||
0.06836548447608948,
|
||||
-0.19481360912322998,
|
||||
-0.08151774108409882,
|
||||
0.013570327311754227,
|
||||
-0.013908851891756058,
|
||||
-0.02302597463130951,
|
||||
-0.14017312228679657,
|
||||
-0.0654999315738678,
|
||||
0.0582318976521492,
|
||||
-0.023702487349510193,
|
||||
-0.046911414712667465,
|
||||
-0.02062028832733631,
|
||||
0.09885907918214798,
|
||||
-0.010111358016729355,
|
||||
-0.009303858503699303,
|
||||
-0.07802718877792358,
|
||||
0.09181840717792511,
|
||||
-0.00822418462485075,
|
||||
-0.024477459490299225,
|
||||
0.04909557104110718,
|
||||
0.024657243862748146,
|
||||
0.08074013143777847,
|
||||
0.10684694349765778,
|
||||
-0.009657780639827251,
|
||||
0.04053448513150215,
|
||||
-0.054968591779470444,
|
||||
0.09773849695920944,
|
||||
-0.019937219098210335,
|
||||
-0.11860335618257523,
|
||||
-0.12553851306438446,
|
||||
0.0016870739636942744,
|
||||
0.07446407526731491,
|
||||
-0.12183381617069244,
|
||||
-0.07524612545967102,
|
||||
0.06794209778308868,
|
||||
-0.04324038699269295,
|
||||
-0.018201345577836037,
|
||||
-0.08356837183237076,
|
||||
0.08218713104724884,
|
||||
-0.1253940612077713,
|
||||
-0.05880133807659149,
|
||||
0.11516888439655304,
|
||||
-0.007864559069275856,
|
||||
0.06438153237104416,
|
||||
-0.06551646441221237,
|
||||
0.11812424659729004,
|
||||
-0.07544125616550446,
|
||||
0.033888354897499084,
|
||||
0.02552076056599617,
|
||||
0.019394448027014732,
|
||||
-0.009682931937277317
|
||||
],
|
||||
"env": "near",
|
||||
"threshold": 0.55,
|
||||
"registered_at": 1768529827.4784193
|
||||
},
|
||||
"user_1768530001": {
|
||||
"embedding": [
|
||||
-0.02827363647520542,
|
||||
0.04181317239999771,
|
||||
-0.07721243053674698,
|
||||
0.031220311298966408,
|
||||
-0.006549456622451544,
|
||||
-0.045262161642313004,
|
||||
-0.06796529144048691,
|
||||
0.10546170920133591,
|
||||
-0.054266564548015594,
|
||||
-0.04982651397585869,
|
||||
0.008982052095234394,
|
||||
0.0887555256485939,
|
||||
-0.03736695274710655,
|
||||
-0.027568811550736427,
|
||||
-0.01881324127316475,
|
||||
-0.030173255130648613,
|
||||
-0.03817622363567352,
|
||||
-0.027703644707798958,
|
||||
-0.020354237407445908,
|
||||
0.08958664536476135,
|
||||
0.027346525341272354,
|
||||
-0.007979321293532848,
|
||||
-0.01638970896601677,
|
||||
0.14815205335617065,
|
||||
-0.029478076845407486,
|
||||
0.0968138799071312,
|
||||
0.011266525834798813,
|
||||
0.10481037944555283,
|
||||
0.006314543075859547,
|
||||
-0.07480890303850174,
|
||||
-0.126618891954422,
|
||||
0.054260920733213425,
|
||||
-0.054261378943920135,
|
||||
0.02066616155207157,
|
||||
0.056972429156303406,
|
||||
-0.02620418183505535,
|
||||
-0.08435375243425369,
|
||||
-0.06768523901700974,
|
||||
-0.001804384752176702,
|
||||
-0.03350691497325897,
|
||||
-0.06783927977085114,
|
||||
0.09583555907011032,
|
||||
0.042077258229255676,
|
||||
-0.03811662644147873,
|
||||
-0.09298640489578247,
|
||||
0.11314687132835388,
|
||||
0.06972789764404297,
|
||||
-0.10421980172395706,
|
||||
0.02739877998828888,
|
||||
-0.06242597475647926,
|
||||
0.06683704257011414,
|
||||
0.030034003779292107,
|
||||
-0.04094783961772919,
|
||||
0.08657337725162506,
|
||||
0.02882716991007328,
|
||||
0.07672230899333954,
|
||||
-0.0162385031580925,
|
||||
0.12335177510976791,
|
||||
-0.07505486160516739,
|
||||
0.05924128741025925,
|
||||
0.02278822474181652,
|
||||
0.051575034856796265,
|
||||
-0.07616295665502548,
|
||||
-0.049982234835624695,
|
||||
-0.021159915253520012,
|
||||
0.023469945415854454,
|
||||
-0.008445728570222855,
|
||||
0.18868982791900635,
|
||||
0.10217619687318802,
|
||||
0.0029947187285870314,
|
||||
0.003596147522330284,
|
||||
-0.010885344818234444,
|
||||
0.002336243400350213,
|
||||
-0.06228164955973625,
|
||||
-0.09452632069587708,
|
||||
0.06288570165634155,
|
||||
0.09799493104219437,
|
||||
0.05772380530834198,
|
||||
-0.012649190612137318,
|
||||
0.037833958864212036,
|
||||
-0.07815677672624588,
|
||||
0.11595622450113297,
|
||||
-0.006132716778665781,
|
||||
-0.047689273953437805,
|
||||
0.10451581329107285,
|
||||
0.12618094682693481,
|
||||
-0.012135603465139866,
|
||||
-0.14452683925628662,
|
||||
-0.011882219463586807,
|
||||
0.05687599256634712,
|
||||
-0.10221579670906067,
|
||||
0.09555421024560928,
|
||||
0.050166770815849304,
|
||||
0.026791365817189217,
|
||||
0.0343380831182003,
|
||||
0.0643647089600563,
|
||||
-0.09814899414777756,
|
||||
-0.01735001988708973,
|
||||
0.0002968672488350421,
|
||||
-0.16691210865974426,
|
||||
-0.044747937470674515,
|
||||
0.10229559987783432,
|
||||
0.01551489345729351,
|
||||
0.0614253506064415,
|
||||
-0.012457458302378654,
|
||||
-0.059297215193510056,
|
||||
-0.0662546306848526,
|
||||
0.06900843977928162,
|
||||
-0.15012530982494354,
|
||||
0.14357514679431915,
|
||||
-0.08563537150621414,
|
||||
0.1512402445077896,
|
||||
-0.05548126623034477,
|
||||
-0.13191379606723785,
|
||||
0.02588576264679432,
|
||||
-0.007292638067156076,
|
||||
-0.033004030585289,
|
||||
-0.08764250576496124,
|
||||
-0.04006534814834595,
|
||||
0.001069005811586976,
|
||||
0.0708790197968483,
|
||||
-0.11471016705036163,
|
||||
-0.08249906450510025,
|
||||
-0.07923658937215805,
|
||||
-0.029890256002545357,
|
||||
0.027568599209189415,
|
||||
-0.00042784016113728285,
|
||||
0.01911524124443531,
|
||||
0.002947323489934206,
|
||||
-0.058468904346227646,
|
||||
0.0006662740488536656,
|
||||
-0.09472604095935822,
|
||||
-0.07827164232730865,
|
||||
0.05823435261845589,
|
||||
-0.022661248221993446,
|
||||
0.007729553151875734,
|
||||
0.044511985033750534,
|
||||
-0.17424426972866058,
|
||||
-0.054321326315402985,
|
||||
-0.010871038772165775,
|
||||
-0.04280569776892662,
|
||||
0.01373684499412775,
|
||||
-0.03464324399828911,
|
||||
0.0012510031228885055,
|
||||
-0.13786448538303375,
|
||||
0.13943427801132202,
|
||||
0.07161138951778412,
|
||||
-0.0017689999658614397,
|
||||
-0.0330035537481308,
|
||||
0.01767006888985634,
|
||||
-0.06832484155893326,
|
||||
-0.16906532645225525,
|
||||
-0.08673631399869919,
|
||||
0.016205811873078346,
|
||||
-0.040736377239227295,
|
||||
-0.053034041076898575,
|
||||
-0.057571377605199814,
|
||||
-0.018383856862783432,
|
||||
0.029812879860401154,
|
||||
-0.005708644632250071,
|
||||
0.07977750152349472,
|
||||
0.03715944290161133,
|
||||
0.029830463230609894,
|
||||
-0.15909501910209656,
|
||||
0.10081987082958221,
|
||||
0.07019384205341339,
|
||||
0.05683498457074165,
|
||||
0.008955223485827446,
|
||||
-0.06697771698236465,
|
||||
0.044268134981393814,
|
||||
0.08812808990478516,
|
||||
-0.17523430287837982,
|
||||
0.05148027464747429,
|
||||
-0.11579684168100357,
|
||||
-0.06281758099794388,
|
||||
-0.08106749504804611,
|
||||
-0.07915353775024414,
|
||||
0.03760797902941704,
|
||||
-0.059639666229486465,
|
||||
0.012170189991593361,
|
||||
-0.028386766090989113,
|
||||
-0.043592486530542374,
|
||||
0.029122747480869293,
|
||||
0.052276406437158585,
|
||||
0.06929390132427216,
|
||||
-0.10774848610162735,
|
||||
0.06797030568122864,
|
||||
-0.017512541264295578,
|
||||
0.07446594536304474,
|
||||
-0.07573172450065613,
|
||||
-0.15186654031276703,
|
||||
-0.03710319101810455
|
||||
],
|
||||
"env": "near",
|
||||
"threshold": 0.55,
|
||||
"registered_at": 1768530001.2158406
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ dashscope:
|
||||
temperature: 0.7
|
||||
max_tokens: 4096
|
||||
max_history: 10
|
||||
summary_trigger: 3
|
||||
tts:
|
||||
model: "cosyvoice-v3-flash"
|
||||
voice: "longanyang"
|
||||
@@ -25,6 +26,8 @@ audio:
|
||||
soundcard:
|
||||
card_index: 1 # USB Audio Device (card 1)
|
||||
device_index: 0 # USB Audio [USB Audio] (device 0)
|
||||
# card_index: -1 # 使用默认声卡
|
||||
# device_index: -1 # 使用默认输出设备
|
||||
sample_rate: 44100 # 输出采样率:44.1kHz(支持48000或44100)
|
||||
channels: 2 # 输出声道数:立体声(2声道,FL+FR)
|
||||
volume: 1.0 # 音量比例(0.0-1.0,0.2表示20%音量)
|
||||
@@ -64,5 +67,3 @@ camera:
|
||||
image:
|
||||
jpeg_quality: 85 # JPEG压缩质量(0-100,85是质量和大小平衡点)
|
||||
max_size: "1280x720" # 最大尺寸
|
||||
commands:
|
||||
capture_keywords: "pai zhao,pai ge zhao,pai zhang zhao pian,pai zhang,da kai xiang ji,kan zhe li,zhao xiang" # 拍照相关指令(拼音,逗号分隔)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
dashscope>=1.20.0
|
||||
openai>=1.0.0
|
||||
pyaudio>=0.2.11
|
||||
webrtcvad>=2.0.10 # WebRTC VAD(语音活动检测),不包含回声消除
|
||||
webrtcvad>=2.0.10
|
||||
pypinyin>=0.49.0
|
||||
rclpy>=3.0.0
|
||||
pyrealsense2>=2.54.0
|
||||
Pillow>=10.0.0
|
||||
numpy>=1.24.0
|
||||
# 回声消除库(可选):
|
||||
# aec-audio-processing - 专门用于回声消除的WebRTC库,API简单(推荐)
|
||||
# pip install aec-audio-processing
|
||||
# 如果未安装,将使用内置的简单自适应算法
|
||||
PyYAML>=6.0
|
||||
aec-audio-processing
|
||||
modelscope>=1.33.0
|
||||
datasets>=3.6.0
|
||||
|
||||
|
||||
|
||||
|
||||
4
robot_speaker/core/__init__.py
Normal file
4
robot_speaker/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""核心模块"""
|
||||
|
||||
|
||||
|
||||
10
robot_speaker/core/conversation_state.py
Normal file
10
robot_speaker/core/conversation_state.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConversationState(Enum):
    """Conversation state machine for the main voice node."""
    IDLE = "idle"                   # waiting for the user to wake the robot or speak
    CHECK_VOICE = "check_voice"     # speech detected -> verify the speaker's voiceprint
    AUTHORIZED = "authorized"       # speaker verified as a registered user
|
||||
|
||||
|
||||
127
robot_speaker/core/intent_router.py
Normal file
127
robot_speaker/core/intent_router.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from pypinyin import pinyin, Style
|
||||
|
||||
|
||||
@dataclass
class IntentResult:
    """Routing decision produced by IntentRouter.route() for one utterance."""
    intent: str                   # one of: "skill_sequence" | "kb_qa" | "chat_text" | "chat_camera"
    text: str                     # the user text that was routed (passed through unchanged)
    need_camera: bool             # whether this turn requires capturing an image
    camera_mode: Optional[str]    # "head" | "left_hand" | "right_hand" | None
    system_prompt: Optional[str]  # system prompt to hand to the LLM for this intent
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
def __init__(self):
|
||||
self.camera_capture_keywords = [
|
||||
"pai zhao", "pai ge zhao", "pai zhang zhao"
|
||||
]
|
||||
self.skill_keywords = [
|
||||
"ban xiang zi"
|
||||
]
|
||||
self.kb_keywords = [
|
||||
"ni shi shei", "ni de ming zi"
|
||||
]
|
||||
|
||||
def to_pinyin(self, text: str) -> str:
|
||||
chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
|
||||
if not chars:
|
||||
return ""
|
||||
py_list = pinyin(''.join(chars), style=Style.NORMAL)
|
||||
return ' '.join([item[0] for item in py_list]).lower().strip()
|
||||
|
||||
def is_skill_sequence_intent(self, text: str) -> bool:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
return any(k in text_pinyin for k in self.skill_keywords)
|
||||
|
||||
|
||||
def check_camera_command(self, text: str) -> tuple[bool, Optional[str]]:
|
||||
if not text:
|
||||
return False, None
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
for keyword in self.camera_capture_keywords:
|
||||
if keyword in text_pinyin:
|
||||
return True, self.detect_camera_mode(text)
|
||||
return False, None
|
||||
|
||||
def detect_camera_mode(self, text: str) -> str:
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
left_keys = ["zuo shou", "zuo bi", "zuo bian"]
|
||||
right_keys = ["you shou", "you bi", "you bian"]
|
||||
head_keys = ["tou", "nao dai"]
|
||||
for kw in left_keys:
|
||||
if kw in text_pinyin:
|
||||
return "left_hand"
|
||||
for kw in right_keys:
|
||||
if kw in text_pinyin:
|
||||
return "right_hand"
|
||||
for kw in head_keys:
|
||||
if kw in text_pinyin:
|
||||
return "head"
|
||||
return "head"
|
||||
|
||||
def build_skill_prompt(self) -> str:
|
||||
return (
|
||||
"你是机器人任务规划器。\n"
|
||||
"本任务必须拍照。请根据用户请求选择使用哪个相机拍照(默认头部相机),并结合当前环境信息生成简洁、可执行的技能序列。"
|
||||
)
|
||||
|
||||
def build_chat_prompt(self, need_camera: bool) -> str:
|
||||
if need_camera:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请结合图片内容简短回答。"
|
||||
)
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"请自然、简短地与用户对话。"
|
||||
)
|
||||
|
||||
def build_kb_prompt(self) -> str:
|
||||
return (
|
||||
"你是蜂核科技的员工。\n"
|
||||
"请基于知识库信息回答用户问题,回答要准确简洁。"
|
||||
)
|
||||
|
||||
def build_default_system_prompt(self) -> str:
|
||||
return (
|
||||
"你是一个智能语音助手。\n"
|
||||
"- 当用户发送图片时,请仔细观察图片内容,结合用户的问题或描述,提供简短、专业的回答。\n"
|
||||
"- 当用户没有发送图片时,请自然、友好地与用户对话。\n"
|
||||
"请根据对话模式调整你的回答风格。"
|
||||
)
|
||||
|
||||
def route(self, text: str) -> IntentResult:
|
||||
need_camera, camera_mode = self.check_camera_command(text)
|
||||
text_pinyin = self.to_pinyin(text)
|
||||
|
||||
if self.is_skill_sequence_intent(text):
|
||||
if camera_mode is None:
|
||||
camera_mode = "head"
|
||||
return IntentResult(
|
||||
intent="skill_sequence",
|
||||
text=text,
|
||||
need_camera=True,
|
||||
camera_mode=camera_mode,
|
||||
system_prompt=self.build_skill_prompt()
|
||||
)
|
||||
|
||||
if any(k in text_pinyin for k in self.kb_keywords):
|
||||
return IntentResult(
|
||||
intent="kb_qa",
|
||||
text=text,
|
||||
need_camera=False,
|
||||
camera_mode=None,
|
||||
system_prompt=self.build_kb_prompt()
|
||||
)
|
||||
|
||||
return IntentResult(
|
||||
intent="chat_camera" if need_camera else "chat_text",
|
||||
text=text,
|
||||
need_camera=need_camera,
|
||||
camera_mode=camera_mode,
|
||||
system_prompt=self.build_chat_prompt(need_camera)
|
||||
)
|
||||
|
||||
246
robot_speaker/core/node_callbacks.py
Normal file
246
robot_speaker/core/node_callbacks.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeCallbacks:
    """Callback facade for the main voice node.

    Holds no state of its own: every callback reads/writes shared state on the
    wrapped ``node`` (locks, events, queues, speaker-verification buffers), so
    the recording/ASR threads stay decoupled from the node class itself.
    """

    # ==================== init & internal helpers ====================
    def __init__(self, node):
        # The owning node; all shared state (locks, events, buffers) lives there.
        self.node = node

    def _mark_utterance_processed(self) -> bool:
        """Atomically claim the current utterance id.

        Returns True exactly once per utterance, so speaker verification is
        triggered by whichever of VAD / ASR fires first and not twice.
        """
        node = self.node
        with node.utterance_lock:
            if node.current_utterance_id == node.last_processed_utterance_id:
                return False
            node.last_processed_utterance_id = node.current_utterance_id
            return True

    def _trigger_sv_for_check_voice(self, source: str):
        """Hand the buffered audio to the speaker-verification worker.

        ``source`` is "VAD" or "ASR" and is used only for log tagging.
        No-op when SV is disabled, the utterance was already claimed, or the
        speaker database is empty.
        """
        node = self.node
        if not (node.sv_enabled and node.sv_client):
            return
        if not self._mark_utterance_processed():
            return
        if node._handle_empty_speaker_db():
            node.get_logger().info(f"[声纹] CHECK_VOICE状态,数据库为空,跳过声纹验证(来源: {source})")
            return
        if not node.sv_speech_end_event.is_set():
            with node.sv_lock:
                node.sv_recording = False
                buffer_size = len(node.sv_audio_buffer)
            node.get_logger().info(f"[声纹] {source}触发验证,缓冲区大小: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")
            if buffer_size > 0:
                # Wake sv_worker only when audio was actually captured.
                node.sv_speech_end_event.set()
        else:
            node.get_logger().debug(f"[声纹] 声纹验证已触发,跳过(来源: {source})")

    # ==================== business-logic delegates ====================
    # Thin pass-throughs to the node's private implementation methods, kept
    # here so worker threads depend only on this facade.
    def handle_interrupt_command(self, msg):
        return self.node._handle_interrupt_command(msg)

    def check_interrupt_and_cancel_turn(self) -> bool:
        return self.node._check_interrupt_and_cancel_turn()

    def handle_wake_word(self, text: str) -> str:
        return self.node._handle_wake_word(text)

    def check_shutup_command(self, text: str) -> bool:
        return self.node._check_shutup_command(text)

    def check_camera_command(self, text: str):
        return self.node.intent_router.check_camera_command(text)

    def llm_process_stream_with_camera(self, user_text: str, need_camera: bool) -> str:
        return self.node._llm_process_stream_with_camera(user_text, need_camera)

    def put_tts_text(self, text: str):
        return self.node._put_tts_text(text)

    def force_stop_tts(self):
        return self.node._force_stop_tts()

    def drain_queue(self, q):
        return self.node._drain_queue(q)

    # ==================== recording / VAD callbacks ====================
    def get_silence_threshold(self) -> int:
        """Return the silence threshold (milliseconds) used by the VAD."""
        node = self.node
        return node.silence_duration_ms

    def should_put_audio_to_queue(self) -> bool:
        """
        Whether audio should be queued for ASR; gated by the conversation
        state machine.
        """
        node = self.node
        state = node._get_state()
        if state in [ConversationState.IDLE, ConversationState.CHECK_VOICE,
                     ConversationState.AUTHORIZED]:
            return True
        return False

    def on_speech_start(self):
        """Recording thread detected the start of speech."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到人声,开始录音")

        # New utterance: bump the id so each utterance is processed once.
        with node.utterance_lock:
            node.current_utterance_id += 1

        state = node._get_state()

        if state == ConversationState.IDLE:
            # Idle -> CheckVoice
            if node.sv_enabled and node.sv_client:
                # Start capturing audio for speaker verification.
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")
                node._change_state(ConversationState.CHECK_VOICE, "检测到语音,开始检查声纹")
            else:
                node._change_state(ConversationState.AUTHORIZED, "未启用声纹,直接授权")

        elif state == ConversationState.CHECK_VOICE:
            # Already checking: keep capturing audio for verification.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 继续录音用于声纹验证")

        elif state == ConversationState.AUTHORIZED:
            # Authorized: capture audio to re-verify the current speaker.
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = True
                    node.sv_audio_buffer.clear()
                node.get_logger().debug("[声纹] 开始录音用于声纹验证")

    def on_audio_chunk_for_sv(self, audio_chunk: bytes):
        """Recording-thread audio chunk callback - buffer audio for SV only while recording."""
        node = self.node
        # NOTE(review): `state` is read but unused in this callback.
        state = node._get_state()

        # Voiceprint capture (CHECK_VOICE / AUTHORIZED states).
        if node.sv_enabled and node.sv_recording:
            try:
                # assumes 16-bit PCM chunks — TODO confirm recorder format
                audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
                with node.sv_lock:
                    node.sv_audio_buffer.extend(audio_array)
            except Exception as e:
                node.get_logger().debug(f"[声纹] 录音失败: {e}")

    def on_speech_end(self):
        """Recording thread detected end of speech (sustained silence)."""
        node = self.node
        node.get_logger().info("[录音线程] 检测到说话结束")

        state = node._get_state()
        node.get_logger().info(f"[录音线程] 说话结束时的状态: {state}")

        if state == ConversationState.CHECK_VOICE:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            self._trigger_sv_for_check_voice("VAD")
            return

        elif state == ConversationState.AUTHORIZED:
            if node.asr_client and node.asr_client.running:
                node.asr_client.stop_current_recognition()
            if node.sv_enabled:
                with node.sv_lock:
                    node.sv_recording = False
                    buffer_size = len(node.sv_audio_buffer)
                node.get_logger().debug(f"[声纹] 停止录音,缓冲区大小: {buffer_size}")
                node.sv_speech_end_event.set()

                # If TTS is playing, wait for the SV result asynchronously and
                # interrupt TTS only when verification passes.
                # A separate thread avoids blocking the recording thread and
                # stalling TTS playback.
                if node.tts_playing_event.is_set():
                    node.get_logger().info("[打断] TTS播放中,用户说话结束,异步等待声纹验证结果...")
                    def _check_sv_and_interrupt():
                        # Wait for the SV result (at most 2 seconds).
                        # NOTE(review): the seq baseline is captured after
                        # sv_speech_end_event.set(); if sv_worker finishes
                        # before this thread runs, the wait can time out —
                        # verify this race is acceptable.
                        with node.sv_result_cv:
                            current_seq = node.sv_result_seq
                            if node.sv_result_cv.wait_for(
                                lambda: node.sv_result_seq > current_seq,
                                timeout=2.0
                            ):
                                # Verification finished: inspect the result.
                                with node.sv_lock:
                                    speaker_id = node.current_speaker_id
                                    speaker_state = node.current_speaker_state

                                if speaker_id and speaker_state == SpeakerState.VERIFIED:
                                    node.get_logger().info(f"[打断] 声纹验证通过({speaker_id}),中断TTS播放")
                                    node._interrupt_tts("检测到人声(已授权用户,说话结束)")
                                else:
                                    node.get_logger().debug(f"[打断] 声纹验证未通过,不中断TTS(状态: {speaker_state.value})")
                            else:
                                node.get_logger().warning("[打断] 声纹验证超时,不中断TTS")
                    # Wait in a dedicated thread so the recording thread is not blocked.
                    threading.Thread(target=_check_sv_and_interrupt, daemon=True, name="SVInterruptCheck").start()
            return

    def on_new_segment(self):
        """Recording thread detected a new voiced segment; start SV capture (no immediate interrupt)."""
        node = self.node
        state = node._get_state()
        if state == ConversationState.AUTHORIZED:
            # While TTS is playing, do not interrupt right away: record for
            # speaker verification and interrupt only after speech_end when
            # the voiceprint matches. This avoids TTS echo triggering a false
            # barge-in while still letting a real user interrupt.
            if node.tts_playing_event.is_set():
                node.get_logger().debug("[打断] TTS播放中,检测到人声,开始录音用于声纹验证(等待说话结束后验证)")
                # Recording already started in on_speech_start; nothing else to do.
            else:
                # TTS idle: check the latest SV result and interrupt immediately.
                if node.sv_enabled and node.sv_client:
                    with node.sv_lock:
                        current_speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                    if speaker_state == SpeakerState.VERIFIED and current_speaker_id:
                        node._interrupt_tts("检测到人声(已授权用户)")
                        node.get_logger().info(f"[打断] 已授权用户({current_speaker_id})发言,中断TTS播放")
                    else:
                        node.get_logger().debug(f"[打断] 检测到人声,但声纹未验证或未匹配,不中断TTS(当前状态: {speaker_state.value})")
                else:
                    # SV disabled: interrupt directly (legacy behaviour).
                    node._interrupt_tts("检测到人声(未启用声纹)")
                    node.get_logger().info("[打断] 检测到人声,中断TTS播放")
        else:
            node.get_logger().debug(f"[打断] 检测到人声,但当前状态为 {state.value},非已授权用户,不允许打断TTS")

    def on_heartbeat(self):
        """Recording-thread silence heartbeat."""
        self.node.get_logger().info("[录音线程] 静音中")

    # ==================== ASR callbacks ====================
    def on_asr_sentence_end(self, text: str):
        """ASR sentence_end callback - push the final text onto the processing queue."""
        node = self.node
        if not text or not text.strip():
            return
        text_clean = text.strip()
        node.get_logger().info(f"[ASR] 识别完成: {text_clean}")

        state = node._get_state()

        # Rule 2: in CHECK_VOICE, if ASR finishes before the VAD fires
        # speech_end, proactively trigger speaker verification.
        if state == ConversationState.CHECK_VOICE:
            if node.sv_enabled and node.sv_client:
                node.get_logger().info("[ASR] CHECK_VOICE状态,ASR识别完成,主动触发声纹验证")
                self._trigger_sv_for_check_voice("ASR")

        # Otherwise hand the text to the processing thread.
        # NOTE(review): put(timeout=1.0) raises queue.Full when the queue is
        # bounded and saturated — confirm this is handled by the caller.
        node.text_queue.put(text_clean, timeout=1.0)

    def on_asr_text_update(self, text: str):
        """ASR incremental-text callback - used for multi-turn hints."""
        if not text or not text.strip():
            return
        self.node.get_logger().debug(f"[ASR] 识别中: {text.strip()}")
|
||||
179
robot_speaker/core/node_workers.py
Normal file
179
robot_speaker/core/node_workers.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import queue
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerState
|
||||
|
||||
|
||||
class NodeWorkers:
|
||||
    def __init__(self, node):
        """Bind the owning node; all queues, events and clients live on it."""
        self.node = node
|
||||
|
||||
    def recording_worker(self):
        """Thread 1: recording thread - the only realtime thread.

        Blocks inside the recorder's VAD loop until the node shuts it down.
        """
        node = self.node
        node.get_logger().info("[录音线程] 启动")
        node.audio_recorder.record_with_vad()
|
||||
|
||||
def asr_worker(self):
|
||||
"""线程2: ASR推理线程 - 只做 audio → text"""
|
||||
node = self.node
|
||||
node.get_logger().info("[ASR推理线程] 启动")
|
||||
while not node.stop_event.is_set():
|
||||
try:
|
||||
audio_chunk = node.audio_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if node.interrupt_event.is_set():
|
||||
continue
|
||||
if node.callbacks.should_put_audio_to_queue() and node.asr_client and node.asr_client.running:
|
||||
node.asr_client.send_audio(audio_chunk)
|
||||
|
||||
    def process_worker(self):
        """Thread 3: main processing thread - runs the business logic.

        Pulls recognised text off the queue and drives the state machine:
        wake-word check, speaker-verification gating (CHECK_VOICE), shut-up
        command handling, then intent routing and dispatch.
        """
        node = self.node
        node.get_logger().info("[主线程] 启动")
        while not node.stop_event.is_set():
            try:
                text = node.text_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            node.get_logger().info(f"[主线程] 收到识别文本: {text}")
            current_state = node._get_state()

            if current_state == ConversationState.CHECK_VOICE:
                # First gate: wake word (when enabled).
                if node.use_wake_word:
                    node.get_logger().info(f"[主线程] CHECK_VOICE状态,检查唤醒词,文本: {text}")
                    processed_text = node.callbacks.handle_wake_word(text)
                    if not processed_text:
                        node.get_logger().info(f"[主线程] 未检测到唤醒词(唤醒词配置: '{node.wake_word}'),回到Idle状态")
                        node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                        continue
                    node.get_logger().info(f"[主线程] 检测到唤醒词,处理后的文本: {processed_text}")
                    text = processed_text

                # Second gate: speaker verification (when enabled).
                if node.sv_enabled and node.sv_client:
                    node.get_logger().info("[主线程] CHECK_VOICE状态:等待声纹验证结果...")
                    with node.sv_result_cv:
                        current_seq = node.sv_result_seq
                        # Wait up to 2 s for sv_worker to publish a new result.
                        if not node.sv_result_cv.wait_for(
                            lambda: node.sv_result_seq > current_seq,
                            timeout=2.0
                        ):
                            node.get_logger().warning("[主线程] CHECK_VOICE状态:声纹结果未ready(超时2秒),拒绝本轮")
                            with node.sv_lock:
                                node.sv_audio_buffer.clear()
                            node._change_state(ConversationState.IDLE, "声纹结果未ready")
                            continue
                        node.get_logger().info("[主线程] CHECK_VOICE状态:声纹结果ready,继续处理")

                    # Snapshot the SV result under the lock.
                    with node.sv_lock:
                        speaker_id = node.current_speaker_id
                        speaker_state = node.current_speaker_state
                        score = node.current_speaker_score

                    if speaker_id and speaker_state == SpeakerState.VERIFIED:
                        node.get_logger().info(f"[主线程] 声纹验证成功: {speaker_id}, 得分: {score:.4f}")
                        node._change_state(ConversationState.AUTHORIZED, "声纹验证成功")
                    else:
                        node.get_logger().info(f"[主线程] 声纹验证失败,得分: {score:.4f}")
                        node.callbacks.put_tts_text("声纹验证失败")
                        node._change_state(ConversationState.IDLE, "声纹验证失败")
                        continue
                else:
                    node._change_state(ConversationState.AUTHORIZED, "未启用声纹")

            elif current_state == ConversationState.AUTHORIZED:
                # While TTS plays, ASR results are ignored; only a verified
                # voice detected by the VAD path may interrupt.
                if node.tts_playing_event.is_set():
                    node.get_logger().debug("[主线程] AUTHORIZED状态,TTS播放中,忽略ASR识别结果(只有VAD检测到已授权用户人声才能中断)")
                    continue

            elif current_state == ConversationState.IDLE:
                node.get_logger().warning("[主线程] Idle状态收到文本,忽略")
                continue

            # Wake-word check for utterances arriving while already AUTHORIZED.
            if node.use_wake_word and current_state == ConversationState.AUTHORIZED:
                processed_text = node.callbacks.handle_wake_word(text)
                if not processed_text:
                    node._change_state(ConversationState.IDLE, "未检测到唤醒词")
                    continue
                text = processed_text

            # "Shut up" command: cancel the turn and stop TTS immediately.
            if node.callbacks.check_shutup_command(text):
                node.get_logger().info("[主线程] 检测到闭嘴指令")
                node.interrupt_event.set()
                node.callbacks.force_stop_tts()
                node._change_state(ConversationState.IDLE, "用户闭嘴指令")
                continue

            # Route the utterance to an intent and dispatch it.
            intent_payload = node.intent_router.route(text)
            node._handle_intent(intent_payload)

            # NOTE(review): compares the state captured at the top of the
            # loop, not the post-dispatch state — confirm this is intended.
            if current_state == ConversationState.AUTHORIZED:
                node.session_start_time = time.time()
|
||||
|
||||
def sv_worker(self):
    """Thread 5: speaker-verification worker — non-realtime, low frequency (CAM++).

    Waits for the ``sv_speech_end_event`` set by the recording pipeline,
    snapshots and clears the shared speaker-verification audio buffer, runs
    embedding extraction + speaker matching, publishes the result into the
    node's shared ``current_speaker_*`` fields (under ``sv_lock``), and then
    bumps ``sv_result_seq`` under ``sv_result_cv`` so the main thread's
    CHECK_VOICE wait can wake up.
    """
    node = self.node
    node.get_logger().info("[声纹识别线程] 启动")

    # Minimum number of samples required before attempting verification.
    # 8000 samples == 0.5 s at a 16 kHz sample rate — TODO confirm the
    # configured node.sample_rate actually is 16000.
    min_audio_samples = 8000

    while not node.stop_event.is_set():
        try:
            # Short wait so the stop_event is re-checked every 100 ms.
            if node.sv_speech_end_event.wait(timeout=0.1):
                node.sv_speech_end_event.clear()

                # Snapshot and clear the shared buffer atomically so a new
                # utterance can start accumulating immediately.
                with node.sv_lock:
                    audio_list = list(node.sv_audio_buffer)
                    buffer_size = len(audio_list)
                    node.sv_audio_buffer.clear()

                node.get_logger().info(f"[声纹识别] 收到speech_end事件,录音长度: {buffer_size} 样本({buffer_size/node.sample_rate:.2f}秒)")

                # Empty speaker DB: the helper handles state itself; skip
                # verification entirely.
                # NOTE(review): this `continue` also skips the sv_result_cv
                # notification below — verify _handle_empty_speaker_db (not
                # visible here) performs its own notify, otherwise the main
                # thread's CHECK_VOICE wait will time out on this path.
                if node._handle_empty_speaker_db():
                    node.get_logger().info("[声纹识别] 数据库为空,跳过验证,直接设置UNKNOWN状态")
                    continue

                if buffer_size >= min_audio_samples:
                    audio_array = np.array(audio_list, dtype=np.int16)
                    embedding, success = node.sv_client.extract_embedding(
                        audio_array,
                        sample_rate=node.sample_rate
                    )

                    if not success or embedding is None:
                        # Extraction failed: publish an explicit ERROR state
                        # rather than leaving stale speaker info behind.
                        node.get_logger().debug("[声纹识别] 提取embedding失败")
                        with node.sv_lock:
                            node.current_speaker_id = None
                            node.current_speaker_state = SpeakerState.ERROR
                            node.current_speaker_score = 0.0
                    else:
                        speaker_id, match_state, score, _ = node.sv_client.match_speaker(embedding)
                        with node.sv_lock:
                            node.current_speaker_id = speaker_id
                            node.current_speaker_state = match_state
                            node.current_speaker_score = score

                        if match_state == SpeakerState.VERIFIED:
                            node.get_logger().info(f"[声纹识别] 识别到说话人: {speaker_id}, 相似度: {score:.4f}")
                        elif match_state == SpeakerState.REJECTED:
                            node.get_logger().info(f"[声纹识别] 未匹配到已知说话人(相似度不足), 相似度: {score:.4f}")
                        else:
                            node.get_logger().info(f"[声纹识别] 状态: {match_state.value}, 相似度: {score:.4f}")
                else:
                    # Recording too short to verify reliably: reset to UNKNOWN.
                    node.get_logger().debug(f"[声纹识别] 录音太短: {buffer_size} < {min_audio_samples},跳过处理")
                    with node.sv_lock:
                        node.current_speaker_id = None
                        node.current_speaker_state = SpeakerState.UNKNOWN
                        node.current_speaker_score = 0.0

                # Signal "a result (success, failure, or too-short) is ready"
                # to anyone waiting on sv_result_cv for a sequence bump.
                with node.sv_result_cv:
                    node.sv_result_seq += 1
                    node.sv_result_cv.notify_all()

        except Exception as e:
            # Keep the worker alive on unexpected errors; brief sleep avoids
            # a tight error loop.
            node.get_logger().error(f"[声纹识别线程] 错误: {e}")
            time.sleep(0.1)
|
||||
399
robot_speaker/core/register_speaker_node.py
Normal file
399
robot_speaker/core/register_speaker_node.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
声纹注册独立节点:运行完成后退出
|
||||
"""
|
||||
import collections
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient
|
||||
from robot_speaker.models.asr.dashscope import DashScopeASR
|
||||
from pypinyin import pinyin, Style
|
||||
|
||||
|
||||
class RegisterSpeakerNode(Node):
    """Standalone speaker-registration node.

    Flow: wait for the wake word via streaming ASR → once detected, use the
    buffered microphone audio (which includes the wake word itself) to
    extract a speaker embedding and register it, then shut the node down.
    """

    def __init__(self):
        super().__init__('register_speaker_node')
        self._load_config()

        self.stop_event = threading.Event()
        # Guards against re-entrant registration while one is in flight.
        self.processing = False
        self.buffer_lock = threading.Lock()
        # Rolling microphone buffer (int16 samples) used for registration.
        self.audio_buffer = collections.deque(maxlen=self.sv_buffer_size)

        # State: waiting for wake word -> waiting for voiceprint speech.
        self.waiting_for_wake_word = True
        self.waiting_for_voiceprint = False

        # Audio queue (recorder -> ASR) and text queue (ASR -> checker).
        self.audio_queue = queue.Queue()
        self.text_queue = queue.Queue()

        self.vad_detector = VADDetector(
            mode=self.vad_mode,
            sample_rate=self.sample_rate
        )

        self.audio_recorder = AudioRecorder(
            device_index=self.input_device_index,
            sample_rate=self.sample_rate,
            channels=self.channels,
            chunk=self.chunk,
            vad_detector=self.vad_detector,
            audio_queue=self.audio_queue,  # feeds ASR for wake-word detection
            silence_duration_ms=self.silence_duration_ms,
            min_energy_threshold=self.min_energy_threshold,
            heartbeat_interval=self.audio_microphone_heartbeat_interval,
            on_heartbeat=self._on_heartbeat,
            is_playing=lambda: False,  # this node never plays TTS
            on_new_segment=None,
            on_speech_start=self._on_speech_start,
            on_speech_end=self._on_speech_end,
            stop_flag=self.stop_event.is_set,
            on_audio_chunk=self._on_audio_chunk,
            should_put_to_queue=self._should_put_to_queue,
            get_silence_threshold=lambda: self.silence_duration_ms,
            enable_echo_cancellation=False,
            reference_signal_buffer=None,
            logger=self.get_logger()
        )

        # ASR client — only used for wake-word detection.
        self.asr_client = DashScopeASR(
            api_key=self.dashscope_api_key,
            sample_rate=self.sample_rate,
            model=self.asr_model,
            url=self.asr_url,
            logger=self.get_logger()
        )
        self.asr_client.on_sentence_end = self._on_asr_sentence_end
        self.asr_client.start()

        # ASR feeder thread.
        self.asr_thread = threading.Thread(
            target=self._asr_worker,
            name="RegisterASRThread",
            daemon=True
        )
        self.asr_thread.start()

        # Text-processing thread (wake-word matching).
        self.text_thread = threading.Thread(
            target=self._text_worker,
            name="RegisterTextThread",
            daemon=True
        )
        self.text_thread.start()

        self.sv_client = SpeakerVerificationClient(
            model_path=self.sv_model_path,
            threshold=self.sv_threshold,
            speaker_db_path=self.sv_speaker_db_path,
            logger=self.get_logger()
        )

        self.get_logger().info("声纹注册节点启动,请说'er gou......'唤醒注册")
        self.recording_thread = threading.Thread(
            target=self.audio_recorder.record_with_vad,
            name="RegisterRecordingThread",
            daemon=True
        )
        self.recording_thread.start()

        # Periodic check for completion / shutdown.
        self.timer = self.create_timer(0.2, self._check_done)

    def _load_config(self):
        """Load the subset of voice.yaml this node needs (ASR, mic, VAD, SV)."""
        config_file = os.path.join(
            get_package_share_directory('robot_speaker'),
            'config',
            'voice.yaml'
        )
        # NOTE(review): relies on the locale default encoding; if voice.yaml
        # contains non-ASCII text, consider encoding='utf-8'.
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)

        dashscope = config['dashscope']
        audio = config['audio']
        mic = audio['microphone']
        vad = config['vad']
        system = config['system']

        self.dashscope_api_key = dashscope['api_key']
        self.asr_model = dashscope['asr']['model']
        self.asr_url = dashscope['asr']['url']

        self.input_device_index = mic['device_index']
        self.sample_rate = mic['sample_rate']
        self.channels = mic['channels']
        self.chunk = mic['chunk']
        self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']

        self.vad_mode = vad['vad_mode']
        self.silence_duration_ms = vad['silence_duration_ms']
        self.min_energy_threshold = vad['min_energy_threshold']

        self.sv_model_path = os.path.expanduser(system['sv_model_path'])
        self.sv_threshold = system['sv_threshold']
        self.sv_speaker_db_path = system['sv_speaker_db_path']
        self.sv_buffer_size = system['sv_buffer_size']
        self.wake_word = system['wake_word']

    def _should_put_to_queue(self) -> bool:
        """Feed audio to the ASR queue only while waiting for the wake word."""
        return self.waiting_for_wake_word

    def _on_heartbeat(self):
        """Periodic recorder heartbeat: log which phase we are waiting in."""
        if self.waiting_for_wake_word:
            self.get_logger().info("[注册录音] 等待唤醒词'er gou'...")
        elif self.waiting_for_voiceprint:
            self.get_logger().info("[注册录音] 等待声纹语音...")

    def _on_speech_start(self):
        """VAD speech-start callback: purely informational logging."""
        if self.waiting_for_wake_word:
            # While waiting for the wake word, this segment may contain it.
            self.get_logger().info("[注册录音] 检测到人声,开始录音")
        elif self.waiting_for_voiceprint:
            self.get_logger().info("[注册录音] 检测到人声,继续录音(用于声纹注册)")
            # Deliberately do NOT clear the buffer: keep the audio that
            # contains the wake word for registration.

    def _on_audio_chunk(self, audio_chunk: bytes):
        """Accumulate every raw chunk (wake word included) for registration."""
        try:
            audio_array = np.frombuffer(audio_chunk, dtype=np.int16)
            with self.buffer_lock:
                self.audio_buffer.extend(audio_array)
        except Exception as e:
            self.get_logger().debug(f"[注册录音] 录音失败: {e}")

    def _on_speech_end(self):
        """VAD speech-end callback: trigger registration once in voiceprint phase."""
        # Still waiting for the wake word: nothing to register yet.
        if self.waiting_for_wake_word:
            return
        # Already registering: do not re-enter.
        if self.processing:
            return

        # In the voiceprint phase the user just finished speaking; use the
        # current audio even if it is shorter than 3 seconds.
        if self.waiting_for_voiceprint:
            self._process_voiceprint_audio(use_current_audio_if_short=True)

    def _process_voiceprint_audio(self, use_current_audio_if_short: bool = False):
        """Process buffered audio and register the speaker embedding.

        Uses the user's complete first utterance. If at least 3 s of audio is
        available, the last 3 s are used; otherwise behaviour depends on
        ``use_current_audio_if_short``.

        Args:
            use_current_audio_if_short: if the buffer holds less than 3 s,
                register with what we have (the user has finished speaking)
                instead of waiting for more audio.
        """
        if self.processing:
            return
        self.processing = True
        with self.buffer_lock:
            audio_list = list(self.audio_buffer)

        buffer_size = len(audio_list)
        buffer_sec = buffer_size / self.sample_rate
        self.get_logger().info(f"[注册录音] 当前音频长度: {buffer_sec:.2f}秒")

        # Target: 3 seconds of audio.
        required_samples = int(self.sample_rate * 3.0)

        if buffer_size < required_samples:
            if use_current_audio_if_short:
                # The user already finished; register with what we have.
                self.get_logger().info(f"[注册录音] 音频不足3秒(当前{buffer_sec:.2f}秒),但用户已说完,使用当前音频进行注册")
                audio_to_use = audio_list
            else:
                # Not enough audio yet — release the guard and keep recording.
                self.get_logger().info(f"[注册录音] 音频不足3秒(当前{buffer_sec:.2f}秒),等待继续录音...")
                self.processing = False
                return
        else:
            # Enough audio: keep only the most recent 3 seconds.
            audio_to_use = audio_list[-required_samples:]
            self.get_logger().info(f"[注册录音] 使用完整的第一段语音,截取最后3秒用于注册")

        # Clear the rolling buffer now that we have a snapshot.
        with self.buffer_lock:
            self.audio_buffer.clear()

        try:
            audio_array = np.array(audio_to_use, dtype=np.int16)
            embedding, success = self.sv_client.extract_embedding(
                audio_array,
                sample_rate=self.sample_rate
            )
            if not success or embedding is None:
                self.get_logger().error("[注册录音] 提取embedding失败")
                self.processing = False
                return

            # Epoch-second user id keeps registrations unique per run.
            speaker_id = f"user_{int(time.time())}"
            if self.sv_client.register_speaker(speaker_id, embedding):
                self.get_logger().info(f"[注册录音] 注册成功,用户ID: {speaker_id},准备退出")
                # `processing` intentionally stays True: the node shuts down.
                self.stop_event.set()
            else:
                self.get_logger().error("[注册录音] 注册失败")
                self.processing = False
        except Exception as e:
            self.get_logger().error(f"[注册录音] 注册异常: {e}")
            self.processing = False

    def _extract_speech_segments(self, audio_array: np.ndarray, frame_size: int = 1024) -> list:
        """Return (start, end) sample ranges whose RMS energy suggests speech.

        NOTE(review): neither this nor _merge_speech_segments is called
        anywhere in this node — confirm whether they are dead code.
        """
        speech_segments = []
        frame_samples = frame_size
        total_frames = 0
        speech_frames = 0

        for i in range(0, len(audio_array), frame_samples):
            frame = audio_array[i:i + frame_samples]
            if len(frame) < frame_samples:
                break

            total_frames += 1
            # RMS energy of the frame (int16 audio widened to float).
            frame_float = frame.astype(np.float32)
            energy = np.sqrt(np.mean(frame_float ** 2))

            # Use a lower threshold than the VAD one to avoid classifying
            # quiet speech as silence.
            threshold = self.min_energy_threshold * 0.5  # 50% of configured threshold

            # Energy above threshold ⇒ treat the frame as speech.
            if energy >= threshold:
                speech_segments.append((i, i + frame_samples))
                speech_frames += 1

        # Debug statistics.
        if total_frames > 0:
            speech_ratio = speech_frames / total_frames
            self.get_logger().debug(f"[注册录音] 能量检测: 总帧数={total_frames}, 人声帧数={speech_frames}, 人声比例={speech_ratio:.2%}, 阈值={self.min_energy_threshold}")

        return speech_segments

    def _merge_speech_segments(self, audio_array: np.ndarray, segments: list, min_samples: int) -> np.ndarray:
        """Merge adjacent speech segments and return the most recent audio
        totalling at least ``min_samples`` samples (empty array if none)."""
        if not segments:
            return np.array([], dtype=np.int16)

        # Merge segments separated by at most one frame (1024 samples).
        merged_segments = []
        current_start, current_end = segments[0]

        for start, end in segments[1:]:
            if start <= current_end + 1024:  # allow a small gap (1 frame)
                current_end = end
            else:
                merged_segments.append((current_start, current_end))
                current_start, current_end = start, end
        merged_segments.append((current_start, current_end))

        # Walk segments from newest to oldest until the 3-second quota is met.
        selected_audio = []
        total_samples = 0

        for start, end in reversed(merged_segments):
            segment_audio = audio_array[start:end]
            selected_audio.insert(0, segment_audio)
            total_samples += len(segment_audio)
            if total_samples >= min_samples:
                break

        if not selected_audio:
            return np.array([], dtype=np.int16)

        return np.concatenate(selected_audio)

    def _asr_worker(self):
        """Feeder thread: forward queued audio chunks to the ASR client."""
        while not self.stop_event.is_set():
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
                if self.asr_client and self.asr_client.running:
                    self.asr_client.send_audio(audio_chunk)
            except queue.Empty:
                continue
            except Exception as e:
                self.get_logger().error(f"[注册ASR] 处理异常: {e}")

    def _on_asr_sentence_end(self, text: str):
        """ASR sentence-complete callback: enqueue non-empty transcripts."""
        if text and text.strip():
            self.text_queue.put(text.strip())

    def _text_worker(self):
        """Consumer thread: check transcripts for the wake word."""
        while not self.stop_event.is_set():
            try:
                text = self.text_queue.get(timeout=0.1)
                if self.waiting_for_wake_word:
                    self._check_wake_word(text)
            except queue.Empty:
                continue
            except Exception as e:
                self.get_logger().error(f"[注册文本] 处理异常: {e}")

    def _to_pinyin(self, text: str) -> str:
        """Convert the CJK characters of *text* to space-separated pinyin.

        Non-Chinese characters are dropped; returns "" when none remain.
        """
        chars = [c for c in text if '\u4e00' <= c <= '\u9fa5']
        if not chars:
            return ""
        py_list = pinyin(chars, style=Style.NORMAL)
        return ' '.join([item[0] for item in py_list]).lower().strip()

    def _check_wake_word(self, text: str):
        """Match the transcript against the wake word (pinyin comparison);
        on a hit, switch to the voiceprint phase and start registration."""
        text_pinyin = self._to_pinyin(text)
        # The configured wake word is assumed to already be pinyin
        # (e.g. "er gou") — TODO confirm against voice.yaml.
        wake_word_pinyin = self.wake_word.lower().strip()
        self.get_logger().info(f"[注册唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")

        if not wake_word_pinyin:
            return

        text_pinyin_parts = text_pinyin.split() if text_pinyin else []
        wake_word_parts = wake_word_pinyin.split()

        # Sliding-window exact match of the wake-word syllable sequence.
        for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
            if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
                self.get_logger().info(f"[注册唤醒词] 检测到唤醒词 '{self.wake_word}'")
                self.get_logger().info("=" * 50)
                self.get_logger().info("[声纹注册] 开始注册声纹,将截取3秒音频用于注册")
                self.get_logger().info("=" * 50)
                self.waiting_for_wake_word = False
                self.waiting_for_voiceprint = True
                # Stop ASR — recognition is no longer needed.
                if self.asr_client:
                    self.asr_client.stop_current_recognition()
                # Process whatever is already buffered: the user may have
                # finished the whole utterance (wake word included).
                self._process_voiceprint_audio()
                return

    def _check_done(self):
        """Timer callback: once registration set stop_event, tear down.

        NOTE(review): calling rclpy.shutdown() from inside a timer callback
        makes spin() raise ExternalShutdownException on recent rclpy —
        confirm main() handles that exit path.
        """
        if self.stop_event.is_set():
            self.get_logger().info("注册完成,节点退出")
            # Release resources before shutdown.
            if self.asr_client:
                self.asr_client.stop()
            self.destroy_node()
            rclpy.shutdown()
|
||||
|
||||
def main(args=None):
    """Entry point for the registration node: spin until registration completes.

    The node calls ``rclpy.shutdown()`` from its own timer callback after a
    successful registration; on current rclpy releases that makes
    ``rclpy.spin`` raise ``ExternalShutdownException``, and Ctrl+C raises
    ``KeyboardInterrupt``. The original let both escape as tracebacks on
    every normal exit — catch them so the process terminates cleanly.
    """
    # Local import keeps the fix self-contained in this entry point.
    from rclpy.executors import ExternalShutdownException

    rclpy.init(args=args)
    node = RegisterSpeakerNode()
    try:
        rclpy.spin(node)
    except (KeyboardInterrupt, ExternalShutdownException):
        # Both are expected, clean exit paths (user interrupt, or
        # self-initiated shutdown after successful registration).
        pass


if __name__ == '__main__':
    main()
|
||||
|
||||
803
robot_speaker/core/robot_speaker_node.py
Normal file
803
robot_speaker/core/robot_speaker_node.py
Normal file
@@ -0,0 +1,803 @@
|
||||
"""
|
||||
语音交互节点
|
||||
"""
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import String
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
import io
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import subprocess
|
||||
import collections
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from robot_speaker.perception.audio_pipeline import VADDetector, AudioRecorder
|
||||
from robot_speaker.models.asr.dashscope import DashScopeASR
|
||||
from robot_speaker.models.tts.dashscope import DashScopeTTSClient
|
||||
from robot_speaker.models.llm.dashscope import DashScopeLLM
|
||||
from robot_speaker.understanding.context_manager import ConversationHistory
|
||||
from robot_speaker.core.types import LLMMessage, TTSRequest
|
||||
from robot_speaker.perception.camera_client import CameraClient
|
||||
from robot_speaker.perception.speaker_verifier import SpeakerVerificationClient, SpeakerState
|
||||
from robot_speaker.perception.echo_cancellation import ReferenceSignalBuffer
|
||||
from robot_speaker.core.conversation_state import ConversationState
|
||||
from robot_speaker.core.node_workers import NodeWorkers
|
||||
from robot_speaker.core.node_callbacks import NodeCallbacks
|
||||
from robot_speaker.core.intent_router import IntentRouter, IntentResult
|
||||
|
||||
|
||||
class RobotSpeakerNode(Node):
|
||||
# ==================== 初始化 ====================
|
||||
def __init__(self):
    """Set up queues, events, state machine, components and worker threads."""
    super().__init__('robot_speaker_node')

    # Load parameters directly from the config file.
    self._load_config()

    # Queues for inter-thread communication.
    self.audio_queue = queue.Queue()  # recording thread -> ASR thread
    self.text_queue = queue.Queue()   # ASR thread -> main processing thread
    self.tts_queue = queue.Queue()    # main thread -> TTS thread

    # Thread-synchronisation events.
    self.interrupt_event = threading.Event()    # interrupt flag
    self.stop_event = threading.Event()         # global stop flag
    self.tts_playing_event = threading.Event()  # TTS currently playing

    # Session management.
    self.session_active = False
    self.session_start_time = 0.0
    self.session_lock = threading.Lock()

    # Conversation state machine.
    self.conversation_state = ConversationState.IDLE  # current state
    self.state_lock = threading.Lock()  # protects conversation_state

    # Speaker-verification shared state (written by the SV worker thread,
    # read by the main thread — always under sv_lock).
    self.current_speaker_id = None                     # current speaker id
    self.current_speaker_state = SpeakerState.UNKNOWN  # current speaker state
    self.current_speaker_score = 0.0                   # similarity score
    self.sv_lock = threading.Lock()
    self.sv_speech_end_event = threading.Event()   # wakes the SV thread on speech_end
    self.sv_result_ready_event = threading.Event()  # kept for compatibility (unused for sync)
    self.sv_result_lock = threading.Lock()          # guards the result sequence number
    self.sv_result_cv = threading.Condition(self.sv_result_lock)
    self.sv_result_seq = 0
    # The SV buffer needs parameters read first; created in _init_components.
    self.sv_audio_buffer = None
    self.sv_recording = False  # currently recording audio for verification

    # Speaker-registration / utterance bookkeeping.
    self.utterance_lock = threading.Lock()
    self.current_utterance_id = 0
    self.last_processed_utterance_id = 0

    self.intent_router = IntentRouter()
    self.callbacks = NodeCallbacks(self)

    # Initialise components (VAD, recorder, ASR, LLM, TTS, camera, SV).
    self._init_components()
    self.workers = NodeWorkers(self)

    # Initial state-machine hint: warn when no voiceprint is enrolled yet.
    if self.sv_enabled and self.sv_client:
        speaker_count = self.sv_client.get_speaker_count()
        if speaker_count == 0:
            self.get_logger().info("声纹数据库为空,请注册声纹")

    # ROS subscription for external interrupt commands.
    self.interrupt_sub = self.create_subscription(
        String, 'interrupt_command', self.callbacks.handle_interrupt_command, self.system_interrupt_command_queue_depth
    )

    # Spawn all worker threads.
    self._start_threads()
    self.get_logger().info("语音节点已启动")
|
||||
# ==================== 配置加载 ====================
|
||||
def _load_config(self):
    """Load all node parameters directly from the voice.yaml config file.

    Populates audio/VAD/DashScope/system/camera attributes on ``self`` and
    resolves the knowledge-base JSON path.

    Fix: the file is now opened with an explicit ``encoding='utf-8'``. The
    original relied on the locale default encoding, which breaks parsing of
    the (Chinese-annotated) YAML on non-UTF-8 locales.
    """
    config_file = os.path.join(
        get_package_share_directory('robot_speaker'),
        'config',
        'voice.yaml'
    )
    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # --- Audio parameters ---
    audio = config['audio']
    mic = audio['microphone']
    soundcard = audio['soundcard']
    echo = audio['echo_cancellation']
    tts_audio = audio['tts']

    self.input_device_index = mic['device_index']
    self.output_card_index = soundcard['card_index']
    self.output_device_index = soundcard['device_index']
    self.sample_rate = mic['sample_rate']
    self.channels = mic['channels']
    self.chunk = mic['chunk']
    self.audio_microphone_heartbeat_interval = mic['heartbeat_interval']
    self.output_sample_rate = soundcard['sample_rate']
    self.output_channels = soundcard['channels']
    self.output_volume = soundcard['volume']
    self.audio_echo_cancellation_max_duration_ms = echo['max_duration_ms']
    self.audio_tts_source_sample_rate = tts_audio['source_sample_rate']
    self.audio_tts_source_channels = tts_audio['source_channels']
    self.audio_tts_ffmpeg_thread_queue_size = tts_audio['ffmpeg_thread_queue_size']

    # --- VAD parameters ---
    vad = config['vad']
    self.vad_mode = vad['vad_mode']
    self.silence_duration_ms = vad['silence_duration_ms']
    self.min_energy_threshold = vad['min_energy_threshold']

    # --- DashScope (ASR / LLM / TTS) parameters ---
    dashscope = config['dashscope']
    self.dashscope_api_key = dashscope['api_key']
    self.asr_model = dashscope['asr']['model']
    self.asr_url = dashscope['asr']['url']
    self.llm_model = dashscope['llm']['model']
    self.llm_base_url = dashscope['llm']['base_url']
    self.llm_temperature = dashscope['llm']['temperature']
    self.llm_max_tokens = dashscope['llm']['max_tokens']
    self.llm_max_history = dashscope['llm']['max_history']
    self.llm_summary_trigger = dashscope['llm']['summary_trigger']
    self.tts_model = dashscope['tts']['model']
    self.tts_voice = dashscope['tts']['voice']

    # --- System parameters ---
    system = config['system']
    self.use_llm = system['use_llm']
    self.use_wake_word = system['use_wake_word']
    self.wake_word = system['wake_word']
    self.session_timeout = system['session_timeout']
    self.system_shutup_keywords = system['shutup_keywords']
    self.system_interrupt_command_queue_depth = system['interrupt_command_queue_depth']
    self.sv_enabled = system['sv_enabled']
    self.sv_model_path = os.path.expanduser(system['sv_model_path'])
    self.sv_threshold = system['sv_threshold']
    self.sv_speaker_db_path = system['sv_speaker_db_path']
    self.sv_buffer_size = system['sv_buffer_size']

    # --- Camera parameters ---
    camera = config['camera']
    self.camera_serial_number = camera['serial_number']
    self.camera_rgb_width = camera['rgb']['width']
    self.camera_rgb_height = camera['rgb']['height']
    self.camera_rgb_fps = camera['rgb']['fps']
    self.camera_rgb_format = camera['rgb']['format']
    self.camera_image_jpeg_quality = camera['image']['jpeg_quality']
    self.camera_image_max_size = camera['image']['max_size']

    # Knowledge-base JSON shipped alongside the config.
    self.knowledge_file = os.path.join(
        get_package_share_directory('robot_speaker'),
        'config',
        'knowledge.json'
    )
|
||||
# ==================== 组件初始化 ====================
|
||||
def _init_components(self):
    """Initialise all components: knowledge base, SV buffer, VAD, echo
    reference buffer, recorder, ASR, LLM (optional), TTS, camera (optional)
    and the speaker-verification client (optional)."""
    # Shut-up keywords come from a comma-separated config string.
    self.shutup_keywords = [k.strip() for k in self.system_shutup_keywords.split(',') if k.strip()]

    # Knowledge base: pattern (lower-cased) -> canned answer.
    self.kb_answers_map = {}
    if self.knowledge_file and os.path.exists(self.knowledge_file):
        try:
            with open(self.knowledge_file, 'r') as f:
                kb_data = json.load(f)
            entries = kb_data["entries"]
            for entry in entries:
                patterns = entry["patterns"]
                answer = entry["answer"]
                if not answer.strip():
                    continue
                for pattern in patterns:
                    key = pattern.strip().lower()
                    if key:
                        self.kb_answers_map[key] = answer.strip()
            self.get_logger().info(f"知识库已加载: {len(self.kb_answers_map)} 条")
        except Exception as e:
            # Best-effort: a broken knowledge file must not kill the node.
            self.get_logger().warning(f"知识库加载失败: {e}")

    # Rolling buffer of int16 samples for speaker verification.
    self.sv_audio_buffer = collections.deque(maxlen=self.sv_buffer_size)

    self.vad_detector = VADDetector(
        mode=self.vad_mode,
        sample_rate=self.sample_rate
    )

    # Reference-signal buffer for echo cancellation. Playback may run at
    # 44.1 kHz, but the microphone input path is the mic sample rate.
    self.reference_signal_buffer = ReferenceSignalBuffer(
        max_duration_ms=self.audio_echo_cancellation_max_duration_ms,
        sample_rate=self.sample_rate,
        channels=self.output_channels
    )

    # Recorder: pushes raw audio chunks straight to the ASR queue.
    self.audio_recorder = AudioRecorder(
        device_index=self.input_device_index,
        sample_rate=self.sample_rate,
        channels=self.channels,
        chunk=self.chunk,
        vad_detector=self.vad_detector,
        audio_queue=self.audio_queue,
        silence_duration_ms=self.silence_duration_ms,
        min_energy_threshold=self.min_energy_threshold,
        heartbeat_interval=self.audio_microphone_heartbeat_interval,
        on_heartbeat=self.callbacks.on_heartbeat,
        is_playing=self.tts_playing_event.is_set,
        on_new_segment=self.callbacks.on_new_segment,
        on_speech_start=self.callbacks.on_speech_start,
        on_speech_end=self.callbacks.on_speech_end,
        stop_flag=self.stop_event.is_set,
        on_audio_chunk=self.callbacks.on_audio_chunk_for_sv if self.sv_enabled else None,  # SV recording callback
        should_put_to_queue=self.callbacks.should_put_audio_to_queue,  # gate audio into the ASR queue
        get_silence_threshold=self.callbacks.get_silence_threshold,    # dynamic silence threshold
        enable_echo_cancellation=True,
        reference_signal_buffer=self.reference_signal_buffer,  # shared with the TTS client
        logger=self.get_logger()
    )

    # ASR client — streaming recognition.
    self.asr_client = DashScopeASR(
        api_key=self.dashscope_api_key,
        sample_rate=self.sample_rate,
        model=self.asr_model,
        url=self.asr_url,
        logger=self.get_logger()
    )
    self.asr_client.on_sentence_end = self.callbacks.on_asr_sentence_end
    self.asr_client.on_text_update = self.callbacks.on_asr_text_update
    self.asr_client.start()

    # LLM client (optional).
    if self.use_llm:
        self.llm_client = DashScopeLLM(
            api_key=self.dashscope_api_key,
            model=self.llm_model,
            base_url=self.llm_base_url,
            temperature=self.llm_temperature,
            max_tokens=self.llm_max_tokens,
            name="LLM-chat",
            logger=self.get_logger()
        )
        self.history = ConversationHistory(
            max_history=self.llm_max_history,
            summary_trigger=self.llm_summary_trigger
        )
    else:
        self.llm_client = None
        self.history = None

    # TTS client.
    self.get_logger().info(f"TTS配置: model={self.tts_model}, voice={self.tts_voice}")
    self.get_logger().info(f"音频输出配置: sample_rate={self.output_sample_rate}, channels={self.output_channels}")
    self.tts_client = DashScopeTTSClient(
        api_key=self.dashscope_api_key,
        model=self.tts_model,
        voice=self.tts_voice,
        card_index=self.output_card_index,
        device_index=self.output_device_index,
        output_sample_rate=self.output_sample_rate,
        output_channels=self.output_channels,
        output_volume=self.output_volume,
        tts_source_sample_rate=self.audio_tts_source_sample_rate,
        tts_source_channels=self.audio_tts_source_channels,
        tts_ffmpeg_thread_queue_size=self.audio_tts_ffmpeg_thread_queue_size,
        reference_signal_buffer=self.reference_signal_buffer,  # shared with the recorder
        logger=self.get_logger()
    )

    # Camera client (runs continuously by default); optional.
    try:
        self.camera_client = CameraClient(
            serial_number=self.camera_serial_number,
            width=self.camera_rgb_width,
            height=self.camera_rgb_height,
            fps=self.camera_rgb_fps,
            format=self.camera_rgb_format,
            logger=self.get_logger()
        )
        self.camera_client.initialize()
    except Exception as e:
        # Camera is a soft dependency: keep running without vision.
        self.get_logger().warning(f"相机初始化失败: {e},相机功能将不可用")
        self.camera_client = None

    # Speaker-verification client; failure disables SV entirely.
    if self.sv_enabled and self.sv_model_path:
        try:
            self.sv_client = SpeakerVerificationClient(
                model_path=self.sv_model_path,
                threshold=self.sv_threshold,
                speaker_db_path=self.sv_speaker_db_path,
                logger=self.get_logger()
            )
        except Exception as e:
            self.get_logger().warning(f"声纹识别初始化失败: {e},声纹功能将不可用")
            self.sv_client = None
            self.sv_enabled = False
    else:
        self.sv_client = None
||||
# ==================== 线程启动 ====================
|
||||
def _start_threads(self):
    """Spawn the five worker threads.

    1 recording, 2 ASR, 3 main processing, 4 TTS playback, and 5 the
    optional speaker-verification worker (only when SV is enabled and the
    client initialised). All threads are daemons so they never block
    process exit.
    """
    def spawn(target, name):
        # Create, start and return a daemon thread in one step.
        worker = threading.Thread(target=target, name=name, daemon=True)
        worker.start()
        return worker

    # Thread 1: microphone recording.
    self.recording_thread = spawn(self.workers.recording_worker, "RecordingThread")
    # Thread 2: ASR inference.
    self.asr_thread = spawn(self.workers.asr_worker, "ASRThread")
    # Thread 3: main business-logic loop.
    self.process_thread = spawn(self.workers.process_worker, "ProcessThread")
    # Thread 4: TTS playback.
    self.tts_thread = spawn(self._tts_worker, "TTSThread")

    # Thread 5: speaker verification — only when enabled and available.
    self.sv_thread = (
        spawn(self.workers.sv_worker, "SVThread")
        if self.sv_enabled and self.sv_client
        else None
    )
|
||||
# ==================== TTS播放线程 ====================
|
||||
def _tts_worker(self):
    """
    Thread 4: TTS playback thread — playback only.

    Pulls texts from tts_queue and synthesizes/plays them via tts_client,
    honouring interrupt_event before, during (via interrupt_check) and
    after each playback. tts_playing_event mirrors the playback state for
    the recorder/main thread.
    """
    self.get_logger().info("[TTS播放线程] 启动")
    while not self.stop_event.is_set():
        try:
            text = self.tts_queue.get(timeout=1.0)
        except queue.Empty:
            # Idle: still surface pending interrupts for debugging.
            if self.interrupt_event.is_set():
                self.get_logger().debug("[TTS播放线程] 检测到中断事件")
            continue

        # Interrupt raised while this text sat in the queue: drop it.
        if self.interrupt_event.is_set():
            self.get_logger().info("[TTS播放线程] 中断播放,跳过文本")
            continue

        # Skip empty / whitespace-only texts.
        if not text or not str(text).strip():
            continue

        text_str = str(text).strip()
        text_len = len(text_str)
        self.get_logger().info(f"[TTS播放线程] 开始播放: {text_str[:100]}... (总长度: {text_len}字符)")
        # Mark playback active before synthesis so the recorder can gate on it.
        self.tts_playing_event.set()

        request = TTSRequest(text=text_str, voice=None)
        # interrupt_check lets the client abort mid-playback.
        success = self.tts_client.synthesize(
            request,
            interrupt_check=lambda: self.interrupt_event.is_set()
        )
        if success:
            self.get_logger().info("[TTS播放线程] 播放完成")
        else:
            self.get_logger().info("[TTS播放线程] 播放被中断")

        self.tts_playing_event.clear()

        # An interrupt during playback invalidates all queued texts:
        # drain the queue, then clear the flag.
        if self.interrupt_event.is_set():
            self.get_logger().info("[TTS播放线程] 播放完成后检测到中断,清空队列")
            self._drain_queue(self.tts_queue)
            self.interrupt_event.clear()
||||
# ==================== 状态机方法 ====================
|
||||
def _change_state(self, new_state: ConversationState, reason: str | None = None):
    """Transition the conversation state machine to *new_state*.

    The swap and the transition log both happen under ``state_lock`` so
    concurrent transitions are serialised; *reason*, when given, is
    appended to the log line.
    """
    with self.state_lock:
        previous = self.conversation_state
        self.conversation_state = new_state
        suffix = f": {reason}" if reason else ""
        self.get_logger().info(f"[状态机] {previous.value} -> {new_state.value}{suffix}")
|
||||
def _get_state(self) -> ConversationState:
    """Return the current conversation state (read under ``state_lock``)."""
    with self.state_lock:
        current = self.conversation_state
    return current
|
||||
# ==================== LLM处理(含拍照) ====================
|
||||
def _encode_image_to_base64(self, image_data: np.ndarray, quality: int = 85) -> str:
    """Encode a numpy image array as a base64 JPEG string.

    Args:
        image_data: image array, either H x W x 3 (treated as RGB) or any
            other layout accepted by ``PIL.Image.fromarray`` (e.g. H x W
            grayscale).
        quality: JPEG quality, default 85.

    Returns:
        Base64-encoded JPEG bytes as an ASCII string, or "" on any failure
        (a failed snapshot must not abort the dialog turn).
    """
    try:
        # Bug fix: a 2-D (grayscale) array has no shape[2]; indexing it
        # raised IndexError, which the except swallowed, so grayscale
        # frames could never be encoded. Check ndim before shape[2].
        if image_data.ndim == 3 and image_data.shape[2] == 3:
            pil_image = Image.fromarray(image_data, 'RGB')
        else:
            pil_image = Image.fromarray(image_data)

        buffer = io.BytesIO()
        pil_image.save(buffer, format='JPEG', quality=quality)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        self.get_logger().error(f"图像编码失败: {e}")
        return ""
|
||||
|
||||
def _llm_process_stream_with_camera(self, user_text: str, need_camera: bool, system_prompt: str | None = None) -> str:
    """Run one streamed LLM turn, optionally attaching a camera snapshot.

    Args:
        user_text: the user's recognized utterance (used here only for
            logging; the actual prompt comes from ``self.history``).
        need_camera: when True and a camera client exists, capture one
            frame and pass it to the LLM as a base64 JPEG.
        system_prompt: system prompt to inject when history has none;
            falls back to the intent router's default.

    Returns:
        The full stripped reply text, or "" when disabled, interrupted,
        or the LLM returned nothing.
    """
    if not self.llm_client or not self.history:
        return ""

    messages = list(self.history.get_messages())

    # Ensure exactly one system message at the front of the prompt.
    has_system_msg = any(msg.role == "system" for msg in messages)
    if not has_system_msg:
        if not system_prompt:
            system_prompt = self.intent_router.build_default_system_prompt()
        messages.insert(0, LLMMessage(role="system", content=system_prompt))

    full_reply = ""
    tts_text_buffer = ""
    image_base64_list = []

    def on_token(token: str):
        # Streaming callback: accumulate tokens; once interrupted, stop
        # accumulating (the client is aborted separately via interrupt_check).
        nonlocal full_reply, tts_text_buffer
        if self.interrupt_event.is_set():
            self.get_logger().info("[LLM流式处理] on_token回调中检测到中断,停止处理")
            return

        full_reply += token
        tts_text_buffer += token

    # Take a single snapshot before the LLM call when the intent asked for it.
    if need_camera and self.camera_client:
        with self.camera_client.capture_context() as image_data:
            if image_data is not None:
                image_base64 = self._encode_image_to_base64(
                    image_data,
                    quality=self.camera_image_jpeg_quality
                )
                if image_base64:
                    image_base64_list.append(image_base64)
                    self.get_logger().info("[相机] 已拍照")

    if image_base64_list:
        self.get_logger().info(
            f"[多模态] 准备发送给LLM: {len(image_base64_list)}张图片,用户文本: {user_text[:50]}"
        )
        for idx, img_b64 in enumerate(image_base64_list):
            self.get_logger().debug(f"[多模态] 图片#{idx+1} base64长度: {len(img_b64)}")

    reply = self.llm_client.chat_stream(
        messages,
        on_token=on_token,
        images=image_base64_list if image_base64_list else None,
        interrupt_check=lambda: self.interrupt_event.is_set()
    )

    # Interrupted or failed stream: discard partial output entirely.
    if self.interrupt_event.is_set() or (reply is None):
        if self.interrupt_event.is_set():
            self.get_logger().info("[LLM流式处理] 处理被中断")
        return ""

    if image_base64_list:
        # NOTE(review): `del img_b64` only unbinds the loop variable; the
        # clear() below is what actually drops the references.
        for img_b64 in image_base64_list:
            del img_b64
        image_base64_list.clear()
        self.get_logger().info("[相机] 已删除照片")

    # Prefer the client's final reply as the TTS text; fall back to the
    # token buffer only when the reply came back empty.
    if reply and reply.strip():
        tts_text_to_send = reply.strip()
        tts_buffer_len = len(tts_text_buffer.strip()) if tts_text_buffer else 0
        reply_len = len(tts_text_to_send)
        if tts_buffer_len != reply_len:
            self.get_logger().info(
                f"[流式TTS] tts_text_buffer({tts_buffer_len}字符)和reply({reply_len}字符)长度不一致,使用reply作为TTS文本"
            )
    elif tts_text_buffer and tts_text_buffer.strip():
        tts_text_to_send = tts_text_buffer.strip()
        self.get_logger().warning(
            f"[流式TTS] reply为空,使用tts_text_buffer({len(tts_text_to_send)}字符)作为TTS文本"
        )
    else:
        tts_text_to_send = ""
        self.get_logger().warning("[流式TTS] reply和tts_text_buffer都为空,无法发送TTS文本")

    # Hand the whole reply to the TTS queue unless an interrupt landed late.
    if not self.interrupt_event.is_set() and tts_text_to_send:
        text_len = len(tts_text_to_send)
        self.get_logger().info(
            f"[流式TTS] 发送完整文本到TTS队列: {tts_text_to_send[:100]}... (总长度: {text_len}字符)"
        )
        if text_len > 100:
            self.get_logger().debug(f"[流式TTS] 完整文本内容: {tts_text_to_send}")
        self._put_tts_text(tts_text_to_send)

    return reply.strip() if reply else ""
|
||||
|
||||
# ==================== 中断与TTS工具 ====================
|
||||
def _force_stop_tts(self):
|
||||
"""强制停止TTS播放 - 直接杀死记录的ffmpeg进程PID"""
|
||||
self._drain_queue(self.tts_queue)
|
||||
self.interrupt_event.set()
|
||||
|
||||
if self.tts_client and self.tts_client.current_ffmpeg_pid:
|
||||
try:
|
||||
pid = self.tts_client.current_ffmpeg_pid
|
||||
os.kill(pid, 9) # SIGKILL
|
||||
self.get_logger().info(f"[强制停止TTS] 已终止ffmpeg进程,PID={pid}")
|
||||
self.tts_client.current_ffmpeg_pid = None
|
||||
except ProcessLookupError:
|
||||
self.get_logger().debug(f"[强制停止TTS] ffmpeg进程已不存在,PID={pid}")
|
||||
self.tts_client.current_ffmpeg_pid = None
|
||||
except Exception as e:
|
||||
self.get_logger().warning(f"[强制停止TTS] 终止ffmpeg进程失败: {e}")
|
||||
|
||||
def _check_interrupt(self, auto_clear: bool = False) -> bool:
|
||||
"""
|
||||
检查中断标志
|
||||
"""
|
||||
if self.interrupt_event.is_set():
|
||||
if auto_clear:
|
||||
self.interrupt_event.clear()
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_interrupt_and_cancel_turn(self) -> bool:
|
||||
"""检查中断并取消轮次(统一处理中断后的清理)"""
|
||||
if self._check_interrupt(auto_clear=True):
|
||||
if self.use_llm and self.history:
|
||||
self.history.cancel_turn()
|
||||
return True
|
||||
return False
|
||||
|
||||
# ==================== 注册/会话/唤醒词 ====================
|
||||
def _handle_empty_speaker_db(self) -> bool:
|
||||
"""处理数据库为空的情况(统一处理)"""
|
||||
if not (self.sv_enabled and self.sv_client):
|
||||
return False
|
||||
|
||||
speaker_count = self.sv_client.get_speaker_count()
|
||||
if speaker_count == 0:
|
||||
with self.sv_lock:
|
||||
self.current_speaker_id = None
|
||||
self.current_speaker_state = SpeakerState.UNKNOWN
|
||||
self.current_speaker_score = 0.0
|
||||
self.sv_result_ready_event.set()
|
||||
return True
|
||||
return False
|
||||
|
||||
def _put_tts_text(self, text: str):
|
||||
"""统一处理TTS队列put(带异常处理)"""
|
||||
try:
|
||||
self.tts_queue.put(text, timeout=0.2)
|
||||
self.get_logger().debug(f"[TTS队列] 文本已成功放入队列: {text[:50]}... (队列大小: {self.tts_queue.qsize()})")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"[TTS队列] 放入队列失败: {e}, 文本: {text[:50]}")
|
||||
|
||||
def _interrupt_tts(self, reason: str):
|
||||
"""
|
||||
中断TTS播放,只设置中断事件,不清空队列,让TTS线程自己检查并停止播放
|
||||
"""
|
||||
self.get_logger().info(f"[中断] {reason}")
|
||||
self.interrupt_event.set()
|
||||
|
||||
@staticmethod
|
||||
def _drain_queue(q: queue.Queue):
|
||||
"""清空队列"""
|
||||
while True:
|
||||
try:
|
||||
q.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def _start_session(self):
|
||||
"""开始会话"""
|
||||
with self.session_lock:
|
||||
self.session_active = True
|
||||
self.session_start_time = time.time()
|
||||
|
||||
def _reset_session(self):
|
||||
"""重置会话"""
|
||||
with self.session_lock:
|
||||
self.session_start_time = time.time()
|
||||
|
||||
def _is_session_active(self) -> bool:
|
||||
"""检查会话是否活跃"""
|
||||
with self.session_lock:
|
||||
if not self.session_active:
|
||||
return False
|
||||
if time.time() - self.session_start_time >= self.session_timeout:
|
||||
self.session_active = False
|
||||
return False
|
||||
return True
|
||||
|
||||
# ==================== 意图处理 ====================
|
||||
def _handle_wake_word(self, text: str) -> str:
    """Gate an ASR result on the wake word and strip the wake word from it.

    The ASR text is transliterated to pinyin and scanned for the wake
    word's pinyin syllables. On a match the matching Chinese characters
    are removed, a session is started, and the remaining text is
    returned; otherwise "" is returned (utterance filtered out). While a
    session is already active, text passes through untouched and the
    session timer is refreshed.
    """
    if not self.use_wake_word:
        return text.strip()

    if self._is_session_active():
        self._reset_session()
        return text.strip()

    text_pinyin = self.intent_router.to_pinyin(text)
    wake_word_pinyin = self.wake_word.lower().strip()
    self.get_logger().info(f"[唤醒词] 原始文本: {text}, 文本拼音: {text_pinyin}, 唤醒词拼音: {wake_word_pinyin}")
    if not wake_word_pinyin:
        self.get_logger().info("[唤醒词] 唤醒词为空,过滤文本")
        return ""

    text_pinyin_parts = text_pinyin.split() if text_pinyin else []
    wake_word_parts = wake_word_pinyin.split()

    # Find the first occurrence of the wake word's syllable sequence.
    start_idx = -1
    for i in range(len(text_pinyin_parts) - len(wake_word_parts) + 1):
        if text_pinyin_parts[i:i + len(wake_word_parts)] == wake_word_parts:
            start_idx = i
            break

    if start_idx == -1:
        self.get_logger().info(f"[唤醒词] 未检测到唤醒词 '{self.wake_word}',过滤文本")
        return ""

    # Remove the wake-word characters: the i-th CJK character is assumed to
    # map to the i-th pinyin syllable. NOTE(review): this assumes to_pinyin
    # emits exactly one syllable per CJK char and none for other characters
    # (digits/latin would shift the alignment) — confirm against the router.
    removed = 0
    new_text = ""
    for c in text:
        if '\u4e00' <= c <= '\u9fa5':
            if removed < start_idx or removed >= start_idx + len(wake_word_parts):
                new_text += c
            removed += 1
        else:
            # Non-CJK characters are always kept and do not advance the index.
            new_text += c

    self._start_session()
    return new_text.strip()
|
||||
|
||||
def _check_shutup_command(self, text: str) -> bool:
|
||||
"""检查闭嘴指令"""
|
||||
if not text:
|
||||
return False
|
||||
text_lower = text.lower()
|
||||
text_pinyin = self.intent_router.to_pinyin(text)
|
||||
for keyword in self.shutup_keywords:
|
||||
kw = keyword.lower().strip()
|
||||
if not kw:
|
||||
continue
|
||||
if kw in text_lower or (text_pinyin and kw in text_pinyin):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_intent(self, intent_payload: IntentResult):
    """Route a classified intent to its handler.

    ``kb_qa`` intents are answered from the local knowledge base (keyed
    by the text's pinyin) and never reach the LLM; everything else goes
    through the streaming LLM path with history bookkeeping.
    """
    intent = intent_payload.intent
    text = intent_payload.text
    need_camera = intent_payload.need_camera
    system_prompt = intent_payload.system_prompt

    if intent == "kb_qa":
        answer = None
        text_pinyin = self.intent_router.to_pinyin(text)
        if text_pinyin:
            answer = self.kb_answers_map.get(text_pinyin)
        if answer:
            # Substitute the configured wake word into templated answers.
            if "{wake_word}" in answer:
                answer = answer.replace("{wake_word}", self.wake_word or "")
            self._put_tts_text(answer)
        else:
            # NOTE(review): a kb_qa intent with no KB hit is silently
            # dropped (no LLM fallback) — confirm this is intentional.
            pass
        return

    if self.use_llm and self.llm_client:
        # Open a history turn; it is committed only on a successful reply.
        if self.history:
            self.history.start_turn(text)

        reply = self._llm_process_stream_with_camera(
            text,
            need_camera=need_camera,
            system_prompt=system_prompt
        )
        if reply:
            if self.history:
                self.history.commit_turn(reply)
        else:
            # Empty reply (interrupted or failed): roll the turn back.
            if self.history:
                self.history.cancel_turn()
    else:
        self.get_logger().warning("[主线程] 未启用LLM,无法处理文本")
|
||||
|
||||
# ==================== 资源清理 ====================
|
||||
def destroy_node(self):
    """Shut the voice node down.

    Ordering matters: flags are raised first so worker loops exit their
    waits, TTS is force-stopped, the worker threads are joined, and only
    then are the client resources (ASR, recorder, camera, voiceprint DB)
    released.
    """
    self.get_logger().info("语音节点正在关闭...")
    # Raise both flags so every worker loop and in-flight playback aborts.
    self.stop_event.set()
    self.interrupt_event.set()
    self.get_logger().info("强制停止TTS播放...")
    self._force_stop_tts()

    self._drain_queue(self.tts_queue)

    # Join all worker threads (voiceprint thread is optional).
    threads_to_join = [self.recording_thread, self.asr_thread, self.process_thread, self.tts_thread]
    if self.sv_thread:
        threads_to_join.append(self.sv_thread)
    for thread in threads_to_join:
        if thread and thread.is_alive():
            thread.join(timeout=1.0)

    # Second pass: kill any ffmpeg the TTS thread spawned while we joined.
    self._force_stop_tts()

    if hasattr(self, 'asr_client') and self.asr_client:
        self.asr_client.stop()

    if hasattr(self, 'audio_recorder') and self.audio_recorder:
        self.audio_recorder.cleanup()

    if hasattr(self, 'camera_client') and self.camera_client:
        self.camera_client.cleanup()

    if hasattr(self, 'sv_client') and self.sv_client:
        try:
            # Persist registered voiceprints before releasing the client.
            self.sv_client.save_speakers()
            self.sv_client.cleanup()
        except Exception as e:
            self.get_logger().warning(f"清理声纹识别资源时出错: {e}")

    super().destroy_node()
|
||||
|
||||
|
||||
def _init_ros(args):
    """Initialize the rclpy context with command-line *args*."""
    rclpy.init(args=args)
|
||||
|
||||
def _create_node():
    """Construct and return the speaker node instance."""
    return RobotSpeakerNode()
|
||||
|
||||
def _run_node(node):
    """Block spinning *node* until shutdown (or an exception, e.g. Ctrl-C)."""
    rclpy.spin(node)
|
||||
|
||||
def _cleanup_node(node):
    """Destroy *node* if it was ever created (tolerates None)."""
    if node:
        node.destroy_node()
|
||||
|
||||
def _shutdown_ros():
    """Shut the rclpy context down if it is still up (idempotent)."""
    if rclpy.ok():
        rclpy.shutdown()
|
||||
|
||||
# ==================== 入口 ====================
|
||||
def main(args=None):
    """Node entry point: init ROS, spin the node, and always clean up.

    Bug fix: the original ran the steps sequentially, so any exception out
    of ``rclpy.spin`` (most commonly KeyboardInterrupt on Ctrl-C) skipped
    node destruction and rclpy shutdown entirely. try/finally guarantees
    teardown on every exit path.
    """
    node = None
    _init_ros(args)
    try:
        node = _create_node()
        _run_node(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        _cleanup_node(node)
        _shutdown_ros()
|
||||
|
||||
|
||||
# Standard script entry: run only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
||||
@@ -1,157 +0,0 @@
|
||||
"""
|
||||
回声消除模块
|
||||
mic = 人声 + 扬声器回声 + 环境噪声
|
||||
ref = 声卡原始播放音频
|
||||
AEC(mic, ref) → 去掉 ref 在 mic 中的那一部分
|
||||
"""
|
||||
import numpy as np
|
||||
import struct
|
||||
import threading
|
||||
from collections import deque
|
||||
import aec_audio_processing
|
||||
|
||||
|
||||
class EchoCanceller:
    """Echo canceller wrapping aec-audio-processing's AudioProcessor."""

    def __init__(self, sample_rate: int, frame_size: int, channels: int,
                 ref_channels: int, logger=None):
        """Set up the AEC processor; on any failure AEC is disabled.

        Args:
            sample_rate: capture sample rate in Hz.
            frame_size: capture chunk size in samples (node-side chunking).
            channels: microphone channel count.
            ref_channels: reference (playback) channel count.
            logger: optional logger with info()/warning().
        """
        self.sample_rate = sample_rate
        self.frame_size = frame_size
        self.channels = channels
        self.ref_channels = ref_channels
        self.logger = logger
        self.aec = None
        # Frame size the AudioProcessor expects (fixed 10 ms, e.g. 160 samples at 16 kHz).
        self.aec_frame_size = None

        # Initialize aec-audio-processing's AudioProcessor.
        try:
            self.aec = aec_audio_processing.AudioProcessor(
                enable_aec=True,  # echo cancellation
                enable_ns=False,  # noise suppression
                enable_agc=False  # automatic gain control
            )

            # Forward-stream format: the microphone capture format.
            self.aec.set_stream_format(
                sample_rate_in=sample_rate,
                channel_count_in=channels,
                sample_rate_out=sample_rate,
                channel_count_out=channels
            )

            # Reverse-stream format: the loudspeaker reference signal
            # (playback channel count, at the capture sample rate).
            self.aec.set_reverse_stream_format(sample_rate, ref_channels)

            # Query the processor's expected frame size (fixed 10 ms).
            self.aec_frame_size = self.aec.get_frame_size()
            if logger:
                logger.info(f"AudioProcessor期望的帧大小: {self.aec_frame_size} 样本 ({self.aec_frame_size / sample_rate * 1000}ms)")
        except Exception as e:
            # AEC is best-effort: fall back to pass-through on init failure.
            if logger:
                logger.warning(f"aec_audio_processing 初始化失败: {e},将禁用回声消除")
            self.aec = None

    def process(self, mic_signal: bytes, ref_signal: bytes = None) -> bytes:
        """Cancel the echo of *ref_signal* out of *mic_signal* (int16 PCM).

        Runs synchronously in the recording thread. Returns the mic bytes
        unchanged when AEC is unavailable or no reference was supplied;
        the output always has the original mic length.
        """
        if self.aec is None or ref_signal is None or self.aec_frame_size is None:
            return mic_signal

        # Remember the original length so padding can be trimmed at the end.
        original_mic_len = len(mic_signal)

        # The AudioProcessor wants fixed 10 ms frames, so a large capture
        # chunk must be split. Mic frame: aec_frame_size * channels * 2 bytes.
        mic_frame_bytes = self.aec_frame_size * self.channels * 2
        # Reference frame: aec_frame_size * ref_channels * 2 bytes.
        ref_frame_bytes = self.aec_frame_size * self.ref_channels * 2

        # Zero-pad both inputs up to a whole number of frames.
        if len(mic_signal) % mic_frame_bytes != 0:
            padding = mic_frame_bytes - (len(mic_signal) % mic_frame_bytes)
            mic_signal = mic_signal + b'\x00' * padding

        if len(ref_signal) % ref_frame_bytes != 0:
            padding = ref_frame_bytes - (len(ref_signal) % ref_frame_bytes)
            ref_signal = ref_signal + b'\x00' * padding

        # Process frame by frame: reference first, then the mic frame.
        # NOTE(review): frame count is derived from the mic signal only; a
        # reference shorter than the mic yields short/empty ref slices for
        # the tail frames — confirm the library tolerates that.
        try:
            num_frames = len(mic_signal) // mic_frame_bytes
            output_chunks = []
            for i in range(num_frames):
                mic_chunk = mic_signal[i * mic_frame_bytes:(i + 1) * mic_frame_bytes]
                ref_chunk = ref_signal[i * ref_frame_bytes:(i + 1) * ref_frame_bytes]

                self.aec.process_reverse_stream(ref_chunk)
                output_chunk = self.aec.process_stream(mic_chunk)

                # AudioProcessor.process_stream returns bytes.
                output_chunks.append(output_chunk)

            result = b''.join(output_chunks)
            return result[:original_mic_len]
        except Exception as e:
            # On any processing error, fall back to the unprocessed mic audio.
            if self.logger:
                self.logger.warning(f"回声消除处理失败: {e}")
            return mic_signal[:original_mic_len]
|
||||
|
||||
|
||||
class ReferenceSignalBuffer:
    """Ring buffer of recently played (loudspeaker) audio for the AEC.

    Samples are kept as interleaved int16 values; the deque's maxlen caps
    the history at ``max_duration_ms`` worth of frames.
    """

    def __init__(self, max_duration_ms: int, sample_rate: int, channels: int):
        capacity = int(sample_rate * max_duration_ms / 1000)
        self.sample_rate = sample_rate
        # Channel count of the reference signal (playback channels, e.g. 2).
        self.channels = channels
        self.buffer = deque(maxlen=capacity * channels)
        self.lock = threading.Lock()

    def add_reference(self, audio_data: bytes, source_sample_rate: int = None, source_channels: int = 1):
        """Append freshly played int16 PCM to the reference history.

        Resamples to the buffer's sample rate when the source rate differs
        and upmixes mono input to stereo when the buffer is two-channel.
        """
        if not audio_data:
            return
        with self.lock:
            if source_sample_rate and source_sample_rate != self.sample_rate:
                audio_data = self._resample(audio_data, source_sample_rate, self.sample_rate)

            decoded = struct.unpack(f'<{len(audio_data) // 2}h', audio_data)
            if source_channels == 1 and self.channels == 2:
                # Mono -> stereo: duplicate each sample into left and right.
                decoded = [value for mono in decoded for value in (mono, mono)]

            self.buffer.extend(decoded)

    def get_reference(self, num_samples: int) -> bytes:
        """Return the most recent *num_samples* frames as packed int16 bytes.

        Output is zero-padded at the front when the history is too short,
        and all-zero when the buffer is empty.
        """
        with self.lock:
            if not self.buffer:
                return b'\x00' * (num_samples * self.channels * 2)
            needed = num_samples * self.channels
            recent = list(self.buffer)
            if len(recent) > needed:
                recent = recent[-needed:]
            if len(recent) < needed:
                recent = [0] * (needed - len(recent)) + recent
            return struct.pack(f'<{len(recent)}h', *recent)

    def clear(self):
        """Drop all buffered reference audio."""
        with self.lock:
            self.buffer.clear()

    def _resample(self, audio_data: bytes, source_rate: int, target_rate: int) -> bytes:
        """Linearly resample int16 PCM from *source_rate* to *target_rate*."""
        if source_rate == target_rate:
            return audio_data
        pcm = np.frombuffer(audio_data, dtype=np.int16)
        ratio = target_rate / source_rate
        positions = np.linspace(0, len(pcm) - 1, int(len(pcm) * ratio))
        interpolated = np.interp(positions, np.arange(len(pcm)), pcm.astype(np.float32))
        return interpolated.astype(np.int16).tobytes()
|
||||
4
robot_speaker/models/__init__.py
Normal file
4
robot_speaker/models/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""模型层"""
|
||||
|
||||
|
||||
|
||||
4
robot_speaker/models/asr/__init__.py
Normal file
4
robot_speaker/models/asr/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""ASR模型"""
|
||||
|
||||
|
||||
|
||||
12
robot_speaker/models/asr/base.py
Normal file
12
robot_speaker/models/asr/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
class ASRClient:
    """Abstract base for ASR (speech-to-text) backends."""

    def start(self) -> bool:
        """Open the recognition session; return True on success."""
        raise NotImplementedError

    def stop(self) -> bool:
        """Close the recognition session; return True on success."""
        raise NotImplementedError

    def send_audio(self, audio_data: bytes) -> bool:
        """Feed a chunk of audio bytes; return True when accepted."""
        raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,10 @@ import threading
|
||||
import dashscope
|
||||
from dashscope.audio.qwen_omni import OmniRealtimeConversation, OmniRealtimeCallback
|
||||
from dashscope.audio.qwen_omni.omni_realtime import TranscriptionParams, MultiModality
|
||||
from robot_speaker.models.asr.base import ASRClient
|
||||
|
||||
|
||||
class DashScopeASR:
|
||||
class DashScopeASR(ASRClient):
|
||||
"""DashScope实时ASR识别器封装"""
|
||||
|
||||
def __init__(self, api_key: str,
|
||||
@@ -65,15 +66,11 @@ class DashScopeASR:
|
||||
callback.conversation = self.conversation
|
||||
|
||||
self.conversation.connect()
|
||||
|
||||
# 自定义文本语料增强识别,听不清的时候高概率说这个词
|
||||
# custom_text = "二狗"
|
||||
|
||||
transcription_params = TranscriptionParams(
|
||||
language='zh',
|
||||
sample_rate=self.sample_rate,
|
||||
input_audio_format="pcm",
|
||||
# corpus_text=custom_text,
|
||||
)
|
||||
|
||||
# 本地 VAD → 只控制 TTS 打断
|
||||
4
robot_speaker/models/llm/__init__.py
Normal file
4
robot_speaker/models/llm/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""LLM模型"""
|
||||
|
||||
|
||||
|
||||
14
robot_speaker/models/llm/base.py
Normal file
14
robot_speaker/models/llm/base.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
|
||||
|
||||
class LLMClient:
    """Abstract base for chat LLM backends."""

    def chat(self, messages: list[LLMMessage]) -> str | None:
        """Run a blocking chat completion; return the reply text or None."""
        raise NotImplementedError

    def chat_stream(self, messages: list[LLMMessage],
                    on_token=None,
                    interrupt_check=None) -> str | None:
        """Stream a completion, calling *on_token* per token.

        Returns the full reply text, or None on failure/interruption
        (callers poll *interrupt_check* to abort mid-stream).
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -3,18 +3,9 @@ LLM大语言模型模块
|
||||
支持多模态(文本+图像)
|
||||
"""
|
||||
from openai import OpenAI
|
||||
from .types import LLMMessage
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def chat(self, messages: list[LLMMessage]) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
def chat_stream(self, messages: list[LLMMessage],
|
||||
on_token=None,
|
||||
interrupt_check=None) -> str | None:
|
||||
raise NotImplementedError
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
from robot_speaker.models.llm.base import LLMClient
|
||||
|
||||
|
||||
class DashScopeLLM(LLMClient):
|
||||
4
robot_speaker/models/tts/__init__.py
Normal file
4
robot_speaker/models/tts/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""TTS模型"""
|
||||
|
||||
|
||||
|
||||
13
robot_speaker/models/tts/base.py
Normal file
13
robot_speaker/models/tts/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
|
||||
|
||||
class TTSClient:
    """Abstract base class for TTS (text-to-speech) clients."""

    def synthesize(self, request: TTSRequest,
                   on_chunk=None,
                   interrupt_check=None) -> bool:
        """Synthesize and play *request*.

        Returns True when playback finished; callers treat False as
        "playback was interrupted".
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -4,16 +4,8 @@ TTS语音合成模块
|
||||
import subprocess
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, ResultCallback, AudioFormat
|
||||
from .types import TTSRequest
|
||||
|
||||
|
||||
class TTSClient:
|
||||
"""TTS客户端抽象基类"""
|
||||
|
||||
def synthesize(self, request: TTSRequest,
|
||||
on_chunk=None,
|
||||
interrupt_check=None) -> bool:
|
||||
raise NotImplementedError
|
||||
from robot_speaker.core.types import TTSRequest
|
||||
from robot_speaker.models.tts.base import TTSClient
|
||||
|
||||
|
||||
class DashScopeTTSClient(TTSClient):
|
||||
@@ -205,8 +197,8 @@ class _TTSCallback(ResultCallback):
|
||||
# 参考信号处理失败不应影响播放
|
||||
self.tts_client._log("warning", f"参考信号处理失败: {e}")
|
||||
|
||||
if self.on_chunk:
|
||||
self.on_chunk(data)
|
||||
if self.on_chunk:
|
||||
self.on_chunk(data)
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
@@ -221,21 +213,26 @@ class _TTSCallback(ResultCallback):
|
||||
except:
|
||||
pass
|
||||
|
||||
# 等待进程自然结束(最多2秒)
|
||||
# 等待进程自然结束(根据文本长度估算,最少10秒,最多30秒)
|
||||
# 假设平均语速:3-4字/秒,加上缓冲时间
|
||||
if self._proc.poll() is None:
|
||||
try:
|
||||
self._proc.wait(timeout=2.0)
|
||||
# 增加等待时间,确保ffmpeg播放完成
|
||||
# 对于长文本,可能需要更长时间
|
||||
self._proc.wait(timeout=30.0)
|
||||
except:
|
||||
# 超时后强制终止
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=0.5)
|
||||
except:
|
||||
# 超时后,如果进程还在运行,说明可能卡住了,强制终止
|
||||
if self._proc.poll() is None:
|
||||
self.tts_client._log("warning", "ffmpeg播放超时,强制终止")
|
||||
try:
|
||||
self._proc.kill()
|
||||
self._proc.wait(timeout=0.1)
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=1.0)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
self._proc.kill()
|
||||
self._proc.wait(timeout=0.1)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 清空PID记录
|
||||
if self.tts_client.current_ffmpeg_pid == self._proc.pid:
|
||||
4
robot_speaker/perception/__init__.py
Normal file
4
robot_speaker/perception/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""感知层"""
|
||||
|
||||
|
||||
|
||||
@@ -18,13 +18,7 @@ class VADDetector:
|
||||
|
||||
|
||||
class AudioRecorder:
|
||||
"""
|
||||
音频录音器 - 录音线程
|
||||
功能:
|
||||
1. VAD + 能量检测
|
||||
2. 检测到人声 → 立即中断TTS
|
||||
3. 音频chunk → 队列(直接发送,不保存文件)
|
||||
"""
|
||||
"""音频录音器 - 录音线程"""
|
||||
|
||||
def __init__(self, device_index: int, sample_rate: int, channels: int,
|
||||
chunk: int, vad_detector: VADDetector,
|
||||
@@ -102,11 +96,7 @@ class AudioRecorder:
|
||||
self.echo_canceller = None
|
||||
|
||||
def record_with_vad(self):
|
||||
"""
|
||||
录音线程:VAD + 能量检测
|
||||
- 检测到人声 → 立即中断TTS
|
||||
- 音频chunk → 队列(直接发送,不保存文件)
|
||||
"""
|
||||
"""录音线程:VAD + 能量检测"""
|
||||
if self.on_heartbeat:
|
||||
self.on_heartbeat()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""
|
||||
相机模块 - RealSense相机封装
|
||||
相机默认一直运行,只在用户说拍照时捕获图像
|
||||
"""
|
||||
import numpy as np
|
||||
import contextlib
|
||||
@@ -35,7 +34,6 @@ class CameraClient:
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
初始化并启动相机管道
|
||||
相机启动后会一直运行,直到调用 cleanup()
|
||||
"""
|
||||
if self._is_initialized:
|
||||
return True
|
||||
@@ -58,7 +56,6 @@ class CameraClient:
|
||||
self.fps
|
||||
)
|
||||
|
||||
# 启动管道,相机开始运行
|
||||
self.pipeline.start(self.config)
|
||||
self._is_initialized = True
|
||||
self._log("info", f"相机已启动并保持运行: {self.width}x{self.height}@{self.fps}fps")
|
||||
@@ -80,7 +77,6 @@ class CameraClient:
|
||||
def capture_rgb(self) -> np.ndarray | None:
|
||||
"""
|
||||
从运行中的相机管道捕获一帧RGB图像
|
||||
相机管道必须已经通过 initialize() 启动
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self._log("error", "相机未初始化,无法捕获图像")
|
||||
98
robot_speaker/perception/echo_cancellation.py
Normal file
98
robot_speaker/perception/echo_cancellation.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ReferenceSignalBuffer:
|
||||
"""参考信号缓冲区"""
|
||||
|
||||
def __init__(self, sample_rate: int, channels: int, max_duration_ms: int | None = None,
|
||||
buffer_seconds: float = 5.0):
|
||||
self.sample_rate = int(sample_rate)
|
||||
self.channels = int(channels)
|
||||
if max_duration_ms is not None:
|
||||
buffer_seconds = max(float(max_duration_ms) / 1000.0, 0.1)
|
||||
self.max_samples = int(self.sample_rate * buffer_seconds)
|
||||
self._buffer = collections.deque(maxlen=self.max_samples * self.channels)
|
||||
|
||||
def add_reference(self, data: bytes, source_sample_rate: int, source_channels: int):
|
||||
if source_sample_rate != self.sample_rate or source_channels != self.channels:
|
||||
return
|
||||
samples = np.frombuffer(data, dtype=np.int16)
|
||||
self._buffer.extend(samples.tolist())
|
||||
|
||||
def get_reference(self, num_samples: int) -> bytes:
|
||||
needed = int(num_samples) * self.channels
|
||||
if needed <= 0:
|
||||
return b""
|
||||
if len(self._buffer) < needed:
|
||||
data = list(self._buffer) + [0] * (needed - len(self._buffer))
|
||||
else:
|
||||
data = list(self._buffer)[-needed:]
|
||||
return np.array(data, dtype=np.int16).tobytes()
|
||||
|
||||
|
||||
class EchoCanceller:
    """Echo canceller (based on aec-audio-processing).

    Degrades gracefully: when the library is missing or fails to
    initialize, ``process`` becomes a pass-through.
    """

    def __init__(self, sample_rate: int, frame_size: int, channels: int, ref_channels: int, logger=None):
        """Probe and configure the AudioProcessor; disable AEC on failure.

        Args:
            sample_rate: capture sample rate in Hz.
            frame_size: node-side capture chunk size in samples.
            channels: microphone channel count.
            ref_channels: reference (playback) channel count.
            logger: optional logger with warning().
        """
        self.sample_rate = int(sample_rate)
        self.frame_size = int(frame_size)
        self.channels = int(channels)
        self.ref_channels = int(ref_channels)
        self.logger = logger
        self.aec = None
        self._process_reverse = None
        # Fixed 10 ms frames of int16: sample_rate/100 samples * channels * 2 bytes.
        self._frame_bytes = int(self.sample_rate / 100) * self.channels * 2  # 10ms, int16
        self._ref_frame_bytes = int(self.sample_rate / 100) * self.ref_channels * 2
        try:
            from aec_audio_processing import AudioProcessor
            self.aec = AudioProcessor(enable_aec=True, enable_ns=False, enable_agc=False)
            self.aec.set_stream_format(self.sample_rate, self.channels)
            # hasattr probes tolerate API differences across library versions.
            if hasattr(self.aec, "set_reverse_stream_format"):
                self.aec.set_reverse_stream_format(self.sample_rate, self.ref_channels)
            if hasattr(self.aec, "set_stream_delay"):
                self.aec.set_stream_delay(0)
            if hasattr(self.aec, "process_reverse_stream"):
                self._process_reverse = self.aec.process_reverse_stream
            elif hasattr(self.aec, "process_reverse"):
                self._process_reverse = self.aec.process_reverse
        except Exception:
            # Best-effort: any import/config failure disables AEC silently.
            self.aec = None

    def process(self, mic_data: bytes, ref_data: bytes) -> bytes:
        """Cancel the echo of *ref_data* out of *mic_data* (int16 PCM).

        Splits the capture chunk into fixed 10 ms frames, feeding the
        matching reference frame before each mic frame. Returns *mic_data*
        unchanged when AEC is disabled or input is empty; output is
        trimmed back to the original mic length.
        """
        if not self.aec:
            return mic_data
        if not mic_data:
            return mic_data

        try:
            out_chunks = []
            total_len = len(mic_data)
            frame_bytes = self._frame_bytes
            ref_frame_bytes = self._ref_frame_bytes

            # Ceil-divide so a trailing partial frame is still processed.
            frame_count = (total_len + frame_bytes - 1) // frame_bytes
            for i in range(frame_count):
                m_start = i * frame_bytes
                m_end = m_start + frame_bytes
                mic_frame = mic_data[m_start:m_end]
                if len(mic_frame) < frame_bytes:
                    # Zero-pad the last (short) mic frame to a full 10 ms.
                    mic_frame = mic_frame + b"\x00" * (frame_bytes - len(mic_frame))

                if ref_data:
                    r_start = i * ref_frame_bytes
                    r_end = r_start + ref_frame_bytes
                    ref_frame = ref_data[r_start:r_end]
                    if len(ref_frame) < ref_frame_bytes:
                        ref_frame = ref_frame + b"\x00" * (ref_frame_bytes - len(ref_frame))
                    if self._process_reverse:
                        self._process_reverse(ref_frame)

                processed = self.aec.process_stream(mic_frame)
                # Fall back to the raw frame when the library returns None.
                out_chunks.append(processed if processed is not None else mic_frame)

            return b"".join(out_chunks)[:total_len]
        except Exception as e:
            # Any processing error falls back to the unprocessed mic audio.
            if self.logger:
                self.logger.warning(f"回声消除处理失败: {e},使用原始音频")
            return mic_data
|
||||
@@ -29,12 +29,8 @@ class SpeakerVerificationClient:
|
||||
self.logger = logger
|
||||
self.speaker_db = {} # {speaker_id: {"embedding": np.ndarray, "env": str, "threshold": float, "registered_at": float}}
|
||||
self._lock = threading.Lock()
|
||||
# 固定embedding维度,避免第一次异常embedding导致后续全部被拒绝
|
||||
# 常见维度:192 (CAM++), 256 等
|
||||
self._expected_embedding_dim = None # 只存储维度大小,不存储shape元组
|
||||
|
||||
from funasr import AutoModel
|
||||
# 确保模型路径是绝对路径(展开 ~)
|
||||
model_path = os.path.expanduser(self.model_path)
|
||||
self.model = AutoModel(model=model_path, device="cpu")
|
||||
if self.logger:
|
||||
@@ -47,7 +43,6 @@ class SpeakerVerificationClient:
|
||||
"""记录日志 - 修复ROS2 logger在多线程环境中的问题"""
|
||||
if self.logger:
|
||||
try:
|
||||
# 使用映射字典,避免动态获取方法导致ROS2 logger错误
|
||||
log_methods = {
|
||||
"debug": self.logger.debug,
|
||||
"info": self.logger.info,
|
||||
@@ -58,8 +53,6 @@ class SpeakerVerificationClient:
|
||||
log_method = log_methods.get(level.lower(), self.logger.info)
|
||||
log_method(msg)
|
||||
except ValueError as e:
|
||||
# ROS2 logger在多线程环境中可能出现"Logger severity cannot be changed between calls"错误
|
||||
# 如果出现此错误,降级使用info级别
|
||||
if "severity cannot be changed" in str(e):
|
||||
try:
|
||||
self.logger.info(f"[声纹-{level.upper()}] {msg}")
|
||||
@@ -88,9 +81,6 @@ class SpeakerVerificationClient:
|
||||
def extract_embedding(self, audio_data: np.ndarray, sample_rate: int = 16000):
|
||||
"""
|
||||
提取说话人embedding(低频调用,一句话只调用一次)
|
||||
返回: (embedding: np.ndarray | None, success: bool)
|
||||
- 成功返回 (embedding, True)
|
||||
- 失败返回 (None, False)
|
||||
"""
|
||||
if len(audio_data) < int(sample_rate * 0.5):
|
||||
return None, False
|
||||
@@ -100,25 +90,13 @@ class SpeakerVerificationClient:
|
||||
temp_wav_path = self._write_temp_wav(audio_data, sample_rate)
|
||||
result = self.model.generate(input=temp_wav_path)
|
||||
|
||||
# funasr返回格式:list[dict],dict键为'spk_embedding',值为torch.Tensor(shape=[1, 192])
|
||||
import torch
|
||||
embedding = result[0]['spk_embedding'].detach().cpu().numpy()[0] # shape [1, 192] -> [192]
|
||||
|
||||
# 校验embedding维度
|
||||
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return None, False
|
||||
|
||||
# 如果已有注册的声纹,校验维度是否一致
|
||||
if self._expected_embedding_dim is not None:
|
||||
if embedding_dim != self._expected_embedding_dim:
|
||||
self._log("error", f"embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
|
||||
return None, False
|
||||
else:
|
||||
# 第一次成功提取,固定维度(只在第一次成功时设置,避免异常embedding污染)
|
||||
self._expected_embedding_dim = embedding_dim
|
||||
self._log("info", f"固定embedding维度: {embedding_dim}")
|
||||
|
||||
return embedding, True
|
||||
except Exception as e:
|
||||
self._log("error", f"提取embedding失败: {e}")
|
||||
@@ -134,22 +112,10 @@ class SpeakerVerificationClient:
|
||||
env: str = "near", threshold: float = None) -> bool:
|
||||
"""
|
||||
注册说话人
|
||||
关键修复:注册时对embedding进行归一化,一劳永逸
|
||||
"""
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return False
|
||||
|
||||
# 校验维度一致性
|
||||
if self._expected_embedding_dim is not None:
|
||||
if embedding_dim != self._expected_embedding_dim:
|
||||
self._log("error", f"注册失败:embedding维度不匹配: 期望{self._expected_embedding_dim}, 实际{embedding_dim}")
|
||||
return False
|
||||
else:
|
||||
# 如果还没有固定维度,使用当前embedding的维度
|
||||
self._expected_embedding_dim = embedding_dim
|
||||
|
||||
# 关键修复:注册时归一化embedding
|
||||
embedding_norm = np.linalg.norm(embedding)
|
||||
if embedding_norm == 0:
|
||||
self._log("error", f"注册失败:embedding范数为0")
|
||||
@@ -159,26 +125,21 @@ class SpeakerVerificationClient:
|
||||
speaker_threshold = threshold if threshold is not None else self.threshold
|
||||
|
||||
with self._lock:
|
||||
# 使用dict结构存储,便于序列化
|
||||
self.speaker_db[speaker_id] = {
|
||||
"embedding": embedding_normalized, # 已归一化
|
||||
"env": env,
|
||||
"embedding": embedding_normalized,
|
||||
"env": env, # 添加 env 字段
|
||||
"threshold": speaker_threshold,
|
||||
"registered_at": time.time()
|
||||
}
|
||||
self._log("info", f"已注册说话人: {speaker_id}, 阈值: {speaker_threshold:.3f}, 维度: {embedding_dim}")
|
||||
|
||||
# 在锁外调用save_speakers,避免死锁(save_speakers内部会获取锁)
|
||||
save_result = self.save_speakers()
|
||||
if not save_result:
|
||||
# 使用info级别而不是warning,避免ROS2 logger在多线程环境中的问题
|
||||
self._log("info", f"保存声纹数据库失败,但说话人已注册到内存: {speaker_id}")
|
||||
return True
|
||||
|
||||
def match_speaker(self, embedding: np.ndarray):
|
||||
"""
|
||||
匹配说话人(一句话只调用一次)
|
||||
返回: (speaker_id: str | None, state: SpeakerState, score: float, threshold: float)
|
||||
"""
|
||||
if not self.speaker_db:
|
||||
return None, SpeakerState.UNKNOWN, 0.0, self.threshold
|
||||
@@ -186,12 +147,7 @@ class SpeakerVerificationClient:
|
||||
embedding_dim = len(embedding)
|
||||
if embedding_dim == 0:
|
||||
return None, SpeakerState.ERROR, 0.0, self.threshold
|
||||
|
||||
# 校验维度一致性
|
||||
if self._expected_embedding_dim is not None and embedding_dim != self._expected_embedding_dim:
|
||||
return None, SpeakerState.ERROR, 0.0, self.threshold
|
||||
|
||||
# 归一化当前embedding(注册时已归一化,这里只需要归一化当前embedding)
|
||||
|
||||
embedding_norm = np.linalg.norm(embedding)
|
||||
if embedding_norm == 0:
|
||||
return None, SpeakerState.ERROR, 0.0, self.threshold
|
||||
@@ -203,13 +159,7 @@ class SpeakerVerificationClient:
|
||||
|
||||
with self._lock:
|
||||
for speaker_id, speaker_data in self.speaker_db.items():
|
||||
# 从dict结构读取
|
||||
ref_embedding = speaker_data["embedding"] # 注册时已归一化
|
||||
|
||||
if len(ref_embedding) != embedding_dim:
|
||||
continue
|
||||
|
||||
# 直接计算cosine similarity(两个向量都已归一化)
|
||||
ref_embedding = speaker_data["embedding"]
|
||||
score = np.dot(embedding_normalized, ref_embedding)
|
||||
|
||||
if score > best_score:
|
||||
@@ -242,8 +192,6 @@ class SpeakerVerificationClient:
|
||||
def load_speakers(self) -> bool:
|
||||
"""
|
||||
从文件加载已注册的声纹
|
||||
使用JSON格式存储,更安全、可迁移
|
||||
格式: {speaker_id: {"embedding": [list], "env": str, "threshold": float, "registered_at": float}}
|
||||
"""
|
||||
if not self.speaker_db_path:
|
||||
return False
|
||||
@@ -257,25 +205,14 @@ class SpeakerVerificationClient:
|
||||
data = json.load(f)
|
||||
|
||||
with self._lock:
|
||||
# 转换list回numpy array,并校验维度
|
||||
for speaker_id, speaker_data in data.items():
|
||||
embedding_list = speaker_data["embedding"]
|
||||
embedding_array = np.array(embedding_list, dtype=np.float32)
|
||||
|
||||
# 校验维度
|
||||
embedding_dim = len(embedding_array)
|
||||
if embedding_dim == 0:
|
||||
self._log("warning", f"跳过无效声纹: {speaker_id} (维度为0)")
|
||||
continue
|
||||
|
||||
# 设置期望维度(如果还没有)
|
||||
if self._expected_embedding_dim is None:
|
||||
self._expected_embedding_dim = embedding_dim
|
||||
elif embedding_dim != self._expected_embedding_dim:
|
||||
self._log("warning", f"跳过维度不匹配的声纹: {speaker_id} (期望{self._expected_embedding_dim}, 实际{embedding_dim})")
|
||||
continue
|
||||
|
||||
# 确保embedding已归一化
|
||||
embedding_norm = np.linalg.norm(embedding_array)
|
||||
if embedding_norm > 0:
|
||||
embedding_array = embedding_array / embedding_norm
|
||||
@@ -288,10 +225,7 @@ class SpeakerVerificationClient:
|
||||
}
|
||||
|
||||
count = len(self.speaker_db)
|
||||
if self._expected_embedding_dim:
|
||||
self._log("info", f"已加载 {count} 个已注册说话人,embedding维度: {self._expected_embedding_dim}")
|
||||
else:
|
||||
self._log("info", f"已加载 {count} 个已注册说话人")
|
||||
self._log("info", f"已加载 {count} 个已注册说话人")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log("error", f"加载声纹数据库失败: {e}")
|
||||
@@ -300,7 +234,6 @@ class SpeakerVerificationClient:
|
||||
def save_speakers(self) -> bool:
|
||||
"""
|
||||
保存已注册的声纹到文件
|
||||
使用JSON格式,更安全、可迁移
|
||||
"""
|
||||
if not self.speaker_db_path:
|
||||
self._log("warning", "声纹数据库路径未配置,无法保存到文件(说话人已注册到内存)")
|
||||
@@ -310,24 +243,20 @@ class SpeakerVerificationClient:
|
||||
db_dir = os.path.dirname(self.speaker_db_path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
# 转换为JSON可序列化的格式
|
||||
json_data = {}
|
||||
with self._lock:
|
||||
for speaker_id, speaker_data in self.speaker_db.items():
|
||||
json_data[speaker_id] = {
|
||||
"embedding": speaker_data["embedding"].tolist(), # numpy array -> list
|
||||
"env": speaker_data["env"],
|
||||
"env": speaker_data.get("env", "near"), # 兼容旧数据,默认使用 "near"
|
||||
"threshold": speaker_data["threshold"],
|
||||
"registered_at": speaker_data["registered_at"]
|
||||
}
|
||||
|
||||
# 使用临时文件 + 原子替换,避免写入过程中崩溃导致数据丢失
|
||||
temp_path = self.speaker_db_path + ".tmp"
|
||||
with open(temp_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(json_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 原子替换
|
||||
|
||||
os.replace(temp_path, self.speaker_db_path)
|
||||
|
||||
self._log("info", f"已保存 {len(json_data)} 个说话人到: {self.speaker_db_path}")
|
||||
@@ -337,7 +266,6 @@ class SpeakerVerificationClient:
|
||||
self._log("error", f"保存声纹数据库失败: {e}")
|
||||
self._log("error", f"保存路径: {self.speaker_db_path}")
|
||||
self._log("error", f"错误详情: {traceback.format_exc()}")
|
||||
# 清理临时文件
|
||||
temp_path = self.speaker_db_path + ".tmp"
|
||||
if os.path.exists(temp_path):
|
||||
try:
|
||||
File diff suppressed because it is too large
Load Diff
4
robot_speaker/understanding/__init__.py
Normal file
4
robot_speaker/understanding/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""理解层"""
|
||||
|
||||
|
||||
|
||||
@@ -1,22 +1,14 @@
|
||||
"""
|
||||
对话历史管理模块
|
||||
"""
|
||||
from .types import LLMMessage
|
||||
from robot_speaker.core.types import LLMMessage
|
||||
import threading
|
||||
|
||||
|
||||
class ConversationHistory:
|
||||
"""
|
||||
对话历史管理器 - 实时语音友好版本
|
||||
"""对话历史管理器 - 实时语音"""
|
||||
|
||||
使用待确认机制确保历史完整性:
|
||||
1. start_turn() - 开始新轮次,暂存用户消息
|
||||
2. get_messages() - 获取历史(包含待确认的用户消息,用于LLM上下文)
|
||||
3. commit_turn() - 确认轮次完成,写入历史
|
||||
4. cancel_turn() - 取消当前轮次,丢弃待确认消息
|
||||
"""
|
||||
|
||||
def __init__(self, max_history: int = 3, summary_trigger: int = 3):
|
||||
def __init__(self, max_history: int, summary_trigger: int):
|
||||
self.max_history = max_history
|
||||
self.summary_trigger = summary_trigger
|
||||
self.conversation_history: list[LLMMessage] = []
|
||||
@@ -27,53 +19,38 @@ class ConversationHistory:
|
||||
self._lock = threading.Lock() # 线程安全锁
|
||||
|
||||
def start_turn(self, user_content: str):
|
||||
"""
|
||||
开始一个新的对话轮次,暂存用户消息,等待LLM完成后确认写入历史
|
||||
"""
|
||||
"""开始一个新的对话轮次,暂存用户消息,等待LLM完成后确认写入历史"""
|
||||
with self._lock:
|
||||
# 如果有未确认的轮次,新消息会覆盖它(不写入历史,防止半句污染)
|
||||
# 这是正常的场景,比如用户快速连续说话,不需要特殊处理
|
||||
self._pending_user_message = LLMMessage(role="user", content=user_content)
|
||||
|
||||
def commit_turn(self, assistant_content: str) -> bool:
|
||||
"""
|
||||
确认当前轮次完成,将用户和助手消息写入历史
|
||||
"""
|
||||
"""确认当前轮次完成,将usr和assistant消息写入历史"""
|
||||
with self._lock:
|
||||
if self._pending_user_message is None:
|
||||
return False
|
||||
|
||||
# 只有助手回复非空时才写入历史
|
||||
if not assistant_content or not assistant_content.strip():
|
||||
self._pending_user_message = None
|
||||
return False
|
||||
|
||||
# 写入用户消息和助手回复
|
||||
self.conversation_history.append(self._pending_user_message)
|
||||
self.conversation_history.append(
|
||||
LLMMessage(role="assistant", content=assistant_content.strip())
|
||||
)
|
||||
|
||||
# 清空待确认消息
|
||||
|
||||
self._pending_user_message = None
|
||||
|
||||
# 检查是否需要压缩
|
||||
|
||||
self._maybe_compress()
|
||||
return True
|
||||
|
||||
def cancel_turn(self):
|
||||
"""
|
||||
取消当前待确认的轮次,丢弃待确认的用户消息,用于处理中断情况,防止不完整内容污染历史
|
||||
"""
|
||||
"""取消当前待确认的轮次,丢弃待确认的用户消息,用于处理中断情况,防止不完整内容污染历史"""
|
||||
with self._lock:
|
||||
if self._pending_user_message is not None:
|
||||
self._pending_user_message = None
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
"""
|
||||
直接添加消息(向后兼容,但推荐使用 start_turn/commit_turn)
|
||||
注意:此方法会立即写入历史,不会经过待确认机制
|
||||
"""
|
||||
"""直接添加消息"""
|
||||
with self._lock:
|
||||
# 如果有待确认的轮次,先取消它
|
||||
self.cancel_turn()
|
||||
@@ -81,21 +58,16 @@ class ConversationHistory:
|
||||
self._maybe_compress()
|
||||
|
||||
def get_messages(self) -> list[LLMMessage]:
|
||||
"""
|
||||
获取消息列表(包含摘要和待确认的用户消息)
|
||||
"""
|
||||
"""获取消息列表"""
|
||||
with self._lock:
|
||||
messages = []
|
||||
|
||||
# 添加摘要
|
||||
|
||||
if self.summary:
|
||||
messages.append(LLMMessage(role="system", content=self.summary))
|
||||
|
||||
# 添加历史消息
|
||||
if self.max_history > 0:
|
||||
messages.extend(self.conversation_history[-self.max_history * 2:])
|
||||
|
||||
# 添加待确认的用户消息(用于LLM上下文,但不写入历史)
|
||||
if self._pending_user_message is not None:
|
||||
messages.append(self._pending_user_message)
|
||||
|
||||
9
setup.py
9
setup.py
@@ -1,4 +1,4 @@
|
||||
from setuptools import setup
|
||||
from setuptools import setup, find_packages
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
@@ -7,14 +7,14 @@ package_name = 'robot_speaker'
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.0.1',
|
||||
packages=[package_name],
|
||||
packages=find_packages(where='.'),
|
||||
package_dir={'': '.'},
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml') + glob('config/*.json')),
|
||||
],
|
||||
install_requires=[
|
||||
'setuptools',
|
||||
@@ -28,7 +28,8 @@ setup(
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'robot_speaker_node = robot_speaker.robot_speaker_node:main',
|
||||
'robot_speaker_node = robot_speaker.core.robot_speaker_node:main',
|
||||
'register_speaker_node = robot_speaker.core.register_speaker_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user