mirror of https://github.com/HiMeditator/auto-caption.git
synced 2026-02-14 20:02:03 +08:00
feat(engine): add GLM-ASR speech recognition engine support

- Add a GLM-ASR cloud speech recognition engine implementation
- Extend the settings UI with GLM-related parameter fields
- Let Ollama use a custom domain and API key, enabling cloud and other LLM backends
- Adjust the audio processing logic to support the new engine
- Update dependencies and build configuration
- Fix issues with the Ollama translation feature
engine/audio2text/__init__.py
@@ -1,3 +1,4 @@
 from .gummy import GummyRecognizer
 from .vosk import VoskRecognizer
 from .sosv import SosvRecognizer
+from .glm import GlmRecognizer
163 engine/audio2text/glm.py Normal file
@@ -0,0 +1,163 @@
import threading
import io
import wave
import audioop  # deprecated since Python 3.11, removed in 3.13
import requests
from datetime import datetime

from utils import shared_data
from utils import stdout_cmd, stdout_obj, google_translate, ollama_translate

class GlmRecognizer:
    """
    Processes audio data with the GLM-ASR engine and writes JSON strings
    readable by the Auto Caption app to standard output.

    Init parameters:
        url: GLM-ASR API URL
        model: GLM-ASR model name
        api_key: GLM-ASR API key
        source: source language
        target: target language
        trans_model: translation model name
        ollama_name: Ollama model name
        ollama_url: custom Ollama (or compatible) API URL, optional
        ollama_api_key: API key for the Ollama endpoint, optional
    """
    def __init__(self, url: str, model: str, api_key: str, source: str, target: str | None, trans_model: str, ollama_name: str, ollama_url: str = '', ollama_api_key: str = ''):
        self.url = url
        self.model = model
        self.api_key = api_key
        self.source = source
        self.target = target
        if trans_model == 'google':
            self.trans_func = google_translate
        else:
            self.trans_func = ollama_translate
        self.ollama_name = ollama_name
        self.ollama_url = ollama_url
        self.ollama_api_key = ollama_api_key

        self.audio_buffer = []
        self.is_speech = False
        self.silence_frames = 0
        self.speech_start_time = None
        self.time_str = ''
        self.cur_id = 0

        # VAD settings (assumes 16 kHz, 16-bit mono input; chunk size ~1024 samples).
        # 16-bit = 2 bytes per sample.
        # The RMS threshold needs tuning; 500 is a conservative guess for silence.
        self.threshold = 500
        self.silence_limit = 15      # frames (~0.5-1 s depending on chunk size)
        self.min_speech_frames = 10  # frames
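        # Worked example (hypothetical chunk size, not fixed by this commit):
        # with 1024-sample chunks at 16 kHz, one chunk spans 1024/16000 = 64 ms,
        # so silence_limit = 15 waits ~0.96 s of silence before closing a
        # segment, and min_speech_frames = 10 requires ~0.64 s of buffered
        # audio before the segment is sent for recognition.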

    def start(self):
        """Start the GLM engine."""
        stdout_cmd('info', 'GLM-ASR recognizer started.')

    def stop(self):
        """Stop the GLM engine."""
        stdout_cmd('info', 'GLM-ASR recognizer stopped.')

    def process_audio(self, chunk):
        # chunk is raw bytes of int16 samples; width 2 matches 16-bit audio
        rms = audioop.rms(chunk, 2)

        if rms > self.threshold:
            if not self.is_speech:
                # Transition from silence to speech: stamp the start time.
                self.is_speech = True
                self.time_str = datetime.now().strftime('%H:%M:%S.%f')[:-3]
                self.audio_buffer = []
            self.audio_buffer.append(chunk)
            self.silence_frames = 0
        else:
            if self.is_speech:
                self.audio_buffer.append(chunk)
                self.silence_frames += 1
                if self.silence_frames > self.silence_limit:
                    # Speech ended; drop segments too short to be real speech.
                    if len(self.audio_buffer) > self.min_speech_frames:
                        self.recognize(self.audio_buffer, self.time_str)
                    self.is_speech = False
                    self.audio_buffer = []
                    self.silence_frames = 0

    def recognize(self, audio_frames, time_s):
        audio_bytes = b''.join(audio_frames)

        # Wrap the raw PCM in an in-memory WAV container for the API upload.
        wav_io = io.BytesIO()
        with wave.open(wav_io, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(16000)
            wav_file.writeframes(audio_bytes)
        wav_io.seek(0)

        # Recognize off-thread so audio capture is never blocked by the network.
        threading.Thread(
            target=self._do_request,
            args=(wav_io.read(), time_s, self.cur_id)
        ).start()
        self.cur_id += 1

    def _do_request(self, audio_content, time_s, index):
        try:
            files = {
                'file': ('audio.wav', audio_content, 'audio/wav')
            }
            data = {
                'model': self.model,
                'stream': 'false'
            }
            headers = {
                'Authorization': f'Bearer {self.api_key}'
            }

            response = requests.post(self.url, headers=headers, data=data, files=files, timeout=15)

            if response.status_code == 200:
                res_json = response.json()
                text = res_json.get('text', '')
                if text:
                    self.output_caption(text, time_s, index)
            else:
                try:
                    err_msg = response.json()
                    stdout_cmd('error', f"GLM API Error: {err_msg}")
                except Exception:
                    stdout_cmd('error', f"GLM API Error: {response.text}")

        except Exception as e:
            stdout_cmd('error', f"GLM Request Failed: {str(e)}")

    def output_caption(self, text, time_s, index):
        caption = {
            'command': 'caption',
            'index': index,
            'time_s': time_s,
            'time_t': datetime.now().strftime('%H:%M:%S.%f')[:-3],
            'text': text,
            'translation': ''
        }

        if self.target:
            if self.trans_func == ollama_translate:
                th = threading.Thread(
                    target=self.trans_func,
                    args=(self.ollama_name, self.target, caption['text'], time_s, self.ollama_url, self.ollama_api_key),
                    daemon=True
                )
            else:
                th = threading.Thread(
                    target=self.trans_func,
                    args=(self.ollama_name, self.target, caption['text'], time_s),
                    daemon=True
                )
            th.start()

        stdout_obj(caption)

    def translate(self):
        # Main loop: drain audio chunks from the shared queue and run VAD on them.
        while shared_data.status == 'running':
            chunk = shared_data.chunk_queue.get()
            self.process_audio(chunk)
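For reference, a minimal driver sketch (not part of the commit; the import path, endpoint URL, model name, and chunk source are all assumptions) showing how the engine side could feed GlmRecognizer:

    # Hypothetical usage; capture_pcm_chunks() stands in for the real audio
    # source, which must yield 16 kHz, 16-bit mono PCM byte chunks.
    from audio2text.glm import GlmRecognizer  # assumed import path

    recognizer = GlmRecognizer(
        url='https://example.com/v4/audio/transcriptions',  # placeholder endpoint
        model='glm-asr',        # assumed model name
        api_key='YOUR_API_KEY',
        source='zh',
        target='en',
        trans_model='google',   # translate with Google rather than Ollama
        ollama_name='',
    )
    recognizer.start()
    for chunk in capture_pcm_chunks():  # hypothetical capture loop
        recognizer.process_audio(chunk)
    recognizer.stop()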

engine/audio2text/sosv.py
@@ -29,7 +29,7 @@ class SosvRecognizer:
         trans_model: translation model name
         ollama_name: Ollama model name
     """
-    def __init__(self, model_path: str, source: str, target: str | None, trans_model: str, ollama_name: str):
+    def __init__(self, model_path: str, source: str, target: str | None, trans_model: str, ollama_name: str, ollama_url: str = '', ollama_api_key: str = ''):
         if model_path.startswith('"'):
             model_path = model_path[1:]
         if model_path.endswith('"'):
@@ -45,6 +45,8 @@ class SosvRecognizer:
         else:
             self.trans_func = ollama_translate
         self.ollama_name = ollama_name
+        self.ollama_url = ollama_url
+        self.ollama_api_key = ollama_api_key
         self.time_str = ''
         self.cur_id = 0
         self.prev_content = ''
@@ -152,7 +154,7 @@ class SosvRecognizer:
         if self.target:
             th = threading.Thread(
                 target=self.trans_func,
-                args=(self.ollama_name, self.target, caption['text'], self.time_str),
+                args=(self.ollama_name, self.target, caption['text'], self.time_str, self.ollama_url, self.ollama_api_key),
                 daemon=True
             )
             th.start()
engine/audio2text/vosk.py
@@ -18,7 +18,7 @@ class VoskRecognizer:
         trans_model: translation model name
         ollama_name: Ollama model name
     """
-    def __init__(self, model_path: str, target: str | None, trans_model: str, ollama_name: str):
+    def __init__(self, model_path: str, target: str | None, trans_model: str, ollama_name: str, ollama_url: str = '', ollama_api_key: str = ''):
         SetLogLevel(-1)
         if model_path.startswith('"'):
             model_path = model_path[1:]
@@ -31,6 +31,8 @@ class VoskRecognizer:
         else:
             self.trans_func = ollama_translate
         self.ollama_name = ollama_name
+        self.ollama_url = ollama_url
+        self.ollama_api_key = ollama_api_key
         self.time_str = ''
         self.cur_id = 0
         self.prev_content = ''
@@ -66,7 +68,7 @@ class VoskRecognizer:
         if self.target:
             th = threading.Thread(
                 target=self.trans_func,
-                args=(self.ollama_name, self.target, caption['text'], self.time_str),
+                args=(self.ollama_name, self.target, caption['text'], self.time_str, self.ollama_url, self.ollama_api_key),
                 daemon=True
             )
             th.start()
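For context, the widened call sites above imply an ollama_translate signature like the sketch below. This is an assumption reconstructed only from the arguments passed (model name, target language, text, timestamp, URL, API key); the real helper lives in the project's utils module and may differ:

    # Hypothetical reconstruction of the extended helper; the endpoint shape
    # follows the public Ollama /api/chat API, which also works for
    # Ollama-compatible cloud services given a base URL and Bearer key.
    import requests

    def ollama_translate(model: str, target: str, text: str, time_s: str,
                         url: str = '', api_key: str = ''):
        base = (url or 'http://localhost:11434').rstrip('/')  # default local Ollama
        headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
        resp = requests.post(
            f'{base}/api/chat',
            headers=headers,
            json={
                'model': model,
                'messages': [{'role': 'user',
                              'content': f'Translate the following into {target}: {text}'}],
                'stream': False,
            },
            timeout=15,
        )
        translation = resp.json()['message']['content']
        # The real helper would emit this via stdout, keyed by time_s.
        return translation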