# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests and benchmarks for Hessian-vector products with ResNet50."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gc
22import time
23
24from absl.testing import parameterized
25from six.moves import xrange
26import tensorflow as tf
27
28from tensorflow.python.eager import forwardprop
29from tensorflow.python.eager.benchmarks.resnet50 import resnet50
30from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util


def _forward_over_back_hvp(model, images, labels, vector):
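  """Computes a Hessian-vector product, forward-mode over reverse-mode.

  A ForwardAccumulator watching the model's variables with tangents `vector`
  wraps a reverse-mode tape that computes the loss gradient. The JVP of the
  gradient is then the Hessian-vector product.
  """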
  with forwardprop.ForwardAccumulator(
      model.trainable_variables, vector) as acc:
    with tf.GradientTape() as grad_tape:
      logits = model(images, training=True)
      loss = tf.compat.v1.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels)
    grads = grad_tape.gradient(loss, model.trainable_variables)
  return acc.jvp(grads)


def _back_over_forward_hvp(model, images, labels, vector):
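  """Computes a Hessian-vector product, reverse-mode over forward-mode.

  A ForwardAccumulator computes the JVP of the loss, i.e. its directional
  derivative along `vector`. Differentiating that scalar with the outer
  reverse-mode tape yields the Hessian-vector product.
  """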
  with tf.GradientTape() as grad_tape:
    grad_tape.watch(model.trainable_variables)
    with forwardprop.ForwardAccumulator(
        model.trainable_variables, vector) as acc:
      logits = model(images, training=True)
      loss = tf.compat.v1.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels)
  return grad_tape.gradient(acc.jvp(loss), model.trainable_variables)


def _tf_gradients_forward_over_back_hvp(model, images, labels, vector):
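  """Computes a forward-over-reverse Hessian-vector product via tf.gradients.

  tf.gradients only provides reverse mode, so the JVP of the gradient is
  emulated with the double-VJP trick: take a VJP of the gradient using
  symbolic `helpers` cotangents, then differentiate with respect to `helpers`
  against `vector`, which transposes the VJP back into a JVP.
  """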
  with tf.GradientTape() as grad_tape:
    logits = model(images, training=True)
    loss = tf.compat.v1.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels)
  variables = model.trainable_variables
  grads = grad_tape.gradient(loss, variables)
  helpers = tf.nest.map_structure(tf.ones_like, grads)
  transposing = tf.gradients(grads, variables, helpers)
  return tf.gradients(transposing, helpers, vector)


def _back_over_back_hvp(model, images, labels, vector):
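  """Computes a Hessian-vector product with two nested reverse-mode tapes.

  The inner tape computes the loss gradient; the outer tape differentiates
  the gradient with `vector` as the output gradient, giving the
  vector-Hessian product, which equals the Hessian-vector product since the
  Hessian is symmetric.
  """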
  with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
      logits = model(images, training=True)
      loss = tf.compat.v1.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels)
    grads = inner_tape.gradient(loss, model.trainable_variables)
  return outer_tape.gradient(
      grads, model.trainable_variables, output_gradients=vector)


class HVPTest(tf.test.TestCase, parameterized.TestCase):
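  """Shape and dtype checks for the Hessian-vector product variants."""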

  @parameterized.named_parameters(
      ("forward_over_back_eager", _forward_over_back_hvp),
      ("forward_over_back_function", tf.function(_forward_over_back_hvp)),
      ("tf_gradients", tf.function(_tf_gradients_forward_over_back_hvp)),
      ("back_over_back_eager", _back_over_back_hvp),
      ("back_over_back_function", tf.function(_back_over_back_hvp)),
      ("back_over_forward_eager", _back_over_forward_hvp),
      ("back_over_forward_function", tf.function(_back_over_forward_hvp)))
  def test_hvp_shapes(self, hvp_function):
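    """Runs each HVP variant on a small batch and checks result shapes."""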
    device, data_format = resnet50_test_util.device_and_data_format()
    model = resnet50.ResNet50(data_format)
    with tf.device(device):
      images, labels = resnet50_test_util.random_batch(2, data_format)
      images = tf.constant(images)
      labels = tf.constant(labels)
      model.build(images.shape)
      vector = [tf.ones_like(v) for v in model.trainable_variables]

      # Note that numerical error accumulates to quite large differences in
      # the final HVP here, so only shapes and dtypes are checked.
      # tensorflow/python/eager:forwardprop_test checks, on a much smaller but
      # otherwise similar model, that the implementations agree numerically.
      hvp = hvp_function(model, images, labels, vector)
      for hvp_component, variable in zip(hvp, model.trainable_variables):
        self.assertEqual(hvp_component.shape, variable.shape)
        self.assertEqual(hvp_component.dtype, variable.dtype)


class HVPBenchmarks(tf.test.Benchmark):
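  """Wall-time benchmarks for the Hessian-vector product variants."""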

  def _force_device_sync(self):
    # If this function is called in the context of a non-CPU device
    # (e.g., inside a 'with tf.device("/gpu:0")' block)
    # then this will force a copy from CPU->NON_CPU_DEVICE->CPU,
    # which forces a sync. This is a roundabout way, yes.
    tf.constant(1.).cpu()

  def _hvp_benchmark(self, hvp_fn, label, batch_sizes,
                     num_iters=30, num_burn=5):
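    """Times `hvp_fn` for each batch size and reports results under `label`.

    Runs `num_burn` untimed warm-up iterations and forces a device sync
    before timing `num_iters` iterations.
    """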
    device, data_format = resnet50_test_util.device_and_data_format()
    model = resnet50.ResNet50(data_format)
    for batch_size in batch_sizes:
      with tf.device(device):
        images, labels = resnet50_test_util.random_batch(
            batch_size, data_format)
        images = tf.constant(images)
        labels = tf.constant(labels)
        model.build(images.shape)
        vector = [tf.ones_like(v) for v in model.trainable_variables]
        for _ in xrange(num_burn):
          results = hvp_fn(model, images, labels, vector)
          for result in results:
            result.cpu()
        self._force_device_sync()
        gc.collect()
        start = time.time()
        for _ in xrange(num_iters):
          results = hvp_fn(model, images, labels, vector)
          for result in results:
            result.cpu()
        self._force_device_sync()
        resnet50_test_util.report(
            self, label, start, num_iters, device, batch_size, data_format)

  def benchmark_forward_over_backward_hvp_eager(self):
    self._hvp_benchmark(_forward_over_back_hvp,
                        "forward_over_backward_hvp_eager",
                        batch_sizes=[8])

  def benchmark_forward_over_backward_hvp_function(self):
    self._hvp_benchmark(tf.function(_forward_over_back_hvp),
                        "forward_over_backward_hvp_function",
                        batch_sizes=[8])

  def benchmark_tf_gradients_forward_over_backward_hvp_function(self):
    self._hvp_benchmark(tf.function(_tf_gradients_forward_over_back_hvp),
                        "tf_gradients_forward_over_backward_hvp_function",
                        batch_sizes=[8])

  def benchmark_backward_over_backward_hvp_eager(self):
    self._hvp_benchmark(_back_over_back_hvp,
                        "backward_over_backward_hvp_eager",
                        batch_sizes=[8])

  def benchmark_backward_over_backward_hvp_function(self):
    self._hvp_benchmark(tf.function(_back_over_back_hvp),
                        "backward_over_backward_hvp_function",
                        batch_sizes=[8])

  def benchmark_backward_over_forward_hvp_eager(self):
    self._hvp_benchmark(_back_over_forward_hvp,
                        "backward_over_forward_hvp_eager",
                        batch_sizes=[8])

  def benchmark_backward_over_forward_hvp_function(self):
    self._hvp_benchmark(tf.function(_back_over_forward_hvp),
                        "backward_over_forward_hvp_function",
                        batch_sizes=[8])


if __name__ == "__main__":
  tf.compat.v1.enable_v2_behavior()
  tf.test.main()