# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic loop for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import errors
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.basic_train_loop"])
def basic_train_loop(supervisor, train_step_fn, args=None,
                     kwargs=None, master=""):
  """Basic loop to train a model.

  Calls `train_step_fn` in a loop to train a model.  The function is called as:

  ```python
  train_step_fn(session, *args, **kwargs)
  ```

  It is passed a `tf.Session` in addition to `args` and `kwargs`.  The function
  typically runs one training step in the session.

  Args:
    supervisor: `tf.train.Supervisor` to run the training services.
    train_step_fn: Callable to execute one training step.  Called
      repeatedly as `train_step_fn(session, *args, **kwargs)`.
    args: Optional positional arguments passed to `train_step_fn`.
    kwargs: Optional keyword arguments passed to `train_step_fn`.
    master: Master to use to create the training session.  Defaults to
      `""` which causes the session to be created in the local process.

  Raises:
    Any exception other than `errors.AbortedError` that escapes the managed
    session (including one raised by `train_step_fn`) is not caught here and
    propagates to the caller.  `errors.AbortedError` is never propagated; it
    causes a new managed session to be created and the loop to resume.
  """
  if args is None:
    args = []
  if kwargs is None:
    kwargs = {}
  while True:
    try:
      with supervisor.managed_session(master) as sess:
        while not supervisor.should_stop():
          train_step_fn(sess, *args, **kwargs)
      # The supervisor asked to stop and the session closed cleanly: done.
      return
    except errors.AbortedError:
      # Always re-run on AbortedError as it indicates a restart of one of the
      # distributed tensorflow servers.
      continue