Files
video-subtitle-remover/backend/tools/train/train_sttn.py
2024-01-09 11:05:07 +08:00

97 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import argparse
from shutil import copyfile
import torch
import torch.multiprocessing as mp
from backend.tools.train.trainer_sttn import Trainer
from backend.tools.train.utils_sttn import (
get_world_size,
get_local_rank,
get_global_rank,
get_master_ip,
)
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()
def main_worker(rank, config):
# 如果配置中没有提到局部排序local_rank就给它和全局排序global_rank赋值为传入的排序rank
if 'local_rank' not in config:
config['local_rank'] = config['global_rank'] = rank
# 如果配置指定为分布式训练
if config['distributed']:
# 设置使用的CUDA设备为当前的本地排名对应的GPU
torch.cuda.set_device(int(config['local_rank']))
# 初始化分布式进程组通过nccl后端
torch.distributed.init_process_group(
backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
# 打印当前GPU的使用情况输出全球排名和本地排名
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank']))
)
# 创建模型保存的目录路径,包括模型名和配置文件名
config['save_dir'] = os.path.join(
config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
)
# 如果CUDA可用则设置设备为相应的CUDA设备否则为CPU
if torch.cuda.is_available():
config['device'] = torch.device("cuda:{}".format(config['local_rank']))
else:
config['device'] = 'cpu'
# 如果不是分布式训练或者是分布式训练的主节点rank 0
if (not config['distributed']) or config['global_rank'] == 0:
# 创建模型保存目录并允许如果该目录存在则忽略创建exist_ok=True
os.makedirs(config['save_dir'], exist_ok=True)
# 设置配置文件的保存路径
config_path = os.path.join(
config['save_dir'], config['config'].split('/')[-1]
)
# 如果配置文件不存在,则从给定的配置文件路径复制到新路径
if not os.path.isfile(config_path):
copyfile(config['config'], config_path)
# 打印创建目录的信息
print('[**] create folder {}'.format(config['save_dir']))
# 初始化训练器传入配置参数和debug标记
trainer = Trainer(config, debug=args.exam)
# 开始训练
trainer.train()
if __name__ == "__main__":
# 加载配置文件
config = json.load(open(args.config))
config['model'] = args.model # 设置模型名称
config['config'] = args.config # 设置配置文件路径
# 设置分布式训练的相关配置
config['world_size'] = get_world_size() # 获取全局进程数即训练过程中参与计算的总GPU数量
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" # 设置初始化方法包括主节点IP和端口
config['distributed'] = True if config['world_size'] > 1 else False # 根据世界规模确定是否启用分布式训练
# 设置分布式并行训练环境
if get_master_ip() == "127.0.0.1":
# 如果主节点IP是本机地址那么手动启动多个分布式训练进程
mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
else:
# 如果是由其他工具如OpenMPI启动的多个进程不需手动创建进程。
config['local_rank'] = get_local_rank() # 获取本地(单个节点)排名
config['global_rank'] = get_global_rank() # 获取全局排名
main_worker(-1, config) # 启动主工作函数