mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-21 05:24:51 +08:00
添加sttn训练代码
This commit is contained in:
77
backend/tools/train/train_sttn.py
Normal file
77
backend/tools/train/train_sttn.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from shutil import copyfile
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from backend.tools.train.trainer_sttn import Trainer
|
||||
from backend.tools.train.utils_sttn import (
|
||||
get_world_size,
|
||||
get_local_rank,
|
||||
get_global_rank,
|
||||
get_master_ip,
|
||||
)
|
||||
|
||||
# Command-line interface of the STTN training script.
parser = argparse.ArgumentParser(description='STTN')
# Path of the JSON training configuration file.
parser.add_argument(
    '-c', '--config',
    type=str,
    default='configs/youtube-vos.json',
)
# Model name; it becomes part of the checkpoint directory name.
parser.add_argument(
    '-m', '--model',
    type=str,
    default='sttn',
)
# TCP port used to build the distributed init method (kept as a string).
parser.add_argument(
    '-p', '--port',
    type=str,
    default='23455',
)
# Flag forwarded to the trainer as its ``debug`` argument.
parser.add_argument(
    '-e', '--exam',
    action='store_true',
)
args = parser.parse_args()
|
||||
|
||||
|
||||
def main_worker(rank, config):
    """Prepare one training process and run the STTN trainer in it.

    Args:
        rank: Process index passed in by the launcher. It is used as both
            the local and the global rank only when ``config`` does not
            already contain a ``'local_rank'`` entry.
        config: Training configuration dict. The function reads
            ``'distributed'``, ``'init_method'``, ``'world_size'``,
            ``'model'``, ``'save_dir'`` and ``'config'``, and updates the
            dict in place with ``'local_rank'``/``'global_rank'`` (when
            missing), the final ``'save_dir'`` and ``'device'``.
    """
    # Ranks are pre-filled by the caller when processes were launched
    # externally; otherwise fall back to the index given by the spawner.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        # NOTE(review): ``group_name`` appears to be deprecated in recent
        # PyTorch releases -- confirm before removing it.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # Checkpoints go to <save_dir>/<model>_<config file name up to first dot>.
    config_stem = os.path.basename(args.config).split('.')[0]
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], config_stem))

    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only one process (rank 0, or the sole non-distributed process) creates
    # the output folder and stores a copy of the config next to the results.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        # os.path.basename handles the platform's path separators, unlike a
        # manual split on '/'.
        config_path = os.path.join(
            config['save_dir'], os.path.basename(config['config']))
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
            print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config, debug=args.exam)
    trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":

    # Load the JSON config; the context manager closes the file handle
    # instead of leaving it open until garbage collection.
    with open(args.config) as config_file:
        config = json.load(config_file)
    config['model'] = args.model
    config['config'] = args.config

    # Distributed settings: training counts as distributed as soon as more
    # than one process takes part.
    config['world_size'] = get_world_size()
    master_ip = get_master_ip()
    config['init_method'] = f"tcp://{master_ip}:{args.port}"
    config['distributed'] = config['world_size'] > 1

    # Set up the distributed parallel training environment.
    if master_ip == "127.0.0.1":
        # Local master: launch the worker processes manually; each worker
        # receives its process index as ``rank``.
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # Processes have already been launched externally (e.g. by OpenMPI),
        # so take the ranks from the helpers and run the worker in-process.
        # The -1 rank is a placeholder: main_worker ignores it because
        # 'local_rank' is already present in the config.
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)
|
||||
Reference in New Issue
Block a user