mirror of
https://github.com/HiMeditator/auto-caption.git
synced 2026-04-08 13:19:39 +08:00
refactor(caption-engine): 重构字幕引擎代码结构
- 重构 GummyTranslator 类,增加启动和停止方法 - 优化 AudioStream 类,添加读取音频数据方法 - 更新 main-gummy.py,使用新的 GummyTranslator 和 AudioStream 接口 - 更新文档和 TODO 列表
This commit is contained in:
2
caption-engine/audio2text/__init__.py
Normal file
2
caption-engine/audio2text/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from dashscope.common.error import InvalidParameter
|
||||
from .gummy import GummyTranslator
|
||||
@@ -69,6 +69,14 @@ class Callback(TranslationRecognizerCallback):
|
||||
print(f"Error sending data to Node.js: {e}", file=sys.stderr)
|
||||
|
||||
class GummyTranslator:
|
||||
"""
|
||||
使用 Gummy 引擎流式处理的音频数据,并在标准输出中输出与 Auto Caption 软件可读取的 JSON 字符串数据
|
||||
|
||||
初始化参数:
|
||||
rate: 音频采样率
|
||||
source: 源语言代码字符串(zh, en, ja 等)
|
||||
target: 目标语言代码字符串(zh, en, ja 等)
|
||||
"""
|
||||
def __init__(self, rate, source, target):
|
||||
self.translator = TranslationRecognizerRealtime(
|
||||
model = "gummy-realtime-v1",
|
||||
@@ -80,3 +88,15 @@ class GummyTranslator:
|
||||
translation_target_languages = [target],
|
||||
callback = Callback()
|
||||
)
|
||||
|
||||
def start(self):
|
||||
"""启动 Gummy 引擎"""
|
||||
self.translator.start()
|
||||
|
||||
def send_audio_frame(self, data):
|
||||
"""发送音频帧"""
|
||||
self.translator.send_audio_frame(data)
|
||||
|
||||
def stop(self):
|
||||
"""停止 Gummy 引擎"""
|
||||
self.translator.stop()
|
||||
|
||||
1
caption-engine/audioprcs/__init__.py
Normal file
1
caption-engine/audioprcs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .streamchnl import mergeStreamChannels
|
||||
22
caption-engine/audioprcs/streamchnl.py
Normal file
22
caption-engine/audioprcs/streamchnl.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
|
||||
def mergeStreamChannels(data, channels):
    """
    Merge interleaved multi-channel stream data into single-channel data.

    The buffer is interpreted as 16-bit integer samples (np.int16) laid out
    frame by frame. Each output sample is the float32 mean of the samples of
    one frame, converted back to int16.

    Args:
        data: bytes-like buffer of interleaved int16 samples. ``None`` or an
            empty buffer produces ``b""`` instead of raising.
        channels: number of interleaved channels. Coerced with ``int()`` so
            an integral float is accepted.

    Returns:
        mono_data_bytes: single-channel int16 samples as bytes.

    Raises:
        ValueError: if ``channels`` is smaller than 1.
    """
    # A capture stream that is not open yields no data; treat it as silence.
    if data is None:
        return b""

    channels = int(channels)
    if channels < 1:
        raise ValueError(f"channels must be >= 1, got {channels}")

    buffer = memoryview(data)
    frame_bytes = channels * np.dtype(np.int16).itemsize
    # Only whole frames can be averaged; a trailing partial frame is dropped
    # rather than letting reshape() fail on it.
    frame_count = buffer.nbytes // frame_bytes
    if frame_count == 0:
        return b""

    # (length * channels,)
    samples = np.frombuffer(buffer, dtype=np.int16, count=frame_count * channels)
    if channels == 1:
        # Already mono: averaging a single sample per frame changes nothing.
        return samples.tobytes()

    # (length, channels)
    frames = samples.reshape(frame_count, channels)
    # (length,)
    mono_data = np.mean(frames.astype(np.float32), axis=1)
    mono_data_bytes = mono_data.astype(np.int16).tobytes()
    return mono_data_bytes
|
||||
@@ -1,40 +1,41 @@
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
if sys.platform == 'win32':
|
||||
from sysaudio.win import AudioStream, mergeStreamChannels
|
||||
from sysaudio.win import AudioStream
|
||||
elif sys.platform == 'linux':
|
||||
from sysaudio.linux import AudioStream, mergeStreamChannels
|
||||
from sysaudio.linux import AudioStream
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
from audio2text.gummy import GummyTranslator
|
||||
import sys
|
||||
import argparse
|
||||
from audioprcs import mergeStreamChannels
|
||||
from audio2text import InvalidParameter, GummyTranslator
|
||||
|
||||
|
||||
def convert_audio_to_text(s_lang, t_lang, audio_type):
    """
    Capture audio, downmix it to mono and stream it to the Gummy engine.

    Runs until interrupted with Ctrl+C (KeyboardInterrupt), at which point the
    audio stream is closed and the engine is stopped.

    Args:
        s_lang: source language code passed to GummyTranslator.
        t_lang: target language code; the string 'none' means no target
            language (``None`` is passed to GummyTranslator).
        audio_type: value passed to AudioStream to select the audio source.
    """
    # Flush stdout after every line written.
    sys.stdout.reconfigure(line_buffering=True)  # type: ignore
    stream = AudioStream(audio_type)

    target = None if t_lang == 'none' else t_lang
    gummy = GummyTranslator(stream.RATE, s_lang, target)

    stream.openStream()
    gummy.start()

    while True:
        try:
            data = stream.read_chunk()
            # read_chunk() returns None while the stream is not open; skip
            # the iteration instead of handing None to the downmixer.
            if not data:
                continue
            data = mergeStreamChannels(data, stream.CHANNELS)
            try:
                gummy.send_audio_frame(data)
            except InvalidParameter:
                # NOTE(review): presumably raised when the engine session has
                # ended - confirm. Restart it and resend the frame once.
                gummy.start()
                gummy.send_audio_frame(data)
        except KeyboardInterrupt:
            stream.closeStream()
            gummy.stop()
            break
|
||||
|
||||
|
||||
@@ -47,5 +48,5 @@ if __name__ == "__main__":
|
||||
convert_audio_to_text(
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
0 if args.audio_type == '0' else 1
|
||||
int(args.audio_type)
|
||||
)
|
||||
|
||||
0
caption-engine/sysaudio/__init__.py
Normal file
0
caption-engine/sysaudio/__init__.py
Normal file
@@ -1,30 +1,15 @@
|
||||
import pyaudio
|
||||
import numpy as np
|
||||
|
||||
def mergeStreamChannels(data, channels):
|
||||
"""
|
||||
将当前多通道流数据合并为单通道流数据
|
||||
|
||||
Args:
|
||||
data: 多通道数据
|
||||
channels: 通道数
|
||||
|
||||
Returns:
|
||||
mono_data_bytes: 单通道数据
|
||||
"""
|
||||
# (length * channels,)
|
||||
data_np = np.frombuffer(data, dtype=np.int16)
|
||||
# (length, channels)
|
||||
data_np_r = data_np.reshape(-1, channels)
|
||||
# (length,)
|
||||
mono_data = np.mean(data_np_r.astype(np.float32), axis=1)
|
||||
mono_data = mono_data.astype(np.int16)
|
||||
mono_data_bytes = mono_data.tobytes()
|
||||
return mono_data_bytes
|
||||
|
||||
|
||||
class AudioStream:
|
||||
def __init__(self, audio_type=1):
|
||||
"""
|
||||
获取系统音频流
|
||||
|
||||
初始化参数:
|
||||
audio_type: 0-系统音频输出流(不支持,不会生效),1-系统音频输入流(默认)
|
||||
chunk_rate: 每秒采集音频块的数量,默认为20
|
||||
"""
|
||||
def __init__(self, audio_type=1, chunk_rate=20):
|
||||
self.audio_type = audio_type
|
||||
self.mic = pyaudio.PyAudio()
|
||||
self.device = self.mic.get_default_input_device_info()
|
||||
@@ -33,7 +18,7 @@ class AudioStream:
|
||||
self.FORMAT = pyaudio.paInt16
|
||||
self.CHANNELS = self.device["maxInputChannels"]
|
||||
self.RATE = int(self.device["defaultSampleRate"])
|
||||
self.CHUNK = self.RATE // 20
|
||||
self.CHUNK = self.RATE // chunk_rate
|
||||
self.INDEX = self.device["index"]
|
||||
|
||||
def printInfo(self):
|
||||
@@ -62,13 +47,20 @@ class AudioStream:
|
||||
if self.stream: return self.stream
|
||||
self.stream = self.mic.open(
|
||||
format = self.FORMAT,
|
||||
channels = self.CHANNELS,
|
||||
channels = int(self.CHANNELS),
|
||||
rate = self.RATE,
|
||||
input = True,
|
||||
input_device_index = self.INDEX
|
||||
input_device_index = int(self.INDEX)
|
||||
)
|
||||
return self.stream
|
||||
|
||||
|
||||
def read_chunk(self):
|
||||
"""
|
||||
读取音频数据
|
||||
"""
|
||||
if not self.stream: return None
|
||||
return self.stream.read(self.CHUNK)
|
||||
|
||||
def closeStream(self):
|
||||
"""
|
||||
关闭系统音频输出流
|
||||
@@ -76,4 +68,4 @@ class AudioStream:
|
||||
if self.stream is None: return
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
self.stream = None
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""获取 Windows 系统音频输出流"""
|
||||
|
||||
import pyaudiowpatch as pyaudio
|
||||
import numpy as np
|
||||
|
||||
|
||||
def getDefaultLoopbackDevice(mic: pyaudio.PyAudio, info = True)->dict:
|
||||
@@ -40,35 +39,15 @@ def getDefaultLoopbackDevice(mic: pyaudio.PyAudio, info = True)->dict:
|
||||
return default_speaker
|
||||
|
||||
|
||||
def mergeStreamChannels(data, channels):
|
||||
"""
|
||||
将当前多通道流数据合并为单通道流数据
|
||||
|
||||
Args:
|
||||
data: 多通道数据
|
||||
channels: 通道数
|
||||
|
||||
Returns:
|
||||
mono_data_bytes: 单通道数据
|
||||
"""
|
||||
# (length * channels,)
|
||||
data_np = np.frombuffer(data, dtype=np.int16)
|
||||
# (length, channels)
|
||||
data_np_r = data_np.reshape(-1, channels)
|
||||
# (length,)
|
||||
mono_data = np.mean(data_np_r.astype(np.float32), axis=1)
|
||||
mono_data = mono_data.astype(np.int16)
|
||||
mono_data_bytes = mono_data.tobytes()
|
||||
return mono_data_bytes
|
||||
|
||||
class AudioStream:
|
||||
"""
|
||||
获取系统音频流
|
||||
|
||||
参数:
|
||||
audio_type: (默认)0-系统音频输出流,1-系统音频输入流
|
||||
初始化参数:
|
||||
audio_type: 0-系统音频输出流(默认),1-系统音频输入流
|
||||
chunk_rate: 每秒采集音频块的数量,默认为20
|
||||
"""
|
||||
def __init__(self, audio_type=0):
|
||||
def __init__(self, audio_type=0, chunk_rate=20):
|
||||
self.audio_type = audio_type
|
||||
self.mic = pyaudio.PyAudio()
|
||||
if self.audio_type == 0:
|
||||
@@ -80,7 +59,7 @@ class AudioStream:
|
||||
self.FORMAT = pyaudio.paInt16
|
||||
self.CHANNELS = self.device["maxInputChannels"]
|
||||
self.RATE = int(self.device["defaultSampleRate"])
|
||||
self.CHUNK = self.RATE // 20
|
||||
self.CHUNK = self.RATE // chunk_rate
|
||||
self.INDEX = self.device["index"]
|
||||
|
||||
def printInfo(self):
|
||||
@@ -117,6 +96,13 @@ class AudioStream:
|
||||
)
|
||||
return self.stream
|
||||
|
||||
def read_chunk(self):
|
||||
"""
|
||||
读取音频数据
|
||||
"""
|
||||
if not self.stream: return None
|
||||
return self.stream.read(self.CHUNK)
|
||||
|
||||
def closeStream(self):
|
||||
"""
|
||||
关闭系统音频输出流
|
||||
|
||||
Reference in New Issue
Block a user