1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15 16"""Utilities for the functionalities.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import time 23import six 24 25from tensorflow.python.platform import tf_logging as logging 26from tensorflow.python.training import training 27 28def check_positive_integer(value, name): 29 """Checks whether `value` is a positive integer.""" 30 if not isinstance(value, six.integer_types): 31 raise TypeError('{} must be int, got {}'.format(name, type(value))) 32 33 if value <= 0: 34 raise ValueError('{} must be positive, got {}'.format(name, value)) 35 36 37# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we 38# release a tensorflow_estimator with MultiHostDatasetInitializerHook in 39# python/estimator/util.py. 40class MultiHostDatasetInitializerHook(training.SessionRunHook): 41 """Creates a SessionRunHook that initializes all passed iterators.""" 42 43 def __init__(self, dataset_initializers): 44 self._initializers = dataset_initializers 45 46 def after_create_session(self, session, coord): 47 del coord 48 start = time.time() 49 session.run(self._initializers) 50 logging.info('Initialized dataset iterators in %d seconds', 51 time.time() - start) 52