#!/usr/bin/env python3

import logging
import os
import sys
import traceback
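# The usual reason for these caps: each DDP rank would otherwise spin up its
# own full-size BLAS/OpenMP thread pool and oversubscribe the CPU. They must
# be exported before numpy/torch are first loaded (via the imports below).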
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import hydra
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

from saicinpainting.training.trainers import make_training_model
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
    handle_deterministic_config

LOGGER = logging.getLogger(__name__)
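# hydra drives this entry point: any config key can be overridden from the
# command line with key=value syntax, e.g. (the script path here is an
# assumption for this checkout):
#   python3 train.py trainer.kwargs.max_epochs=1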
@handle_ddp_subprocess()
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
def main(config: OmegaConf):
    try:
        need_set_deterministic = handle_deterministic_config(config)

        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        is_in_ddp_subprocess = handle_ddp_parent_process()
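        # By default @hydra.main has already switched the cwd to a per-run
        # output directory, so the paths built from os.getcwd() below keep
        # each run's config, checkpoints and logs together; only the DDP
        # parent process logs and saves the resolved config.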
        config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
        if not is_in_ddp_subprocess:
            LOGGER.info(OmegaConf.to_yaml(config))
            OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))

        checkpoints_dir = os.path.join(os.getcwd(), 'models')
        os.makedirs(checkpoints_dir, exist_ok=True)
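        # TensorBoard events go to config.location.tb_dir; each run is named
        # after its hydra run directory, so runs stay distinguishable.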
        # there is no need to suppress this logger in ddp, because it handles rank on its own
        metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
        metrics_logger.log_hyperparams(config)

        training_model = make_training_model(config)
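        # resolve=True expands any ${...} interpolations so the kwargs are
        # splatted into the Lightning Trainer as plain Python values.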
        trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
        if need_set_deterministic:
            trainer_kwargs['deterministic'] = True

        trainer = Trainer(
            # there is no need to suppress checkpointing in ddp, because it handles rank on its own
            callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
            logger=metrics_logger,
            default_root_dir=os.getcwd(),
            **trainer_kwargs
        )
        trainer.fit(training_model)
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)


if __name__ == '__main__':
    main()