mirror of
https://github.com/HiMeditator/auto-caption.git
synced 2026-02-04 04:14:42 +08:00
feat(engine): 替换重采样模型、SOSV 添加标点恢复模型
- 将 samplerate 库替换为 resampy 库,提高重采样质量 - Shepra-ONNX SenseVoice 添加中文和英语标点恢复模型
This commit is contained in:
@@ -9,6 +9,7 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/simulate-s
|
||||
import time
|
||||
from datetime import datetime
|
||||
import sherpa_onnx
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from utils import shared_data
|
||||
@@ -23,23 +24,27 @@ class SosvRecognizer:
|
||||
初始化参数:
|
||||
model_path: Shepra ONNX Sense Voice 识别模型路径
|
||||
vad_model: Silero VAD 模型路径
|
||||
source: 识别源语言(auto, zh, en, ja, ko, yue)
|
||||
target: 翻译目标语言
|
||||
trans_model: 翻译模型名称
|
||||
ollama_name: Ollama 模型名称
|
||||
"""
|
||||
def __init__(self, model_path: str, target: str | None, trans_model: str, ollama_name: str):
|
||||
def __init__(self, model_path: str, source: str, target: str | None, trans_model: str, ollama_name: str):
|
||||
if model_path.startswith('"'):
|
||||
model_path = model_path[1:]
|
||||
if model_path.endswith('"'):
|
||||
model_path = model_path[:-1]
|
||||
self.model_path = model_path
|
||||
self.ext = ""
|
||||
if self.model_path[-4:] == "int8":
|
||||
self.ext = ".int8"
|
||||
self.source = source
|
||||
self.target = target
|
||||
if trans_model == 'google':
|
||||
self.trans_func = google_translate
|
||||
else:
|
||||
self.trans_func = ollama_translate
|
||||
self.ollama_name = ollama_name
|
||||
|
||||
self.time_str = ''
|
||||
self.cur_id = 0
|
||||
self.prev_content = ''
|
||||
@@ -47,19 +52,39 @@ class SosvRecognizer:
|
||||
def start(self):
|
||||
"""启动 Sense Voice 模型"""
|
||||
self.recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
|
||||
model=f"{self.model_path}/model.onnx",
|
||||
tokens=f"{self.model_path}/tokens.txt",
|
||||
model=f"{self.model_path}/sensevoice/model{self.ext}.onnx",
|
||||
tokens=f"{self.model_path}/sensevoice/tokens.txt",
|
||||
language=self.source,
|
||||
num_threads = 2,
|
||||
)
|
||||
config = sherpa_onnx.VadModelConfig()
|
||||
config.silero_vad.model = f"{self.model_path}/silero_vad.onnx"
|
||||
config.silero_vad.threshold = 0.5
|
||||
config.silero_vad.min_silence_duration = 0.1
|
||||
config.silero_vad.min_speech_duration = 0.25
|
||||
config.silero_vad.max_speech_duration = 8
|
||||
config.sample_rate = 16000
|
||||
self.window_size = config.silero_vad.window_size
|
||||
self.vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100)
|
||||
|
||||
vad_config = sherpa_onnx.VadModelConfig()
|
||||
vad_config.silero_vad.model = f"{self.model_path}/silero_vad.onnx"
|
||||
vad_config.silero_vad.threshold = 0.5
|
||||
vad_config.silero_vad.min_silence_duration = 0.1
|
||||
vad_config.silero_vad.min_speech_duration = 0.25
|
||||
vad_config.silero_vad.max_speech_duration = 8
|
||||
vad_config.sample_rate = 16000
|
||||
self.window_size = vad_config.silero_vad.window_size
|
||||
self.vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100)
|
||||
|
||||
if self.source == 'en':
|
||||
model_config = sherpa_onnx.OnlinePunctuationModelConfig(
|
||||
cnn_bilstm=f"{self.model_path}/punct-en/model{self.ext}.onnx",
|
||||
bpe_vocab=f"{self.model_path}/punct-en/bpe.vocab"
|
||||
)
|
||||
punct_config = sherpa_onnx.OnlinePunctuationConfig(
|
||||
model_config=model_config,
|
||||
)
|
||||
self.punct = sherpa_onnx.OnlinePunctuation(punct_config)
|
||||
else:
|
||||
punct_config = sherpa_onnx.OfflinePunctuationConfig(
|
||||
model=sherpa_onnx.OfflinePunctuationModelConfig(
|
||||
ct_transformer=f"{self.model_path}/punct/model{self.ext}.onnx"
|
||||
),
|
||||
)
|
||||
self.punct = sherpa_onnx.OfflinePunctuation(punct_config)
|
||||
|
||||
self.buffer = []
|
||||
self.offset = 0
|
||||
self.started = False
|
||||
@@ -112,15 +137,27 @@ class SosvRecognizer:
|
||||
self.vad.pop()
|
||||
self.recognizer.decode_stream(stream)
|
||||
text = stream.result.text.strip()
|
||||
|
||||
|
||||
if self.source == 'en':
|
||||
text_with_punct = self.punct.add_punctuation_with_case(text)
|
||||
else:
|
||||
text_with_punct = self.punct.add_punctuation(text)
|
||||
|
||||
caption['index'] = self.cur_id
|
||||
caption['text'] = text
|
||||
caption['text'] = text_with_punct
|
||||
caption['time_s'] = self.time_str
|
||||
caption['time_t'] = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
if text:
|
||||
stdout_obj(caption)
|
||||
if self.target:
|
||||
th = threading.Thread(
|
||||
target=self.trans_func,
|
||||
args=(self.ollama_name, self.target, caption['text'], self.time_str),
|
||||
daemon=True
|
||||
)
|
||||
th.start()
|
||||
self.cur_id += 1
|
||||
self.prev_content = ''
|
||||
stdout_obj(caption)
|
||||
|
||||
self.cur_id += 1
|
||||
self.time_str = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||||
self.buffer = []
|
||||
self.offset = 0
|
||||
|
||||
@@ -62,11 +62,12 @@ class VoskRecognizer:
|
||||
self.prev_content = ''
|
||||
if content == '': return
|
||||
self.cur_id += 1
|
||||
|
||||
if self.target:
|
||||
self.trans_time = time.time()
|
||||
th = threading.Thread(
|
||||
target=self.trans_func,
|
||||
args=(self.ollama_name, self.target, caption['text'], self.time_str)
|
||||
args=(self.ollama_name, self.target, caption['text'], self.time_str),
|
||||
daemon=True
|
||||
)
|
||||
th.start()
|
||||
else:
|
||||
|
||||
@@ -10,23 +10,26 @@ from audio2text import SosvRecognizer
|
||||
from sysaudio import AudioStream
|
||||
|
||||
|
||||
def audio_recording(stream: AudioStream, resample: bool, save = False):
|
||||
def audio_recording(stream: AudioStream, resample: bool, save = False, path = ''):
|
||||
global shared_data
|
||||
stream.open_stream()
|
||||
wf = None
|
||||
if save:
|
||||
wf = wave.open(f'record.wav', 'wb')
|
||||
wf.setnchannels(1)
|
||||
if path != '':
|
||||
path += '/'
|
||||
wf = wave.open(f'{path}record.wav', 'wb')
|
||||
wf.setnchannels(stream.CHANNELS)
|
||||
wf.setsampwidth(stream.SAMP_WIDTH)
|
||||
wf.setframerate(16000)
|
||||
wf.setframerate(stream.CHUNK_RATE)
|
||||
while shared_data.status == 'running':
|
||||
raw_chunk = stream.read_chunk()
|
||||
if save: wf.writeframes(raw_chunk) # type: ignore
|
||||
if raw_chunk is None: continue
|
||||
if resample:
|
||||
chunk = resample_chunk_mono(raw_chunk, stream.CHANNELS, stream.RATE, 16000)
|
||||
else:
|
||||
chunk = merge_chunk_channels(raw_chunk, stream.CHANNELS)
|
||||
shared_data.chunk_queue.put(chunk)
|
||||
if save: wf.writeframes(chunk) # type: ignore
|
||||
if save: wf.close() # type: ignore
|
||||
stream.close_stream_signal()
|
||||
|
||||
@@ -88,21 +91,22 @@ def main_vosk(a: int, c: int, vosk: str, t: str, tm: str, omn: str):
|
||||
engine.stop()
|
||||
|
||||
|
||||
def main_sosv(a: int, c: int, sosv: str, t: str, tm: str, omn: str):
|
||||
def main_sosv(a: int, c: int, sosv: str, s: str, t: str, tm: str, omn: str):
|
||||
"""
|
||||
Parameters:
|
||||
a: Audio source: 0 for output, 1 for input
|
||||
c: Chunk number in 1 second
|
||||
sosv: Sherpa-ONNX SenseVoice model path
|
||||
s: Source language
|
||||
t: Target language
|
||||
tm: Translation model type, ollama or google
|
||||
omn: Ollama model name
|
||||
"""
|
||||
stream = AudioStream(a, c)
|
||||
if t == 'none':
|
||||
engine = SosvRecognizer(sosv, None, tm, omn)
|
||||
engine = SosvRecognizer(sosv, s, None, tm, omn)
|
||||
else:
|
||||
engine = SosvRecognizer(sosv, t, tm, omn)
|
||||
engine = SosvRecognizer(sosv, s, t, tm, omn)
|
||||
|
||||
engine.start()
|
||||
stream_thread = threading.Thread(
|
||||
@@ -120,14 +124,15 @@ def main_sosv(a: int, c: int, sosv: str, t: str, tm: str, omn: str):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Convert system audio stream to text')
|
||||
# both
|
||||
# all
|
||||
parser.add_argument('-e', '--caption_engine', default='gummy', help='Caption engine: gummy or vosk')
|
||||
parser.add_argument('-a', '--audio_type', default=0, help='Audio stream source: 0 for output, 1 for input')
|
||||
parser.add_argument('-c', '--chunk_rate', default=10, help='Number of audio stream chunks collected per second')
|
||||
parser.add_argument('-p', '--port', default=0, help='The port to run the server on, 0 for no server')
|
||||
parser.add_argument('-t', '--target_language', default='zh', help='Target language code, "none" for no translation')
|
||||
# gummy and sosv
|
||||
parser.add_argument('-s', '--source_language', default='auto', help='Source language code')
|
||||
# gummy only
|
||||
parser.add_argument('-s', '--source_language', default='en', help='Source language code')
|
||||
parser.add_argument('-k', '--api_key', default='', help='API KEY for Gummy model')
|
||||
# vosk and sosv
|
||||
parser.add_argument('-tm', '--translation_model', default='ollama', help='Model for translation: ollama or google')
|
||||
@@ -165,6 +170,7 @@ if __name__ == "__main__":
|
||||
int(args.audio_type),
|
||||
int(args.chunk_rate),
|
||||
args.sosv_model,
|
||||
args.source_language,
|
||||
args.target_language,
|
||||
args.translation_model,
|
||||
args.ollama_name
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
dashscope
|
||||
numpy
|
||||
samplerate
|
||||
resampy
|
||||
vosk
|
||||
pyinstaller
|
||||
pyaudio; sys_platform == 'darwin'
|
||||
pyaudiowpatch; sys_platform == 'win32'
|
||||
googletrans
|
||||
ollama
|
||||
sherpa_onnx
|
||||
@@ -1,9 +1,4 @@
|
||||
from .audioprcs import (
|
||||
merge_chunk_channels,
|
||||
resample_chunk_mono,
|
||||
resample_chunk_mono_np,
|
||||
resample_mono_chunk
|
||||
)
|
||||
from .audioprcs import merge_chunk_channels, resample_chunk_mono
|
||||
from .sysout import stdout, stdout_err, stdout_cmd, stdout_obj, stderr
|
||||
from .shared import shared_data
|
||||
from .server import start_server
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import samplerate
|
||||
import resampy
|
||||
import numpy as np
|
||||
import numpy.core.multiarray # do not remove
|
||||
|
||||
@@ -24,16 +24,15 @@ def merge_chunk_channels(chunk: bytes, channels: int) -> bytes:
|
||||
return chunk_mono.tobytes()
|
||||
|
||||
|
||||
def resample_chunk_mono(chunk: bytes, channels: int, orig_sr: int, target_sr: int, mode="sinc_best") -> bytes:
|
||||
def resample_chunk_mono(chunk: bytes, channels: int, orig_sr: int, target_sr: int) -> bytes:
|
||||
"""
|
||||
将当前多通道音频数据块转换成单通道音频数据块,然后进行重采样
|
||||
将当前多通道音频数据块转换成单通道音频数据块,并进行重采样
|
||||
|
||||
Args:
|
||||
chunk: 多通道音频数据块
|
||||
channels: 通道数
|
||||
orig_sr: 原始采样率
|
||||
target_sr: 目标采样率
|
||||
mode: 重采样模式,可选:'sinc_best' | 'sinc_medium' | 'sinc_fastest' | 'zero_order_hold' | 'linear'
|
||||
|
||||
Return:
|
||||
单通道音频数据块
|
||||
@@ -52,82 +51,14 @@ def resample_chunk_mono(chunk: bytes, channels: int, orig_sr: int, target_sr: in
|
||||
if orig_sr == target_sr:
|
||||
return chunk_mono.astype(np.int16).tobytes()
|
||||
|
||||
ratio = target_sr / orig_sr
|
||||
chunk_mono_r = samplerate.resample(chunk_mono, ratio, converter_type=mode)
|
||||
chunk_mono_r = resampy.resample(chunk_mono, orig_sr, target_sr)
|
||||
chunk_mono_r = np.round(chunk_mono_r).astype(np.int16)
|
||||
real_len = round(chunk_mono.shape[0] * ratio)
|
||||
real_len = round(chunk_mono.shape[0] * target_sr / orig_sr)
|
||||
if(chunk_mono_r.shape[0] != real_len):
|
||||
print(chunk_mono_r.shape[0], real_len)
|
||||
if(chunk_mono_r.shape[0] > real_len):
|
||||
chunk_mono_r = chunk_mono_r[:real_len]
|
||||
else:
|
||||
while chunk_mono_r.shape[0] < real_len:
|
||||
chunk_mono_r = np.append(chunk_mono_r, chunk_mono_r[-1])
|
||||
return chunk_mono_r.tobytes()
|
||||
|
||||
|
||||
def resample_chunk_mono_np(chunk: bytes, channels: int, orig_sr: int, target_sr: int, mode="sinc_best", dtype=np.float32) -> np.ndarray:
|
||||
"""
|
||||
将当前多通道音频数据块转换成单通道音频数据块,然后进行重采样,返回 Numpy 数组
|
||||
|
||||
Args:
|
||||
chunk: 多通道音频数据块
|
||||
channels: 通道数
|
||||
orig_sr: 原始采样率
|
||||
target_sr: 目标采样率
|
||||
mode: 重采样模式,可选:'sinc_best' | 'sinc_medium' | 'sinc_fastest' | 'zero_order_hold' | 'linear'
|
||||
dtype: 返回 Numpy 数组的数据类型
|
||||
|
||||
Return:
|
||||
单通道音频数据块
|
||||
"""
|
||||
if channels == 1:
|
||||
chunk_mono = np.frombuffer(chunk, dtype=np.int16)
|
||||
chunk_mono = chunk_mono.astype(np.float32)
|
||||
else:
|
||||
# (length * channels,)
|
||||
chunk_np = np.frombuffer(chunk, dtype=np.int16)
|
||||
# (length, channels)
|
||||
chunk_np = chunk_np.reshape(-1, channels)
|
||||
# (length,)
|
||||
chunk_mono = np.mean(chunk_np.astype(np.float32), axis=1)
|
||||
|
||||
if orig_sr == target_sr:
|
||||
return chunk_mono.astype(dtype)
|
||||
|
||||
ratio = target_sr / orig_sr
|
||||
chunk_mono_r = samplerate.resample(chunk_mono, ratio, converter_type=mode)
|
||||
chunk_mono_r = chunk_mono_r.astype(dtype)
|
||||
real_len = round(chunk_mono.shape[0] * ratio)
|
||||
if(chunk_mono_r.shape[0] > real_len):
|
||||
chunk_mono_r = chunk_mono_r[:real_len]
|
||||
else:
|
||||
while chunk_mono_r.shape[0] < real_len:
|
||||
chunk_mono_r = np.append(chunk_mono_r, chunk_mono_r[-1])
|
||||
return chunk_mono_r
|
||||
|
||||
|
||||
def resample_mono_chunk(chunk: bytes, orig_sr: int, target_sr: int, mode="sinc_best") -> bytes:
|
||||
"""
|
||||
将当前单通道音频块进行重采样
|
||||
|
||||
Args:
|
||||
chunk: 单通道音频数据块
|
||||
orig_sr: 原始采样率
|
||||
target_sr: 目标采样率
|
||||
mode: 重采样模式,可选:'sinc_best' | 'sinc_medium' | 'sinc_fastest' | 'zero_order_hold' | 'linear'
|
||||
|
||||
Return:
|
||||
单通道音频数据块
|
||||
"""
|
||||
if orig_sr == target_sr: return chunk
|
||||
chunk_np = np.frombuffer(chunk, dtype=np.int16)
|
||||
chunk_np = chunk_np.astype(np.float32)
|
||||
ratio = target_sr / orig_sr
|
||||
chunk_r = samplerate.resample(chunk_np, ratio, converter_type=mode)
|
||||
chunk_r = np.round(chunk_r).astype(np.int16)
|
||||
real_len = round(chunk_np.shape[0] * ratio)
|
||||
if(chunk_r.shape[0] > real_len):
|
||||
chunk_r = chunk_r[:real_len]
|
||||
else:
|
||||
while chunk_r.shape[0] < real_len:
|
||||
chunk_r = np.append(chunk_r, chunk_r[-1])
|
||||
return chunk_r.tobytes()
|
||||
|
||||
@@ -13,6 +13,7 @@ lang_map = {
|
||||
'ru': 'Russian',
|
||||
'ja': 'Japanese',
|
||||
'ko': 'Korean',
|
||||
'zh': 'Chinese',
|
||||
'zh-cn': 'Chinese'
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user