# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training as train


class MemoryOptimizerSwapTest(test.TestCase):
  """Tests the Grappler memory optimizer."""

  @test_util.run_deprecated_v1
  def testNoSwapping(self):
    """Make sure the graph is preserved when there is nothing to swap."""
    a = variables.VariableV1(10, name='a')
    b = variables.VariableV1(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    graph_size = len(mg.graph_def.node)
    nodes = [node.name for node in mg.graph_def.node]

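    # MANUAL memory optimization only acts on ops that carry swap annotations;
    # with none present, the optimized graph is expected to be unchanged.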
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    graph = tf_optimizer.OptimizeGraph(config, mg)

    self.assertEqual(len(graph.node), graph_size)
    self.assertItemsEqual([node.name for node in graph.node], nodes)

  @test_util.run_v1_only('b/120545219')
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    with ops.device('/gpu:0'):
      a = variables.VariableV1(10, name='a')
      b = variables.VariableV1(20, name='b')
      c = math_ops.add_n([a, b], name='c')
      d = math_ops.add_n([b, c], name='d')
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)

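      # Annotate 'd' so the MANUAL memory optimizer swaps its first input
      # (index 0) out to host memory and back in before 'd' consumes it.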
      d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      graph_size = len(mg.graph_def.node)

      config = config_pb2.ConfigProto()
      config.graph_options.rewrite_options.CopyFrom(
          rewriter_config_pb2.RewriterConfig(
              disable_model_pruning=True,
              meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
              constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
              memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
              min_graph_nodes=-1))
      graph = tf_optimizer.OptimizeGraph(config, mg)

      self.assertEqual(len(graph.node), graph_size + 2)
      self.assertTrue(
          set(node.name for node in graph.node) > set(
              ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
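      # The rewrite should have inserted a swap_out_d_0/swap_in_d_0 pair
      # between 'b/read' and 'd', with a control dependency on 'b/read' that
      # delays the swap-in.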
      for node in graph.node:
        if node.name == 'swap_in_d_0':
          self.assertEqual('swap_out_d_0', node.input[0])
          self.assertEqual('^b/read', node.input[1])
        elif node.name == 'swap_out_d_0':
          self.assertEqual('b/read', node.input[0])
        elif node.name == 'd':
          self.assertEqual('swap_in_d_0', node.input[0])
          self.assertEqual('c', node.input[1])


class MemoryOptimizerRecomputeTest(test.TestCase):
  """Tests the Python interface to recomputation rewrites.

  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
  """

  def _GetMetaGraph(self, batch_size=14, image_dim=12,
                    optimizer_scope_name=''):
    """A simple layered graph with conv, an intermediate op, and a ReLU."""
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(1)
      current_activation = variable_scope.get_variable(
          name='start', shape=[batch_size, image_dim, image_dim, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
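      # Stack ten conv -> scale -> ReLU layers; the intermediate ops in each
      # layer give the recomputation rewrites below something to target (the
      # tests expect two recomputed nodes per layer).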
      for layer_number in range(10):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation = nn.relu(current_activation)
      loss = math_ops.reduce_mean(current_activation)
      with ops.name_scope(optimizer_scope_name):
        optimizer = train.AdamOptimizer(0.001)
        train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
      metagraph = train.export_meta_graph()
    return (metagraph, init_op.name, train_op.name, loss.name)

  def testRewritingDefaultGradientNames(self):
    """Tests that rewriting occurs with default gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph()
    config = config_pb2.ConfigProto()
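    # Turn the other Grappler passes off so that any change to the graph can
    # only come from the recomputation heuristics.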
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=(
                rewriter_config_pb2.RewriterConfig.RECOMPUTATION_HEURISTICS)))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(
        config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNames(self):
    """Tests that rewriting occurs with non-standard gradient names."""
    (original_metagraph, _, _, _) = self._GetMetaGraph(
        optimizer_scope_name='optimizer')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # Checks that the name scope "gradients/" also matches sub-scopes
            # such as "optimizer/gradients/".
            memory_optimizer_target_node_name_scope='gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(
        config, original_metagraph)
    self.assertGreater(
        len(rewritten_graph_def.node),
        len(original_metagraph.graph_def.node))
    self.assertEqual(
        0,
        len([node for node in original_metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    self.assertEqual(
        20,  # Two per layer
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))

  def testRewritingNameScopedGradientNamesScope(self):
    """Tests that no rewriting occurs when the target scope does not match."""
    (original_metagraph, _, _,
     _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig
            .RECOMPUTATION_HEURISTICS,
            # This should not match anything.
            memory_optimizer_target_node_name_scope='r/gradients/'))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(
        config, original_metagraph)
    self.assertEqual(
        len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
    self.assertEqual(0,
                     len([
                         node for node in original_metagraph.graph_def.node
                         if 'Recomputed/' in node.name
                     ]))
    self.assertEqual(0,
                     len([
                         node for node in rewritten_graph_def.node
                         if 'Recomputed/' in node.name
                     ]))

  def _GetMemoryOptimizerSessionConfig(self):
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def _RunMetaGraphWithConfig(
      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
    graph = ops.Graph()
    with graph.as_default():
      train.import_meta_graph(metagraph)
      init_op = graph.get_operation_by_name(init_op_name)
      train_op = graph.get_operation_by_name(train_op_name)
      loss_op = graph.get_tensor_by_name(loss_op_name)
      with session.Session(config=config, graph=graph) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)
        self.evaluate(train_op)
        return self.evaluate(loss_op)

  def testRecomputationRewritingNoErrors(self):
    """Tests that graph output is not significantly different with rewriting.
    """
    (original_metagraph, init_op_name, train_op_name, loss_op_name
    ) = self._GetMetaGraph()
    original_loss = self._RunMetaGraphWithConfig(
        config=config_pb2.ConfigProto(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
    memory_optimized_loss = self._RunMetaGraphWithConfig(
        config=self._GetMemoryOptimizerSessionConfig(),
        metagraph=original_metagraph,
        init_op_name=init_op_name,
        train_op_name=train_op_name,
        loss_op_name=loss_op_name)
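    # The rewritten and original graphs are only expected to agree
    # approximately, hence the loose relative tolerance below.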
    self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)

  def _annotated_graph(self):
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(2)
      current_activation = variable_scope.get_variable(
          name='start', shape=[1, 2, 2, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(3):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
          current_activation.op._set_attr(
              '_recompute_hint',
              # The value of the attribute does not matter; just that the key
              # exists in the op's attributes.
              attr_value_pb2.AttrValue(i=1))
          current_activation += 5.
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=0))
          current_activation = nn.relu(current_activation)
          current_activation.op._set_attr(
              '_recompute_hint', attr_value_pb2.AttrValue(i=1))
      loss = math_ops.reduce_mean(current_activation)
      optimizer = train.AdamOptimizer(0.001)
      train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
    return graph, init_op, train_op

  def testHintNoMetaGraph(self):
    # Closer to expected usage, but does not check that a re-write actually
    # happens; see testHintDoesRewrite.
    graph, init_op, train_op = self._annotated_graph()
    with graph.as_default():
      manual_memory_config = rewriter_config_pb2.RewriterConfig(
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      graph_options = config_pb2.GraphOptions(
          rewrite_options=manual_memory_config)
      session_config = config_pb2.ConfigProto(graph_options=graph_options)
      with session.Session(config=session_config) as sess:
        self.evaluate(init_op)
        self.evaluate(train_op)

  @test_util.run_v1_only('b/120545219')
  def testHintDoesRewrite(self):
    graph = self._annotated_graph()[0]
    with graph.as_default():
      metagraph = train.export_meta_graph()
    self.assertEqual(
        0,
        len([node for node in metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            min_graph_nodes=-1,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, metagraph)
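    # _annotated_graph() attaches a '_recompute_hint' to three ops in each of
    # its three layers, so the MANUAL rewrite is expected to add nine
    # recomputed copies.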
    self.assertEqual(
        9,
        len([
            node for node in rewritten_graph_def.node
            if 'Recomputed/' in node.name
        ]))


if __name__ == '__main__':
  test.main()