mirror of
https://github.com/HiMeditator/auto-caption.git
synced 2026-02-16 05:01:07 +08:00
feat(engine): 重构字幕引擎并实现 WebSocket 通信
- 重构了 Gummy 和 Vosk 字幕引擎的代码,提高了可扩展性和可读性 - 合并 Gummy 和 Vosk 引擎为单个可执行文件 - 实现了字幕引擎和主程序之间的 WebSocket 通信,避免了孤儿进程问题
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from dashscope.common.error import InvalidParameter
|
||||
from .gummy import GummyTranslator
|
||||
from .gummy import GummyRecognizer
|
||||
from .vosk import VoskRecognizer
|
||||
@@ -62,7 +62,7 @@ class Callback(TranslationRecognizerCallback):
|
||||
stdout_obj(caption)
|
||||
|
||||
|
||||
class GummyTranslator:
|
||||
class GummyRecognizer:
|
||||
"""
|
||||
使用 Gummy 引擎流式处理的音频数据,并在标准输出中输出与 Auto Caption 软件可读取的 JSON 字符串数据
|
||||
|
||||
@@ -70,6 +70,7 @@ class GummyTranslator:
|
||||
rate: 音频采样率
|
||||
source: 源语言代码字符串(zh, en, ja 等)
|
||||
target: 目标语言代码字符串(zh, en, ja 等)
|
||||
api_key: 阿里云百炼平台 API KEY
|
||||
"""
|
||||
def __init__(self, rate: int, source: str, target: str | None, api_key: str | None):
|
||||
if api_key:
|
||||
|
||||
@@ -2,7 +2,8 @@ import json
|
||||
from datetime import datetime
|
||||
|
||||
from vosk import Model, KaldiRecognizer, SetLogLevel
|
||||
from utils import stdout_obj
|
||||
from utils import stdout_cmd, stdout_obj
|
||||
|
||||
|
||||
class VoskRecognizer:
|
||||
"""
|
||||
@@ -11,7 +12,7 @@ class VoskRecognizer:
|
||||
初始化参数:
|
||||
model_path: Vosk 识别模型路径
|
||||
"""
|
||||
def __int__(self, model_path: str):
|
||||
def __init__(self, model_path: str):
|
||||
SetLogLevel(-1)
|
||||
if model_path.startswith('"'):
|
||||
model_path = model_path[1:]
|
||||
@@ -24,7 +25,11 @@ class VoskRecognizer:
|
||||
|
||||
self.model = Model(self.model_path)
|
||||
self.recognizer = KaldiRecognizer(self.model, 16000)
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动 Vosk 引擎"""
|
||||
stdout_cmd('info', 'Vosk recognizer started.')
|
||||
|
||||
def send_audio_frame(self, data: bytes):
|
||||
"""
|
||||
发送音频帧给 Vosk 引擎,引擎将自动识别并将识别结果输出到标准输出中
|
||||
@@ -57,3 +62,7 @@ class VoskRecognizer:
|
||||
self.prev_content = content
|
||||
|
||||
stdout_obj(caption)
|
||||
|
||||
def stop(self):
|
||||
"""停止 Vosk 引擎"""
|
||||
stdout_cmd('info', 'Vosk recognizer closed.')
|
||||
@@ -1,49 +0,0 @@
|
||||
import sys
|
||||
import argparse
|
||||
from sysaudio import AudioStream
|
||||
from utils import merge_chunk_channels
|
||||
from audio2text import InvalidParameter, GummyTranslator
|
||||
|
||||
|
||||
def convert_audio_to_text(s_lang, t_lang, audio_type, chunk_rate, api_key):
    """Stream system audio into the Gummy engine until interrupted.

    s_lang: source language code (zh, en, ja, ...)
    t_lang: target language code, or the literal string 'none' to
        disable translation
    audio_type: 0 for the output audio stream, 1 for the input stream
    chunk_rate: number of audio chunks read per second
    api_key: Aliyun Bailian platform API KEY
    """
    stream = AudioStream(audio_type, chunk_rate)
    # 'none' is the CLI sentinel for "recognition only, no translation".
    target = None if t_lang == 'none' else t_lang
    gummy = GummyTranslator(stream.RATE, s_lang, target, api_key)

    stream.open_stream()
    gummy.start()

    while True:
        try:
            chunk = stream.read_chunk()
            if chunk is None:
                continue
            mono = merge_chunk_channels(chunk, stream.CHANNELS)
            try:
                gummy.send_audio_frame(mono)
            except InvalidParameter:
                # The engine session went stale/closed; reopen it and
                # resend the frame that failed.
                gummy.start()
                gummy.send_audio_frame(mono)
        except KeyboardInterrupt:
            stream.close_stream()
            gummy.stop()
            break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Convert system audio stream to text')
|
||||
parser.add_argument('-s', '--source_language', default='en', help='Source language code')
|
||||
parser.add_argument('-t', '--target_language', default='zh', help='Target language code')
|
||||
parser.add_argument('-a', '--audio_type', default=0, help='Audio stream source: 0 for output audio stream, 1 for input audio stream')
|
||||
parser.add_argument('-c', '--chunk_rate', default=20, help='The number of audio stream chunks collected per second.')
|
||||
parser.add_argument('-k', '--api_key', default='', help='API KEY for Gummy model')
|
||||
args = parser.parse_args()
|
||||
convert_audio_to_text(
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
int(args.audio_type),
|
||||
int(args.chunk_rate),
|
||||
args.api_key
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec file: bundles the Gummy caption engine entry point
# (main-gummy.py) into a single one-file console executable.
# Names like Analysis, PYZ, EXE are injected by PyInstaller at build time.


a = Analysis(
    ['main-gummy.py'],
    pathex=[],
    binaries=[],
    datas=[],
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)

exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.datas,
    [],
    name='main-gummy',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    # Keep a console: the engine communicates over stdout/stderr.
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
    onefile=True,
)
|
||||
@@ -1,77 +0,0 @@
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import numpy.core.multiarray
|
||||
|
||||
from sysaudio import AudioStream
|
||||
from vosk import Model, KaldiRecognizer, SetLogLevel
|
||||
from utils import resample_chunk_mono
|
||||
|
||||
SetLogLevel(-1)
|
||||
|
||||
def convert_audio_to_text(audio_type, chunk_rate, model_path):
    """Run offline Vosk speech recognition on system audio, emitting one
    JSON caption object per line on stdout.

    audio_type: 0 for the output audio stream, 1 for the input stream
    chunk_rate: number of audio chunks read per second
    model_path: path to the Vosk model directory (may arrive quoted)
    """
    # Line-buffer stdout so the consumer receives each caption promptly.
    sys.stdout.reconfigure(line_buffering=True) # type: ignore

    # Strip shell quoting that may survive argument passing on Windows.
    if model_path.startswith('"'):
        model_path = model_path[1:]
    if model_path.endswith('"'):
        model_path = model_path[:-1]

    model = Model(model_path)
    # Vosk is fed 16 kHz mono audio regardless of the capture rate.
    recognizer = KaldiRecognizer(model, 16000)

    stream = AudioStream(audio_type, chunk_rate)
    stream.open_stream()

    time_str = ''       # start timestamp of the caption being built
    cur_id = 0          # monotonically increasing caption index
    prev_content = ''   # last partial text, to suppress duplicate output

    while True:
        chunk = stream.read_chunk()
        if chunk is None: continue
        # Downmix to mono and resample from the device rate to 16 kHz.
        chunk_mono = resample_chunk_mono(chunk, stream.CHANNELS, stream.RATE, 16000)

        caption = {}
        if recognizer.AcceptWaveform(chunk_mono):
            # Final result: the utterance is complete; close out this
            # caption index and reset the partial-tracking state.
            content = json.loads(recognizer.Result()).get('text', '')
            caption['index'] = cur_id
            caption['text'] = content
            caption['time_s'] = time_str
            caption['time_t'] = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            caption['translation'] = ''
            prev_content = ''
            cur_id += 1
            # NOTE(review): this branch does not set caption['command']
            # while the partial branch below does — confirm the consumer
            # tolerates final captions without a 'command' key. It also
            # emits a caption even when content is '' — verify intended.
        else:
            # Partial (in-progress) result: only emit when the text
            # changed, and stamp the start time on the first partial.
            content = json.loads(recognizer.PartialResult()).get('partial', '')
            if content == '' or content == prev_content:
                continue
            if prev_content == '':
                time_str = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            caption['command'] = 'caption'
            caption['index'] = cur_id
            caption['text'] = content
            caption['time_s'] = time_str
            caption['time_t'] = datetime.now().strftime('%H:%M:%S.%f')[:-3]
            caption['translation'] = ''
            prev_content = content
        try:
            # One JSON object per line; flush so the parent process
            # sees captions immediately.
            json_str = json.dumps(caption) + '\n'
            sys.stdout.write(json_str)
            sys.stdout.flush()
        except Exception as e:
            print(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Convert system audio stream to text')
|
||||
parser.add_argument('-a', '--audio_type', default=0, help='Audio stream source: 0 for output audio stream, 1 for input audio stream')
|
||||
parser.add_argument('-c', '--chunk_rate', default=20, help='The number of audio stream chunks collected per second.')
|
||||
parser.add_argument('-m', '--model_path', default='', help='The path to the vosk model.')
|
||||
args = parser.parse_args()
|
||||
convert_audio_to_text(
|
||||
int(args.audio_type),
|
||||
int(args.chunk_rate),
|
||||
args.model_path
|
||||
)
|
||||
@@ -1,10 +1,59 @@
|
||||
import argparse
|
||||
from utils import stdout_cmd
|
||||
from utils import thread_data, start_server
|
||||
from utils import merge_chunk_channels, resample_chunk_mono
|
||||
from audio2text import InvalidParameter, GummyRecognizer
|
||||
from audio2text import VoskRecognizer
|
||||
from sysaudio import AudioStream
|
||||
|
||||
def gummy_engine(s, t, a, c, k):
|
||||
pass
|
||||
|
||||
def vosk_engine(a, c, m):
|
||||
pass
|
||||
def main_gummy(s: str, t: str, a: int, c: int, k: str):
    """Run the Gummy recognizer loop until thread_data signals a stop.

    s: source language code
    t: target language code, or 'none' to disable translation
    a: audio source (0 = output stream, 1 = input stream)
    c: number of audio chunks read per second
    k: Aliyun Bailian platform API KEY
    """
    stream = AudioStream(a, c)
    # 'none' is the CLI sentinel for "recognition only, no translation".
    target = None if t == 'none' else t
    engine = GummyRecognizer(stream.RATE, s, target, k)

    stream.open_stream()
    engine.start()

    # The control server flips thread_data.status off 'running' to stop us.
    while thread_data.status == "running":
        try:
            chunk = stream.read_chunk()
            if chunk is None:
                continue
            mono = merge_chunk_channels(chunk, stream.CHANNELS)
            try:
                engine.send_audio_frame(mono)
            except InvalidParameter:
                # Session expired/closed: report, reopen, resend.
                stdout_cmd('info', 'Gummy engine stopped, restart engine')
                engine.start()
                engine.send_audio_frame(mono)
        except KeyboardInterrupt:
            break

    stream.close_stream()
    engine.stop()
|
||||
|
||||
|
||||
def main_vosk(a: int, c: int, m: str):
    """Run the Vosk recognizer loop until thread_data signals a stop.

    a: audio source (0 = output stream, 1 = input stream)
    c: number of audio chunks read per second
    m: path to the Vosk model directory
    """
    stream = AudioStream(a, c)
    engine = VoskRecognizer(m)

    stream.open_stream()
    engine.start()

    # The control server flips thread_data.status off 'running' to stop us.
    while thread_data.status == "running":
        try:
            chunk = stream.read_chunk()
            if chunk is None:
                continue
            # Vosk expects 16 kHz mono audio regardless of capture rate.
            mono = resample_chunk_mono(chunk, stream.CHANNELS, stream.RATE, 16000)
            engine.send_audio_frame(mono)
        except KeyboardInterrupt:
            break

    stream.close_stream()
    engine.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Convert system audio stream to text')
|
||||
@@ -12,15 +61,23 @@ if __name__ == "__main__":
|
||||
parser.add_argument('-e', '--caption_engine', default='gummy', help='Caption engine: gummy or vosk')
|
||||
parser.add_argument('-a', '--audio_type', default=0, help='Audio stream source: 0 for output, 1 for input')
|
||||
parser.add_argument('-c', '--chunk_rate', default=20, help='Number of audio stream chunks collected per second')
|
||||
parser.add_argument('-p', '--port', default=7070, help='The port to run the server on, 0 for no server')
|
||||
# gummy
|
||||
parser.add_argument('-s', '--source_language', default='en', help='Source language code')
|
||||
parser.add_argument('-t', '--target_language', default='zh', help='Target language code')
|
||||
parser.add_argument('-k', '--api_key', default='', help='API KEY for Gummy model')
|
||||
# vosk
|
||||
parser.add_argument('-m', '--model_path', default='', help='The path to the vosk model.')
|
||||
# for test
|
||||
args = parser.parse_args()
|
||||
|
||||
if int(args.port) == 0:
|
||||
thread_data.status = "running"
|
||||
else:
|
||||
start_server(int(args.port))
|
||||
|
||||
if args.caption_engine == 'gummy':
|
||||
gummy_engine(
|
||||
main_gummy(
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
int(args.audio_type),
|
||||
@@ -28,7 +85,7 @@ if __name__ == "__main__":
|
||||
args.api_key
|
||||
)
|
||||
elif args.caption_engine == 'vosk':
|
||||
vosk_engine(
|
||||
main_vosk(
|
||||
int(args.audio_type),
|
||||
int(args.chunk_rate),
|
||||
args.model_path
|
||||
|
||||
@@ -9,7 +9,7 @@ else:
|
||||
vosk_path = str(Path('./subenv/lib/python3.12/site-packages/vosk').resolve())
|
||||
|
||||
a = Analysis(
|
||||
['main-vosk.py'],
|
||||
['main.py'],
|
||||
pathex=[],
|
||||
binaries=[],
|
||||
datas=[(vosk_path, 'vosk')],
|
||||
@@ -30,7 +30,7 @@ exe = EXE(
|
||||
a.binaries,
|
||||
a.datas,
|
||||
[],
|
||||
name='main-vosk',
|
||||
name='main',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
@@ -1,2 +1,4 @@
|
||||
from .process import merge_chunk_channels, resample_chunk_mono, resample_mono_chunk
|
||||
from .sysout import stdout, stdout_cmd, stdout_obj, stderr
|
||||
from .audioprcs import merge_chunk_channels, resample_chunk_mono, resample_mono_chunk
|
||||
from .sysout import stdout, stdout_cmd, stdout_obj, stderr
|
||||
from .thdata import thread_data
|
||||
from .server import start_server
|
||||
@@ -1,6 +1,6 @@
|
||||
import samplerate
|
||||
import numpy as np
|
||||
|
||||
import numpy.core.multiarray
|
||||
|
||||
def merge_chunk_channels(chunk: bytes, channels: int) -> bytes:
|
||||
"""
|
||||
@@ -13,6 +13,7 @@ def merge_chunk_channels(chunk: bytes, channels: int) -> bytes:
|
||||
Returns:
|
||||
单通道音频数据块
|
||||
"""
|
||||
if channels == 1: return chunk
|
||||
# (length * channels,)
|
||||
chunk_np = np.frombuffer(chunk, dtype=np.int16)
|
||||
# (length, channels)
|
||||
@@ -37,13 +38,17 @@ def resample_chunk_mono(chunk: bytes, channels: int, orig_sr: int, target_sr: in
|
||||
Return:
|
||||
单通道音频数据块
|
||||
"""
|
||||
# (length * channels,)
|
||||
chunk_np = np.frombuffer(chunk, dtype=np.int16)
|
||||
# (length, channels)
|
||||
chunk_np = chunk_np.reshape(-1, channels)
|
||||
# (length,)
|
||||
chunk_mono_f = np.mean(chunk_np.astype(np.float32), axis=1)
|
||||
chunk_mono = chunk_mono_f.astype(np.int16)
|
||||
if channels == 1:
|
||||
chunk_mono = chunk
|
||||
else:
|
||||
# (length * channels,)
|
||||
chunk_np = np.frombuffer(chunk, dtype=np.int16)
|
||||
# (length, channels)
|
||||
chunk_np = chunk_np.reshape(-1, channels)
|
||||
# (length,)
|
||||
chunk_mono_f = np.mean(chunk_np.astype(np.float32), axis=1)
|
||||
chunk_mono = chunk_mono_f.astype(np.int16)
|
||||
|
||||
ratio = target_sr / orig_sr
|
||||
chunk_mono_r = samplerate.resample(chunk_mono, ratio, converter_type=mode)
|
||||
chunk_mono_r = np.round(chunk_mono_r).astype(np.int16)
|
||||
37
engine/utils/server.py
Normal file
37
engine/utils/server.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import socket
|
||||
import threading
|
||||
import json
|
||||
from utils import thread_data, stdout_cmd, stderr
|
||||
|
||||
|
||||
def handle_client(client_socket):
    """Serve one control connection: read JSON commands until the peer
    disconnects or sends 'stop', then mark the engine as stopped.

    client_socket: connected TCP socket handed over by start_server.

    Note: the previous ``global thread_data`` declaration was removed —
    the module-level object is only mutated (its .status attribute),
    never rebound, so the global statement was a misleading no-op.
    """
    while True:
        try:
            # assumes each recv() yields exactly one complete JSON
            # message within 4096 bytes — TODO confirm the client never
            # splits or batches messages across sends.
            data = client_socket.recv(4096).decode('utf-8')
            if not data:
                # Empty read: the peer closed the connection.
                break
            data = json.loads(data)

            if data['command'] == 'stop':
                if thread_data.status == 'running':
                    thread_data.status = 'stop'
                break
        except Exception as e:
            stderr(f'Communication error: {e}')
            break

    # Losing the control connection for any reason stops the engine, so
    # the engine process can never outlive its parent (orphan prevention).
    thread_data.status = 'stop'
    client_socket.close()
|
||||
|
||||
|
||||
def start_server(port: int):
    """Open a localhost control server and hand the first client to a
    daemon handler thread.

    port: TCP port to listen on.
    """
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(('localhost', port))
    # Backlog of 1: only the parent process is expected to connect.
    server.listen(1)
    # Announce on stdout that the server is ready to accept.
    # NOTE(review): stdout_cmd is called here with a single argument,
    # elsewhere with two ('info', msg) — confirm its signature allows this.
    stdout_cmd('ready')

    # Blocks until the main program connects; exactly one client is
    # ever accepted, and the listening socket is deliberately left open
    # for the lifetime of the process.
    client, addr = server.accept()
    client_handler = threading.Thread(target=handle_client, args=(client,))
    # Daemon thread: it must not keep the process alive on its own.
    client_handler.daemon = True
    client_handler.start()
||||
5
engine/utils/thdata.py
Normal file
5
engine/utils/thdata.py
Normal file
@@ -0,0 +1,5 @@
|
||||
class ThreadData:
    """Mutable run-state flag shared between the control-server thread
    and the audio capture loop."""

    def __init__(self):
        # Lifecycle state: starts as "running"; the server sets it to
        # 'stop' to end the capture loop.
        self.status = "running"


# Single shared instance used across the engine modules.
thread_data = ThreadData()
|
||||
Reference in New Issue
Block a user