1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.framework.importer.""" 16 17import numpy as np 18 19from google.protobuf import text_format 20 21from tensorflow.core.framework import graph_pb2 22from tensorflow.core.framework import op_def_pb2 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import device 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import function 27from tensorflow.python.framework import importer 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import test_ops # pylint: disable=unused-import 31from tensorflow.python.framework import test_util 32from tensorflow.python.framework import versions 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import gradients_impl 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import nn_ops 38from tensorflow.python.ops import random_ops 39from tensorflow.python.ops import resource_variable_ops 40from tensorflow.python.ops import variables 41import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 42from tensorflow.python.platform import test 43 44 45class 
ImportGraphDefTest(test.TestCase): 46 47 def _MakeGraphDef(self, 48 text, 49 producer=versions.GRAPH_DEF_VERSION, 50 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER): 51 text = "versions: { producer: %d min_consumer: %d };\n%s" % (producer, 52 min_consumer, 53 text) 54 ret = graph_pb2.GraphDef() 55 text_format.Merge(text, ret) 56 return ret 57 58 def testBasic(self): 59 with ops.Graph().as_default(): 60 a, b, c, d = importer.import_graph_def( 61 self._MakeGraphDef(""" 62 node { name: 'A' op: 'IntOutputFloatOutput' } 63 node { name: 'B' op: 'ListOutput' 64 attr { key: 'T' 65 value { list { type: DT_INT32 type: DT_FLOAT } } } } 66 node { name: 'C' op: 'ListInput' 67 attr { key: 'N' value { i: 2 } } 68 attr { key: 'T' value { type: DT_INT32 } } 69 input: 'A:0' input: 'B:0' } 70 node { name: 'D' op: 'ListInput' 71 attr { key: 'N' value { i: 2 } } 72 attr { key: 'T' value { type: DT_FLOAT } } 73 input: 'A:1' input: 'B:1' } 74 """), 75 return_elements=["A", "B", "C", "D"], 76 name="import") 77 78 # Assert that the import process creates distinct tensors. 79 self.assertNotEqual(a.outputs[0].name, a.outputs[1].name) 80 self.assertNotEqual(b.outputs[0].name, b.outputs[1].name) 81 self.assertNotEqual(a.outputs[0].name, b.outputs[0].name) 82 self.assertNotEqual(a.outputs[0].name, b.outputs[1].name) 83 self.assertNotEqual(a.outputs[1].name, b.outputs[0].name) 84 self.assertNotEqual(a.outputs[1].name, b.outputs[1].name) 85 86 # Assert that the ops are connected according to the GraphDef topology. 87 self.assertEqual(c.inputs[0], a.outputs[0]) 88 self.assertEqual(c.inputs[1], b.outputs[0]) 89 self.assertEqual(d.inputs[0], a.outputs[1]) 90 self.assertEqual(d.inputs[1], b.outputs[1]) 91 92 # Check the types of the returned ops and tensors. 
93 self.assertEqual(a.type, "IntOutputFloatOutput") 94 self.assertEqual(b.type, "ListOutput") 95 self.assertEqual(c.type, "ListInput") 96 self.assertEqual(d.type, "ListInput") 97 self.assertEqual(a.outputs[0].dtype, dtypes.int32) 98 self.assertEqual(a.outputs[1].dtype, dtypes.float32) 99 self.assertEqual(b.outputs[0].dtype, dtypes.int32) 100 self.assertEqual(b.outputs[1].dtype, dtypes.float32) 101 102 # Check the names of the returned ops. 103 self.assertEqual(a.name, "import/A") 104 self.assertEqual(b.name, "import/B") 105 self.assertEqual(c.name, "import/C") 106 self.assertEqual(d.name, "import/D") 107 108 # Check that the op_def is still available. 109 self.assertNotEqual(None, a.op_def) 110 111 def testMultipleImport(self): 112 graph_def = self._MakeGraphDef(""" 113 node { name: 'A' op: 'IntOutput' } 114 node { name: 'B' op: 'IntInput' input: 'A:0' } 115 """) 116 117 with ops.Graph().as_default(): 118 # Initial import 119 a, b = importer.import_graph_def( 120 graph_def, 121 return_elements=["A", "B"], 122 name="") 123 self.assertEqual(a.name, "A") 124 self.assertEqual(b.name, "B") 125 self.assertEqual(list(b.inputs), [a.outputs[0]]) 126 127 # Repeat the same import 128 a1, b1 = importer.import_graph_def( 129 graph_def, 130 return_elements=["A", "B"], 131 name="") 132 self.assertEqual(a1.name, "A_1") 133 self.assertEqual(b1.name, "B_1") 134 self.assertEqual(list(b1.inputs), [a1.outputs[0]]) 135 136 # Repeat the same import again 137 a2, b2 = importer.import_graph_def( 138 graph_def, 139 return_elements=["A", "B"], 140 name="") 141 self.assertEqual(a2.name, "A_2") 142 self.assertEqual(b2.name, "B_2") 143 self.assertEqual(list(b2.inputs), [a2.outputs[0]]) 144 145 # Import with an already-used name 146 a3, b3 = importer.import_graph_def( 147 graph_def, 148 return_elements=["A", "B"], 149 name="A") 150 self.assertEqual(a3.name, "A_3/A") 151 self.assertEqual(b3.name, "A_3/B") 152 self.assertEqual(list(b3.inputs), [a3.outputs[0]]) 153 154 # Import with an already-used 
name but with a '/' to indicate an 155 # "absolute" name scope (see the Graph.name_scope docstring). 156 a_a, a_b = importer.import_graph_def( 157 graph_def, 158 return_elements=["A", "B"], 159 name="A/") 160 self.assertEqual(a_a.name, "A/A") 161 self.assertEqual(a_b.name, "A/B") 162 self.assertEqual(list(a_b.inputs), [a_a.outputs[0]]) 163 164 # Repeat the same import. 165 a_a1, a_b1 = importer.import_graph_def( 166 graph_def, 167 return_elements=["A", "B"], 168 name="A/") 169 self.assertEqual(a_a1.name, "A/A_1") 170 self.assertEqual(a_b1.name, "A/B_1") 171 self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]]) 172 173 # Import with existing de-duped node names 174 a1_1, b1_1 = importer.import_graph_def( 175 self._MakeGraphDef(""" 176 node { name: 'A_1' op: 'IntOutput' } 177 node { name: 'B_1' op: 'IntInput' input: 'A_1:0' } 178 """), 179 return_elements=["A_1", "B_1"], 180 name="") 181 self.assertEqual(a1_1.name, "A_1_1") 182 self.assertEqual(b1_1.name, "B_1_1") 183 self.assertEqual(list(b1_1.inputs), [a1_1.outputs[0]]) 184 185 # Create a name scope and then import node with same name 186 with ops.name_scope("foo"): 187 constant_op.constant(1) 188 foo, = importer.import_graph_def( 189 self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"), 190 return_elements=["foo"], 191 name="") 192 self.assertEqual(foo.name, "foo_1") 193 194 # Imported node name can't conflict with intermediate name scope (but can 195 # conflict with outer scope and full name scope) 196 with ops.name_scope("outer"): 197 with ops.name_scope("inner"): 198 c = constant_op.constant(1, name="c") 199 self.assertEqual(c.op.name, "outer/inner/c") 200 201 outer, inner, new_c, outer_inner, outer_inner_c = ( 202 importer.import_graph_def( 203 self._MakeGraphDef( 204 "node { name: 'outer' op: 'IntOutput' }" 205 "node { name: 'inner' op: 'IntOutput' }" 206 "node { name: 'c' op: 'IntOutput' }" 207 "node { name: 'outer/inner' op: 'IntOutput' }" 208 "node { name: 'outer/inner/c' op: 'IntOutput' }"), 209 
return_elements=["outer", "inner", "c", "outer/inner", 210 "outer/inner/c"], 211 name="")) 212 self.assertEqual(outer.name, "outer_1") 213 self.assertEqual(inner.name, "inner") 214 self.assertEqual(new_c.name, "c") 215 self.assertEqual(outer_inner.name, "outer/inner_1") 216 self.assertEqual(outer_inner_c.name, "outer/inner/c_1") 217 218 def testEmptyNameScope(self): 219 with ops.Graph().as_default(): 220 # Create name scope but don't create any ops with it 221 with ops.name_scope("foo"): 222 pass 223 224 # Import graph def that uses name scope name 225 op, = importer.import_graph_def( 226 self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"), 227 return_elements=["foo"], 228 name="") 229 230 self.assertEqual(op.name, "foo") 231 232 def testInputMap(self): 233 with ops.Graph().as_default(): 234 feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) 235 feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) 236 237 a, b, c, d = importer.import_graph_def( 238 self._MakeGraphDef(""" 239 node { name: 'A' op: 'TwoIntOutputs' } 240 node { name: 'B' op: 'TwoIntOutputs' } 241 node { name: 'C' op: 'ListInput' 242 attr { key: 'N' value { i: 2 } } 243 attr { key: 'T' value { type: DT_INT32 } } 244 input: 'A:0' input: 'B:0' } 245 node { name: 'D' op: 'ListInput' 246 attr { key: 'N' value { i: 2 } } 247 attr { key: 'T' value { type: DT_INT32 } } 248 input: 'A:1' input: 'B:1' } 249 """), 250 input_map={"A:0": feed_a_0, 251 "B:1": feed_b_1}, 252 return_elements=["A", "B", "C", "D"]) 253 254 self.assertEqual(c.inputs[0], feed_a_0) 255 self.assertEqual(c.inputs[1], b.outputs[0]) 256 self.assertEqual(d.inputs[0], a.outputs[1]) 257 self.assertEqual(d.inputs[1], feed_b_1) 258 259 def testInputMapBytes(self): 260 with ops.Graph().as_default(): 261 feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) 262 feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) 263 264 a, b, c, d = importer.import_graph_def( 265 self._MakeGraphDef(""" 266 node { name: 'A' op: 'TwoIntOutputs' } 267 
node { name: 'B' op: 'TwoIntOutputs' } 268 node { name: 'C' op: 'ListInput' 269 attr { key: 'N' value { i: 2 } } 270 attr { key: 'T' value { type: DT_INT32 } } 271 input: 'A:0' input: 'B:0' } 272 node { name: 'D' op: 'ListInput' 273 attr { key: 'N' value { i: 2 } } 274 attr { key: 'T' value { type: DT_INT32 } } 275 input: 'A:1' input: 'B:1' } 276 """), 277 input_map={b"A:0": feed_a_0, 278 b"B:1": feed_b_1}, 279 return_elements=[b"A", b"B", b"C", b"D"]) 280 281 self.assertEqual(c.inputs[0], feed_a_0) 282 self.assertEqual(c.inputs[1], b.outputs[0]) 283 self.assertEqual(d.inputs[0], a.outputs[1]) 284 self.assertEqual(d.inputs[1], feed_b_1) 285 286 def testInputMapUnicode(self): 287 with ops.Graph().as_default(): 288 feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) 289 feed_b_1 = constant_op.constant(1, dtype=dtypes.int32) 290 291 a, b, c, d = importer.import_graph_def( 292 self._MakeGraphDef(""" 293 node { name: 'A' op: 'TwoIntOutputs' } 294 node { name: 'B' op: 'TwoIntOutputs' } 295 node { name: 'C' op: 'ListInput' 296 attr { key: 'N' value { i: 2 } } 297 attr { key: 'T' value { type: DT_INT32 } } 298 input: 'A:0' input: 'B:0' } 299 node { name: 'D' op: 'ListInput' 300 attr { key: 'N' value { i: 2 } } 301 attr { key: 'T' value { type: DT_INT32 } } 302 input: 'A:1' input: 'B:1' } 303 """), 304 input_map={u"A:0": feed_a_0, 305 u"B:1": feed_b_1}, 306 return_elements=[u"A", u"B", u"C", u"D"]) 307 308 self.assertEqual(c.inputs[0], feed_a_0) 309 self.assertEqual(c.inputs[1], b.outputs[0]) 310 self.assertEqual(d.inputs[0], a.outputs[1]) 311 self.assertEqual(d.inputs[1], feed_b_1) 312 313 def testImplicitZerothOutput(self): 314 with ops.Graph().as_default(): 315 a, b = importer.import_graph_def( 316 self._MakeGraphDef(""" 317 node { name: 'A' op: 'TwoIntOutputs' } 318 node { name: 'B' op: 'IntInput' input: 'A' } 319 """), 320 return_elements=["A", "B"]) 321 322 self.assertEqual(b.inputs[0], a.outputs[0]) 323 324 def testInputMapImplicitZerothOutput(self): 325 with 
ops.Graph().as_default(): 326 feed_a_0 = constant_op.constant(0, dtype=dtypes.int32) 327 b, = importer.import_graph_def( 328 self._MakeGraphDef(""" 329 node { name: 'A' op: 'TwoIntOutputs' } 330 node { name: 'B' op: 'IntInput' input: 'A:0' } 331 """), 332 input_map={"A": feed_a_0}, 333 return_elements=["B"]) 334 335 self.assertEqual(b.inputs[0], feed_a_0) 336 337 def testWithControlDependency(self): 338 with ops.Graph().as_default(): 339 a, b = importer.import_graph_def( 340 self._MakeGraphDef(""" 341 node { name: 'A' op: 'None' } 342 node { name: 'B' op: 'None' input: '^A' } 343 """), 344 return_elements=["A", "B"]) 345 346 self.assertEqual(b.control_inputs, [a]) 347 348 def testWithRefs(self): 349 with ops.Graph().as_default(): 350 a, b, c, d = importer.import_graph_def( 351 self._MakeGraphDef(""" 352 node { name: 'A' op: 'RefOutput' } 353 node { name: 'B' op: 'IntOutput' } 354 node { name: 'C' op: 'TwoIntInputs' input: 'A:0' input: 'B:0' } 355 node { name: 'D' op: 'RefInputIntInput' input: 'A:0' input: 'B:0' } 356 """), 357 return_elements=["A", "B", "C", "D"]) 358 359 self.assertEqual(c.inputs[0], a.outputs[0]) 360 self.assertEqual(c.inputs[1], b.outputs[0]) 361 self.assertEqual(d.inputs[0], a.outputs[0]) 362 self.assertEqual(d.inputs[1], b.outputs[0]) 363 364 self.assertEqual(a.outputs[0].dtype, dtypes.int32_ref) 365 self.assertEqual(c._input_types, [dtypes.int32, dtypes.int32]) 366 self.assertEqual(c.outputs, []) 367 self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32]) 368 self.assertEqual(d.outputs, []) 369 370 def testResources(self): 371 # Produce GraphDef containing a ops producing and consuming resources. 372 graph = ops.Graph() 373 with graph.as_default(): 374 var = resource_variable_ops.ResourceVariable(1.0) 375 var_assign = var.assign(2.0) 376 # Use an op that requires handle shape to be set. 
377 var_shape = resource_variable_ops.variable_shape(var.handle) 378 init = variables.global_variables_initializer() 379 graph_def = graph.as_graph_def() 380 381 # Import the GraphDef. 382 with ops.Graph().as_default(): 383 # pylint: disable=unused-variable 384 imported_var, imported_assign, imported_shape, imported_init = ( 385 importer.import_graph_def( 386 graph_def, 387 return_elements=[var.name, var_assign.name, var_shape.name, 388 init.name])) 389 390 # Make sure the handle shape is set on the imported variable. 391 new_var_shape = resource_variable_ops.variable_shape(imported_var) 392 # pylint: enable=unused-variable 393 394 # Run the imported graph. 395 # TODO(b/76173421): make this work (currently DCHECKS) 396 # with self.cached_session() as sess: 397 # self.evaluate(imported_init) 398 # self.assertEqual(self.evaluate(imported_var), 1.0) 399 # self.assertEqual(self.evaluate(imported_assign), 2.0) 400 # self.assertEqual(list(self.evaluate(imported_shape)), []) 401 # self.assertEqual(list(self.evaluate(new_var_shape)), []) 402 403 def testWhileLoop(self): 404 # Produce GraphDef containing while loop. 405 graph = ops.Graph() 406 with graph.as_default(): 407 r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) 408 # Add an op that consumes the while loop output. 409 math_ops.add(r, 1) 410 graph_def = graph.as_graph_def() 411 412 # Import the GraphDef and make sure it runs. 413 with ops.Graph().as_default(): 414 imported_r, = importer.import_graph_def(graph_def, 415 return_elements=[r.name]) 416 self.assertEqual(imported_r.name, "import/" + r.name) 417 with self.cached_session() as sess: 418 self.assertEqual(self.evaluate(imported_r), 10) 419 420 def testImportWhileLoopInCond(self): 421 # Produce GraphDef containing while loop. 
422 graph = ops.Graph() 423 with graph.as_default(): 424 r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) 425 graph_def = graph.as_graph_def() 426 427 # Import the GraphDef inside a cond and make sure it runs. 428 with ops.Graph().as_default(): 429 430 def ImportFn(): 431 return importer.import_graph_def(graph_def, return_elements=[r.name])[0] 432 433 pred = array_ops.placeholder(dtypes.bool) 434 out = control_flow_ops.cond(pred, ImportFn, 435 lambda: constant_op.constant(1)) 436 with self.cached_session() as sess: 437 self.assertEqual(sess.run(out, {pred: True}), 10) 438 self.assertEqual(sess.run(out, {pred: False}), 1) 439 440 def testImportWhileLoopInWhileLoop(self): 441 self.skipTest("b/111757448") 442 # Produce GraphDef containing while loop. 443 graph = ops.Graph() 444 with graph.as_default(): 445 r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) 446 graph_def = graph.as_graph_def() 447 448 # Import the GraphDef inside another loop and make sure it runs. 449 with ops.Graph().as_default(): 450 451 def ImportFn(_): 452 return importer.import_graph_def(graph_def, return_elements=[r.name])[0] 453 454 out = control_flow_ops.while_loop( 455 lambda i: i < 2, ImportFn, [0], 456 shape_invariants=[tensor_shape.TensorShape(None)]) 457 with self.cached_session() as sess: 458 self.assertEqual(self.evaluate(out), 10) 459 460 def testTypeMismatchInGraphDef(self): 461 # TODO(skyewm): improve error message 462 error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 " 463 "incompatible with expected float.") 464 with ops.Graph().as_default(): 465 with self.assertRaisesRegex(ValueError, error_msg): 466 importer.import_graph_def( 467 self._MakeGraphDef(""" 468 node { name: 'A' op: 'IntOutput' } 469 node { name: 'B' op: 'FloatInput' input: 'A:0' } 470 """)) 471 472 def testShapeAllowlistViolation(self): 473 # L2 loss produces a scalar shape, but the graph 474 # has the wrong shape, so raise an error. 
475 with ops.Graph().as_default(): 476 with self.assertRaises(ValueError) as e: 477 _ = importer.import_graph_def( 478 self._MakeGraphDef(""" 479 node { name: 'A' op: 'FloatOutput' } 480 node { name: 'B' op: 'L2Loss' 481 input: 'A:0' 482 attr { key: 'T' value { type: DT_FLOAT } } 483 attr { key: '_output_shapes' 484 value { list { shape { dim { size: 43 } } } } } } 485 """), 486 return_elements=["B"], 487 name="import") 488 self.assertTrue( 489 "Shapes () and (43,) are not compatible" in str(e.exception)) 490 491 def testInvalidSignatureTooManyInputsInGraphDef(self): 492 with ops.Graph().as_default(): 493 # TODO(skyewm): improve error message 494 with self.assertRaisesRegex( 495 ValueError, 496 "NodeDef expected inputs '' do not match 1 inputs specified"): 497 importer.import_graph_def( 498 self._MakeGraphDef(""" 499 node { name: 'A' op: 'IntOutput' } 500 node { name: 'B' op: 'None' input: 'A:0' } 501 """)) 502 503 def testInvalidSignatureNotEnoughInputsInGraphDef(self): 504 with ops.Graph().as_default(): 505 # TODO(skyewm): improve error message 506 with self.assertRaisesRegex( 507 ValueError, 508 "NodeDef expected inputs 'int32, float' do not match 1 inputs " 509 "specified"): 510 importer.import_graph_def( 511 self._MakeGraphDef(""" 512 node { name: 'A' op: 'IntOutput' } 513 node { name: 'B' op: 'IntInputFloatInput' input: 'A:0' } 514 """)) 515 516 def testMissingInputOpInGraphDef(self): 517 with ops.Graph().as_default(): 518 with self.assertRaisesRegex(ValueError, 519 "Node 'B': Unknown input node 'A:0'"): 520 importer.import_graph_def( 521 self._MakeGraphDef(""" 522 node { name: 'B' op: 'FloatInput' input: 'A:0' } 523 """)) 524 525 def testMissingInputOpInGraphDefButAppearsInInputMap(self): 526 with ops.Graph().as_default(): 527 feed_a_0 = constant_op.constant(5.0) 528 b, = importer.import_graph_def( 529 self._MakeGraphDef(""" 530 node { name: 'B' op: 'FloatInput' input: 'A:0' } 531 """), 532 input_map={"A:0": feed_a_0}, 533 return_elements=["B"]) 534 
self.assertEqual(b.inputs[0], feed_a_0) 535 536 def testMissingInputTensorInGraphDef(self): 537 with ops.Graph().as_default(): 538 with self.assertRaisesRegex( 539 ValueError, 540 "Node 'B': Connecting to invalid output 1 of source node A " 541 "which has 1 outputs"): 542 importer.import_graph_def( 543 self._MakeGraphDef(""" 544 node { name: 'A' op: 'FloatOutput' } 545 node { name: 'B' op: 'FloatInput' input: 'A:1' } 546 """)) 547 548 def testMissingControlInputInGraphDef(self): 549 with ops.Graph().as_default(): 550 with self.assertRaisesRegex(ValueError, 551 r"Node 'B': Unknown input node '\^A'"): 552 importer.import_graph_def( 553 self._MakeGraphDef(""" 554 node { name: 'B' op: 'None' input: '^A' } 555 """)) 556 557 def testInvalidTensorNameOutputIndexInGraphDef(self): 558 with ops.Graph().as_default(): 559 with self.assertRaisesRegex(ValueError, 560 "Node 'B': Unknown input node 'A:B'"): 561 importer.import_graph_def( 562 self._MakeGraphDef(""" 563 node { name: 'B' op: 'None' input: 'A:B' } 564 """)) 565 566 def testInvalidTensorNameInGraphDef(self): 567 with ops.Graph().as_default(): 568 with self.assertRaisesRegex(ValueError, 569 "Node 'B': Unknown input node 'A:B:0'"): 570 importer.import_graph_def( 571 self._MakeGraphDef(""" 572 node { name: 'B' op: 'None' input: 'A:B:0' } 573 """)) 574 575 def testMissingReturnOperation(self): 576 with ops.Graph().as_default(): 577 with self.assertRaisesRegex( 578 ValueError, "Requested return node 'B' not found in graph def"): 579 importer.import_graph_def( 580 self._MakeGraphDef(""" 581 node { name: 'A' op: 'None' } 582 """), 583 return_elements=["B"]) 584 585 def testMissingReturnTensor(self): 586 with ops.Graph().as_default(): 587 with self.assertRaisesRegex( 588 ValueError, 589 r"Invalid return output 1 of node 'A', which has 1 output\(s\)"): 590 importer.import_graph_def( 591 self._MakeGraphDef(""" 592 node { name: 'A' op: 'IntOutput' } 593 """), 594 return_elements=["A:1"]) 595 596 with self.assertRaisesRegex( 597 
ValueError, "Requested return tensor 'B:0' not found in graph def"): 598 importer.import_graph_def( 599 self._MakeGraphDef(""" 600 node { name: 'A' op: 'IntOutput' } 601 """), 602 return_elements=["B:0"]) 603 604 with self.assertRaisesRegex(ValueError, 605 "Cannot convert 'A:B:0' to a tensor name."): 606 importer.import_graph_def( 607 self._MakeGraphDef(""" 608 node { name: 'A' op: 'IntOutput' } 609 """), 610 return_elements=["A:B:0"]) 611 612 def testMissingInputMap(self): 613 with ops.Graph().as_default(): 614 with self.assertRaisesRegex( 615 ValueError, 616 r"Attempted to map inputs that were not found in graph_def: \[B:0\]"): 617 importer.import_graph_def( 618 self._MakeGraphDef(""" 619 node { name: 'A' op: 'None' } 620 """), 621 input_map={"B:0": constant_op.constant(5.0)}) 622 623 def testInputMapUnusedAsInput(self): 624 with ops.Graph().as_default(): 625 # Mapping an unused node output should succeed. 626 importer.import_graph_def( 627 self._MakeGraphDef(""" 628 node { name: 'A' op: 'IntOutput' } 629 """), 630 input_map={"A:0": constant_op.constant(5.0)}) 631 632 # Mapping a non-existent output of an existing node should fail. 
633 with self.assertRaisesRegex( 634 ValueError, 635 r"Attempted to map inputs that were not found in graph_def: \[A:2\]"): 636 importer.import_graph_def( 637 self._MakeGraphDef(""" 638 node { name: 'A' op: 'IntOutput' } 639 """), 640 input_map={"A:2": constant_op.constant(5.0)}) 641 642 def testInputMapTypeMismatch(self): 643 with ops.Graph().as_default(): 644 with self.assertRaisesRegex( 645 ValueError, "Input 0 of node import/B was passed float from Const:0 " 646 "incompatible with expected int32."): 647 importer.import_graph_def( 648 self._MakeGraphDef(""" 649 node { name: 'A' op: 'IntOutput' } 650 node { name: 'B' op: 'IntInput' input: 'A:0' } 651 """), 652 input_map={"A:0": constant_op.constant(5.0)}) 653 654 def testNoReturns(self): 655 with ops.Graph().as_default() as g: 656 ret = importer.import_graph_def( 657 self._MakeGraphDef(""" 658 node { name: 'A' op: 'None' } 659 """)) 660 self.assertEqual(ret, None) 661 662 a = g.get_operation_by_name("import/A") 663 self.assertEqual(a.type, "None") 664 665 def testOverrideNamePrefix(self): 666 with ops.Graph().as_default(): 667 a, = importer.import_graph_def( 668 self._MakeGraphDef(""" 669 node { name: 'A' op: 'None' } 670 """), 671 return_elements=["A"], 672 name="imported_graph") 673 self.assertEqual(a.name, "imported_graph/A") 674 675 def testDefaultNamePrefix(self): 676 with ops.Graph().as_default(): 677 a, = importer.import_graph_def( 678 self._MakeGraphDef(""" 679 node { name: 'A' op: 'None' } 680 """), 681 return_elements=["A"], 682 name=None) 683 self.assertEqual(a.name, "import/A") 684 685 def testNamePrefixColocationAttrs(self): 686 original_graph_def = self._MakeGraphDef(""" 687 node { name: 'A' op: 'None' } 688 node { name: 'B' op: 'None' attr { 689 key: '_class' 690 value { list { s: 'loc:@A' } } 691 } }""") 692 693 with ops.Graph().as_default(): 694 b, = importer.import_graph_def( 695 original_graph_def, return_elements=["B"], name="imported_graph") 696 self.assertTrue("_class" in b.node_def.attr) 
697 self.assertProtoEquals( 698 "list { s: 'loc:@imported_graph/A' }", 699 b.node_def.attr["_class"]) 700 701 def testColocationAndDevice(self): 702 # A and B are colocated, device set on A. 703 original_graph_def = self._MakeGraphDef(""" 704 node { name: 'A' op: 'None' device: '/device:CPU:0' attr { 705 key: '_class' 706 value { list { s: 'loc:@A' } } 707 } } 708 node { name: 'B' op: 'None' attr { 709 key: '_class' 710 value { list { s: 'loc:@A' } } 711 } }""") 712 713 with ops.Graph().as_default(): 714 a, b = importer.import_graph_def(original_graph_def, 715 return_elements=["A", "B"], 716 name="") 717 self.assertEqual(a.device, "/device:CPU:0") 718 self.assertEqual(b.device, "/device:CPU:0") 719 self.assertEqual(a.colocation_groups(), [b"loc:@A"]) 720 self.assertEqual(b.colocation_groups(), [b"loc:@A"]) 721 722 # A and B are colocated, device set on B. 723 original_graph_def = self._MakeGraphDef(""" 724 node { name: 'A' op: 'None' attr { 725 key: '_class' 726 value { list { s: 'loc:@A' } } 727 } } 728 node { name: 'B' op: 'None' device: '/device:CPU:0' attr { 729 key: '_class' 730 value { list { s: 'loc:@A' } } 731 } }""") 732 733 with ops.Graph().as_default(): 734 a, b = importer.import_graph_def(original_graph_def, 735 return_elements=["A", "B"], 736 name="") 737 # TODO(skyewm): this behavior seems inconsistent with the above. Why is 738 # B's device ignored? 739 self.assertEqual(a.device, "") 740 self.assertEqual(b.device, "") 741 self.assertEqual(a.colocation_groups(), [b"loc:@A"]) 742 self.assertEqual(b.colocation_groups(), [b"loc:@A"]) 743 744 def testColocationWithDeviceFn(self): 745 original_graph_def = self._MakeGraphDef(""" 746 node { name: 'A' op: 'None' attr { 747 key: '_class' 748 value { list { s: 'loc:@A' } } 749 } } 750 node { name: 'B' op: 'None' attr { 751 key: '_class' 752 value { list { s: 'loc:@A' } } 753 } }""") 754 755 # A device function that places "A" on one device and "B" on 756 # another device. 
Because B is colocated with A, we test that B's 757 # device function is overridden by A. 758 def CustomDeviceFn(op): 759 if "A" in op.name: 760 return "/device:A:0" 761 else: 762 return "/device:B:0" 763 764 with ops.Graph().as_default(): 765 with ops.device(CustomDeviceFn): 766 a, b = importer.import_graph_def(original_graph_def, 767 return_elements=["A", "B"], 768 name="imported_graph") 769 self.assertEqual(a.device, "/device:A:0") 770 self.assertEqual(b.device, "/device:A:0") 771 self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) 772 self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) 773 774 # Test a scenario where 'A' doesn't get a device; 'A' should not have a 775 # device, but during runtime will get colocated with 'B' because of the 776 # colocation attribute. B's device function is still overridden by A. 777 def BDeviceFn(op): 778 if "B" in op.name: 779 return "/device:B:0" 780 return "" 781 782 with ops.Graph().as_default(): 783 with ops.device(BDeviceFn): 784 a, b = importer.import_graph_def(original_graph_def, 785 return_elements=["A", "B"], 786 name="imported_graph") 787 self.assertEqual(a.device, "") 788 self.assertEqual(b.device, "") 789 self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) 790 self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) 791 792 # Only A gets a device, so B inherits it implicitly. 
793 def ADeviceFn(op): 794 if "A" in op.name: 795 return "/device:A:0" 796 return "" 797 798 with ops.Graph().as_default(): 799 with ops.device(ADeviceFn): 800 a, b = importer.import_graph_def(original_graph_def, 801 return_elements=["A", "B"], 802 name="imported_graph") 803 self.assertEqual(a.device, "/device:A:0") 804 self.assertEqual(b.device, "/device:A:0") 805 self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"]) 806 self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"]) 807 808 def testMultipleColocationWithDeviceFn(self): 809 original_graph_def = self._MakeGraphDef(""" 810 node { name: 'A' op: 'None'} 811 node { name: 'B' op: 'None'} 812 node { name: 'C' op: 'None' attr { 813 key: '_class' 814 value { list { s: 'loc:@A' s: 'loc:@B' } } 815 } }""") 816 817 # A device function that places "B" on a device, and "A" is empty. 818 # 819 # B and C should contain "/device:B". A will not right now. But 820 # because of the colocation property, at runtime it would be 821 # placed with B and C. 
    # Device function under test: ops whose name contains "B" are placed on
    # "/device:B:0"; every other op is given no device ("").
    def CustomDeviceFn(op):
      if "B" in op.name:
        return "/device:B:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(CustomDeviceFn):
        a, b, c = importer.import_graph_def(original_graph_def,
                                            return_elements=["A", "B", "C"],
                                            name="imported_graph")
      # A gets no device from CustomDeviceFn. C's name does not contain "B",
      # yet C is expected to end up on B's device; presumably this comes from
      # its colocation with B (C's colocation groups, asserted below, include
      # imported_graph/B).
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "/device:B:0")
      self.assertEqual(c.device, "/device:B:0")
      # Colocation attrs are rewritten to use the "imported_graph/" prefix.
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/B"])
      self.assertEqual(c.colocation_groups(),
                       [b"loc:@imported_graph/A", b"loc:@imported_graph/B"])

  def testNamePrefixColocationAttrsMultipleImport(self):
    """Importing one GraphDef twice with name="" rewrites colocation attrs.

    The second import's ops are uniquified to "A_1"/"B_1", and B_1's
    "loc:@A" colocation attr is expected to point at A_1, not at the first
    import's A.
    """
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(
          original_graph_def, return_elements=["A", "B"], name="")
      a_1, b_1 = importer.import_graph_def(
          original_graph_def, return_elements=["A", "B"], name="")

      # First import keeps the original names.
      self.assertEqual(a.name, "A")
      self.assertEqual(b.name, "B")
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

      # Second import is uniquified, and its colocation attr follows the rename.
      self.assertEqual(a_1.name, "A_1")
      self.assertEqual(b_1.name, "B_1")
      self.assertEqual(b_1.colocation_groups(), [b"loc:@A_1"])

  def testNamePrefixColocationAttrsNotFound(self):
    """A colocation attr naming a node absent from the GraphDef is rejected.

    Node 'B' asks to be colocated with 'A', which the GraphDef does not
    define, so import_graph_def is expected to raise ValueError.
    """
    original_graph_def = self._MakeGraphDef("""
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Node 'B' expects to be colocated with unknown node 'A'"):
        importer.import_graph_def(
            original_graph_def, return_elements=["B"], name="imported_graph")

  def testEmptyGraph(self):
    """Importing an empty GraphDef leaves the graph's `version` unchanged."""
    with ops.Graph().as_default() as g:
      init_version = g.version
      importer.import_graph_def(self._MakeGraphDef(""))
      self.assertEqual(init_version, g.version)

  def testInvalidInputForGraphDef(self):
    """Passing something other than a GraphDef proto raises TypeError."""
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          TypeError, r"Argument `graph_def` must be a GraphDef proto."):
        importer.import_graph_def("")

  def testInvalidInputForInputMap(self):
    """Checks validation of `input_map`, plus one valid remapping.

    Covers three cases:
      * a list (not a dict) `input_map` raises TypeError;
      * mapping to a Variable with name="" raises ValueError, because the
        error message requires a non-empty name for non-Tensor values;
      * mapping to a constant Tensor with name="" imports successfully, and
        the Identity node 'id' evaluates to the mapped value.
    """
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          TypeError,
          r"Argument `input_map` must be a dictionary. Obtained list"):
        importer.import_graph_def(
            self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
    graph_def = self._MakeGraphDef("""
         node { name: 'a' op: 'Placeholder'
                attr { key: 'dtype' value { type: DT_FLOAT } }}
         node { name: 'id' op: 'Identity' input: 'a:0'
                attr { key: 'T' value { type: DT_FLOAT } }}""")
    with ops.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        importer.import_graph_def(
            graph_def,
            input_map={"a:0": variables.Variable(5.0)},
            name="")
      self.assertStartsWith(str(e.exception),
                            "tf.import_graph_def() requires a non-empty `name` "
                            "if `input_map` contains non-Tensor values.")
    with ops.Graph().as_default():
      t, = importer.import_graph_def(
          graph_def,
          input_map={"a:0": constant_op.constant(5.0)},
          name="",
          return_elements=["id:0"])
      with self.cached_session():
        self.assertEqual(5.0, self.evaluate(t))

  def testInvalidInputForReturnOperations(self):
    """Malformed `return_elements` entries are rejected.

    A non-string entry raises TypeError. A string that cannot be parsed as a
    tensor name ('a:b:c') raises ValueError.
    """
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          TypeError, "Argument `return_elements` must be a list of strings."):
        importer.import_graph_def(self._MakeGraphDef(""), return_elements=[7])

      with self.assertRaisesRegex(ValueError,
                                  "Cannot convert 'a:b:c' to a tensor name."):
        importer.import_graph_def(
            self._MakeGraphDef(""), return_elements=["a:b:c"])

  def testDuplicateOperationNames(self):
    """A GraphDef with two nodes named 'A' raises ValueError on import."""
    # NOTE(review): unlike most tests here, this one imports into the current
    # default graph instead of a fresh ops.Graph().
    with self.assertRaisesRegex(ValueError, "Node 'A' is not unique"):
      importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutput' }
          node { name: 'B' op: 'IntOutput' }
          node { name: 'A' op: 'IntOutput' }
          """))

  @test_util.run_v1_only("v1 Tensor doesn't have attribute 'numpy'")
  def testWithExtensionAndAttr(self):
    """A graph with a `stack` ("pack") op imports and evaluates correctly."""
    with ops.Graph().as_default() as g:
      c = constant_op.constant(5.0, dtype=dtypes.float32, name="c")
      array_ops.stack([c, c], name="pack")
    gdef = g.as_graph_def()

    with self.cached_session():
      pack, = importer.import_graph_def(gdef, return_elements=["pack"])
      self.assertAllEqual(pack.outputs[0], [5.0, 5.0])

  def testWithDevice(self):
    """Device specs in the GraphDef merge with an enclosing merge_device scope.

    From the assertions below: fields an imported op already specifies win
    over the scope (e.g. job:worker over job:ps, CPU over GPU), and fields it
    leaves unset are filled in from the scope.
    """
    with ops.Graph().as_default() as g:
      # No device.
      a = constant_op.constant(3.0, name="a")

      with ops.device("/cpu:0"):
        b = constant_op.constant(4.0, name="b")
      with ops.device("/job:worker"):
        c = constant_op.constant(5.0, name="c")

    gdef = g.as_graph_def()

    # No enclosing device scope: devices round-trip unchanged.
    with ops.Graph().as_default():
      a2, b2, c2 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/task:0")):
        a3, b3, c3 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/task:0", a3.device)
        self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
        self.assertEqual(c.device + "/task:0", c3.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/job:ps")):
        a4, b4, c4 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/job:ps", a4.device)
        self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/device:GPU:0")):
        a5, b5, c5 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/device:GPU:0", a5.device)
        self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + "/device:GPU:0", c5.device)

  def testWithDeviceFunctionDependingOnInputs(self):
    """A device function active during import sees each op's inputs.

    InputCounter records ops that have exactly two inputs when it is
    invoked. The final assertion expects the add and the subtract to be seen,
    so their inputs must already be connected when the device function runs.
    """
    with ops.Graph().as_default() as g:
      with ops.device("/job:ps"):
        v1 = constant_op.constant(1.0)
        v2 = constant_op.constant(1.0)
      _ = v1 + v2
      _ = v1 - v2
      _ = array_ops.identity(v1)
    gdef = g.as_graph_def()

    # We'll use the following device function to observe ops with two inputs.
    ops_with_two_inputs = []

    def InputCounter(op):
      if len(op.inputs) == 2:
        ops_with_two_inputs.append(op)
      return ""

    with ops.Graph().as_default() as g:
      with ops.device(InputCounter):
        importer.import_graph_def(gdef)

    # We expect to see the add and subtract, but not identity.
    self.assertEqual(2, len(ops_with_two_inputs))

  def testGradient(self):
    """Gradients flow through an imported graph whose inputs are remapped.

    The exported graph is built from placeholders. On import, those
    placeholders are replaced via `input_map` by a fixed-shape placeholder
    and two Variables. The test then checks that gradients w.r.t. the
    Variables have the Variables' shapes.
    """
    with ops.Graph().as_default() as g:
      inputs = array_ops.placeholder(
          dtypes.float32, shape=[None, 100], name="input")
      weights = array_ops.placeholder(
          dtypes.float32, shape=[100, 10], name="weights")
      biases = array_ops.placeholder(dtypes.float32, shape=[10], name="biases")
      activations = nn_ops.relu(
          math_ops.matmul(inputs, weights) + biases, name="activations")
      loss = math_ops.reduce_mean(activations, name="loss")
    gdef = g.as_graph_def()

    with ops.Graph().as_default() as g:
      input_placeholder = array_ops.placeholder(dtypes.float32, shape=[32, 100])
      weights_var = variables.Variable(
          random_ops.truncated_normal([100, 10]), name="weights")
      biases_var = variables.Variable(array_ops.zeros([10]), name="biases")
      activations, loss = importer.import_graph_def(
          gdef,
          input_map={
              "input:0": input_placeholder,
              "weights:0": weights_var,
              "biases:0": biases_var
          },
          return_elements=["activations:0", "loss:0"])
      # The mapped input has a fully defined [32, 100] shape (the exported
      # placeholder was [None, 100]). The imported activations are expected
      # to pick up the static batch size of 32.
      self.assertEqual([32, 10], activations.get_shape())
      self.assertEqual([], loss.get_shape())
      weights_grad, biases_grad = gradients_impl.gradients(
          loss, [weights_var, biases_var])
      self.assertEqual([100, 10], weights_grad.get_shape())
      self.assertEqual([10], biases_grad.get_shape())

  def testLargeGraph(self):
    """A graph holding a very large constant tensor can still be evaluated."""
    with self.cached_session():
      # The default message byte limit is 64M. Ours is 2G with a warning at
      # 512M. Adding a float32 tensor with 130M entries (130 * 1000 * 1000
      # elements, roughly 520MB of data) should exceed the warning, but not
      # the hard limit.
      input_shape = [130, 1000, 1000]
      tensor_input = np.ones(input_shape, dtype=np.float32)
      t = constant_op.constant(tensor_input, shape=input_shape)
      g = array_ops.identity(t)
      self.evaluate(g)

  def testVersion(self):
    """Imports succeed across the supported version range and record versions.

    Producer and min_consumer each take three values: the minimum consumer
    version, the current GRAPH_DEF_VERSION, and their midpoint. After import,
    the graph's `graph_def_versions` must match the GraphDef's values.
    """
    v0 = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
    v2 = versions.GRAPH_DEF_VERSION
    v1 = (v0 + v2) // 2
    for producer in v0, v1, v2:
      for min_consumer in v0, v1, v2:
        with ops.Graph().as_default():
          a, = importer.import_graph_def(
              self._MakeGraphDef(
                  "node { name: 'A' op: 'TwoIntOutputs' }",
                  producer=producer,
                  min_consumer=min_consumer),
              return_elements=["A"])
          self.assertEqual(a.graph.graph_def_versions.producer, producer)
          self.assertEqual(a.graph.graph_def_versions.min_consumer,
                           min_consumer)

  def testVersionLow(self):
    """A producer version below GRAPH_DEF_VERSION_MIN_PRODUCER is rejected."""
    # NOTE(review): the expected type is the broad `Exception`, and the '.'
    # before '$' is unescaped here (testVersionHigh escapes it), so that
    # position matches any character.
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          Exception,
          r"GraphDef producer version -1 below min producer %d supported "
          r"by TensorFlow \S+\. Please regenerate your graph.$" %
          versions.GRAPH_DEF_VERSION_MIN_PRODUCER):
        importer.import_graph_def(self._MakeGraphDef("", producer=-1))

  def testVersionHigh(self):
    """A min_consumer above the current GRAPH_DEF_VERSION raises ValueError."""
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError,
          r"GraphDef min consumer version %d above current version %d "
          r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
          (1 << 30, versions.GRAPH_DEF_VERSION)):
        importer.import_graph_def(self._MakeGraphDef("", min_consumer=1 << 30))

  def testVersionAppliesToOpConstruction(self):
    """These tests rely on shape fns in test_ops.cc.

    Importing 'RequiresOlderGraphVersion' succeeds when the GraphDef producer
    is GRAPH_DEF_VERSION - 1. It raises ValueError ("Wrong graph version...")
    when the producer is GRAPH_DEF_VERSION. This shows that the GraphDef's
    producer version is visible while the op is being constructed.
    """
    with ops.Graph().as_default():
      importer.import_graph_def(
          self._MakeGraphDef(
              "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
              producer=versions.GRAPH_DEF_VERSION - 1),
          return_elements=["A"])

    with ops.Graph().as_default():
      with self.assertRaisesWithPredicateMatch(ValueError,
                                               "Wrong graph version.*"):
        importer.import_graph_def(
            self._MakeGraphDef(
                "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
                producer=versions.GRAPH_DEF_VERSION),
            return_elements=["A"])

  def testDefaultAttrsAdded(self):
    """An attr omitted from the NodeDef is present after import.

    Node 'A' sets no attrs, yet get_attr("default_float") returns 123.0.
    This is presumably the default declared by the OpWithDefaultAttr
    registration in test_ops (not visible here).
    """
    with ops.Graph().as_default():
      a = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithDefaultAttr' }
          """),
          return_elements=["A"])
      self.assertEqual(123.0, a[0].get_attr("default_float"))

  def testDefaultAttrsRemoved(self):
    """A producer-only attr set to its default value is stripped on import.

    `producer_op_list` declares 'default_int' (default 456) for
    OpWithFutureDefaultAttr. The NodeDef sets it to exactly 456, so the
    imported op is expected to have no 'default_int' attr at all.
    """
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge("""
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """, producer_op_list)
    # Attr only in producer_op_list with default value gets removed.
    with ops.Graph().as_default():
      a = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """),
          return_elements=["A"],
          producer_op_list=producer_op_list)
      with self.assertRaisesRegex(
          ValueError, "Operation 'import/A' has no attr named 'default_int'."):
        a[0].get_attr("default_int")

  def testFunctions(self):
    """Function library entries survive import and a second export/import.

    Exercises three kinds of Defun:
      * one with a custom gradient function (FuncWithGrad / Grad);
      * one that captures a tensor from the enclosing graph
        (ExternalTensorFunc captures `c`);
      * a nested function (OuterFunc calls InnerFunc).
    Each graph (g2 imported from g1, g3 re-imported from g2) is run to check
    the expected values.
    """
    dtype = dtypes.float32

    @function.Defun(dtype, dtype, dtype, dtype)
    def Grad(x, y, dout1, dout2):  # pylint: disable=unused-argument
      # Return the inputs for simplicity of testing. The correct return value
      # would be (dout1 + dout2, dout1 - dout2)
      return x, y

    @function.Defun(dtype, dtype, grad_func=Grad)
    def FuncWithGrad(x, y):
      return x + y, x - y

    @function.Defun(dtypes.int32)
    def ExternalTensorFunc(x):
      # c must be defined in the containing graph
      return x + c

    @function.Defun(dtypes.int32, dtypes.int32)
    def OuterFunc(x, y):

      @function.Defun(dtypes.int32)
      def InnerFunc(x):
        return x + x

      return InnerFunc(x) + y

    # Create graph with function calls and export to GraphDef
    with ops.Graph().as_default() as g1:
      p1 = array_ops.placeholder(dtype, name="p1")
      p2 = array_ops.placeholder(dtype, name="p2")
      # pylint: disable=unexpected-keyword-arg
      a, b = FuncWithGrad(p1, p2, name="f")

      c = constant_op.constant(10, dtype=dtypes.int32)
      ExternalTensorFunc(1, name="external")

      OuterFunc(10, 1, name="outer")
      # pylint: enable=unexpected-keyword-arg

    gdef = g1.as_graph_def()

    # Import GraphDef into new graph, add imported gradients, and test that
    # imported functions can be run
    with ops.Graph().as_default() as g2:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      grad = gradients_impl.gradients([a], [p1, p2])

    with self.session(graph=g2) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      # Grad function returns input values for testing
      self.assertEqual(grad_val, [1.0, 2.0])
      # external: 1 + c (10) == 11; outer: InnerFunc(10) + 1 == 21.
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)

    # Export the new graph and reimport to test that imported functions can be
    # successfully exported/imported again
    gdef = g2.as_graph_def()
    with ops.Graph().as_default() as g3:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      # Create new gradient functions (in addition to the imported gradient
      # functions created in g2).
      grad = gradients_impl.gradients([a], [p1, p2])

    with self.session(graph=g3) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      self.assertEqual(grad_val, [1.0, 2.0])
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)

  @test_util.run_v1_only("import inside defun not supported when eager "
                         "execution is enabled.")
  def testImportInsideDefun(self):
    """import_graph_def can be called inside a Defun body.

    The imported graph computes Add2(3.0, -5.0) through a function call. The
    Defun returns the imported "z:0", which should evaluate to -2.0.
    """
    g = ops.Graph()
    with g.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      x = constant_op.constant(3.0, dtype=dtypes.float32)
      y = constant_op.constant(-5.0, dtype=dtypes.float32)
      z = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg

    gdef = g.as_graph_def()

    @function.Defun()
    def TestFunc():
      return importer.import_graph_def(gdef, return_elements=["z:0"])[0]

    z = TestFunc()

    with self.cached_session():
      z_val = self.evaluate(z)
      self.assertEqual(z_val, -2.0)

  @test_util.run_v1_only("_as_tf_output not supported when eager execution "
                         "is enabled.")
  def testImportGraphWithFunctionTwice(self):
    """A GraphDef containing a function call can be imported twice.

    Both imports go into the same graph, under different name scopes, and
    share one `input_map`. Both read the same random x and y tensors, so
    their outputs must be equal.
    """
    g = ops.Graph()
    with g.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      x = array_ops.placeholder(dtype=dtypes.float32, name="x")
      y = array_ops.placeholder(dtype=dtypes.float32, name="y")
      _ = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg

    gdef = g.as_graph_def()

    x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    input_map = {"x:0": x, "y:0": y}

    with ops.name_scope("first"):
      z1 = importer.import_graph_def(gdef, return_elements=["z:0"],
                                     input_map=input_map)[0]

    with ops.name_scope("second"):
      z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
                                     input_map=input_map)[0]

    with self.cached_session() as sess:
      z1_val, z2_val = sess.run((z1, z2))
      self.assertAllEqual(z1_val, z2_val)


if __name__ == "__main__":
  test.main()