# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class WrapFunctionTest(test.TestCase):

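  # Mirrors the `tf.compat.v1.wrap_function` docstring example: each call to
  # wrap_function traces the Python function once, creating its own variable
  # whose state persists across calls to the returned function.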
  def testDocString(self):

    def f(x, do_add):
      v = variables.Variable(5.0)
      if do_add:
        op = v.assign_add(x)
      else:
        op = v.assign_sub(x)
      with ops.control_dependencies([op]):
        return v.read_value()

    f_add = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), True])

    self.assertAllEqual(f_add(1.0), 6.0)
    self.assertAllEqual(f_add(1.0), 7.0)

    # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
    # of variables, and possibly different non-template arguments.
    f_sub = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), False])

    self.assertAllEqual(f_sub(1.0), 4.0)
    self.assertAllEqual(f_sub(1.0), 3.0)

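  # prune() extracts a callable that computes only the requested fetches from
  # the requested feeds of an already-wrapped graph.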
  def testPrune(self):

    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, 2 * y * y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)

    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
    self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])

  def testNoArguments(self):

    def f():
      return constant_op.constant(1.)

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(1.0, f_wrapped())

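  # Pruning keeps captured state working: both the externally captured v1 and
  # the internally created v2 remain usable whether or not feeds are supplied.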
  def testPruneCaptures(self):

    v1 = variables.Variable(2.)

    def f():
      v2 = variables.Variable(3.)
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())

    # Test pruning directly on the inputs
    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

    # Test pruning with no inputs
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

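  # Each wrap_function call traces into its own graph, so graph collections
  # (LOSSES, TRAINABLE_VARIABLES, ...) are not shared between independently
  # wrapped functions.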
  def testCollectionsIsolation(self):

    v1 = variables.Variable(2.)
    v2_holder = []
    def f():
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v2 * constant_op.constant(3.))
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())
    self.assertEqual(
        len(f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
    f_var_collection = f_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(f_var_collection), 1)
    self.assertIs(f_var_collection[0], v2_holder[0])

    v3_holder = []
    def g():
      v3 = variables.Variable(4.)
      v3_holder.append(v3)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v3 * constant_op.constant(3.))
      return array_ops.identity(v1 * v3 * constant_op.constant(1.), 'fetch')

    g_wrapped = wrap_function.wrap_function(g, [])
    self.assertAllEqual(8.0, g_wrapped())
    self.assertEqual(
        len(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
    g_var_collection = g_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(g_var_collection), 1)
    self.assertIs(g_var_collection[0], v3_holder[0])

    # Each LOSSES collection holds a single element, and the elements differ,
    # so the collections are not shared between the two wrapped graphs.
    self.assertNotEqual(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES),
                        f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES))

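  # Gradients flow through a wrapped function and through its pruned version,
  # including gradients with respect to captured variables.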
  def testGradientsOfPrune(self):

    v1 = variables.Variable(2.)
    v2_holder = []

    def f(z):
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      return array_ops.identity(v1 * v2 * z, 'fetch')

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = f_wrapped(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = pruned(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

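  # Operations (not just tensors) may be used as fetches; an operation fetch
  # runs for its side effect and yields None in the outputs.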
  def testPruneOperations(self):

    v = variables.Variable(0)

    def f():
      v.assign_add(1, name='increment', read_value=False)

    f_wrapped = wrap_function.wrap_function(f, [])
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),))
    self.assertEqual((None,), pruned())
    self.assertEqual(1, self.evaluate(v))

    del f, f_wrapped

    def f1():
      v.assign_add(
          array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step'),
          name='increment', read_value=False)
      return constant_op.constant(1, name='other')

    f_wrapped = wrap_function.wrap_function(f1, [])
    increments = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                 f_wrapped.graph.get_tensor_by_name('other:0')))
    first_output, second_output = increments(constant_op.constant(2))
    self.assertEqual(['step:0', 'increment/resource:0'],
                     [t.name for t in increments.inputs])
    self.assertIs(None, first_output)
    self.assertEqual(1, second_output.numpy())
    self.assertEqual(3, v.numpy())
    does_not_increment = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
    self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
    self.assertEqual(3, v.numpy())

  def testPruneStatefulOpsFromWrappedFunc(self):

    v0 = variables.Variable(0)
    v1 = variables.Variable(0)

    # When we wrap a function, we expect it to be executed with `tf.Graph`
    # rules: it's allowed to prune all ops that are not in the transitive fanin
    # of the fetches.
    def f(x):
      v0.assign_add(1, name='increment_v0')
      v1.assign_add(1, name='increment_v1')
      return x

    f_wrapped = wrap_function.wrap_function(f, [1])

    self.assertEqual(1, f_wrapped().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

    f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')

    self.assertEqual(2, f_wrapped_with_name().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

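  # function_from_graph_def rebuilds a callable from a serialized GraphDef,
  # given the names of its input and output tensors.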
  def test_function_from_graph_def(self):
    @def_function.function
    def make_graph_def(x):
      return x + 1.

    original_func_graph = make_graph_def.get_concrete_function(
        tensor_spec.TensorSpec([None, 2], dtypes.float32)).graph
    graph_def = original_func_graph.as_graph_def()
    revived_function = wrap_function.function_from_graph_def(
        graph_def, inputs=original_func_graph.inputs[0].name,
        outputs=original_func_graph.outputs[0].name)
    self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())


class WrappedGraphTest(test.TestCase):

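  # WrappedGraph wraps multiple v1-style functions into a single backing
  # graph; here a function mixing `Variable` and `get_variable` matches the
  # result of running the same code in a v1 session.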
  def testAddFunction(self):

    def fn(x):
      v = variables.Variable(3, name='v')
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v + v2 + x

    with self.cached_session() as sess:
      result = fn(constant_op.constant(5))
      sess.run(variables.global_variables_initializer())
      expected = sess.run(result)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    wrapped_fn = g.wrap_function(fn, signature)
    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())

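  # Custom collection keys passed when creating variables are preserved in the
  # wrapped graph and visible from functions wrapped into the same graph.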
  def testCollections(self):

    def fn(x):
      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
          collections=['a', 'b'])
      return v + v2 + x

    def assert_collections(graph):
      self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
      self.assertLen(graph.get_collection('a'), 2)
      self.assertLen(graph.get_collection('b'), 1)

    g = wrap_function.WrappedGraph()
    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
    assert_collections(g.graph)

    def assert_fn():
      assert_collections(ops.get_default_graph())
      return 1  # Return is required

    # Assert that collections are accessible within a wrapped function.
    g.wrap_function(assert_fn, [])

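  # Functions wrapped into the same WrappedGraph share variables through
  # variable_scope reuse: later initializers are ignored, and updates made by
  # one function are visible to the others.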
  def testShareVariablesSameGraph(self):

    def add_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
      return v + x

    def subtract_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v - x

    def different_variable_fn_v1(x):
      with variable_scope.variable_scope(
          'no_reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
      return v * x

    def increment_variable_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
      return v.assign_add(x)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    add = g.wrap_function(add_v1, signature)
    subtract = g.wrap_function(subtract_v1, signature)
    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
    increment_variable = g.wrap_function(increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # The shared variable has a starting value of 3 because add_v1 was wrapped
    # first.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the update made by increment_variable is visible to the other
    # functions that share the variable.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'reuse/v:0', 'no_reuse/v:0'},
                        set([v.name for v in g.variables]))

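  # Passing a shared VariableHolder (share_variables=True) lets separate
  # WrappedGraph instances share variables by name, with the first creation
  # determining the initial value.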
  def testShareVariablesDifferentGraphs(self):

    def add_v1(x):
      v = variables.Variable(3, name='v')
      return v + x

    def subtract_v1(x):
      v = variables.Variable(4, name='v')
      return v - x

    def different_variable_fn_v1(x):
      with ops.name_scope('different_scope'):
        v = variables.Variable(5, name='v')
      return v * x

    def increment_variable_v1(x):
      v = variables.Variable(6, name='v')
      return v.assign_add(x)

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    vh = wrap_function.VariableHolder(share_variables=True)
    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

    add = new_graph().wrap_function(add_v1, signature)
    subtract = new_graph().wrap_function(subtract_v1, signature)
    different_variable_fn = new_graph().wrap_function(
        different_variable_fn_v1, signature)
    increment_variable = new_graph().wrap_function(
        increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # Because the variable in add_v1 was created first, its starting value is 3
    # instead of the values defined in subtract_v1 or increment_variable_v1.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the update made by increment_variable is visible to the other
    # functions that share the variable.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'v:0', 'different_scope/v:0'},
                        set([v.name for v in vh.variables]))

if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()