mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-03-11 06:07:33 +08:00
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
try:
|
|
from urllib import urlretrieve
|
|
except ImportError:
|
|
from urllib.request import urlretrieve
|
|
|
|
|
|
def load_url(url, model_dir='./pretrained', map_location=None):
    """Download a checkpoint into a local cache (if needed) and load it.

    The cache file name is the last ``/``-separated component of ``url``.
    If a file with that name already exists in ``model_dir``, it is loaded
    directly and nothing is downloaded.

    Args:
        url: Location of the checkpoint, passed to ``urlretrieve``.
        model_dir: Cache directory; created (including parents) if missing.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.
    """
    # Tolerate the directory being created concurrently (e.g. by another
    # process) between a check and the mkdir; re-raise any other failure.
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download to a side file and move it into place only once the
        # transfer completed. An interrupted or failed download therefore
        # never leaves a truncated file that later calls would treat as a
        # valid cache hit.
        partial_file = cached_file + '.part'
        try:
            urlretrieve(url, partial_file)
            os.rename(partial_file, cached_file)
        finally:
            if os.path.exists(partial_file):
                os.remove(partial_file)
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from trusted sources.
    return torch.load(cached_file, map_location=map_location)
|
|
|
|
|
|
def color_encode(labelmap, colors, mode='RGB'):
    """Render an integer label map as a 3-channel colour image.

    Args:
        labelmap: 2-D ``(H, W)`` array of class ids; cast with
            ``astype('int')`` (floats are truncated). Pixels with a negative
            id are left as zeros.
        colors: Palette indexable by class id, where ``colors[k]`` is the
            3-channel colour for class ``k`` (e.g. an ``(N, 3)`` uint8 array).
        mode: ``'BGR'`` returns the channels reversed; any other value returns
            them in palette order.

    Returns:
        ``(H, W, 3)`` uint8 array. For ``'BGR'`` this is a channel-reversed
        view of the same buffer, not a copy.

    Raises:
        IndexError/KeyError: if a non-negative id is missing from ``colors``.
    """
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    for label in np.unique(labelmap):
        if label < 0:
            continue
        # Writing through the boolean mask touches only this class's pixels
        # and avoids building a full-size tiled colour image plus a masked
        # product for every label.
        labelmap_rgb[labelmap == label] = colors[label]
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb
|