# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.importer."""

import numpy as np

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_ops  # pylint: disable=unused-import
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class ImportGraphDefTest(test.TestCase):

47  def _MakeGraphDef(self,
48                    text,
49                    producer=versions.GRAPH_DEF_VERSION,
50                    min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER):
51    text = "versions: { producer: %d min_consumer: %d };\n%s" % (producer,
52                                                                 min_consumer,
53                                                                 text)
54    ret = graph_pb2.GraphDef()
55    text_format.Merge(text, ret)
56    return ret
57
  def testBasic(self):
    """Imports a 4-node GraphDef and checks tensors, edges, types, and names."""
    with ops.Graph().as_default():
      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutputFloatOutput' }
          node { name: 'B' op: 'ListOutput'
                 attr { key: 'T'
                        value { list { type: DT_INT32 type: DT_FLOAT } } } }
          node { name: 'C' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_FLOAT } }
                 input: 'A:1' input: 'B:1' }
          """),
          return_elements=["A", "B", "C", "D"],
          name="import")

      # Assert that the import process creates distinct tensors.
      self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
      self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)

      # Assert that the ops are connected according to the GraphDef topology.
      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], b.outputs[1])

      # Check the types of the returned ops and tensors.
      self.assertEqual(a.type, "IntOutputFloatOutput")
      self.assertEqual(b.type, "ListOutput")
      self.assertEqual(c.type, "ListInput")
      self.assertEqual(d.type, "ListInput")
      self.assertEqual(a.outputs[0].dtype, dtypes.int32)
      self.assertEqual(a.outputs[1].dtype, dtypes.float32)
      self.assertEqual(b.outputs[0].dtype, dtypes.int32)
      self.assertEqual(b.outputs[1].dtype, dtypes.float32)

      # Check the names of the returned ops (prefixed with the import name).
      self.assertEqual(a.name, "import/A")
      self.assertEqual(b.name, "import/B")
      self.assertEqual(c.name, "import/C")
      self.assertEqual(d.name, "import/D")

      # Check that the op_def is still available.
      self.assertNotEqual(None, a.op_def)

  def testMultipleImport(self):
    """Repeated imports de-duplicate op names and name prefixes predictably.

    NOTE: the assertions here depend on the exact order of the imports;
    each import's expected names follow from the names already taken.
    """
    graph_def = self._MakeGraphDef("""
    node { name: 'A' op: 'IntOutput' }
    node { name: 'B' op: 'IntInput' input: 'A:0' }
    """)

    with ops.Graph().as_default():
      # Initial import: names are used verbatim.
      a, b = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a.name, "A")
      self.assertEqual(b.name, "B")
      self.assertEqual(list(b.inputs), [a.outputs[0]])

      # Repeat the same import: "_1" suffix de-dupes.
      a1, b1 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a1.name, "A_1")
      self.assertEqual(b1.name, "B_1")
      self.assertEqual(list(b1.inputs), [a1.outputs[0]])

      # Repeat the same import again: suffix increments.
      a2, b2 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a2.name, "A_2")
      self.assertEqual(b2.name, "B_2")
      self.assertEqual(list(b2.inputs), [a2.outputs[0]])

      # Import with an already-used name: the PREFIX is de-duped instead.
      a3, b3 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A")
      self.assertEqual(a3.name, "A_3/A")
      self.assertEqual(b3.name, "A_3/B")
      self.assertEqual(list(b3.inputs), [a3.outputs[0]])

      # Import with an already-used name but with a '/' to indicate an
      # "absolute" name scope (see the Graph.name_scope docstring).
      a_a, a_b = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A/")
      self.assertEqual(a_a.name, "A/A")
      self.assertEqual(a_b.name, "A/B")
      self.assertEqual(list(a_b.inputs), [a_a.outputs[0]])

      # Repeat the same import: nodes inside the absolute scope are de-duped.
      a_a1, a_b1 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A/")
      self.assertEqual(a_a1.name, "A/A_1")
      self.assertEqual(a_b1.name, "A/B_1")
      self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]])

      # Import with existing de-duped node names: another suffix is appended.
      a1_1, b1_1 = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A_1' op: 'IntOutput' }
          node { name: 'B_1' op: 'IntInput' input: 'A_1:0' }
          """),
          return_elements=["A_1", "B_1"],
          name="")
      self.assertEqual(a1_1.name, "A_1_1")
      self.assertEqual(b1_1.name, "B_1_1")
      self.assertEqual(list(b1_1.inputs), [a1_1.outputs[0]])

      # Create a name scope and then import node with same name.
      with ops.name_scope("foo"):
        constant_op.constant(1)
      foo, = importer.import_graph_def(
          self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
          return_elements=["foo"],
          name="")
      self.assertEqual(foo.name, "foo_1")

      # Imported node name can't conflict with intermediate name scope (but can
      # conflict with outer scope and full name scope).
      with ops.name_scope("outer"):
        with ops.name_scope("inner"):
          c = constant_op.constant(1, name="c")
          self.assertEqual(c.op.name, "outer/inner/c")

      outer, inner, new_c, outer_inner, outer_inner_c = (
          importer.import_graph_def(
              self._MakeGraphDef(
                  "node { name: 'outer' op: 'IntOutput' }"
                  "node { name: 'inner' op: 'IntOutput' }"
                  "node { name: 'c' op: 'IntOutput' }"
                  "node { name: 'outer/inner' op: 'IntOutput' }"
                  "node { name: 'outer/inner/c' op: 'IntOutput' }"),
              return_elements=["outer", "inner", "c", "outer/inner",
                               "outer/inner/c"],
              name=""))
      self.assertEqual(outer.name, "outer_1")
      self.assertEqual(inner.name, "inner")
      self.assertEqual(new_c.name, "c")
      self.assertEqual(outer_inner.name, "outer/inner_1")
      self.assertEqual(outer_inner_c.name, "outer/inner/c_1")

218  def testEmptyNameScope(self):
219    with ops.Graph().as_default():
220      # Create name scope but don't create any ops with it
221      with ops.name_scope("foo"):
222        pass
223
224      # Import graph def that uses name scope name
225      op, = importer.import_graph_def(
226          self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
227          return_elements=["foo"],
228          name="")
229
230      self.assertEqual(op.name, "foo")
231
232  def testInputMap(self):
233    with ops.Graph().as_default():
234      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
235      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
236
237      a, b, c, d = importer.import_graph_def(
238          self._MakeGraphDef("""
239          node { name: 'A' op: 'TwoIntOutputs' }
240          node { name: 'B' op: 'TwoIntOutputs' }
241          node { name: 'C' op: 'ListInput'
242                 attr { key: 'N' value { i: 2 } }
243                 attr { key: 'T' value { type: DT_INT32 } }
244                 input: 'A:0' input: 'B:0' }
245          node { name: 'D' op: 'ListInput'
246                 attr { key: 'N' value { i: 2 } }
247                 attr { key: 'T' value { type: DT_INT32 } }
248                 input: 'A:1' input: 'B:1' }
249          """),
250          input_map={"A:0": feed_a_0,
251                     "B:1": feed_b_1},
252          return_elements=["A", "B", "C", "D"])
253
254      self.assertEqual(c.inputs[0], feed_a_0)
255      self.assertEqual(c.inputs[1], b.outputs[0])
256      self.assertEqual(d.inputs[0], a.outputs[1])
257      self.assertEqual(d.inputs[1], feed_b_1)
258
259  def testInputMapBytes(self):
260    with ops.Graph().as_default():
261      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
262      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
263
264      a, b, c, d = importer.import_graph_def(
265          self._MakeGraphDef("""
266          node { name: 'A' op: 'TwoIntOutputs' }
267          node { name: 'B' op: 'TwoIntOutputs' }
268          node { name: 'C' op: 'ListInput'
269                 attr { key: 'N' value { i: 2 } }
270                 attr { key: 'T' value { type: DT_INT32 } }
271                 input: 'A:0' input: 'B:0' }
272          node { name: 'D' op: 'ListInput'
273                 attr { key: 'N' value { i: 2 } }
274                 attr { key: 'T' value { type: DT_INT32 } }
275                 input: 'A:1' input: 'B:1' }
276          """),
277          input_map={b"A:0": feed_a_0,
278                     b"B:1": feed_b_1},
279          return_elements=[b"A", b"B", b"C", b"D"])
280
281      self.assertEqual(c.inputs[0], feed_a_0)
282      self.assertEqual(c.inputs[1], b.outputs[0])
283      self.assertEqual(d.inputs[0], a.outputs[1])
284      self.assertEqual(d.inputs[1], feed_b_1)
285
286  def testInputMapUnicode(self):
287    with ops.Graph().as_default():
288      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
289      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
290
291      a, b, c, d = importer.import_graph_def(
292          self._MakeGraphDef("""
293          node { name: 'A' op: 'TwoIntOutputs' }
294          node { name: 'B' op: 'TwoIntOutputs' }
295          node { name: 'C' op: 'ListInput'
296                 attr { key: 'N' value { i: 2 } }
297                 attr { key: 'T' value { type: DT_INT32 } }
298                 input: 'A:0' input: 'B:0' }
299          node { name: 'D' op: 'ListInput'
300                 attr { key: 'N' value { i: 2 } }
301                 attr { key: 'T' value { type: DT_INT32 } }
302                 input: 'A:1' input: 'B:1' }
303          """),
304          input_map={u"A:0": feed_a_0,
305                     u"B:1": feed_b_1},
306          return_elements=[u"A", u"B", u"C", u"D"])
307
308      self.assertEqual(c.inputs[0], feed_a_0)
309      self.assertEqual(c.inputs[1], b.outputs[0])
310      self.assertEqual(d.inputs[0], a.outputs[1])
311      self.assertEqual(d.inputs[1], feed_b_1)
312
313  def testImplicitZerothOutput(self):
314    with ops.Graph().as_default():
315      a, b = importer.import_graph_def(
316          self._MakeGraphDef("""
317          node { name: 'A' op: 'TwoIntOutputs' }
318          node { name: 'B' op: 'IntInput' input: 'A' }
319          """),
320          return_elements=["A", "B"])
321
322      self.assertEqual(b.inputs[0], a.outputs[0])
323
324  def testInputMapImplicitZerothOutput(self):
325    with ops.Graph().as_default():
326      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
327      b, = importer.import_graph_def(
328          self._MakeGraphDef("""
329          node { name: 'A' op: 'TwoIntOutputs' }
330          node { name: 'B' op: 'IntInput' input: 'A:0' }
331          """),
332          input_map={"A": feed_a_0},
333          return_elements=["B"])
334
335      self.assertEqual(b.inputs[0], feed_a_0)
336
337  def testWithControlDependency(self):
338    with ops.Graph().as_default():
339      a, b = importer.import_graph_def(
340          self._MakeGraphDef("""
341          node { name: 'A' op: 'None' }
342          node { name: 'B' op: 'None' input: '^A' }
343          """),
344          return_elements=["A", "B"])
345
346      self.assertEqual(b.control_inputs, [a])
347
  def testWithRefs(self):
    """Ref-typed outputs import and feed both ref and non-ref inputs."""
    with ops.Graph().as_default():
      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'RefOutput' }
          node { name: 'B' op: 'IntOutput' }
          node { name: 'C' op: 'TwoIntInputs' input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'RefInputIntInput' input: 'A:0' input: 'B:0' }
          """),
          return_elements=["A", "B", "C", "D"])

      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[0])
      self.assertEqual(d.inputs[1], b.outputs[0])

      # The ref output feeds C's non-ref input as plain int32, while D's
      # declared ref input keeps the int32_ref type.
      self.assertEqual(a.outputs[0].dtype, dtypes.int32_ref)
      self.assertEqual(c._input_types, [dtypes.int32, dtypes.int32])
      self.assertEqual(c.outputs, [])
      self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32])
      self.assertEqual(d.outputs, [])

  def testResources(self):
    """Imports resource-producing/consuming ops; handle shapes must survive."""
    # Produce GraphDef containing ops producing and consuming resources.
    graph = ops.Graph()
    with graph.as_default():
      var = resource_variable_ops.ResourceVariable(1.0)
      var_assign = var.assign(2.0)
      # Use an op that requires handle shape to be set.
      var_shape = resource_variable_ops.variable_shape(var.handle)
      init = variables.global_variables_initializer()
    graph_def = graph.as_graph_def()

    # Import the GraphDef.
    with ops.Graph().as_default():
      # pylint: disable=unused-variable
      imported_var, imported_assign, imported_shape, imported_init = (
          importer.import_graph_def(
              graph_def,
              return_elements=[var.name, var_assign.name, var_shape.name,
                               init.name]))

      # Make sure the handle shape is set on the imported variable:
      # variable_shape would fail here if the handle data were dropped.
      new_var_shape = resource_variable_ops.variable_shape(imported_var)
      # pylint: enable=unused-variable

      # Run the imported graph.
      # TODO(b/76173421): make this work (currently DCHECKS)
      # with self.cached_session() as sess:
      #   self.evaluate(imported_init)
      #   self.assertEqual(self.evaluate(imported_var), 1.0)
      #   self.assertEqual(self.evaluate(imported_assign), 2.0)
      #   self.assertEqual(list(self.evaluate(imported_shape)), [])
      #   self.assertEqual(list(self.evaluate(new_var_shape)), [])

403  def testWhileLoop(self):
404    # Produce GraphDef containing while loop.
405    graph = ops.Graph()
406    with graph.as_default():
407      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
408      # Add an op that consumes the while loop output.
409      math_ops.add(r, 1)
410    graph_def = graph.as_graph_def()
411
412    # Import the GraphDef and make sure it runs.
413    with ops.Graph().as_default():
414      imported_r, = importer.import_graph_def(graph_def,
415                                              return_elements=[r.name])
416      self.assertEqual(imported_r.name, "import/" + r.name)
417      with self.cached_session() as sess:
418        self.assertEqual(self.evaluate(imported_r), 10)
419
420  def testImportWhileLoopInCond(self):
421    # Produce GraphDef containing while loop.
422    graph = ops.Graph()
423    with graph.as_default():
424      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
425    graph_def = graph.as_graph_def()
426
427    # Import the GraphDef inside a cond and make sure it runs.
428    with ops.Graph().as_default():
429
430      def ImportFn():
431        return importer.import_graph_def(graph_def, return_elements=[r.name])[0]
432
433      pred = array_ops.placeholder(dtypes.bool)
434      out = control_flow_ops.cond(pred, ImportFn,
435                                  lambda: constant_op.constant(1))
436      with self.cached_session() as sess:
437        self.assertEqual(sess.run(out, {pred: True}), 10)
438        self.assertEqual(sess.run(out, {pred: False}), 1)
439
440  def testImportWhileLoopInWhileLoop(self):
441    self.skipTest("b/111757448")
442    # Produce GraphDef containing while loop.
443    graph = ops.Graph()
444    with graph.as_default():
445      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
446    graph_def = graph.as_graph_def()
447
448    # Import the GraphDef inside another loop and make sure it runs.
449    with ops.Graph().as_default():
450
451      def ImportFn(_):
452        return importer.import_graph_def(graph_def, return_elements=[r.name])[0]
453
454      out = control_flow_ops.while_loop(
455          lambda i: i < 2, ImportFn, [0],
456          shape_invariants=[tensor_shape.TensorShape(None)])
457      with self.cached_session() as sess:
458        self.assertEqual(self.evaluate(out), 10)
459
460  def testTypeMismatchInGraphDef(self):
461    # TODO(skyewm): improve error message
462    error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 "
463                 "incompatible with expected float.")
464    with ops.Graph().as_default():
465      with self.assertRaisesRegex(ValueError, error_msg):
466        importer.import_graph_def(
467            self._MakeGraphDef("""
468            node { name: 'A' op: 'IntOutput' }
469            node { name: 'B' op: 'FloatInput' input: 'A:0' }
470            """))
471
472  def testShapeAllowlistViolation(self):
473    # L2 loss produces a scalar shape, but the graph
474    # has the wrong shape, so raise an error.
475    with ops.Graph().as_default():
476      with self.assertRaises(ValueError) as e:
477        _ = importer.import_graph_def(
478            self._MakeGraphDef("""
479              node { name: 'A' op: 'FloatOutput' }
480              node { name: 'B' op: 'L2Loss'
481                     input: 'A:0'
482                     attr { key: 'T' value { type: DT_FLOAT } }
483                     attr { key: '_output_shapes'
484                            value { list { shape { dim { size: 43 } } } } } }
485            """),
486            return_elements=["B"],
487            name="import")
488        self.assertTrue(
489            "Shapes () and (43,) are not compatible" in str(e.exception))
490
491  def testInvalidSignatureTooManyInputsInGraphDef(self):
492    with ops.Graph().as_default():
493      # TODO(skyewm): improve error message
494      with self.assertRaisesRegex(
495          ValueError,
496          "NodeDef expected inputs '' do not match 1 inputs specified"):
497        importer.import_graph_def(
498            self._MakeGraphDef("""
499            node { name: 'A' op: 'IntOutput' }
500            node { name: 'B' op: 'None' input: 'A:0' }
501            """))
502
503  def testInvalidSignatureNotEnoughInputsInGraphDef(self):
504    with ops.Graph().as_default():
505      # TODO(skyewm): improve error message
506      with self.assertRaisesRegex(
507          ValueError,
508          "NodeDef expected inputs 'int32, float' do not match 1 inputs "
509          "specified"):
510        importer.import_graph_def(
511            self._MakeGraphDef("""
512            node { name: 'A' op: 'IntOutput' }
513            node { name: 'B' op: 'IntInputFloatInput' input: 'A:0' }
514            """))
515
516  def testMissingInputOpInGraphDef(self):
517    with ops.Graph().as_default():
518      with self.assertRaisesRegex(ValueError,
519                                  "Node 'B': Unknown input node 'A:0'"):
520        importer.import_graph_def(
521            self._MakeGraphDef("""
522            node { name: 'B' op: 'FloatInput' input: 'A:0' }
523            """))
524
525  def testMissingInputOpInGraphDefButAppearsInInputMap(self):
526    with ops.Graph().as_default():
527      feed_a_0 = constant_op.constant(5.0)
528      b, = importer.import_graph_def(
529          self._MakeGraphDef("""
530          node { name: 'B' op: 'FloatInput' input: 'A:0' }
531          """),
532          input_map={"A:0": feed_a_0},
533          return_elements=["B"])
534      self.assertEqual(b.inputs[0], feed_a_0)
535
536  def testMissingInputTensorInGraphDef(self):
537    with ops.Graph().as_default():
538      with self.assertRaisesRegex(
539          ValueError,
540          "Node 'B': Connecting to invalid output 1 of source node A "
541          "which has 1 outputs"):
542        importer.import_graph_def(
543            self._MakeGraphDef("""
544            node { name: 'A' op: 'FloatOutput' }
545            node { name: 'B' op: 'FloatInput' input: 'A:1' }
546            """))
547
548  def testMissingControlInputInGraphDef(self):
549    with ops.Graph().as_default():
550      with self.assertRaisesRegex(ValueError,
551                                  r"Node 'B': Unknown input node '\^A'"):
552        importer.import_graph_def(
553            self._MakeGraphDef("""
554            node { name: 'B' op: 'None' input: '^A' }
555            """))
556
557  def testInvalidTensorNameOutputIndexInGraphDef(self):
558    with ops.Graph().as_default():
559      with self.assertRaisesRegex(ValueError,
560                                  "Node 'B': Unknown input node 'A:B'"):
561        importer.import_graph_def(
562            self._MakeGraphDef("""
563            node { name: 'B' op: 'None' input: 'A:B' }
564            """))
565
566  def testInvalidTensorNameInGraphDef(self):
567    with ops.Graph().as_default():
568      with self.assertRaisesRegex(ValueError,
569                                  "Node 'B': Unknown input node 'A:B:0'"):
570        importer.import_graph_def(
571            self._MakeGraphDef("""
572            node { name: 'B' op: 'None' input: 'A:B:0' }
573            """))
574
575  def testMissingReturnOperation(self):
576    with ops.Graph().as_default():
577      with self.assertRaisesRegex(
578          ValueError, "Requested return node 'B' not found in graph def"):
579        importer.import_graph_def(
580            self._MakeGraphDef("""
581            node { name: 'A' op: 'None' }
582            """),
583            return_elements=["B"])
584
  def testMissingReturnTensor(self):
    """Invalid return_elements tensor names raise descriptive ValueErrors."""
    with ops.Graph().as_default():
      # Case 1: output index out of range for an existing node.
      with self.assertRaisesRegex(
          ValueError,
          r"Invalid return output 1 of node 'A', which has 1 output\(s\)"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["A:1"])

      # Case 2: tensor of a node that does not exist in the GraphDef.
      with self.assertRaisesRegex(
          ValueError, "Requested return tensor 'B:0' not found in graph def"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["B:0"])

      # Case 3: malformed tensor name (too many ':' separators).
      with self.assertRaisesRegex(ValueError,
                                  "Cannot convert 'A:B:0' to a tensor name."):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["A:B:0"])

612  def testMissingInputMap(self):
613    with ops.Graph().as_default():
614      with self.assertRaisesRegex(
615          ValueError,
616          r"Attempted to map inputs that were not found in graph_def: \[B:0\]"):
617        importer.import_graph_def(
618            self._MakeGraphDef("""
619            node { name: 'A' op: 'None' }
620            """),
621            input_map={"B:0": constant_op.constant(5.0)})
622
  def testInputMapUnusedAsInput(self):
    """input_map may target unused outputs, but not nonexistent ones."""
    with ops.Graph().as_default():
      # Mapping an unused node output should succeed.
      importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutput' }
          """),
          input_map={"A:0": constant_op.constant(5.0)})

      # Mapping a non-existent output of an existing node should fail.
      with self.assertRaisesRegex(
          ValueError,
          r"Attempted to map inputs that were not found in graph_def: \[A:2\]"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            input_map={"A:2": constant_op.constant(5.0)})

642  def testInputMapTypeMismatch(self):
643    with ops.Graph().as_default():
644      with self.assertRaisesRegex(
645          ValueError, "Input 0 of node import/B was passed float from Const:0 "
646          "incompatible with expected int32."):
647        importer.import_graph_def(
648            self._MakeGraphDef("""
649            node { name: 'A' op: 'IntOutput' }
650            node { name: 'B' op: 'IntInput' input: 'A:0' }
651            """),
652            input_map={"A:0": constant_op.constant(5.0)})
653
654  def testNoReturns(self):
655    with ops.Graph().as_default() as g:
656      ret = importer.import_graph_def(
657          self._MakeGraphDef("""
658          node { name: 'A' op: 'None' }
659          """))
660      self.assertEqual(ret, None)
661
662      a = g.get_operation_by_name("import/A")
663      self.assertEqual(a.type, "None")
664
665  def testOverrideNamePrefix(self):
666    with ops.Graph().as_default():
667      a, = importer.import_graph_def(
668          self._MakeGraphDef("""
669          node { name: 'A' op: 'None' }
670          """),
671          return_elements=["A"],
672          name="imported_graph")
673      self.assertEqual(a.name, "imported_graph/A")
674
675  def testDefaultNamePrefix(self):
676    with ops.Graph().as_default():
677      a, = importer.import_graph_def(
678          self._MakeGraphDef("""
679          node { name: 'A' op: 'None' }
680          """),
681          return_elements=["A"],
682          name=None)
683      self.assertEqual(a.name, "import/A")
684
685  def testNamePrefixColocationAttrs(self):
686    original_graph_def = self._MakeGraphDef("""
687          node { name: 'A' op: 'None' }
688          node { name: 'B' op: 'None'  attr {
689            key: '_class'
690            value { list { s: 'loc:@A' } }
691          } }""")
692
693    with ops.Graph().as_default():
694      b, = importer.import_graph_def(
695          original_graph_def, return_elements=["B"], name="imported_graph")
696      self.assertTrue("_class" in b.node_def.attr)
697      self.assertProtoEquals(
698          "list { s: 'loc:@imported_graph/A' }",
699          b.node_def.attr["_class"])
700
  def testColocationAndDevice(self):
    """Colocated nodes share the device of the colocation-group root."""
    # A and B are colocated, device set on A: both end up on A's device.
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' device: '/device:CPU:0' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(original_graph_def,
                                       return_elements=["A", "B"],
                                       name="")
      self.assertEqual(a.device, "/device:CPU:0")
      self.assertEqual(b.device, "/device:CPU:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

    # A and B are colocated, device set on B.
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None' device: '/device:CPU:0' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(original_graph_def,
                                       return_elements=["A", "B"],
                                       name="")
      # TODO(skyewm): this behavior seems inconsistent with the above. Why is
      # B's device ignored?
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "")
      self.assertEqual(a.colocation_groups(), [b"loc:@A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

  def testColocationWithDeviceFn(self):
    """Device functions interact with colocation: the group root wins."""
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    # A device function that places "A" on one device and "B" on
    # another device.  Because B is colocated with A, we test that B's
    # device function is overridden by A.
    def CustomDeviceFn(op):
      if "A" in op.name:
        return "/device:A:0"
      else:
        return "/device:B:0"

    with ops.Graph().as_default():
      with ops.device(CustomDeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "/device:A:0")
      self.assertEqual(b.device, "/device:A:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

    # Test a scenario where 'A' doesn't get a device; 'A' should not have a
    # device, but during runtime will get colocated with 'B' because of the
    # colocation attribute. B's device function is still overridden by A.
    def BDeviceFn(op):
      if "B" in op.name:
        return "/device:B:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(BDeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

    # Only A gets a device, so B inherits it implicitly.
    def ADeviceFn(op):
      if "A" in op.name:
        return "/device:A:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(ADeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "/device:A:0")
      self.assertEqual(b.device, "/device:A:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

808  def testMultipleColocationWithDeviceFn(self):
809    original_graph_def = self._MakeGraphDef("""
810          node { name: 'A' op: 'None'}
811          node { name: 'B' op: 'None'}
812          node { name: 'C' op: 'None'  attr {
813            key: '_class'
814            value { list { s: 'loc:@A' s: 'loc:@B' } }
815          } }""")
816
817    # A device function that places "B" on a device, and "A" is empty.
818    #
819    # B and C should contain "/device:B".  A will not right now.  But
820    # because of the colocation property, at runtime it would be
821    # placed with B and C.
822    def CustomDeviceFn(op):
823      if "B" in op.name:
824        return "/device:B:0"
825      return ""
826
827    with ops.Graph().as_default():
828      with ops.device(CustomDeviceFn):
829        a, b, c = importer.import_graph_def(original_graph_def,
830                                            return_elements=["A", "B", "C"],
831                                            name="imported_graph")
832      self.assertEqual(a.device, "")
833      self.assertEqual(b.device, "/device:B:0")
834      self.assertEqual(c.device, "/device:B:0")
835      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
836      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/B"])
837      self.assertEqual(c.colocation_groups(),
838                       [b"loc:@imported_graph/A", b"loc:@imported_graph/B"])
839
840  def testNamePrefixColocationAttrsMultipleImport(self):
841    original_graph_def = self._MakeGraphDef("""
842          node { name: 'A' op: 'None' }
843          node { name: 'B' op: 'None'  attr {
844            key: '_class'
845            value { list { s: 'loc:@A' } }
846          } }""")
847
848    with ops.Graph().as_default():
849      a, b = importer.import_graph_def(
850          original_graph_def, return_elements=["A", "B"], name="")
851      a_1, b_1 = importer.import_graph_def(
852          original_graph_def, return_elements=["A", "B"], name="")
853
854      self.assertEqual(a.name, "A")
855      self.assertEqual(b.name, "B")
856      self.assertEqual(b.colocation_groups(), [b"loc:@A"])
857
858      self.assertEqual(a_1.name, "A_1")
859      self.assertEqual(b_1.name, "B_1")
860      self.assertEqual(b_1.colocation_groups(), [b"loc:@A_1"])
861
862  def testNamePrefixColocationAttrsNotFound(self):
863    original_graph_def = self._MakeGraphDef("""
864          node { name: 'B' op: 'None'  attr {
865            key: '_class'
866            value { list { s: 'loc:@A' } }
867          } }""")
868
869    with ops.Graph().as_default():
870      with self.assertRaisesRegex(
871          ValueError, "Node 'B' expects to be colocated with unknown node 'A'"):
872        importer.import_graph_def(
873            original_graph_def, return_elements=["B"], name="imported_graph")
874
875  def testEmptyGraph(self):
876    with ops.Graph().as_default() as g:
877      init_version = g.version
878      importer.import_graph_def(self._MakeGraphDef(""))
879      self.assertEqual(init_version, g.version)
880
881  def testInvalidInputForGraphDef(self):
882    with ops.Graph().as_default():
883      with self.assertRaisesRegex(
884          TypeError, r"Argument `graph_def` must be a GraphDef proto."):
885        importer.import_graph_def("")
886
887  def testInvalidInputForInputMap(self):
888    with ops.Graph().as_default():
889      with self.assertRaisesRegex(
890          TypeError,
891          r"Argument `input_map` must be a dictionary. Obtained list"):
892        importer.import_graph_def(
893            self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
894    graph_def = self._MakeGraphDef("""
895         node { name: 'a' op: 'Placeholder'
896                attr { key: 'dtype' value { type: DT_FLOAT } }}
897         node { name: 'id' op: 'Identity' input: 'a:0'
898                attr { key: 'T' value { type: DT_FLOAT } }}""")
899    with ops.Graph().as_default():
900      with self.assertRaises(ValueError) as e:
901        importer.import_graph_def(
902            graph_def,
903            input_map={"a:0": variables.Variable(5.0)},
904            name="")
905      self.assertStartsWith(str(e.exception),
906                            "tf.import_graph_def() requires a non-empty `name` "
907                            "if `input_map` contains non-Tensor values.")
908    with ops.Graph().as_default():
909      t, = importer.import_graph_def(
910          graph_def,
911          input_map={"a:0": constant_op.constant(5.0)},
912          name="",
913          return_elements=["id:0"])
914      with self.cached_session():
915        self.assertEqual(5.0, self.evaluate(t))
916
917  def testInvalidInputForReturnOperations(self):
918    with ops.Graph().as_default():
919      with self.assertRaisesRegex(
920          TypeError, "Argument `return_elements` must be a list of strings."):
921        importer.import_graph_def(self._MakeGraphDef(""), return_elements=[7])
922
923      with self.assertRaisesRegex(ValueError,
924                                  "Cannot convert 'a:b:c' to a tensor name."):
925        importer.import_graph_def(
926            self._MakeGraphDef(""), return_elements=["a:b:c"])
927
928  def testDuplicateOperationNames(self):
929    with self.assertRaisesRegex(ValueError, "Node 'A' is not unique"):
930      importer.import_graph_def(
931          self._MakeGraphDef("""
932          node { name: 'A' op: 'IntOutput' }
933          node { name: 'B' op: 'IntOutput' }
934          node { name: 'A' op: 'IntOutput' }
935          """))
936
937  @test_util.run_v1_only("v1 Tensor doesn't have attribute 'numpy'")
938  def testWithExtensionAndAttr(self):
939    with ops.Graph().as_default() as g:
940      c = constant_op.constant(5.0, dtype=dtypes.float32, name="c")
941      array_ops.stack([c, c], name="pack")
942    gdef = g.as_graph_def()
943
944    with self.cached_session():
945      pack, = importer.import_graph_def(gdef, return_elements=["pack"])
946      self.assertAllEqual(pack.outputs[0], [5.0, 5.0])
947
  def testWithDevice(self):
    """Imported ops merge their recorded device with the importing scope's.

    Exports ops with three device settings (none, '/cpu:0', '/job:worker')
    and re-imports them under various `ops.device` scopes, checking how the
    scope's fields merge with each op's recorded device.
    """
    with ops.Graph().as_default() as g:
      # No device.
      a = constant_op.constant(3.0, name="a")

      with ops.device("/cpu:0"):
        b = constant_op.constant(4.0, name="b")
      with ops.device("/job:worker"):
        c = constant_op.constant(5.0, name="c")

    gdef = g.as_graph_def()

    # With no importing device scope, recorded devices are preserved as-is.
    with ops.Graph().as_default():
      a2, b2, c2 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    # A '/task:0' scope merges into each op's device specification.
    with ops.Graph().as_default():
      with ops.device(device.merge_device("/task:0")):
        a3, b3, c3 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/task:0", a3.device)
        self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
        self.assertEqual(c.device + "/task:0", c3.device)

    # A field already set on the op wins over the importing scope's value.
    with ops.Graph().as_default():
      with ops.device(device.merge_device("/job:ps")):
        a4, b4, c4 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/job:ps", a4.device)
        self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    # Same for the device-type field: the op's CPU pin beats the GPU scope.
    with ops.Graph().as_default():
      with ops.device(device.merge_device("/device:GPU:0")):
        a5, b5, c5 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/device:GPU:0", a5.device)
        self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + "/device:GPU:0", c5.device)
990
991  def testWithDeviceFunctionDependingOnInputs(self):
992    with ops.Graph().as_default() as g:
993      with ops.device("/job:ps"):
994        v1 = constant_op.constant(1.0)
995        v2 = constant_op.constant(1.0)
996      _ = v1 + v2
997      _ = v1 - v2
998      _ = array_ops.identity(v1)
999    gdef = g.as_graph_def()
1000
1001    # We'll use the following device function to observe ops with two inputs.
1002    ops_with_two_inputs = []
1003
1004    def InputCounter(op):
1005      if len(op.inputs) == 2:
1006        ops_with_two_inputs.append(op)
1007      return ""
1008
1009    with ops.Graph().as_default() as g:
1010      with ops.device(InputCounter):
1011        importer.import_graph_def(gdef)
1012
1013    # We expect to see the add and subtract, but not identity.
1014    self.assertEqual(2, len(ops_with_two_inputs))
1015
  def testGradient(self):
    """Gradients can be taken through ops of an imported graph.

    Exports a small forward network as a GraphDef, re-imports it with its
    placeholders remapped to outer-graph variables via `input_map`, then
    differentiates the imported loss w.r.t. those variables.
    """
    with ops.Graph().as_default() as g:
      inputs = array_ops.placeholder(
          dtypes.float32, shape=[None, 100], name="input")
      weights = array_ops.placeholder(
          dtypes.float32, shape=[100, 10], name="weights")
      biases = array_ops.placeholder(dtypes.float32, shape=[10], name="biases")
      activations = nn_ops.relu(
          math_ops.matmul(inputs, weights) + biases, name="activations")
      loss = math_ops.reduce_mean(activations, name="loss")
    gdef = g.as_graph_def()

    with ops.Graph().as_default() as g:
      # Feed the imported placeholders from this graph's own tensors.
      input_placeholder = array_ops.placeholder(dtypes.float32, shape=[32, 100])
      weights_var = variables.Variable(
          random_ops.truncated_normal([100, 10]), name="weights")
      biases_var = variables.Variable(array_ops.zeros([10]), name="biases")
      activations, loss = importer.import_graph_def(
          gdef,
          input_map={
              "input:0": input_placeholder,
              "weights:0": weights_var,
              "biases:0": biases_var
          },
          return_elements=["activations:0", "loss:0"])
      # The mapped inputs' static shapes propagate into the imported ops.
      self.assertEqual([32, 10], activations.get_shape())
      self.assertEqual([], loss.get_shape())
      weights_grad, biases_grad = gradients_impl.gradients(
          loss, [weights_var, biases_var])
      self.assertEqual([100, 10], weights_grad.get_shape())
      self.assertEqual([10], biases_grad.get_shape())
1047
1048  def testLargeGraph(self):
1049    with self.cached_session():
1050      # The default message byte limit is 64M. Ours is 2G with a warning at 512.
1051      # Adding a 130M entries float32 tensor should exceed the warning, but not
1052      # the hard limit.
1053      input_shape = [130, 1000, 1000]
1054      tensor_input = np.ones(input_shape, dtype=np.float32)
1055      t = constant_op.constant(tensor_input, shape=input_shape)
1056      g = array_ops.identity(t)
1057      self.evaluate(g)
1058
1059  def testVersion(self):
1060    v0 = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
1061    v2 = versions.GRAPH_DEF_VERSION
1062    v1 = (v0 + v2) // 2
1063    for producer in v0, v1, v2:
1064      for min_consumer in v0, v1, v2:
1065        with ops.Graph().as_default():
1066          a, = importer.import_graph_def(
1067              self._MakeGraphDef(
1068                  "node { name: 'A' op: 'TwoIntOutputs' }",
1069                  producer=producer,
1070                  min_consumer=min_consumer),
1071              return_elements=["A"])
1072          self.assertEqual(a.graph.graph_def_versions.producer, producer)
1073          self.assertEqual(a.graph.graph_def_versions.min_consumer,
1074                           min_consumer)
1075
1076  def testVersionLow(self):
1077    with ops.Graph().as_default():
1078      with self.assertRaisesRegex(
1079          Exception,
1080          r"GraphDef producer version -1 below min producer %d supported "
1081          r"by TensorFlow \S+\.  Please regenerate your graph.$" %
1082          versions.GRAPH_DEF_VERSION_MIN_PRODUCER):
1083        importer.import_graph_def(self._MakeGraphDef("", producer=-1))
1084
1085  def testVersionHigh(self):
1086    with ops.Graph().as_default():
1087      with self.assertRaisesRegex(
1088          ValueError,
1089          r"GraphDef min consumer version %d above current version %d "
1090          r"for TensorFlow \S+\.  Please upgrade TensorFlow\.$" %
1091          (1 << 30, versions.GRAPH_DEF_VERSION)):
1092        importer.import_graph_def(self._MakeGraphDef("", min_consumer=1 << 30))
1093
1094  def testVersionAppliesToOpConstruction(self):
1095    """These tests rely on shape fns in test_ops.cc."""
1096    with ops.Graph().as_default():
1097      importer.import_graph_def(
1098          self._MakeGraphDef(
1099              "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
1100              producer=versions.GRAPH_DEF_VERSION - 1),
1101          return_elements=["A"])
1102
1103    with ops.Graph().as_default():
1104      with self.assertRaisesWithPredicateMatch(ValueError,
1105                                               "Wrong graph version.*"):
1106        importer.import_graph_def(
1107            self._MakeGraphDef(
1108                "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
1109                producer=versions.GRAPH_DEF_VERSION),
1110            return_elements=["A"])
1111
1112  def testDefaultAttrsAdded(self):
1113    with ops.Graph().as_default():
1114      a = importer.import_graph_def(
1115          self._MakeGraphDef("""
1116          node { name: 'A' op: 'OpWithDefaultAttr' }
1117          """),
1118          return_elements=["A"])
1119      self.assertEqual(123.0, a[0].get_attr("default_float"))
1120
1121  def testDefaultAttrsRemoved(self):
1122    producer_op_list = op_def_pb2.OpList()
1123    text_format.Merge("""
1124      op {
1125        name: 'OpWithFutureDefaultAttr'
1126        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
1127      }
1128    """, producer_op_list)
1129    # Attr only in producer_op_list with default value gets removed.
1130    with ops.Graph().as_default():
1131      a = importer.import_graph_def(
1132          self._MakeGraphDef("""
1133          node { name: 'A' op: 'OpWithFutureDefaultAttr'
1134                 attr { key: 'default_int' value { i: 456 } } }
1135          """),
1136          return_elements=["A"],
1137          producer_op_list=producer_op_list)
1138      with self.assertRaisesRegex(
1139          ValueError, "Operation 'import/A' has no attr named 'default_int'."):
1140        a[0].get_attr("default_int")
1141
  def testFunctions(self):
    """Defun functions survive export/import round trips.

    Covers a function with a custom gradient, a function capturing a tensor
    from the enclosing graph, and a nested function, and verifies that
    imported copies still run and can be re-exported and re-imported.
    """
    dtype = dtypes.float32

    @function.Defun(dtype, dtype, dtype, dtype)
    def Grad(x, y, dout1, dout2):  # pylint: disable=unused-argument
      # Return the inputs for simplicity of testing. The correct return value
      # would be (dout1 + dout2, dout1 - dout2)
      return x, y

    @function.Defun(dtype, dtype, grad_func=Grad)
    def FuncWithGrad(x, y):
      return x + y, x - y

    @function.Defun(dtypes.int32)
    def ExternalTensorFunc(x):
      # c must be defined in the containing graph
      return x + c

    @function.Defun(dtypes.int32, dtypes.int32)
    def OuterFunc(x, y):

      @function.Defun(dtypes.int32)
      def InnerFunc(x):
        return x + x

      return InnerFunc(x) + y

    # Create graph with function calls and export to GraphDef
    with ops.Graph().as_default() as g1:
      p1 = array_ops.placeholder(dtype, name="p1")
      p2 = array_ops.placeholder(dtype, name="p2")
      # pylint: disable=unexpected-keyword-arg
      a, b = FuncWithGrad(p1, p2, name="f")

      # Captured by ExternalTensorFunc above.
      c = constant_op.constant(10, dtype=dtypes.int32)
      ExternalTensorFunc(1, name="external")

      OuterFunc(10, 1, name="outer")
      # pylint: enable=unexpected-keyword-arg

    gdef = g1.as_graph_def()

    # Import GraphDef into new graph, add imported gradients, and test that
    # imported functions can be run
    with ops.Graph().as_default() as g2:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      grad = gradients_impl.gradients([a], [p1, p2])

      with self.session(graph=g2) as sess:
        feed_dict = {p1: 1, p2: 2}
        a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
        self.assertEqual(a_val, 3.0)
        self.assertEqual(b_val, -1.0)
        # Grad function returns inputs values for testing
        self.assertEqual(grad_val, [1.0, 2.0])
        self.assertEqual(sess.run("external:0"), 11)
        self.assertEqual(sess.run("outer:0"), 21)

    # Export the new graph and reimport to test that imported functions can be
    # successfully exported/imported again
    gdef = g2.as_graph_def()
    with ops.Graph().as_default() as g3:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      # Create new gradient functions (in additional to the imported gradient
      # functions created in g2).
      grad = gradients_impl.gradients([a], [p1, p2])

      with self.session(graph=g3) as sess:
        feed_dict = {p1: 1, p2: 2}
        a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
        self.assertEqual(a_val, 3.0)
        self.assertEqual(b_val, -1.0)
        self.assertEqual(grad_val, [1.0, 2.0])
        self.assertEqual(sess.run("external:0"), 11)
        self.assertEqual(sess.run("outer:0"), 21)
1219
1220  @test_util.run_v1_only("import inside defun not supported when eager "
1221                         "execution is enabled.")
1222  def testImportInsideDefun(self):
1223    g = ops.Graph()
1224    with g.as_default():
1225
1226      @function.Defun()
1227      def Add2(x, y):
1228        return math_ops.add(x, y)
1229
1230      x = constant_op.constant(3.0, dtype=dtypes.float32)
1231      y = constant_op.constant(-5.0, dtype=dtypes.float32)
1232      z = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg
1233
1234    gdef = g.as_graph_def()
1235
1236    @function.Defun()
1237    def TestFunc():
1238      return importer.import_graph_def(gdef, return_elements=["z:0"])[0]
1239
1240    z = TestFunc()
1241
1242    with self.cached_session():
1243      z_val = self.evaluate(z)
1244      self.assertEqual(z_val, -2.0)
1245
1246  @test_util.run_v1_only("_as_tf_output not supported when eager execution "
1247                         "is enabled.")
1248  def testImportGraphWithFunctionTwice(self):
1249    g = ops.Graph()
1250    with g.as_default():
1251
1252      @function.Defun()
1253      def Add2(x, y):
1254        return math_ops.add(x, y)
1255
1256      x = array_ops.placeholder(dtype=dtypes.float32, name="x")
1257      y = array_ops.placeholder(dtype=dtypes.float32, name="y")
1258      _ = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg
1259
1260    gdef = g.as_graph_def()
1261
1262    x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
1263    y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
1264    input_map = {"x:0": x, "y:0": y}
1265
1266    with ops.name_scope("first"):
1267      z1 = importer.import_graph_def(gdef, return_elements=["z:0"],
1268                                     input_map=input_map)[0]
1269
1270    with ops.name_scope("second"):
1271      z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
1272                                     input_map=input_map)[0]
1273
1274    with self.cached_session() as sess:
1275      z1_val, z2_val = sess.run((z1, z2))
1276      self.assertAllEqual(z1_val, z2_val)
1277
1278
if __name__ == "__main__":
  # Run the test suite when this file is executed directly.
  test.main()
1281