• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils for Estimator (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.python.util import tf_inspect
27
28
def assert_estimator_contract(tester, estimator_class):
  """Asserts that `estimator_class` exposes the expected Estimator members.

  This is a smoke check, not a full verification of the contract: it only
  confirms that each expected member name is present on the class, so that a
  pre-canned Estimator does not accidentally omit one of them.

  Args:
    tester: A tf.test.TestCase used to report failures.
    estimator_class: 'type' object of the pre-canned estimator.

  Raises:
    AssertionError: (raised through `tester`) on the first expected member
      that `estimator_class` does not define.
  """
  attribute_names = set(
      name for name, _ in tf_inspect.getmembers(estimator_class))
  # Checked in a fixed order so the first missing member is reported
  # deterministically. assertIn (unlike assertTrue(x in y)) includes the
  # missing name in the failure message.
  for expected_name in ('config', 'evaluate', 'export', 'fit',
                        'get_variable_names', 'get_variable_value',
                        'model_dir', 'predict'):
    tester.assertIn(
        expected_name, attribute_names,
        '%s does not define %r.' % (estimator_class, expected_name))
50
51
def assert_in_range(min_value, max_value, key, metrics):
  """Raises if `metrics[key]` lies outside the closed range [min, max].

  Args:
    min_value: Inclusive lower bound.
    max_value: Inclusive upper bound.
    key: Key of the value to check in `metrics`.
    metrics: Mapping from keys to values, indexed with `metrics[key]`.

  Raises:
    KeyError: If `key` is not in `metrics`.
    ValueError: If the value is below `min_value`, above `max_value`, or not
      ordered with respect to the bounds at all (e.g. NaN).
  """
  actual_value = metrics[key]
  if actual_value < min_value:
    raise ValueError('%s: %s < %s.' % (key, actual_value, min_value))
  if actual_value > max_value:
    raise ValueError('%s: %s > %s.' % (key, actual_value, max_value))
  # NaN compares False against everything, so it slips past both checks
  # above; reject any value that still is not inside the closed range.
  if not min_value <= actual_value <= max_value:
    raise ValueError('%s: %s not in [%s, %s].' %
                     (key, actual_value, min_value, max_value))
58