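"""Training entry point for the STTN model used by video-subtitle-remover.

Loads a JSON config, layers the command-line options on top of it, and then either
spawns `world_size` worker processes locally (when the master IP is 127.0.0.1) or
joins an externally launched (e.g. OpenMPI) distributed run before handing the
config to Trainer.
"""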
import os
import json
import argparse
from shutil import copyfile

import torch
import torch.multiprocessing as mp

from backend.tools.train.trainer_sttn import Trainer
from backend.tools.train.utils_sttn import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)
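
# get_world_size / get_local_rank / get_global_rank / get_master_ip are assumed to
# read the launcher's environment (e.g. MPI rank/size variables) and to fall back to
# a single-machine setup derived from the visible GPUs; see
# backend/tools/train/utils_sttn.py for the actual logic.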

parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()
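
# Example invocation (the script path is an assumption based on the package layout;
# every flag falls back to the default declared above):
#   python backend/tools/train/train_sttn.py -c configs/youtube-vos.json -m sttn -p 23455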


def main_worker(rank, config):
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        # one process per GPU, NCCL backend, rendezvous via the TCP init method
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=config['init_method'],
                                             world_size=config['world_size'],
                                             rank=config['global_rank'],
                                             group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # checkpoints go to <save_dir>/<model>_<config name>, e.g. sttn_youtube-vos
    config['save_dir'] = os.path.join(
        config['save_dir'],
        '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0]))
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # only the main process creates the save dir and copies the config file into it
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1])
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    # the '-e'/'--exam' flag maps to the trainer's debug mode
    trainer = Trainer(config, debug=args.exam)
    trainer.train()


if __name__ == "__main__":

    # loading configs
    with open(args.config) as f:
        config = json.load(f)
    config['model'] = args.model
    config['config'] = args.config
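
    # The JSON config must provide at least 'save_dir' (consumed in main_worker above);
    # everything else in it is passed through to Trainer untouched. A minimal sketch
    # (assumed; the real configs/youtube-vos.json will carry Trainer-specific keys too):
    #
    #   {
    #       "save_dir": "checkpoints/"
    #   }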

    # setting distributed configurations
    config['world_size'] = get_world_size()
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = True if config['world_size'] > 1 else False

    # setup distributed parallel training environments
    if get_master_ip() == "127.0.0.1":
        # manually launch distributed processes; mp.spawn passes each process its
        # index as the first argument (the rank expected by main_worker)
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # multiple processes have been launched by openmpi
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)
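
# The second branch above covers launcher-based (e.g. OpenMPI) multi-node training.
# An illustrative launch, assuming this script lives at backend/tools/train/train_sttn.py
# and that utils_sttn picks up rank and world size from the MPI environment:
#
#   mpirun -np 8 python backend/tools/train/train_sttn.py -c configs/youtube-vos.json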