mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 22:27:33 +08:00
vsr v1.0.0
This commit is contained in:
78
backend/ppocr/utils/loggers/wandb_logger.py
Normal file
78
backend/ppocr/utils/loggers/wandb_logger.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
from .base_logger import BaseLogger
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
def __init__(self,
|
||||
project=None,
|
||||
name=None,
|
||||
id=None,
|
||||
entity=None,
|
||||
save_dir=None,
|
||||
config=None,
|
||||
**kwargs):
|
||||
try:
|
||||
import wandb
|
||||
self.wandb = wandb
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install wandb using `pip install wandb`"
|
||||
)
|
||||
|
||||
self.project = project
|
||||
self.name = name
|
||||
self.id = id
|
||||
self.save_dir = save_dir
|
||||
self.config = config
|
||||
self.kwargs = kwargs
|
||||
self.entity = entity
|
||||
self._run = None
|
||||
self._wandb_init = dict(
|
||||
project=self.project,
|
||||
name=self.name,
|
||||
id=self.id,
|
||||
entity=self.entity,
|
||||
dir=self.save_dir,
|
||||
resume="allow"
|
||||
)
|
||||
self._wandb_init.update(**kwargs)
|
||||
|
||||
_ = self.run
|
||||
|
||||
if self.config:
|
||||
self.run.settings_config.update(self.config)
|
||||
|
||||
@property
|
||||
def run(self):
|
||||
if self._run is None:
|
||||
if self.wandb.run is not None:
|
||||
logger.info(
|
||||
"There is a wandb run already in progress "
|
||||
"and newly created instances of `WandbLogger` will reuse"
|
||||
" this run. If this is not desired, call `wandb.finish()`"
|
||||
"before instantiating `WandbLogger`."
|
||||
)
|
||||
self._run = self.wandb.run
|
||||
else:
|
||||
self._run = self.wandb.init(**self._wandb_init)
|
||||
return self._run
|
||||
|
||||
def log_metrics(self, metrics, prefix=None, step=None):
|
||||
if not prefix:
|
||||
prefix = ""
|
||||
updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
|
||||
|
||||
self.run.log(updated_metrics, step=step)
|
||||
|
||||
def log_model(self, is_best, prefix, metadata=None):
|
||||
model_path = os.path.join(self.save_dir, prefix + '.pdparams')
|
||||
artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
|
||||
artifact.add_file(model_path, name="model_ckpt.pdparams")
|
||||
|
||||
aliases = [prefix]
|
||||
if is_best:
|
||||
aliases.append("best")
|
||||
|
||||
self.run.log_artifact(artifact, aliases=aliases)
|
||||
|
||||
def close(self):
|
||||
self.run.finish()
|
||||
Reference in New Issue
Block a user