""" This file contains tf.train.SessionRunHooks defined by TWML """
|
|
from datetime import datetime
|
|
import json
|
|
import operator
|
|
import os
|
|
|
|
from absl import logging
|
|
import numpy as np
|
|
import tensorflow.compat.v1 as tf
|
|
from tensorflow.python.training.basic_session_run_hooks import NeverTriggerTimer, SecondOrStepTimer
|
|
import twml
|
|
|
|
|
|
class StepProgressHook(tf.train.SessionRunHook):
|
|
"""Hook that displays a progress bar to monitor global step progress """
|
|
|
|
def __init__(self, max_step):
|
|
"""
|
|
Initializes a `StepProgressHook`.
|
|
This hook displays a progress bar for max_steps.
|
|
|
|
Note that this hook only works for training and calibration.
|
|
|
|
Args:
|
|
max_steps:
|
|
maximum steps to monitor in progress bar.
|
|
When this many steps is reached, the progress bar will be full.
|
|
"""
|
|
self._max_step = max_step
|
|
self._start_step = 0
|
|
self._global_step_tensor = None
|
|
self._progress_bar = None
|
|
|
|
def begin(self):
|
|
""" sets the global_step_tensor """
|
|
self._global_step_tensor = tf.train.get_or_create_global_step()
|
|
if self._global_step_tensor is None:
|
|
raise RuntimeError("Global step should be created to use StepProgressHook.")
|
|
|
|
def after_create_session(self, session, coord):
|
|
""" creates the progress bar and keeps track of the first global step upon session creation """
|
|
global_step = session.run(self._global_step_tensor)
|
|
self._start_step = global_step
|
|
self._progress_bar = tf.keras.utils.Progbar(self._max_step)
|
|
|
|
def before_run(self, run_context): # pylint: disable=unused-argument
|
|
""" invoked before calling session.run """
|
|
return tf.train.SessionRunArgs(self._global_step_tensor)
|
|
|
|
def after_run(self, run_context, run_values):
|
|
""" invoked after run is called. Updates the progress bar. """
|
|
step = run_context.session.run(self._global_step_tensor)
|
|
self._progress_bar.update(step - self._start_step)
|
|
|
|
|
|
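# A minimal usage sketch (illustrative, with hypothetical names): `estimator`
# stands in for a configured tf.estimator.Estimator and `train_input_fn` for
# its input function. The hook fills the progress bar as global_step
# approaches max_step:
#
#   progress_hook = StepProgressHook(max_step=10000)
#   estimator.train(input_fn=train_input_fn, max_steps=10000, hooks=[progress_hook])

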
class GetMetricsHook(tf.train.SessionRunHook):
  """
  Hook used to obtain evaluation metrics.
  Typically used for early-stopping by obtaining the value of a
  metric at the end of an epoch.
  Note that the metric tensor and its commensurate update Op
  are responsible for aggregating the metric during the session
  (one session per epoch). Used for evaluation.
  """

  def __init__(self, get_metrics_fn):
    """GetMetricsHook constructor.

    Args:
      get_metrics_fn:
        Function that returns a dict mapping metric keys to their
        value tensors (each a tf.Tensor).
        See Trainer.learn for an example use-case.
    """
    self._get_metrics_fn = get_metrics_fn
    self._metric_tensors = None
    self.metric_values = None

  def begin(self):
    """Sets the metric tensors."""
    self._metric_tensors = self._get_metrics_fn()
    assert isinstance(self._metric_tensors, dict)

  def end(self, session):
    self.metric_values = session.run(self._metric_tensors)

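# A minimal sketch of a get_metrics_fn for GetMetricsHook. The metric names and
# value ops below are hypothetical; each value is the value_op of a tf.metrics
# (value_op, update_op) pair created elsewhere in the graph:
#
#   def get_metrics_fn():
#     return {"accuracy": accuracy_value_op, "loss": loss_value_op}
#
#   metrics_hook = GetMetricsHook(get_metrics_fn=get_metrics_fn)
#   # after the evaluation session ends, metrics_hook.metric_values is a dict
#   # of numpy values keyed by the same metric names

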
class EarlyStopHook(GetMetricsHook):
  """
  A GetMetricsHook augmented with early-stopping logic for use
  within the Trainer.learn method.
  """

  def __init__(self,
               metric,
               patience,
               minimize,
               get_estimator_spec_fn,
               checkpoint_dir,
               file_path=None,
               exit_on_end=True,
               start_epoch=0,
               tolerance=0):
    """
    Prepare early-stopping hook and variables.

    Args:
      metric:
        String specifying the metric to early-stop on. Required with positive
        ``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc.
        The string is used to extract the relevant tensor Op from the dict returned by
        the get_eval_metric_ops method. For ``metrics`` passed to the constructor,
        the string is one of those. For multi-class (that is, multi-metric)
        metrics, the string may be appended with a ``_0``, ``_1``, etc. or one
        of the ``multi_metric_names`` (one per class).
      patience:
        Maximum number of epochs to wait for an improvement in the early_stop_metric
        before breaking off training. For example, a patience of 10 means that
        training will have 10 epochs to improve the metric before it is killed.
        Whenever the metric is improved before running out of patience,
        patience is reset to ``early_stop_patience``.
      minimize:
        Set this to True for metrics that need to be minimized
        (like ``loss``). Metrics like ``accuracy`` that need to be maximized
        should set this to False.
      get_estimator_spec_fn:
        function that returns the current EstimatorSpec.
        The EstimatorSpec is used to obtain the current eval_metric_ops.
      checkpoint_dir:
        path to directory containing the Estimator checkpoints.
      file_path:
        path to file that is used by this hook to communicate early-stopping
        to StopIfExistsHook. This hook would be used for evaluation, while
        the StopIfExistsHooks (the listeners) would be used for training.
        When the file is created, the StopIfExistsHooks detect and terminate training.
        This argument is used by ``Trainer.train_and_evaluate``.
      exit_on_end:
        when the end() method is called to indicate that the session is terminating,
        and exit_on_end is True, twml.errors.EarlyStopError() is raised to stop the
        evaluation job. This is set to False by the trainer for non-distributed jobs.
      start_epoch:
        Specifies the starting epoch number. This is used for logging purposes only.
      tolerance:
        A non-negative tolerance for comparing early_stop_metric,
        e.g. when maximizing, the condition is current_metric > best_metric + tolerance.
        Defaults to 0.
    """
    if not isinstance(metric, str):
      raise ValueError("Expecting string for metric arg")
    if not isinstance(patience, int):
      raise ValueError("Expecting positive integer for patience arg")

    self.should_stop = False
    self._metric = metric
    self._patience = patience
    self._current_patience = patience
    self._checkpoint_dir = checkpoint_dir
    self._exit_on_end = exit_on_end
    self._latest_checkpoint_path = None

    # used for distributed training (tf.estimator.train_and_evaluate)
    self._file_path = file_path
    self._epoch = start_epoch
    if self._file_path is not None:
      # TODO try to read epoch from a file that we create
      if tf.io.gfile.exists(self._file_path):
        # delete the file if it exists (not sure this makes sense)
        logging.info("EarlyStopHook: Removing existing file: %s.", self._file_path)
        tf.io.gfile.remove(self._file_path)

    # best_checkpoint dir will contain the best checkpoint
    self._best_checkpoint_path = os.path.join(checkpoint_dir, 'best_checkpoint')
    self._eval_checkpoint_path = os.path.join(checkpoint_dir, 'eval_checkpoint')
    self._best_metric_path = os.path.join(self._best_checkpoint_path, self._metric)

    if tf.io.gfile.exists(self._best_metric_path):
      with tf.io.gfile.GFile(self._best_metric_path, mode="r") as f:
        best_metric_from_file = float(f.read())
    else:
      best_metric_from_file = None

    if minimize:
      # current < best : is better
      self._is_better_than = operator.lt
      # initialize with the worst possible metric
      if best_metric_from_file is None:
        self._best_metric = np.inf
      else:
        self._best_metric = best_metric_from_file - tolerance
      # used for printing
      self._early_stop_name = "minimum"
    else:
      # current > best : is better
      self._is_better_than = operator.gt
      # initialize with the worst possible metric
      if best_metric_from_file is None:
        self._best_metric = -np.inf
      else:
        self._best_metric = best_metric_from_file + tolerance
      # used for printing
      self._early_stop_name = "maximum"

    def get_metrics_fn():
      """Returns the metric value tensors used for early-stopping."""
      estimator_spec = get_estimator_spec_fn()
      eval_metric_ops = estimator_spec.eval_metric_ops
      if metric not in eval_metric_ops:
        raise ValueError(
          "Expecting early_stop_metric '%s' key in eval_metric_ops dict"
          % (metric))
      # get the value_op from the (value_op, update_op) value
      return {k: v[0] for k, v in eval_metric_ops.items()}

    # initialize GetMetricsHook to get current value of metric from session
    super(EarlyStopHook, self).__init__(get_metrics_fn=get_metrics_fn)

  def early_stop(self, epoch):
    """
    Looks at the current value of the early-stopping metric.
    Decrements current patience. If the metric improves, patience is reset
    and the latest checkpoint is moved to checkpoint_dir/best_checkpoint.
    If current patience reaches zero, returns True.

    Args:
      epoch:
        The current epoch number.

    Returns:
      True when early-stopped. False otherwise.
    """
    # decrement patience
    self._current_patience -= 1

    # get the current metric value
    current_metric = self.metric_values[self._metric]

    if self._is_better_than(current_metric, self._best_metric):
      # save best version of model
      self._best_metric = current_metric
      logging.info(
        "Found new %s %s=%f @ epoch %d",
        self._early_stop_name, self._metric, self._best_metric, epoch)
      # backup the file to checkpoint_dir/best_checkpoint
      assert self._latest_checkpoint_path, "expecting latest checkpoint"
      logging.info("Backing up " + self._latest_checkpoint_path)

      try:
        eval_checkpoint = tf.train.latest_checkpoint(self._eval_checkpoint_path)
        twml.util.backup_checkpoint(
          checkpoint_path_prefix=eval_checkpoint,
          backup_path=self._best_checkpoint_path)
      except twml.errors.CheckpointNotFoundError as ex:
        msg = "Consider increasing 'keep_checkpoint_max' or 'save_checkpoint_secs'"
        raise twml.errors.CheckpointNotFoundError(str(ex) + "\n" + msg)

      tf.io.gfile.makedirs(os.path.dirname(self._best_metric_path))
      with tf.io.gfile.GFile(self._best_metric_path, mode="w") as f:
        # write with enough precision
        f.write("%.8f" % self._best_metric)

      # reset patience
      self._current_patience = self._patience

    elif self._current_patience > 0:
      logging.info("No new %s found after %d epochs",
                   self._early_stop_name, self._patience - self._current_patience)
    elif self._current_patience == 0:
      logging.info(
        "No new %s found after %d epochs. Early-stopping experiment.",
        self._early_stop_name, self._patience)
      return True

    return False

  def cleanup_checkpoints(self):
    """
    Makes it so that the best checkpoint is the only checkpoint
    in checkpoint_dir.
    """
    raise NotImplementedError("cleanup_checkpoints is no longer supported")

  def end(self, session):
    """
    This method is called at the end of an evaluation/epoch.
    It calls ``early_stop()`` to check the early-stopping criteria.
    When ``early_stop()`` returns True and the file_path constructor argument
    was provided, the file at file_path is created. The file is detected by
    the StopIfExistsHooks, which then stop training for all workers and the chief.
    """
    super(EarlyStopHook, self).end(session)

    # checks for early-stopping criteria and makes a backup
    self.should_stop = self.early_stop(self._epoch)

    if self._file_path is not None:
      if self.should_stop:
        # create a file to inform workers
        with tf.io.gfile.GFile(self._file_path, "wb") as gfile:
          gfile.write("early-stop\n")
        msg = "early-stopping evaluation at epoch %d" % self._epoch
        logging.info(msg)
        if self._exit_on_end:
          raise twml.errors.EarlyStopError(msg)
      else:
        self._latest_checkpoint_path = None

    self._epoch += 1

  def begin(self):
    """
    Saves the latest_checkpoint in case it gets superseded by another checkpoint.
    Remember that when used with train_and_evaluate, the chief saves checkpoints
    continuously. The chief could save a checkpoint after evaluation started.
    So saving the checkpoint at the beginning of evaluation ensures that we
    later save the correct best checkpoint.
    """
    super(EarlyStopHook, self).begin()
    self._latest_checkpoint_path = tf.train.latest_checkpoint(self._checkpoint_dir)

    assert self._latest_checkpoint_path, "expecting latest checkpoint"
    # backup to temporary directory
    try:
      twml.util.backup_checkpoint(
        checkpoint_path_prefix=self._latest_checkpoint_path,
        backup_path=self._eval_checkpoint_path)
    except twml.errors.CheckpointNotFoundError as ex:
      msg = "Consider increasing 'keep_checkpoint_max' or 'save_checkpoint_secs'"
      raise twml.errors.CheckpointNotFoundError(str(ex) + "\n" + msg)


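# A usage sketch (hypothetical names, assuming a Trainer-managed evaluation
# loop): stop when 'loss' has not improved for 5 consecutive epochs, keeping
# the best checkpoint under checkpoint_dir/best_checkpoint:
#
#   early_stop_hook = EarlyStopHook(
#       metric="loss",
#       patience=5,
#       minimize=True,
#       get_estimator_spec_fn=lambda: estimator_spec,  # current EstimatorSpec
#       checkpoint_dir="/tmp/model",
#       exit_on_end=False)
#   # pass early_stop_hook to the evaluation loop; check early_stop_hook.should_stop
#   # after each epoch to decide whether to end training

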
class MetricsUpdateHook(GetMetricsHook):
  """
  A GetMetricsHook augmented with logic to map SessionRun events to metrics updates.
  It is mainly used by `TrackRun` to persist model metrics via Model Repo.
  """

  def __init__(self,
               get_estimator_spec_fn,
               add_metrics_fn,
               every_n_iter=None,
               every_n_secs=None):
    """
    Args:
      get_estimator_spec_fn:
        function that returns the current EstimatorSpec.
        The EstimatorSpec is used to obtain the current eval_metric_ops.
      add_metrics_fn: `function` callback used to report metrics, called automatically
        at the end of every epoch.
      every_n_iter: `int`, log the metrics once every N local
        steps taken in the current epoch.
      every_n_secs: `int` or `float`, log the metrics once every N
        seconds passed in the current epoch. At most one of `every_n_iter` and
        `every_n_secs` should be provided; if neither is provided, metrics are
        only reported at the end of the epoch.
    Raises:
      ValueError: if `every_n_iter` is non-positive or if both `every_n_iter`
        and `every_n_secs` are set.
    """
    only_log_at_end = (every_n_iter is None) and (every_n_secs is None)

    if not only_log_at_end and every_n_iter and every_n_secs:
      raise ValueError(
        'only one of every_n_iter and every_n_secs can be provided'
      )

    # TODO: should have a minimum to avoid too many calls to ModelRepo?
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)

    self._timer = (
      NeverTriggerTimer() if only_log_at_end else
      SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter)
    )

    self._should_trigger = False
    self._iter_count = 0

    self._add_metrics_fn = add_metrics_fn

    def get_metrics_fn():
      """
      Returns the current eval metric value tensors, extracted from the
      eval_metric_ops of the current EstimatorSpec.
      """
      estimator_spec = get_estimator_spec_fn()
      eval_metric_ops = estimator_spec.eval_metric_ops
      # get the value_op from the (value_op, update_op) value
      return {k: v[0] for k, v in eval_metric_ops.items()}

    super(MetricsUpdateHook, self).__init__(get_metrics_fn=get_metrics_fn)

  def report_metrics(self):
    """
    Triggers a metrics report.
    """
    self._timer.update_last_triggered_step(self._iter_count)
    if self.metric_values is not None:
      self._add_metrics_fn(self.metric_values)

  def begin(self):
    """
    Triggered before each epoch.
    """
    self._timer.reset()
    self._iter_count = 0
    return super(MetricsUpdateHook, self).begin()

  def before_run(self, run_context):
    """
    Triggered before each step.
    """
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    return super(MetricsUpdateHook, self).before_run(run_context)

  def after_run(self, run_context, run_values):
    """
    Triggered after each step.
    """
    if self._should_trigger:
      self.report_metrics()
    self._iter_count += 1
    return super(MetricsUpdateHook, self).after_run(run_context, run_values)

  def end(self, session):
    """
    Triggered after each epoch.
    """
    self.report_metrics()
    return super(MetricsUpdateHook, self).end(session)


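# A sketch of wiring MetricsUpdateHook to a reporting callback (names are
# hypothetical). Metrics are reported every 100 local steps and once more at
# the end of the epoch:
#
#   def add_metrics_fn(metric_values):
#     for name, value in metric_values.items():
#       logging.info("metric %s=%s", name, value)
#
#   update_hook = MetricsUpdateHook(
#       get_estimator_spec_fn=lambda: estimator_spec,
#       add_metrics_fn=add_metrics_fn,
#       every_n_iter=100)

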
class EarlyStopDuration(tf.train.SessionRunHook):
  """
  Hook that can be used to terminate a job (training or validation) after a certain duration.
  The hook is fault tolerant, i.e., if a job is allotted 1 hour to run and fails after 45 minutes,
  then it will only run for 15 minutes once restarted.

  Args:
    max_duration:
      A float. When this argument is defined, the job will automatically terminate after
      `max_duration` seconds if it has not already completed.

    overwrite:
      A boolean. If set to True, this hook will overwrite the file containing the elapsed time
      since the beginning of the job. In a distributed setting, this is used so that only one
      job writes to the file while all others have read access. In a distributed setting,
      if all executors have this parameter set to False, then it just means that the hook will
      not be fault tolerant. When restarted, the job will restart the clock from 0.

    save_dir:
      String. A directory (located on a file system that is Tensorflow compatible) where
      we can store the file which contains the record of the elapsed time. This file is what makes
      the hook fault tolerant.

    exit_on_end:
      when exit_on_end is True, twml.errors.EarlyStopError() is raised to stop the job.
      This is usually set to True to kill a validation job in a distributed setting.
  """

  def __init__(self, max_duration: float, exit_on_end: bool, save_dir: str, overwrite: bool):
    self._overwrite = overwrite
    self._save_dir = save_dir
    self._exit_on_end = exit_on_end
    self._max_duration = max_duration
    self._last_time_check = datetime.now()

    # initialize the elapsed-time file
    if overwrite:
      self.elapsed_time()

  @property
  def elapsed_file_path(self):
    return os.path.join(self._save_dir, "early_stop_duration.txt")

  def early_stop(self) -> bool:
    return self.elapsed_time() > self._max_duration

  def elapsed_time(self) -> float:
    # recorded elapsed time is 0 unless it has already been recorded in a file
    recorded_elapsed_time = 0
    if tf.io.gfile.exists(self.elapsed_file_path):
      with tf.io.gfile.GFile(self.elapsed_file_path, mode="r") as file:
        recorded_elapsed_time = json.loads(file.read())["elapsed_time"]

    elapsed_time = recorded_elapsed_time + (datetime.now() - self._last_time_check).total_seconds()
    self._last_time_check = datetime.now()

    if self._overwrite:
      # record the updated elapsed time to the file
      tf.io.gfile.makedirs(os.path.dirname(self.elapsed_file_path))
      with tf.io.gfile.GFile(self.elapsed_file_path, mode="w") as file:
        record = {
          "elapsed_time": elapsed_time,
          "max_duration": self._max_duration
        }
        file.write(json.dumps(record, indent=2))

    return elapsed_time

  def before_run(self, run_context: tf.estimator.SessionRunContext) -> None:
    if self.early_stop():
      message = (
        "Stopping job which has now exceeded the maximum duration of "
        f"{self._max_duration} seconds."
      )
      logging.info(message)
      run_context.request_stop()

      if self._exit_on_end:
        raise twml.errors.EarlyStopError(message)


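# A sketch (hypothetical values): cap a job at one hour of cumulative runtime,
# persisting elapsed time so restarts resume the clock rather than reset it:
#
#   duration_hook = EarlyStopDuration(
#       max_duration=3600.0,
#       exit_on_end=False,
#       save_dir="/tmp/model",
#       overwrite=True)  # this instance owns the elapsed-time file
#   estimator.train(input_fn=train_input_fn, hooks=[duration_hook])

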
class StopAtStepHook(tf.train.StopAtStepHook):
  """
  Overrides ``tf.train.StopAtStepHook`` so that
  a ``stop_requested`` property can be accessed to determine
  if this hook requested a stop.
  """

  def __init__(self, *args, **kwargs):
    super(StopAtStepHook, self).__init__(*args, **kwargs)
    self._stop_requested = False

  @property
  def stop_requested(self):
    """True if this hook requested a stop."""
    return self._stop_requested

  def after_run(self, run_context, run_values):
    """Sets self._stop_requested to True when requesting a stop."""
    super(StopAtStepHook, self).after_run(run_context, run_values)
    self._stop_requested = run_context.stop_requested


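# A sketch (hypothetical names): after training, the stop_requested property
# tells whether this hook, rather than input exhaustion, ended the run:
#
#   stop_hook = StopAtStepHook(last_step=100000)
#   estimator.train(input_fn=train_input_fn, hooks=[stop_hook])
#   if stop_hook.stop_requested:
#     logging.info("training stopped at last_step")

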
class StopIfExistsHook(tf.train.SessionRunHook):
  """
  Hook that requests stop if a file exists.
  This hook is used with the EarlyStopHook to implement
  early-stopping for distributed training (tf.estimator.train_and_evaluate).
  """

  def __init__(self, file_path):
    """
    Arguments:
      file_path:
        path to file. When this hook detects that the file exists,
        it requests a stop, which effectively kills this worker.
    """
    self._file_path = file_path
    self._stop_requested = False

  def after_run(self, run_context, run_values):
    if tf.io.gfile.exists(self._file_path):
      logging.info("Early-stopping file detected; requesting stop")
      run_context.request_stop()
      self._stop_requested = True

  @property
  def stop_requested(self):
    """True if this hook requested a stop."""
    return self._stop_requested


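# A sketch of the distributed early-stopping protocol described above
# (hypothetical path): the evaluator runs EarlyStopHook with file_path set,
# and every training worker runs a StopIfExistsHook watching the same path:
#
#   stop_file = "/tmp/model/early_stop"
#   # evaluator side: EarlyStopHook(..., file_path=stop_file, ...)
#   # worker side:
#   worker_hook = StopIfExistsHook(file_path=stop_file)
#   estimator.train(input_fn=train_input_fn, hooks=[worker_hook])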