1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Estimator.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import itertools 23import json 24import os 25import tempfile 26 27import numpy as np 28import six 29from six.moves import xrange # pylint: disable=redefined-builtin 30 31from google.protobuf import text_format 32 33from tensorflow.contrib import learn 34from tensorflow.contrib import lookup 35from tensorflow.python.training import training_util 36from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib 37from tensorflow.contrib.layers.python.layers import optimizers 38from tensorflow.contrib.learn.python.learn import experiment 39from tensorflow.contrib.learn.python.learn import models 40from tensorflow.contrib.learn.python.learn import monitors as monitors_lib 41from tensorflow.contrib.learn.python.learn.datasets import base 42from tensorflow.contrib.learn.python.learn.estimators import _sklearn 43from tensorflow.contrib.learn.python.learn.estimators import constants 44from tensorflow.contrib.learn.python.learn.estimators import estimator 45from tensorflow.contrib.learn.python.learn.estimators import linear 46from tensorflow.contrib.learn.python.learn.estimators import model_fn 47from tensorflow.contrib.learn.python.learn.estimators import run_config 48from tensorflow.contrib.learn.python.learn.utils import input_fn_utils 49from tensorflow.contrib.metrics.python.ops import metric_ops 50from tensorflow.contrib.testing.python.framework import util_test 51from tensorflow.python.client import session as session_lib 52from tensorflow.python.framework import constant_op 53from tensorflow.python.framework import dtypes 54from tensorflow.python.framework import ops 55from tensorflow.python.lib.io import file_io 56from tensorflow.python.ops import array_ops 57from tensorflow.python.ops import check_ops 58from tensorflow.python.ops import control_flow_ops 59from tensorflow.python.ops import math_ops 60from tensorflow.python.ops import parsing_ops 61from tensorflow.python.ops import variables as variables_lib 62from tensorflow.python.platform import gfile 63from tensorflow.python.platform import test 64from tensorflow.python.saved_model import loader 65from tensorflow.python.saved_model import tag_constants 66from tensorflow.python.summary import summary 67from tensorflow.python.training import basic_session_run_hooks 68from tensorflow.python.training import checkpoint_state_pb2 69from tensorflow.python.training import input as input_lib 70from tensorflow.python.training import monitored_session 71from tensorflow.python.training import saver as saver_lib 72from tensorflow.python.training import session_run_hook 73from tensorflow.python.util import compat 74 75_BOSTON_INPUT_DIM = 13 76_IRIS_INPUT_DIM = 4 77 78 79def boston_input_fn(num_epochs=None): 80 boston = base.load_boston() 81 features = input_lib.limit_epochs( 82 array_ops.reshape( 83 constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM]), 84 num_epochs=num_epochs) 85 labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1]) 86 return features, labels 87 88 89def iris_input_fn(): 90 iris = base.load_iris() 91 features = array_ops.reshape( 92 constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM]) 93 labels = array_ops.reshape(constant_op.constant(iris.target), [-1]) 94 return features, labels 95 96 97def iris_input_fn_labels_dict(): 98 iris = base.load_iris() 99 features = array_ops.reshape( 100 constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM]) 101 labels = { 102 'labels': array_ops.reshape(constant_op.constant(iris.target), [-1]) 103 } 104 return features, labels 105 106 107def boston_eval_fn(): 108 boston = base.load_boston() 109 n_examples = len(boston.target) 110 features = array_ops.reshape( 111 constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]) 112 labels = array_ops.reshape( 113 constant_op.constant(boston.target), [n_examples, 1]) 114 return array_ops.concat([features, features], 115 0), array_ops.concat([labels, labels], 0) 116 117 118def extract(data, key): 119 if isinstance(data, dict): 120 assert key in data 121 return data[key] 122 else: 123 return data 124 125 126def linear_model_params_fn(features, labels, mode, params): 127 features = extract(features, 'input') 128 labels = extract(labels, 'labels') 129 130 assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, 131 model_fn.ModeKeys.INFER) 132 prediction, loss = (models.linear_regression_zero_init(features, labels)) 133 train_op = optimizers.optimize_loss( 134 loss, 135 training_util.get_global_step(), 136 optimizer='Adagrad', 137 learning_rate=params['learning_rate']) 138 return prediction, loss, train_op 139 140 141def linear_model_fn(features, labels, mode): 142 features = extract(features, 'input') 143 labels = extract(labels, 'labels') 144 assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, 145 model_fn.ModeKeys.INFER) 146 if isinstance(features, dict): 147 (_, features), = features.items() 148 prediction, loss = (models.linear_regression_zero_init(features, labels)) 149 train_op = optimizers.optimize_loss( 150 loss, 151 training_util.get_global_step(), 152 optimizer='Adagrad', 153 learning_rate=0.1) 154 return prediction, loss, train_op 155 156 157def linear_model_fn_with_model_fn_ops(features, labels, mode): 158 """Same as linear_model_fn, but returns `ModelFnOps`.""" 159 assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, 160 model_fn.ModeKeys.INFER) 161 prediction, loss = (models.linear_regression_zero_init(features, labels)) 162 train_op = optimizers.optimize_loss( 163 loss, 164 training_util.get_global_step(), 165 optimizer='Adagrad', 166 learning_rate=0.1) 167 return model_fn.ModelFnOps( 168 mode=mode, predictions=prediction, loss=loss, train_op=train_op) 169 170 171def logistic_model_no_mode_fn(features, labels): 172 features = extract(features, 'input') 173 labels = extract(labels, 'labels') 174 labels = array_ops.one_hot(labels, 3, 1, 0) 175 prediction, loss = (models.logistic_regression_zero_init(features, labels)) 176 train_op = optimizers.optimize_loss( 177 loss, 178 training_util.get_global_step(), 179 optimizer='Adagrad', 180 learning_rate=0.1) 181 return { 182 'class': math_ops.argmax(prediction, 1), 183 'prob': prediction 184 }, loss, train_op 185 186 187VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' 188EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' 189 190 191def _build_estimator_for_export_tests(tmpdir): 192 193 def _input_fn(): 194 iris = base.load_iris() 195 return { 196 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) 197 }, constant_op.constant( 198 iris.target, shape=[150], dtype=dtypes.int32) 199 200 feature_columns = [ 201 feature_column_lib.real_valued_column('feature', dimension=4) 202 ] 203 204 est = linear.LinearRegressor(feature_columns) 205 est.fit(input_fn=_input_fn, steps=20) 206 207 feature_spec = feature_column_lib.create_feature_spec_for_parsing( 208 feature_columns) 209 serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) 210 211 # hack in an op that uses an asset, in order to test asset export. 212 # this is not actually valid, of course. 213 def serving_input_fn_with_asset(): 214 features, labels, inputs = serving_input_fn() 215 216 vocab_file_name = os.path.join(tmpdir, 'my_vocab_file') 217 vocab_file = gfile.GFile(vocab_file_name, mode='w') 218 vocab_file.write(VOCAB_FILE_CONTENT) 219 vocab_file.close() 220 hashtable = lookup.HashTable( 221 lookup.TextFileStringTableInitializer(vocab_file_name), 'x') 222 features['bogus_lookup'] = hashtable.lookup( 223 math_ops.cast(features['feature'], dtypes.int64)) 224 225 return input_fn_utils.InputFnOps(features, labels, inputs) 226 227 return est, serving_input_fn_with_asset 228 229 230def _build_estimator_for_resource_export_test(): 231 232 def _input_fn(): 233 iris = base.load_iris() 234 return { 235 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) 236 }, constant_op.constant( 237 iris.target, shape=[150], dtype=dtypes.int32) 238 239 feature_columns = [ 240 feature_column_lib.real_valued_column('feature', dimension=4) 241 ] 242 243 def resource_constant_model_fn(unused_features, unused_labels, mode): 244 """A model_fn that loads a constant from a resource and serves it.""" 245 assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, 246 model_fn.ModeKeys.INFER) 247 248 const = constant_op.constant(-1, dtype=dtypes.int64) 249 table = lookup.MutableHashTable( 250 dtypes.string, dtypes.int64, const, name='LookupTableModel') 251 update_global_step = training_util.get_global_step().assign_add(1) 252 if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): 253 key = constant_op.constant(['key']) 254 value = constant_op.constant([42], dtype=dtypes.int64) 255 train_op_1 = table.insert(key, value) 256 training_state = lookup.MutableHashTable( 257 dtypes.string, dtypes.int64, const, name='LookupTableTrainingState') 258 training_op_2 = training_state.insert(key, value) 259 return (const, const, 260 control_flow_ops.group(train_op_1, training_op_2, 261 update_global_step)) 262 if mode == model_fn.ModeKeys.INFER: 263 key = constant_op.constant(['key']) 264 prediction = table.lookup(key) 265 return prediction, const, update_global_step 266 267 est = estimator.Estimator(model_fn=resource_constant_model_fn) 268 est.fit(input_fn=_input_fn, steps=1) 269 270 feature_spec = feature_column_lib.create_feature_spec_for_parsing( 271 feature_columns) 272 serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) 273 return est, serving_input_fn 274 275 276class CheckCallsMonitor(monitors_lib.BaseMonitor): 277 278 def __init__(self, expect_calls): 279 super(CheckCallsMonitor, self).__init__() 280 self.begin_calls = None 281 self.end_calls = None 282 self.expect_calls = expect_calls 283 284 def begin(self, max_steps): 285 self.begin_calls = 0 286 self.end_calls = 0 287 288 def step_begin(self, step): 289 self.begin_calls += 1 290 return {} 291 292 def step_end(self, step, outputs): 293 self.end_calls += 1 294 return False 295 296 def end(self): 297 assert (self.end_calls == self.expect_calls and 298 self.begin_calls == self.expect_calls) 299 300 301def _model_fn_ops(expected_features, expected_labels, actual_features, 302 actual_labels, mode): 303 assert_ops = tuple([ 304 check_ops.assert_equal( 305 expected_features[k], actual_features[k], name='assert_%s' % k) 306 for k in expected_features 307 ] + [ 308 check_ops.assert_equal( 309 expected_labels, actual_labels, name='assert_labels') 310 ]) 311 with ops.control_dependencies(assert_ops): 312 return model_fn.ModelFnOps( 313 mode=mode, 314 predictions=constant_op.constant(0.), 315 loss=constant_op.constant(0.), 316 train_op=training_util.get_global_step().assign_add(1)) 317 318 319def _make_input_fn(features, labels): 320 321 def _input_fn(): 322 return {k: constant_op.constant(v) 323 for k, v in six.iteritems(features)}, constant_op.constant(labels) 324 325 return _input_fn 326 327 328class EstimatorModelFnTest(test.TestCase): 329 330 def testModelFnArgs(self): 331 features = {'x': 42., 'y': 43.} 332 labels = 44. 333 expected_params = {'some_param': 'some_value'} 334 expected_config = run_config.RunConfig() 335 expected_config.i_am_test = True 336 337 # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments 338 # doesn't work with mock fns. 339 model_fn_call_count = [0] 340 341 # `features` and `labels` are passed by position, `arg0` and `arg1` here. 342 def _model_fn(arg0, arg1, mode, params, config): 343 model_fn_call_count[0] += 1 344 self.assertItemsEqual(features.keys(), arg0.keys()) 345 self.assertEqual(model_fn.ModeKeys.TRAIN, mode) 346 self.assertEqual(expected_params, params) 347 self.assertTrue(config.i_am_test) 348 return _model_fn_ops(features, labels, arg0, arg1, mode) 349 350 est = estimator.Estimator( 351 model_fn=_model_fn, params=expected_params, config=expected_config) 352 self.assertEqual(0, model_fn_call_count[0]) 353 est.fit(input_fn=_make_input_fn(features, labels), steps=1) 354 self.assertEqual(1, model_fn_call_count[0]) 355 356 def testPartialModelFnArgs(self): 357 features = {'x': 42., 'y': 43.} 358 labels = 44. 359 expected_params = {'some_param': 'some_value'} 360 expected_config = run_config.RunConfig() 361 expected_config.i_am_test = True 362 expected_foo = 45. 363 expected_bar = 46. 364 365 # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments 366 # doesn't work with mock fns. 367 model_fn_call_count = [0] 368 369 # `features` and `labels` are passed by position, `arg0` and `arg1` here. 370 def _model_fn(arg0, arg1, foo, mode, params, config, bar): 371 model_fn_call_count[0] += 1 372 self.assertEqual(expected_foo, foo) 373 self.assertEqual(expected_bar, bar) 374 self.assertItemsEqual(features.keys(), arg0.keys()) 375 self.assertEqual(model_fn.ModeKeys.TRAIN, mode) 376 self.assertEqual(expected_params, params) 377 self.assertTrue(config.i_am_test) 378 return _model_fn_ops(features, labels, arg0, arg1, mode) 379 380 partial_model_fn = functools.partial( 381 _model_fn, foo=expected_foo, bar=expected_bar) 382 383 est = estimator.Estimator( 384 model_fn=partial_model_fn, 385 params=expected_params, 386 config=expected_config) 387 self.assertEqual(0, model_fn_call_count[0]) 388 est.fit(input_fn=_make_input_fn(features, labels), steps=1) 389 self.assertEqual(1, model_fn_call_count[0]) 390 391 def testModelFnWithModelDir(self): 392 expected_param = {'some_param': 'some_value'} 393 expected_model_dir = tempfile.mkdtemp() 394 395 def _argument_checker(features, 396 labels, 397 mode, 398 params, 399 config=None, 400 model_dir=None): 401 _, _, _ = features, labels, config 402 self.assertEqual(model_fn.ModeKeys.TRAIN, mode) 403 self.assertEqual(expected_param, params) 404 self.assertEqual(model_dir, expected_model_dir) 405 return (constant_op.constant(0.), constant_op.constant(0.), 406 training_util.get_global_step().assign_add(1)) 407 408 est = estimator.Estimator( 409 model_fn=_argument_checker, 410 params=expected_param, 411 model_dir=expected_model_dir) 412 est.fit(input_fn=boston_input_fn, steps=1) 413 414 def testInvalidModelFn_no_train_op(self): 415 416 def _invalid_model_fn(features, labels): 417 # pylint: disable=unused-argument 418 w = variables_lib.Variable(42.0, 'weight') 419 update_global_step = training_util.get_global_step().assign_add(1) 420 with ops.control_dependencies([update_global_step]): 421 loss = 100.0 - w 422 return None, loss, None 423 424 est = estimator.Estimator(model_fn=_invalid_model_fn) 425 with self.assertRaisesRegexp(ValueError, 'Missing train_op'): 426 est.fit(input_fn=boston_input_fn, steps=1) 427 428 def testInvalidModelFn_no_loss(self): 429 430 def _invalid_model_fn(features, labels, mode): 431 # pylint: disable=unused-argument 432 w = variables_lib.Variable(42.0, 'weight') 433 loss = 100.0 - w 434 update_global_step = training_util.get_global_step().assign_add(1) 435 with ops.control_dependencies([update_global_step]): 436 train_op = w.assign_add(loss / 100.0) 437 predictions = loss 438 if mode == model_fn.ModeKeys.EVAL: 439 loss = None 440 return predictions, loss, train_op 441 442 est = estimator.Estimator(model_fn=_invalid_model_fn) 443 est.fit(input_fn=boston_input_fn, steps=1) 444 with self.assertRaisesRegexp(ValueError, 'Missing loss'): 445 est.evaluate(input_fn=boston_eval_fn, steps=1) 446 447 def testInvalidModelFn_no_prediction(self): 448 449 def _invalid_model_fn(features, labels): 450 # pylint: disable=unused-argument 451 w = variables_lib.Variable(42.0, 'weight') 452 loss = 100.0 - w 453 update_global_step = training_util.get_global_step().assign_add(1) 454 with ops.control_dependencies([update_global_step]): 455 train_op = w.assign_add(loss / 100.0) 456 return None, loss, train_op 457 458 est = estimator.Estimator(model_fn=_invalid_model_fn) 459 est.fit(input_fn=boston_input_fn, steps=1) 460 with self.assertRaisesRegexp(ValueError, 'Missing prediction'): 461 est.evaluate(input_fn=boston_eval_fn, steps=1) 462 with self.assertRaisesRegexp(ValueError, 'Missing prediction'): 463 est.predict(input_fn=boston_input_fn) 464 with self.assertRaisesRegexp(ValueError, 'Missing prediction'): 465 est.predict( 466 input_fn=functools.partial(boston_input_fn, num_epochs=1), 467 as_iterable=True) 468 469 def testModelFnScaffoldInTraining(self): 470 self.is_init_fn_called = False 471 472 def _init_fn(scaffold, session): 473 _, _ = scaffold, session 474 self.is_init_fn_called = True 475 476 def _model_fn_scaffold(features, labels, mode): 477 _, _ = features, labels 478 return model_fn.ModelFnOps( 479 mode=mode, 480 predictions=constant_op.constant(0.), 481 loss=constant_op.constant(0.), 482 train_op=training_util.get_global_step().assign_add(1), 483 scaffold=monitored_session.Scaffold(init_fn=_init_fn)) 484 485 est = estimator.Estimator(model_fn=_model_fn_scaffold) 486 est.fit(input_fn=boston_input_fn, steps=1) 487 self.assertTrue(self.is_init_fn_called) 488 489 def testModelFnScaffoldSaverUsage(self): 490 491 def _model_fn_scaffold(features, labels, mode): 492 _, _ = features, labels 493 variables_lib.Variable(1., 'weight') 494 real_saver = saver_lib.Saver() 495 self.mock_saver = test.mock.Mock( 496 wraps=real_saver, saver_def=real_saver.saver_def) 497 return model_fn.ModelFnOps( 498 mode=mode, 499 predictions=constant_op.constant([[1.]]), 500 loss=constant_op.constant(0.), 501 train_op=training_util.get_global_step().assign_add(1), 502 scaffold=monitored_session.Scaffold(saver=self.mock_saver)) 503 504 def input_fn(): 505 return { 506 'x': constant_op.constant([[1.]]), 507 }, constant_op.constant([[1.]]) 508 509 est = estimator.Estimator(model_fn=_model_fn_scaffold) 510 est.fit(input_fn=input_fn, steps=1) 511 self.assertTrue(self.mock_saver.save.called) 512 est.evaluate(input_fn=input_fn, steps=1) 513 self.assertTrue(self.mock_saver.restore.called) 514 est.predict(input_fn=input_fn) 515 self.assertTrue(self.mock_saver.restore.called) 516 517 def serving_input_fn(): 518 serialized_tf_example = array_ops.placeholder( 519 dtype=dtypes.string, shape=[None], name='input_example_tensor') 520 features, labels = input_fn() 521 return input_fn_utils.InputFnOps(features, labels, { 522 'examples': serialized_tf_example 523 }) 524 525 est.export_savedmodel( 526 os.path.join(est.model_dir, 'export'), serving_input_fn) 527 self.assertTrue(self.mock_saver.restore.called) 528 529 530class EstimatorTest(test.TestCase): 531 532 def testExperimentIntegration(self): 533 exp = experiment.Experiment( 534 estimator=estimator.Estimator(model_fn=linear_model_fn), 535 train_input_fn=boston_input_fn, 536 eval_input_fn=boston_input_fn) 537 exp.test() 538 539 def testCheckpointSaverHookSuppressesTheDefaultOne(self): 540 saver_hook = test.mock.Mock( 541 spec=basic_session_run_hooks.CheckpointSaverHook) 542 saver_hook.before_run.return_value = None 543 est = estimator.Estimator(model_fn=linear_model_fn) 544 est.fit(input_fn=boston_input_fn, steps=1, monitors=[saver_hook]) 545 # test nothing is saved, due to suppressing default saver 546 with self.assertRaises(learn.NotFittedError): 547 est.evaluate(input_fn=boston_input_fn, steps=1) 548 549 def testCustomConfig(self): 550 test_random_seed = 5783452 551 552 class TestInput(object): 553 554 def __init__(self): 555 self.random_seed = 0 556 557 def config_test_input_fn(self): 558 self.random_seed = ops.get_default_graph().seed 559 return constant_op.constant([[1.]]), constant_op.constant([1.]) 560 561 config = run_config.RunConfig(tf_random_seed=test_random_seed) 562 test_input = TestInput() 563 est = estimator.Estimator(model_fn=linear_model_fn, config=config) 564 est.fit(input_fn=test_input.config_test_input_fn, steps=1) 565 # If input_fn ran, it will have given us the random seed set on the graph. 566 self.assertEquals(test_random_seed, test_input.random_seed) 567 568 def testRunConfigModelDir(self): 569 config = run_config.RunConfig(model_dir='test_dir') 570 est = estimator.Estimator(model_fn=linear_model_fn, config=config) 571 self.assertEqual('test_dir', est.config.model_dir) 572 self.assertEqual('test_dir', est.model_dir) 573 574 def testModelDirAndRunConfigModelDir(self): 575 config = run_config.RunConfig(model_dir='test_dir') 576 est = estimator.Estimator( 577 model_fn=linear_model_fn, config=config, model_dir='test_dir') 578 self.assertEqual('test_dir', est.config.model_dir) 579 580 with self.assertRaisesRegexp( 581 ValueError, 'model_dir are set both in constructor and RunConfig, ' 582 'but with different'): 583 estimator.Estimator( 584 model_fn=linear_model_fn, config=config, model_dir='different_dir') 585 586 def testModelDirIsCopiedToRunConfig(self): 587 config = run_config.RunConfig() 588 self.assertIsNone(config.model_dir) 589 590 est = estimator.Estimator( 591 model_fn=linear_model_fn, model_dir='test_dir', config=config) 592 self.assertEqual('test_dir', est.config.model_dir) 593 self.assertEqual('test_dir', est.model_dir) 594 595 def testModelDirAsTempDir(self): 596 with test.mock.patch.object(tempfile, 'mkdtemp', return_value='temp_dir'): 597 est = estimator.Estimator(model_fn=linear_model_fn) 598 self.assertEqual('temp_dir', est.config.model_dir) 599 self.assertEqual('temp_dir', est.model_dir) 600 601 def testCheckInputs(self): 602 est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn)) 603 # Lambdas so we have to different objects to compare 604 right_features = lambda: np.ones(shape=[7, 8], dtype=np.float32) 605 right_labels = lambda: np.ones(shape=[7, 10], dtype=np.int32) 606 est.fit(right_features(), right_labels(), steps=1) 607 # TODO(wicke): This does not fail for np.int32 because of data_feeder magic. 608 wrong_type_features = np.ones(shape=[7, 8], dtype=np.int64) 609 wrong_size_features = np.ones(shape=[7, 10]) 610 wrong_type_labels = np.ones(shape=[7, 10], dtype=np.float32) 611 wrong_size_labels = np.ones(shape=[7, 11]) 612 est.fit(x=right_features(), y=right_labels(), steps=1) 613 with self.assertRaises(ValueError): 614 est.fit(x=wrong_type_features, y=right_labels(), steps=1) 615 with self.assertRaises(ValueError): 616 est.fit(x=wrong_size_features, y=right_labels(), steps=1) 617 with self.assertRaises(ValueError): 618 est.fit(x=right_features(), y=wrong_type_labels, steps=1) 619 with self.assertRaises(ValueError): 620 est.fit(x=right_features(), y=wrong_size_labels, steps=1) 621 622 def testBadInput(self): 623 est = estimator.Estimator(model_fn=linear_model_fn) 624 self.assertRaisesRegexp( 625 ValueError, 626 'Either x or input_fn must be provided.', 627 est.fit, 628 x=None, 629 input_fn=None, 630 steps=1) 631 self.assertRaisesRegexp( 632 ValueError, 633 'Can not provide both input_fn and x or y', 634 est.fit, 635 x='X', 636 input_fn=iris_input_fn, 637 steps=1) 638 self.assertRaisesRegexp( 639 ValueError, 640 'Can not provide both input_fn and x or y', 641 est.fit, 642 y='Y', 643 input_fn=iris_input_fn, 644 steps=1) 645 self.assertRaisesRegexp( 646 ValueError, 647 'Can not provide both input_fn and batch_size', 648 est.fit, 649 input_fn=iris_input_fn, 650 batch_size=100, 651 steps=1) 652 self.assertRaisesRegexp( 653 ValueError, 654 'Inputs cannot be tensors. Please provide input_fn.', 655 est.fit, 656 x=constant_op.constant(1.), 657 steps=1) 658 659 def testUntrained(self): 660 boston = base.load_boston() 661 est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn)) 662 with self.assertRaises(learn.NotFittedError): 663 _ = est.score(x=boston.data, y=boston.target.astype(np.float64)) 664 with self.assertRaises(learn.NotFittedError): 665 est.predict(x=boston.data) 666 667 def testContinueTraining(self): 668 boston = base.load_boston() 669 output_dir = tempfile.mkdtemp() 670 est = estimator.SKCompat( 671 estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) 672 float64_labels = boston.target.astype(np.float64) 673 est.fit(x=boston.data, y=float64_labels, steps=50) 674 scores = est.score( 675 x=boston.data, 676 y=float64_labels, 677 metrics={ 678 'MSE': metric_ops.streaming_mean_squared_error 679 }) 680 del est 681 # Create another estimator object with the same output dir. 682 est2 = estimator.SKCompat( 683 estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) 684 685 # Check we can evaluate and predict. 686 scores2 = est2.score( 687 x=boston.data, 688 y=float64_labels, 689 metrics={ 690 'MSE': metric_ops.streaming_mean_squared_error 691 }) 692 self.assertAllClose(scores['MSE'], scores2['MSE']) 693 predictions = np.array(list(est2.predict(x=boston.data))) 694 other_score = _sklearn.mean_squared_error(predictions, float64_labels) 695 self.assertAllClose(scores['MSE'], other_score) 696 697 # Check we can keep training. 698 est2.fit(x=boston.data, y=float64_labels, steps=100) 699 scores3 = est2.score( 700 x=boston.data, 701 y=float64_labels, 702 metrics={ 703 'MSE': metric_ops.streaming_mean_squared_error 704 }) 705 self.assertLess(scores3['MSE'], scores['MSE']) 706 707 def test_checkpoint_contains_relative_paths(self): 708 tmpdir = tempfile.mkdtemp() 709 est = estimator.Estimator( 710 model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops) 711 est.fit(input_fn=boston_input_fn, steps=5) 712 713 checkpoint_file_content = file_io.read_file_to_string( 714 os.path.join(tmpdir, 'checkpoint')) 715 ckpt = checkpoint_state_pb2.CheckpointState() 716 text_format.Merge(checkpoint_file_content, ckpt) 717 self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') 718 # TODO(b/78461127): Please modify tests to not directly rely on names of 719 # checkpoints. 720 self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'], 721 ckpt.all_model_checkpoint_paths) 722 723 def test_train_save_copy_reload(self): 724 tmpdir = tempfile.mkdtemp() 725 model_dir1 = os.path.join(tmpdir, 'model_dir1') 726 est1 = estimator.Estimator( 727 model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops) 728 est1.fit(input_fn=boston_input_fn, steps=5) 729 730 model_dir2 = os.path.join(tmpdir, 'model_dir2') 731 os.renames(model_dir1, model_dir2) 732 est2 = estimator.Estimator( 733 model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops) 734 self.assertEqual(5, est2.get_variable_value('global_step')) 735 est2.fit(input_fn=boston_input_fn, steps=5) 736 self.assertEqual(10, est2.get_variable_value('global_step')) 737 738 def testEstimatorParams(self): 739 boston = base.load_boston() 740 est = estimator.SKCompat( 741 estimator.Estimator( 742 model_fn=linear_model_params_fn, params={ 743 'learning_rate': 0.01 744 })) 745 est.fit(x=boston.data, y=boston.target, steps=100) 746 747 def testHooksNotChanged(self): 748 est = estimator.Estimator(model_fn=logistic_model_no_mode_fn) 749 # We pass empty array and expect it to remain empty after calling 750 # fit and evaluate. Requires inside to copy this array if any hooks were 751 # added. 752 my_array = [] 753 est.fit(input_fn=iris_input_fn, steps=100, monitors=my_array) 754 _ = est.evaluate(input_fn=iris_input_fn, steps=1, hooks=my_array) 755 self.assertEqual(my_array, []) 756 757 def testIrisIterator(self): 758 iris = base.load_iris() 759 est = estimator.Estimator(model_fn=logistic_model_no_mode_fn) 760 x_iter = itertools.islice(iris.data, 100) 761 y_iter = itertools.islice(iris.target, 100) 762 estimator.SKCompat(est).fit(x_iter, y_iter, steps=20) 763 eval_result = est.evaluate(input_fn=iris_input_fn, steps=1) 764 x_iter_eval = itertools.islice(iris.data, 100) 765 y_iter_eval = itertools.islice(iris.target, 100) 766 score_result = estimator.SKCompat(est).score(x_iter_eval, y_iter_eval) 767 print(score_result) 768 self.assertItemsEqual(eval_result.keys(), score_result.keys()) 769 self.assertItemsEqual(['global_step', 'loss'], score_result.keys()) 770 predictions = estimator.SKCompat(est).predict(x=iris.data)['class'] 771 self.assertEqual(len(predictions), iris.target.shape[0]) 772 773 def testIrisIteratorArray(self): 774 iris = base.load_iris() 775 est = estimator.Estimator(model_fn=logistic_model_no_mode_fn) 776 x_iter = itertools.islice(iris.data, 100) 777 y_iter = (np.array(x) for x in iris.target) 778 est.fit(x_iter, y_iter, steps=100) 779 _ = est.evaluate(input_fn=iris_input_fn, steps=1) 780 _ = six.next(est.predict(x=iris.data))['class'] 781 782 def testIrisIteratorPlainInt(self): 783 iris = base.load_iris() 784 est = estimator.Estimator(model_fn=logistic_model_no_mode_fn) 785 x_iter = itertools.islice(iris.data, 100) 786 y_iter = (v for v in iris.target) 787 est.fit(x_iter, y_iter, steps=100) 788 _ = est.evaluate(input_fn=iris_input_fn, steps=1) 789 _ = six.next(est.predict(x=iris.data))['class'] 790 791 def testIrisTruncatedIterator(self): 792 iris = base.load_iris() 793 est = estimator.Estimator(model_fn=logistic_model_no_mode_fn) 794 x_iter = itertools.islice(iris.data, 50) 795 y_iter = ([np.int32(v)] for v in iris.target) 796 est.fit(x_iter, y_iter, steps=100) 797 798 def testTrainStepsIsIncremental(self): 799 est = estimator.Estimator(model_fn=linear_model_fn) 800 est.fit(input_fn=boston_input_fn, steps=10) 801 self.assertEqual(10, est.get_variable_value('global_step')) 802 est.fit(input_fn=boston_input_fn, steps=15) 803 self.assertEqual(25, est.get_variable_value('global_step')) 804 805 def testTrainMaxStepsIsNotIncremental(self): 806 est = estimator.Estimator(model_fn=linear_model_fn) 807 est.fit(input_fn=boston_input_fn, max_steps=10) 808 self.assertEqual(10, est.get_variable_value('global_step')) 809 est.fit(input_fn=boston_input_fn, max_steps=15) 810 self.assertEqual(15, est.get_variable_value('global_step')) 811 812 def testPredict(self): 813 est = estimator.Estimator(model_fn=linear_model_fn) 814 boston = base.load_boston() 815 est.fit(input_fn=boston_input_fn, steps=1) 816 output = list(est.predict(x=boston.data, batch_size=10)) 817 self.assertEqual(len(output), boston.target.shape[0]) 818 819 def testWithModelFnOps(self): 820 """Test for model_fn that returns `ModelFnOps`.""" 821 est = estimator.Estimator(model_fn=linear_model_fn_with_model_fn_ops) 822 boston = base.load_boston() 823 est.fit(input_fn=boston_input_fn, steps=1) 824 input_fn = functools.partial(boston_input_fn, num_epochs=1) 825 scores = est.evaluate(input_fn=input_fn, steps=1) 826 self.assertIn('loss', scores.keys()) 827 output = list(est.predict(input_fn=input_fn)) 828 self.assertEqual(len(output), boston.target.shape[0]) 829 830 def testWrongInput(self): 831 832 def other_input_fn(): 833 return { 834 'other': constant_op.constant([0, 0, 0]) 835 }, constant_op.constant([0, 0, 0]) 836 837 est = estimator.Estimator(model_fn=linear_model_fn) 838 est.fit(input_fn=boston_input_fn, steps=1) 839 with self.assertRaises(ValueError): 840 est.fit(input_fn=other_input_fn, steps=1) 841 842 def testMonitorsForFit(self): 843 est = estimator.Estimator(model_fn=linear_model_fn) 844 est.fit( 845 input_fn=boston_input_fn, 846 steps=21, 847 monitors=[CheckCallsMonitor(expect_calls=21)]) 848 849 def testHooksForEvaluate(self): 850 851 class CheckCallHook(session_run_hook.SessionRunHook): 852 853 def __init__(self): 854 self.run_count = 0 855 856 def after_run(self, run_context, run_values): 857 self.run_count += 1 858 859 est = learn.Estimator(model_fn=linear_model_fn) 860 est.fit(input_fn=boston_input_fn, steps=1) 861 hook = CheckCallHook() 862 est.evaluate(input_fn=boston_eval_fn, steps=3, hooks=[hook]) 863 864 self.assertEqual(3, hook.run_count) 865 866 def testSummaryWriting(self): 867 est = estimator.Estimator(model_fn=linear_model_fn) 868 est.fit(input_fn=boston_input_fn, steps=200) 869 est.evaluate(input_fn=boston_input_fn, steps=200) 870 loss_summary = util_test.simple_values_from_events( 871 util_test.latest_events(est.model_dir), ['OptimizeLoss/loss']) 872 self.assertEqual(1, len(loss_summary)) 873 874 def testSummaryWritingWithSummaryProto(self): 875 876 def _streaming_mean_squared_error_histogram(predictions, 877 labels, 878 weights=None, 879 metrics_collections=None, 880 updates_collections=None, 881 name=None): 882 metrics, update_ops = metric_ops.streaming_mean_squared_error( 883 predictions, 884 labels, 885 weights=weights, 886 metrics_collections=metrics_collections, 887 updates_collections=updates_collections, 888 name=name) 889 return summary.histogram('histogram', metrics), update_ops 890 891 est = estimator.Estimator(model_fn=linear_model_fn) 892 est.fit(input_fn=boston_input_fn, steps=200) 893 est.evaluate( 894 input_fn=boston_input_fn, 895 steps=200, 896 metrics={ 897 'MSE': _streaming_mean_squared_error_histogram 898 }) 899 events = util_test.latest_events(est.model_dir + '/eval') 900 output_values = {} 901 for e in events: 902 if e.HasField('summary'): 903 for v in e.summary.value: 904 output_values[v.tag] = v 905 self.assertTrue('MSE' in output_values) 906 self.assertTrue(output_values['MSE'].HasField('histo')) 907 908 def testSummaryWritingWithTensor(self): 909 910 def _streaming_precition_mean_tensor(predictions, 911 weights=None, 912 metrics_collections=None, 913 updates_collections=None, 914 name=None): 915 return metric_ops.streaming_mean_tensor( 916 predictions, 917 weights=weights, 918 metrics_collections=metrics_collections, 919 updates_collections=updates_collections, 920 name=name) 921 922 est = estimator.Estimator(model_fn=linear_model_fn) 923 est.fit(input_fn=boston_input_fn, steps=200) 924 est.evaluate( 925 input_fn=boston_input_fn, 926 steps=200, 927 metrics={ 928 'PMT': _streaming_precition_mean_tensor 929 }) 930 events = util_test.latest_events(est.model_dir + '/eval') 931 output_values = {} 932 for e in events: 933 if e.HasField('summary'): 934 for v in e.summary.value: 935 output_values[v.tag] = v 936 self.assertTrue('PMT' in output_values) 937 self.assertTrue(output_values['PMT'].HasField('tensor')) 938 939 def testLossInGraphCollection(self): 940 941 class _LossCheckerHook(session_run_hook.SessionRunHook): 942 943 def begin(self): 944 self.loss_collection = ops.get_collection(ops.GraphKeys.LOSSES) 945 946 hook = _LossCheckerHook() 947 est = estimator.Estimator(model_fn=linear_model_fn) 948 est.fit(input_fn=boston_input_fn, steps=200, monitors=[hook]) 949 self.assertTrue(hook.loss_collection) 950 951 def test_export_returns_exported_dirname(self): 952 expected = '/path/to/some_dir' 953 with test.mock.patch.object(estimator, 'export') as mock_export_module: 954 mock_export_module._export_estimator.return_value = expected 955 956 est = estimator.Estimator(model_fn=linear_model_fn) 957 actual = est.export('/path/to') 958 959 self.assertEquals(expected, actual) 960 961 def test_export_savedmodel(self): 962 tmpdir = tempfile.mkdtemp() 963 est, serving_input_fn = _build_estimator_for_export_tests(tmpdir) 964 965 extra_file_name = os.path.join( 966 compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file')) 967 extra_file = gfile.GFile(extra_file_name, mode='w') 968 extra_file.write(EXTRA_FILE_CONTENT) 969 extra_file.close() 970 assets_extra = {'some/sub/directory/my_extra_file': extra_file_name} 971 972 export_dir_base = os.path.join( 973 compat.as_bytes(tmpdir), compat.as_bytes('export')) 974 export_dir = est.export_savedmodel( 975 export_dir_base, serving_input_fn, assets_extra=assets_extra) 976 977 self.assertTrue(gfile.Exists(export_dir_base)) 978 self.assertTrue(gfile.Exists(export_dir)) 979 self.assertTrue( 980 gfile.Exists( 981 os.path.join( 982 compat.as_bytes(export_dir), 983 compat.as_bytes('saved_model.pb')))) 984 self.assertTrue( 985 gfile.Exists( 986 os.path.join( 987 compat.as_bytes(export_dir), compat.as_bytes('variables')))) 988 self.assertTrue( 989 gfile.Exists( 990 os.path.join( 991 compat.as_bytes(export_dir), 992 compat.as_bytes('variables/variables.index')))) 993 self.assertTrue( 994 gfile.Exists( 995 os.path.join( 996 compat.as_bytes(export_dir), 997 compat.as_bytes('variables/variables.data-00000-of-00001')))) 998 999 self.assertTrue( 1000 gfile.Exists( 1001 os.path.join( 1002 compat.as_bytes(export_dir), compat.as_bytes('assets')))) 1003 self.assertTrue( 1004 gfile.Exists( 1005 os.path.join( 1006 compat.as_bytes(export_dir), 1007 compat.as_bytes('assets/my_vocab_file')))) 1008 self.assertEqual( 1009 compat.as_bytes(VOCAB_FILE_CONTENT), 1010 compat.as_bytes( 1011 gfile.GFile( 1012 os.path.join( 1013 compat.as_bytes(export_dir), 1014 compat.as_bytes('assets/my_vocab_file'))).read())) 1015 1016 expected_extra_path = os.path.join( 1017 compat.as_bytes(export_dir), 1018 compat.as_bytes('assets.extra/some/sub/directory/my_extra_file')) 1019 self.assertTrue( 1020 gfile.Exists( 1021 os.path.join( 1022 compat.as_bytes(export_dir), compat.as_bytes('assets.extra')))) 1023 self.assertTrue(gfile.Exists(expected_extra_path)) 1024 self.assertEqual( 1025 compat.as_bytes(EXTRA_FILE_CONTENT), 1026 compat.as_bytes(gfile.GFile(expected_extra_path).read())) 1027 1028 expected_vocab_file = os.path.join( 1029 compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file')) 1030 # Restore, to validate that the export was well-formed. 1031 with ops.Graph().as_default() as graph: 1032 with session_lib.Session(graph=graph) as sess: 1033 loader.load(sess, [tag_constants.SERVING], export_dir) 1034 assets = [ 1035 x.eval() 1036 for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 1037 ] 1038 self.assertItemsEqual([expected_vocab_file], assets) 1039 graph_ops = [x.name for x in graph.get_operations()] 1040 self.assertTrue('input_example_tensor' in graph_ops) 1041 self.assertTrue('ParseExample/ParseExample' in graph_ops) 1042 self.assertTrue('linear/linear/feature/matmul' in graph_ops) 1043 self.assertItemsEqual(['bogus_lookup', 'feature'], [ 1044 compat.as_str_any(x) 1045 for x in graph.get_collection( 1046 constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS) 1047 ]) 1048 1049 # cleanup 1050 gfile.DeleteRecursively(tmpdir) 1051 1052 def test_export_savedmodel_with_resource(self): 1053 tmpdir = tempfile.mkdtemp() 1054 est, serving_input_fn = _build_estimator_for_resource_export_test() 1055 1056 export_dir_base = os.path.join( 1057 compat.as_bytes(tmpdir), compat.as_bytes('export')) 1058 export_dir = est.export_savedmodel(export_dir_base, serving_input_fn) 1059 1060 self.assertTrue(gfile.Exists(export_dir_base)) 1061 self.assertTrue(gfile.Exists(export_dir)) 1062 self.assertTrue( 1063 gfile.Exists( 1064 os.path.join( 1065 compat.as_bytes(export_dir), 1066 compat.as_bytes('saved_model.pb')))) 1067 self.assertTrue( 1068 gfile.Exists( 1069 os.path.join( 1070 compat.as_bytes(export_dir), compat.as_bytes('variables')))) 1071 self.assertTrue( 1072 gfile.Exists( 1073 os.path.join( 1074 compat.as_bytes(export_dir), 1075 compat.as_bytes('variables/variables.index')))) 1076 self.assertTrue( 1077 gfile.Exists( 1078 os.path.join( 1079 compat.as_bytes(export_dir), 1080 compat.as_bytes('variables/variables.data-00000-of-00001')))) 1081 1082 # Restore, to validate that the export was well-formed. 1083 with ops.Graph().as_default() as graph: 1084 with session_lib.Session(graph=graph) as sess: 1085 loader.load(sess, [tag_constants.SERVING], export_dir) 1086 graph_ops = [x.name for x in graph.get_operations()] 1087 self.assertTrue('input_example_tensor' in graph_ops) 1088 self.assertTrue('ParseExample/ParseExample' in graph_ops) 1089 self.assertTrue('LookupTableModel' in graph_ops) 1090 self.assertFalse('LookupTableTrainingState' in graph_ops) 1091 1092 # cleanup 1093 gfile.DeleteRecursively(tmpdir) 1094 1095 def test_export_savedmodel_with_graph_transforms(self): 1096 tmpdir = tempfile.mkdtemp() 1097 est, serving_input_fn = _build_estimator_for_export_tests(tmpdir) 1098 1099 extra_file_name = os.path.join( 1100 compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file')) 1101 extra_file = gfile.GFile(extra_file_name, mode='w') 1102 extra_file.write(EXTRA_FILE_CONTENT) 1103 extra_file.close() 1104 assets_extra = {'some/sub/directory/my_extra_file': extra_file_name} 1105 1106 export_dir_base = os.path.join( 1107 compat.as_bytes(tmpdir), compat.as_bytes('export')) 1108 export_dir = est.export_savedmodel( 1109 export_dir_base, 1110 serving_input_fn, 1111 assets_extra=assets_extra, 1112 graph_rewrite_specs=[ 1113 estimator.GraphRewriteSpec(['tag_1'], []), 1114 estimator.GraphRewriteSpec(['tag_2', 'tag_3'], 1115 ['strip_unused_nodes']) 1116 ]) 1117 1118 self.assertTrue(gfile.Exists(export_dir_base)) 1119 self.assertTrue(gfile.Exists(export_dir)) 1120 self.assertTrue( 1121 gfile.Exists( 1122 os.path.join( 1123 compat.as_bytes(export_dir), 1124 compat.as_bytes('saved_model.pb')))) 1125 self.assertTrue( 1126 gfile.Exists( 1127 os.path.join( 1128 compat.as_bytes(export_dir), compat.as_bytes('variables')))) 1129 self.assertTrue( 1130 gfile.Exists( 1131 os.path.join( 1132 compat.as_bytes(export_dir), 1133 compat.as_bytes('variables/variables.index')))) 1134 self.assertTrue( 1135 gfile.Exists( 1136 os.path.join( 1137 compat.as_bytes(export_dir), 1138 compat.as_bytes('variables/variables.data-00000-of-00001')))) 1139 1140 self.assertTrue( 1141 gfile.Exists( 1142 os.path.join( 1143 compat.as_bytes(export_dir), compat.as_bytes('assets')))) 1144 self.assertTrue( 1145 gfile.Exists( 1146 os.path.join( 1147 compat.as_bytes(export_dir), 1148 compat.as_bytes('assets/my_vocab_file')))) 1149 self.assertEqual( 1150 compat.as_bytes(VOCAB_FILE_CONTENT), 1151 compat.as_bytes( 1152 gfile.GFile( 1153 os.path.join( 1154 compat.as_bytes(export_dir), 1155 compat.as_bytes('assets/my_vocab_file'))).read())) 1156 1157 expected_extra_path = os.path.join( 1158 compat.as_bytes(export_dir), 1159 compat.as_bytes('assets.extra/some/sub/directory/my_extra_file')) 1160 self.assertTrue( 1161 gfile.Exists( 1162 os.path.join( 1163 compat.as_bytes(export_dir), compat.as_bytes('assets.extra')))) 1164 self.assertTrue(gfile.Exists(expected_extra_path)) 1165 self.assertEqual( 1166 compat.as_bytes(EXTRA_FILE_CONTENT), 1167 compat.as_bytes(gfile.GFile(expected_extra_path).read())) 1168 1169 expected_vocab_file = os.path.join( 1170 compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file')) 1171 1172 # Restore, to validate that the export was well-formed. 1173 # tag_1 is untransformed. 1174 tags = ['tag_1'] 1175 with ops.Graph().as_default() as graph: 1176 with session_lib.Session(graph=graph) as sess: 1177 loader.load(sess, tags, export_dir) 1178 assets = [ 1179 x.eval() 1180 for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 1181 ] 1182 self.assertItemsEqual([expected_vocab_file], assets) 1183 graph_ops = [x.name for x in graph.get_operations()] 1184 self.assertIn('input_example_tensor', graph_ops) 1185 self.assertIn('ParseExample/ParseExample', graph_ops) 1186 self.assertIn('linear/linear/feature/matmul', graph_ops) 1187 # Since there were no transforms, both save ops are still present. 1188 self.assertIn('save/SaveV2/tensor_names', graph_ops) 1189 self.assertIn('save_1/SaveV2/tensor_names', graph_ops) 1190 # Since there were no transforms, the hash table lookup is still there. 1191 self.assertIn('hash_table_Lookup/LookupTableFindV2', graph_ops) 1192 1193 # Restore, to validate that the export was well-formed. 1194 # tag_2, tag_3 was subjected to strip_unused_nodes. 1195 tags = ['tag_2', 'tag_3'] 1196 with ops.Graph().as_default() as graph: 1197 with session_lib.Session(graph=graph) as sess: 1198 loader.load(sess, tags, export_dir) 1199 assets = [ 1200 x.eval() 1201 for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 1202 ] 1203 self.assertItemsEqual([expected_vocab_file], assets) 1204 graph_ops = [x.name for x in graph.get_operations()] 1205 self.assertTrue('input_example_tensor' in graph_ops) 1206 self.assertTrue('ParseExample/ParseExample' in graph_ops) 1207 self.assertTrue('linear/linear/feature/matmul' in graph_ops) 1208 # The Saver used to restore the checkpoint into the export Session 1209 # was not added to the SAVERS collection, so strip_unused_nodes removes 1210 # it. The one explicitly created in export_savedmodel is tracked in 1211 # the MetaGraphDef saver_def field, so that one is retained. 1212 # TODO(soergel): Make Savers sane again. I understand this is all a bit 1213 # nuts but for now the test demonstrates what actually happens. 1214 self.assertFalse('save/SaveV2/tensor_names' in graph_ops) 1215 self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops) 1216 # The fake hash table lookup wasn't connected to anything; stripped. 1217 self.assertFalse('hash_table_Lookup' in graph_ops) 1218 1219 # cleanup 1220 gfile.DeleteRecursively(tmpdir) 1221 1222 1223class InferRealValuedColumnsTest(test.TestCase): 1224 1225 def testInvalidArgs(self): 1226 with self.assertRaisesRegexp(ValueError, 'x or input_fn must be provided'): 1227 estimator.infer_real_valued_columns_from_input(None) 1228 1229 with self.assertRaisesRegexp(ValueError, 'cannot be tensors'): 1230 estimator.infer_real_valued_columns_from_input(constant_op.constant(1.0)) 1231 1232 def _assert_single_feature_column(self, expected_shape, expected_dtype, 1233 feature_columns): 1234 self.assertEqual(1, len(feature_columns)) 1235 feature_column = feature_columns[0] 1236 self.assertEqual('', feature_column.name) 1237 self.assertEqual({ 1238 '': 1239 parsing_ops.FixedLenFeature( 1240 shape=expected_shape, dtype=expected_dtype) 1241 }, feature_column.config) 1242 1243 def testInt32Input(self): 1244 feature_columns = estimator.infer_real_valued_columns_from_input( 1245 np.ones(shape=[7, 8], dtype=np.int32)) 1246 self._assert_single_feature_column([8], dtypes.int32, feature_columns) 1247 1248 def testInt32InputFn(self): 1249 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1250 lambda: (array_ops.ones(shape=[7, 8], dtype=dtypes.int32), None)) 1251 self._assert_single_feature_column([8], dtypes.int32, feature_columns) 1252 1253 def testInt64Input(self): 1254 feature_columns = estimator.infer_real_valued_columns_from_input( 1255 np.ones(shape=[7, 8], dtype=np.int64)) 1256 self._assert_single_feature_column([8], dtypes.int64, feature_columns) 1257 1258 def testInt64InputFn(self): 1259 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1260 lambda: (array_ops.ones(shape=[7, 8], dtype=dtypes.int64), None)) 1261 self._assert_single_feature_column([8], dtypes.int64, feature_columns) 1262 1263 def testFloat32Input(self): 1264 feature_columns = estimator.infer_real_valued_columns_from_input( 1265 np.ones(shape=[7, 8], dtype=np.float32)) 1266 self._assert_single_feature_column([8], dtypes.float32, feature_columns) 1267 1268 def testFloat32InputFn(self): 1269 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1270 lambda: (array_ops.ones(shape=[7, 8], dtype=dtypes.float32), None)) 1271 self._assert_single_feature_column([8], dtypes.float32, feature_columns) 1272 1273 def testFloat64Input(self): 1274 feature_columns = estimator.infer_real_valued_columns_from_input( 1275 np.ones(shape=[7, 8], dtype=np.float64)) 1276 self._assert_single_feature_column([8], dtypes.float64, feature_columns) 1277 1278 def testFloat64InputFn(self): 1279 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1280 lambda: (array_ops.ones(shape=[7, 8], dtype=dtypes.float64), None)) 1281 self._assert_single_feature_column([8], dtypes.float64, feature_columns) 1282 1283 def testBoolInput(self): 1284 with self.assertRaisesRegexp( 1285 ValueError, 'on integer or non floating types are not supported'): 1286 estimator.infer_real_valued_columns_from_input( 1287 np.array([[False for _ in xrange(8)] for _ in xrange(7)])) 1288 1289 def testBoolInputFn(self): 1290 with self.assertRaisesRegexp( 1291 ValueError, 'on integer or non floating types are not supported'): 1292 # pylint: disable=g-long-lambda 1293 estimator.infer_real_valued_columns_from_input_fn( 1294 lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), None) 1295 ) 1296 1297 def testStringInput(self): 1298 with self.assertRaisesRegexp( 1299 ValueError, 'on integer or non floating types are not supported'): 1300 # pylint: disable=g-long-lambda 1301 estimator.infer_real_valued_columns_from_input( 1302 np.array([['%d.0' % i for i in xrange(8)] for _ in xrange(7)])) 1303 1304 def testStringInputFn(self): 1305 with self.assertRaisesRegexp( 1306 ValueError, 'on integer or non floating types are not supported'): 1307 # pylint: disable=g-long-lambda 1308 estimator.infer_real_valued_columns_from_input_fn( 1309 lambda: ( 1310 constant_op.constant([['%d.0' % i 1311 for i in xrange(8)] 1312 for _ in xrange(7)]), 1313 None)) 1314 1315 def testBostonInputFn(self): 1316 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1317 boston_input_fn) 1318 self._assert_single_feature_column([_BOSTON_INPUT_DIM], dtypes.float64, 1319 feature_columns) 1320 1321 def testIrisInputFn(self): 1322 feature_columns = estimator.infer_real_valued_columns_from_input_fn( 1323 iris_input_fn) 1324 self._assert_single_feature_column([_IRIS_INPUT_DIM], dtypes.float64, 1325 feature_columns) 1326 1327 1328class ReplicaDeviceSetterTest(test.TestCase): 1329 1330 def testVariablesAreOnPs(self): 1331 tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} 1332 with test.mock.patch.dict('os.environ', { 1333 'TF_CONFIG': json.dumps(tf_config) 1334 }): 1335 config = run_config.RunConfig() 1336 1337 with ops.device(estimator._get_replica_device_setter(config)): 1338 v = variables_lib.Variable([1, 2]) 1339 w = variables_lib.Variable([2, 1]) 1340 a = v + w 1341 self.assertDeviceEqual('/job:ps/task:0', v.device) 1342 self.assertDeviceEqual('/job:ps/task:0', v.initializer.device) 1343 self.assertDeviceEqual('/job:ps/task:0', w.device) 1344 self.assertDeviceEqual('/job:ps/task:0', w.initializer.device) 1345 self.assertDeviceEqual('/job:worker', a.device) 1346 1347 def testVariablesAreLocal(self): 1348 with ops.device( 1349 estimator._get_replica_device_setter(run_config.RunConfig())): 1350 v = variables_lib.Variable([1, 2]) 1351 w = variables_lib.Variable([2, 1]) 1352 a = v + w 1353 self.assertDeviceEqual('', v.device) 1354 self.assertDeviceEqual('', v.initializer.device) 1355 self.assertDeviceEqual('', w.device) 1356 self.assertDeviceEqual('', w.initializer.device) 1357 self.assertDeviceEqual('', a.device) 1358 1359 def testMutableHashTableIsOnPs(self): 1360 tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} 1361 with test.mock.patch.dict('os.environ', { 1362 'TF_CONFIG': json.dumps(tf_config) 1363 }): 1364 config = run_config.RunConfig() 1365 1366 with ops.device(estimator._get_replica_device_setter(config)): 1367 default_val = constant_op.constant([-1, -1], dtypes.int64) 1368 table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) 1369 input_string = constant_op.constant(['brain', 'salad', 'tank']) 1370 output = table.lookup(input_string) 1371 self.assertDeviceEqual('/job:ps/task:0', table.resource_handle.device) 1372 self.assertDeviceEqual('/job:ps/task:0', output.device) 1373 1374 def testMutableHashTableIsLocal(self): 1375 with ops.device( 1376 estimator._get_replica_device_setter(run_config.RunConfig())): 1377 default_val = constant_op.constant([-1, -1], dtypes.int64) 1378 table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) 1379 input_string = constant_op.constant(['brain', 'salad', 'tank']) 1380 output = table.lookup(input_string) 1381 self.assertDeviceEqual('', table.resource_handle.device) 1382 self.assertDeviceEqual('', output.device) 1383 1384 def testTaskIsSetOnWorkerWhenJobNameIsSet(self): 1385 tf_config = { 1386 'cluster': { 1387 run_config.TaskType.PS: ['fake_ps_0'] 1388 }, 1389 'task': { 1390 'type': run_config.TaskType.WORKER, 1391 'index': 3 1392 } 1393 } 1394 with test.mock.patch.dict('os.environ', { 1395 'TF_CONFIG': json.dumps(tf_config) 1396 }): 1397 config = run_config.RunConfig() 1398 1399 with ops.device(estimator._get_replica_device_setter(config)): 1400 v = variables_lib.Variable([1, 2]) 1401 w = variables_lib.Variable([2, 1]) 1402 a = v + w 1403 self.assertDeviceEqual('/job:ps/task:0', v.device) 1404 self.assertDeviceEqual('/job:ps/task:0', v.initializer.device) 1405 self.assertDeviceEqual('/job:ps/task:0', w.device) 1406 self.assertDeviceEqual('/job:ps/task:0', w.initializer.device) 1407 self.assertDeviceEqual('/job:worker/task:3', a.device) 1408 1409 1410if __name__ == '__main__': 1411 test.main() 1412