# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `wrap_function`: wrapping, pruning, and `WrappedGraph`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class WrapFunctionTest(test.TestCase):
  """Tests for `wrap_function.wrap_function` and `WrappedFunction.prune`."""

  def testDocString(self):
    """Each `wrap_function` call traces `f` anew with its own variable."""

    def f(x, do_add):
      v = variables.Variable(5.0)
      if do_add:
        op = v.assign_add(x)
      else:
        op = v.assign_sub(x)
      with ops.control_dependencies([op]):
        return v.read_value()

    f_add = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), True])

    # The variable persists between calls, so successive calls accumulate.
    self.assertAllEqual(f_add(1.0), 6.0)
    self.assertAllEqual(f_add(1.0), 7.0)

    # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
    # of variables, and possibly different non-template arguments.
    f_sub = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), False])

    self.assertAllEqual(f_sub(1.0), 4.0)
    self.assertAllEqual(f_sub(1.0), 3.0)

  def testPrune(self):
    """Pruning to one feed and one fetch drops the other input and output."""

    x_in = []
    x_out = []

    def f(x, y):
      # Stash the traced tensors so the test can name them as prune targets.
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, 2 * y * y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)

    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
    self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])

  def testNoArguments(self):
    """A function with an empty signature can be wrapped and called."""

    def f():
      return constant_op.constant(1.)

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(1.0, f_wrapped())

  def testPruneCaptures(self):
    """Pruned functions still see the variables the original function used."""

    v1 = variables.Variable(2.)

    def f():
      v2 = variables.Variable(3.)
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())

    # Test pruning directly on the inputs. `f` has no arguments, so any
    # entries in `f_wrapped.inputs` come from captures.
    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

    # Test pruning with no inputs
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

  def testCollectionsIsolation(self):
    """Each wrapped function's graph keeps its own collections."""

    v1 = variables.Variable(2.)
    v2_holder = []

    def f():
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v2 * constant_op.constant(3.))
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())
    self.assertLen(f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES), 1)
    f_var_collection = f_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    # Only the variable created inside `f` is listed; the outer `v1` is not.
    self.assertLen(f_var_collection, 1)
    self.assertIs(f_var_collection[0], v2_holder[0])

    v3_holder = []

    def g():
      v3 = variables.Variable(4.)
      v3_holder.append(v3)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v3 * constant_op.constant(3.))
      return array_ops.identity(v1 * v3 * constant_op.constant(1.), 'fetch')

    g_wrapped = wrap_function.wrap_function(g, [])
    self.assertAllEqual(8.0, g_wrapped())
    self.assertLen(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES), 1)
    g_var_collection = g_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertLen(g_var_collection, 1)
    self.assertIs(g_var_collection[0], v3_holder[0])

    # Both have only one value, and their values aren't equal. So no sharing.
    self.assertNotEqual(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES),
                        f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES))

  def testGradientsOfPrune(self):
    """Gradients reach inputs and variables through wrapped and pruned fns."""

    v1 = variables.Variable(2.)
    v2_holder = []

    def f(z):
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      return array_ops.identity(v1 * v2 * z, 'fetch')

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = f_wrapped(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    # out = v1 * v2 * x, so the gradients are [v1 * v2, v2 * x, v1 * x].
    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = pruned(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

  def testPruneOperations(self):
    """Ops can be fetched when pruning; fetched ops produce `None` outputs."""

    v = variables.Variable(0)

    def f():
      v.assign_add(1, name='increment', read_value=False)

    f_wrapped = wrap_function.wrap_function(f, [])
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),))
    self.assertEqual((None,), pruned())
    self.assertEqual(1, self.evaluate(v))

    del f, f_wrapped

    def f1():
      v.assign_add(
          array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step'),
          name='increment', read_value=False)
      return constant_op.constant(1, name='other')

    f_wrapped = wrap_function.wrap_function(f1, [])
    # `step` is a placeholder created inside `f1`, so it can be used as a feed
    # of the pruned function even though `f1` itself takes no arguments.
    increments = f_wrapped.prune(
        feeds=f_wrapped.graph.get_tensor_by_name('step:0'),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                 f_wrapped.graph.get_tensor_by_name('other:0')))
    first_output, second_output = increments(constant_op.constant(2))
    # The explicit feed comes first, then the `increment` op's resource input.
    self.assertEqual(['step:0', 'increment/resource:0'],
                     [t.name for t in increments.inputs])
    self.assertIsNone(first_output)
    self.assertEqual(1, second_output.numpy())
    self.assertEqual(3, v.numpy())
    # With `increment` absent from the fetches it is pruned away, so `v` keeps
    # its value.
    does_not_increment = f_wrapped.prune(
        feeds=f_wrapped.graph.get_tensor_by_name('step:0'),
        fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
    self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
    self.assertEqual(3, v.numpy())

  def testPruneStatefulOpsFromWrappedFunc(self):
    """Stateful ops outside the fetches' fan-in are not executed."""

    v0 = variables.Variable(0)
    v1 = variables.Variable(0)

    # When we wrap a function, we expect it to be executed with `tf.Graph`
    # rules: it's allowed to prune all ops that are not in transitive fanin of
    # the fetches.
    def f(x):
      v0.assign_add(1, name='increment_v0')
      v1.assign_add(1, name='increment_v1')
      return x

    # The Python value in the signature is baked into the trace, so the
    # wrapped function is called with no arguments.
    f_wrapped = wrap_function.wrap_function(f, [1])

    self.assertEqual(1, f_wrapped().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

    f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')

    self.assertEqual(2, f_wrapped_with_name().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

  def test_function_from_graph_def(self):
    """A GraphDef exported from a concrete function can be called again."""

    @def_function.function
    def make_graph_def(x):
      return x + 1.

    original_func_graph = make_graph_def.get_concrete_function(
        tensor_spec.TensorSpec([None, 2], dtypes.float32)).graph
    graph_def = original_func_graph.as_graph_def()
    revived_function = wrap_function.function_from_graph_def(
        graph_def, inputs=original_func_graph.inputs[0].name,
        outputs=original_func_graph.outputs[0].name)
    self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())


class WrappedGraphTest(test.TestCase):
  """Tests for `WrappedGraph` and variable sharing through `VariableHolder`."""

  def testAddFunction(self):
    """A fn wrapped in a `WrappedGraph` matches running it in a session."""

    def fn(x):
      v = variables.Variable(3, name='v')
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v + v2 + x

    # Reference value computed by running `fn` without wrapping.
    with self.cached_session() as sess:
      result = fn(constant_op.constant(5))
      sess.run(variables.global_variables_initializer())
      expected = sess.run(result)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    wrapped_fn = g.wrap_function(fn, signature)
    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())

  def testCollections(self):
    """Collections filled while wrapping are visible on the `WrappedGraph`."""

    def fn(x):
      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
          collections=['a', 'b'])
      return v + v2 + x

    def assert_collections(graph):
      # `v` is non-trainable and only in 'a'; `v2` is in both 'a' and 'b'.
      self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
      self.assertLen(graph.get_collection('a'), 2)
      self.assertLen(graph.get_collection('b'), 1)

    g = wrap_function.WrappedGraph()
    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
    assert_collections(g.graph)

    def assert_fn():
      assert_collections(ops.get_default_graph())
      return 1  # Return is required

    # Assert that collections are accessible within a wrapped function.
    g.wrap_function(assert_fn, [])

  def testShareVariablesSameGraph(self):
    """Fns wrapped into one `WrappedGraph` share variables via AUTO_REUSE."""

    def add_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
      return v + x

    def subtract_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v - x

    def different_variable_fn_v1(x):
      with variable_scope.variable_scope(
          'no_reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
      return v * x

    def increment_variable_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
      return v.assign_add(x)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    add = g.wrap_function(add_v1, signature)
    subtract = g.wrap_function(subtract_v1, signature)
    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
    increment_variable = g.wrap_function(increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # The shared variable has a starting value of 3 because add_v1 was wrapped
    # first.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the variable update is visible to every function sharing it.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertEqual({'reuse/v:0', 'no_reuse/v:0'},
                     {v.name for v in g.variables})

  def testShareVariablesDifferentGraphs(self):
    """A shared `VariableHolder` shares same-named vars across WrappedGraphs."""

    def add_v1(x):
      v = variables.Variable(3, name='v')
      return v + x

    def subtract_v1(x):
      v = variables.Variable(4, name='v')
      return v - x

    def different_variable_fn_v1(x):
      with ops.name_scope('different_scope'):
        v = variables.Variable(5, name='v')
      return v * x

    def increment_variable_v1(x):
      v = variables.Variable(6, name='v')
      return v.assign_add(x)

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    vh = wrap_function.VariableHolder(share_variables=True)
    # Every WrappedGraph built here draws its variables from the same holder.
    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

    add = new_graph().wrap_function(add_v1, signature)
    subtract = new_graph().wrap_function(subtract_v1, signature)
    different_variable_fn = new_graph().wrap_function(
        different_variable_fn_v1, signature)
    increment_variable = new_graph().wrap_function(
        increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # Because the variable in add_v1 was created first, its starting value is 3
    # instead of the values defined in subtract_v1 or increment_variable_v1.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the variable update is visible to every function sharing it.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertEqual({'v:0', 'different_scope/v:0'},
                     {v.name for v in vh.variables})


if __name__ == '__main__':
  # The tests call `.numpy()` on results and use `GradientTape`, so they are
  # run with eager execution enabled.
  ops.enable_eager_execution()
  test.main()