添加注释

This commit is contained in:
YaoFANGUK
2024-01-09 11:05:07 +08:00
parent 6b353455a0
commit a3dd7b797d
11 changed files with 271 additions and 154 deletions

View File

@@ -14,7 +14,7 @@ from backend.tools.train.utils_sttn import (
)
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
@@ -22,56 +22,75 @@ args = parser.parse_args()
def main_worker(rank, config):
# 如果配置中没有提到局部排序local_rank就给它和全局排序global_rank赋值为传入的排序rank
if 'local_rank' not in config:
config['local_rank'] = config['global_rank'] = rank
if config['distributed']:
torch.cuda.set_device(int(config['local_rank']))
torch.distributed.init_process_group(backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank'])))
config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'],
os.path.basename(args.config).split('.')[0]))
# 如果配置指定为分布式训练
if config['distributed']:
# 设置使用的CUDA设备为当前的本地排名对应的GPU
torch.cuda.set_device(int(config['local_rank']))
# 初始化分布式进程组通过nccl后端
torch.distributed.init_process_group(
backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
# 打印当前GPU的使用情况输出全球排名和本地排名
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank']))
)
# 创建模型保存的目录路径,包括模型名和配置文件名
config['save_dir'] = os.path.join(
config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
)
# 如果CUDA可用则设置设备为相应的CUDA设备否则为CPU
if torch.cuda.is_available():
config['device'] = torch.device("cuda:{}".format(config['local_rank']))
else:
config['device'] = 'cpu'
# 如果不是分布式训练或者是分布式训练的主节点rank 0
if (not config['distributed']) or config['global_rank'] == 0:
# 创建模型保存目录并允许如果该目录存在则忽略创建exist_ok=True
os.makedirs(config['save_dir'], exist_ok=True)
# 设置配置文件的保存路径
config_path = os.path.join(
config['save_dir'], config['config'].split('/')[-1])
config['save_dir'], config['config'].split('/')[-1]
)
# 如果配置文件不存在,则从给定的配置文件路径复制到新路径
if not os.path.isfile(config_path):
copyfile(config['config'], config_path)
# 打印创建目录的信息
print('[**] create folder {}'.format(config['save_dir']))
# 初始化训练器传入配置参数和debug标记
trainer = Trainer(config, debug=args.exam)
# 开始训练
trainer.train()
if __name__ == "__main__":
# loading configs
# 加载配置文件
config = json.load(open(args.config))
config['model'] = args.model
config['config'] = args.config
config['model'] = args.model # 设置模型名称
config['config'] = args.config # 设置配置文件路径
# setting distributed configurations
config['world_size'] = get_world_size()
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
config['distributed'] = True if config['world_size'] > 1 else False
# 设置分布式训练的相关配置
config['world_size'] = get_world_size() # 获取全局进程数即训练过程中参与计算的总GPU数量
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" # 设置初始化方法包括主节点IP和端口
config['distributed'] = True if config['world_size'] > 1 else False # 根据世界规模确定是否启用分布式训练
# setup distributed parallel training environments
# 设置分布式并行训练环境
if get_master_ip() == "127.0.0.1":
# manually launch distributed processes
# 如果主节点IP是本机地址那么手动启动多个分布式训练进程
mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
else:
# multiple processes have been launched by openmpi
config['local_rank'] = get_local_rank()
config['global_rank'] = get_global_rank()
main_worker(-1, config)
# 如果是由其他工具如OpenMPI启动的多个进程不需手动创建进程。
config['local_rank'] = get_local_rank() # 获取本地(单个节点)排名
config['global_rank'] = get_global_rank() # 获取全局排名
main_worker(-1, config) # 启动主工作函数