• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test utils for composite op definition."""
15from tensorflow.python.eager import backprop
16from tensorflow.python.platform import test
17
18
class OpsDefsTest(test.TestCase):
  """Test utils."""

  def _assertOpAndComposite(self, vars_, compute_op, compute_composite, kwargs,
                            op_kwargs=None, rtol=1e-6, atol=1e-6):
    """Checks that an op and its composite agree on outputs and gradients.

    Each callable is run under its own `GradientTape`, which watches every
    tensor in `vars_`. The forward results and their gradients with respect
    to `vars_` are then compared structurally with `assertAllClose`. Lengths
    and shapes must match, not just a zipped prefix.

    Args:
      vars_: Tensors to watch and differentiate against. The sequence is
        iterated twice, so it must be re-iterable (e.g. a list), not a
        one-shot generator.
      compute_op: Callable that computes the result with the op under test.
        Called as `compute_op(**op_kwargs)`.
      compute_composite: Callable that computes the result with the composite
        function. Called as `compute_composite(**kwargs)`.
      kwargs: Keyword arguments for `compute_composite`. Also used for
        `compute_op` when `op_kwargs` is None.
      op_kwargs: Optional keyword arguments for `compute_op`. Defaults to
        `kwargs`.
      rtol: Relative tolerance forwarded to `assertAllClose`.
      atol: Absolute tolerance forwarded to `assertAllClose`.
    """
    if op_kwargs is None:
      op_kwargs = kwargs

    # Compute with the op. Per the original author's note, the op may be
    # decomposed by a graph pass; gradients use its registered gradient.
    with backprop.GradientTape() as gt:
      for var_ in vars_:
        gt.watch(var_)
      y = compute_op(**op_kwargs)
    # Taken outside the tape context so the gradient computation itself is
    # not recorded.
    grads = gt.gradient(y, vars_)

    # Compute with the composite function; gradients go through the
    # composite's gradient function.
    with backprop.GradientTape() as gt:
      for var_ in vars_:
        gt.watch(var_)
      re_y = compute_composite(**kwargs)
    re_grads = gt.gradient(re_y, vars_)

    # Compare whole structures instead of zip()-ing them. zip() silently
    # drops trailing elements when lengths differ. It also iterates a single
    # tensor along its first axis, hiding size mismatches, and cannot iterate
    # a scalar tensor at all.
    self.assertAllClose(y, re_y, rtol=rtol, atol=atol)
    self.assertAllClose(grads, re_grads, rtol=rtol, atol=atol)
45