/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {
namespace {

using ops::Assign;
using ops::Const;
using ops::Identity;
using ops::MatMul;
using ops::OnesLike;
using ops::Placeholder;
using ops::Square;
using ops::Stack;
using ops::StopGradient;
using ops::Unstack;
using ops::Variable;

// TODO(andydavis) Add more unit tests once more gradient functions are ported.
class GradientsTest : public ::testing::Test {
 protected:
  GradientsTest()
      : scope_expected_(Scope::NewRootScope()),
        scope_test_(Scope::NewRootScope()) {}

  void CompareTestAndExpectedGraphs() {
    GraphDef gdef_test;
    TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test));
    GraphDef gdef_exp;
    TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp));
    TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test);
  }

  Scope scope_expected_;
  Scope scope_test_;
};
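
// Most tests below build the same forward graph in both scopes: the expected
// gradient graph is constructed by hand in 'scope_expected_', while
// AddSymbolicGradients builds its gradient graph in 'scope_test_', and the
// two GraphDefs are compared for equality. Tests that check numeric gradient
// values instead evaluate the gradient outputs directly.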

// Example:
//      ^             ^
//    dy|           dx|        (MatMul Gradient Graph)
//      |             |
//   MatMul_1      MatMul_2
//   ^   ^          ^    ^
//   |   |----------|    |
//   |        ^          |
//   |      dz|          |
//   |        |          |
//   |     Const_3       |
//   |                   |
//   |        ^          |
//   |       z|          |     (MatMul Forward Graph)
//   |        |          |
//   |      MatMul_0     |
//   |     /        \    |
//   |    ^          ^   |
//   |    |          |   |
//   |---x|         y|---|
//        |          |
//        |          |
//      Const_0   Const_1
//

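// For reference in the tests below: given z = MatMul(x, y), the MatMul
// gradient function produces
//   dx = MatMul(dz, y, TransposeB(true))  // i.e. dz * y^T
//   dy = MatMul(x, dz, TransposeA(true))  // i.e. x^T * dz
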
TEST_F(GradientsTest, OneMatMul) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(
          AddSymbolicGradients(scope, {z}, {x, y}, {dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, OneMatMul_InferGradInputs) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      // The gradients function adds a OnesLike to create a dz of ones with the
      // shape of z.
      auto dz = OnesLike(scope, z);
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, TwoMatMuls_Chained) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto u = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto v = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto x = MatMul(scope, u, v);

    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);

    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));

      auto du = MatMul(scope, dx, v, MatMul::TransposeB(true));
      auto dv = MatMul(scope, u, dx, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(
          AddSymbolicGradients(scope, {z}, {u, v}, {dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, TwoMatMuls_Independent) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto t = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto u = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto v = MatMul(scope, t, u);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(v.node());

    auto x = Const(scope, {{5.0, 6.0}, {7.0, 8.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dv = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dt = MatMul(scope, dv, u, MatMul::TransposeB(true));
      auto du = MatMul(scope, t, dv, MatMul::TransposeA(true));

      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dv = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {v, z}, {t, u, x, y}, {dv, dz},
                                        &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

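// The gradient of Unstack is a Stack of the incoming gradients, and the
// gradient of Stack is an Unstack, so the expected backward graph below
// mirrors the forward graph.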
TEST_F(GradientsTest, StackUnstack_Chained) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto a = Const(scope, 1, {4, 2});
    auto b = Const(scope, 2, {4, 2});
    auto c = Const(scope, 3, {4, 2});

    auto pack = Stack(scope, {a, b, c});
    auto unpack = Unstack(scope, pack.output, 3);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dx = Const(scope, 4, {4, 2});
    auto dy = Const(scope, 5, {4, 2});
    auto dz = Const(scope, 6, {4, 2});

    if (expected) {
      // Construct backward graph.
      auto unpack_grad = Stack(scope, {dx, dy, dz});
      auto pack_grad = Unstack(scope, unpack_grad.output, 3);
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, unpack.output, {a, b, c},
                                        {dx, dy, dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, StackUnstack_StopBackprop) {
  // Tests that backprop stops before calculating gradients for Stack (because
  // only gradients w.r.t. the output of Stack are requested).
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto a = Const(scope, 1, {4, 2});
    auto b = Const(scope, 2, {4, 2});
    auto c = Const(scope, 3, {4, 2});

    auto pack = Stack(scope, {a, b, c});
    auto unpack = Unstack(scope, pack.output, 3);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dx = Const(scope, 4, {4, 2});
    auto dy = Const(scope, 5, {4, 2});
    auto dz = Const(scope, 6, {4, 2});

    if (expected) {
      // Construct backward graph.
      // NOTE: We should only expect the grad function for unpack in the
      // gradients graph, based on the requested grad outputs.
      auto unpack_grad = Stack(scope, {dx, dy, dz});
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, unpack.output, {pack},
                                        {dx, dy, dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, StackUnstack_SubsetOfUnstackOutputs) {
  // Constructs an unstack with three outputs, and takes the gradient with
  // respect to only two of the outputs. Tests that gradients are computed
  // for both requested outputs.
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto c = Const(scope, 1, {3, 4, 2});
    auto unpack = Unstack(scope, c, 3);
    auto x = Identity(scope, unpack.output[0]);
    auto y = Identity(scope, unpack.output[1]);
    auto z = Identity(scope, unpack.output[2]);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dy = Const(scope, 4, {4, 2});
    auto dz = Const(scope, 5, {4, 2});

    if (expected) {
      // Construct backward graph.
      auto g1 = Identity(scope, dy);
      auto g2 = Identity(scope, dz);
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {y, z},
                                        {unpack.output[1], unpack.output[2]},
                                        {dy, dz}, &grad_outputs));
      ASSERT_EQ(grad_outputs.size(), 2);
      EXPECT_TRUE(grad_outputs[0].node() != nullptr);
      EXPECT_TRUE(grad_outputs[1].node() != nullptr);
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, DependentGradOutputs) {
  // Tests that dependent gradients (in this case the gradients w.r.t. the
  // output and one input of MatMul) are computed properly.

  // Create two chained MatMul ops.
  auto u = Const(scope_test_, {{2}});
  auto v = Const(scope_test_, {{3}});
  auto x = MatMul(scope_test_, u, v);

  auto y = Const(scope_test_, {{4}});
  auto z = MatMul(scope_test_, x, y);

  TF_ASSERT_OK(scope_test_.status());
  CHECK_NOTNULL(z.node());

  // Call AddSymbolicGradients with '5' as the initial gradient 'dz'.
  // The gradient w.r.t. 'v' (returned in grad_outputs[0]) depends on the
  // gradient w.r.t. 'x' (returned in grad_outputs[1]).
  auto dz = Const(scope_test_, {{5}});
  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(
      AddSymbolicGradients(scope_test_, {z}, {v, x}, {dz}, &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {grad_outputs[0], grad_outputs[1]}, &outputs);

  // The gradient w.r.t. 'z' is passed into AddSymbolicGradients as '5'.
  // Since z = MatMul(x, y), the gradient w.r.t. 'x' is computed as:
  //   'dx' = 'dz' * 'y'^T = 5 * 4 = 20.
  // Since x = MatMul(u, v), the gradient w.r.t. 'v' is computed as:
  //   'dv' = 'u'^T * 'dx' = 2 * 20 = 40.
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({40}, {1, 1}));
  test::ExpectTensorEqual<int>(outputs[1], test::AsTensor<int>({20}, {1, 1}));
}

TEST_F(GradientsTest, MultipleNodeOutputGrads) {
  // Tests that gradients for multiple outputs of the same node are returned.
  auto x = Const(scope_test_, 1, {3, 4, 2});
  auto unpack = Unstack(scope_test_, x, 3);
  auto pack = Stack(scope_test_, unpack.output);

  // clang-format off
  auto dx = Const(scope_test_, {40, 41, 42, 43, 44, 45, 46, 47,
                                50, 51, 52, 53, 55, 55, 56, 57,
                                60, 61, 62, 63, 66, 66, 66, 67},
                               {3, 4, 2});
  // clang-format on

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {pack}, unpack.output, {dx},
                                    &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_,
                   {grad_outputs[0], grad_outputs[1], grad_outputs[2]},
                   &outputs);

  test::ExpectTensorEqual<int>(
      outputs[0],
      test::AsTensor<int>({40, 41, 42, 43, 44, 45, 46, 47}, {4, 2}));
  test::ExpectTensorEqual<int>(
      outputs[1],
      test::AsTensor<int>({50, 51, 52, 53, 55, 55, 56, 57}, {4, 2}));
  test::ExpectTensorEqual<int>(
      outputs[2],
      test::AsTensor<int>({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2}));
}

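// Gradients are requested only w.r.t. 'y'. The variables are initialized via
// Assign so the graph can be evaluated; 'z' plays no part in computing 'm'.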
TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto diff_m = Const(scope_test_, {{0.5}, {0.5}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(
      AddSymbolicGradients(scope_test_, {m}, {y}, {diff_m}, &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);
  // dm/dy = x^T * diff_m
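  // With x = {{1, 2, 3}, {4, 5, 6}} and diff_m = {{0.5}, {0.5}}:
  //   x^T * diff_m = {{0.5 + 2.0}, {1.0 + 2.5}, {1.5 + 3.0}}
  //                = {{2.5}, {3.5}, {4.5}}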
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
}

TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m1 = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto m2 = MatMul(scope_test_, y, z);

  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
  auto dm2 =
      Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2},
                                    &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);

  // The gradients from m1 and m2 are summed to compute the gradient
  // w.r.t. y:
  //   dy = x^T * dm1 + dm2 * z^T
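  // Numerically:
  //   x^T * dm1 = {{2.5}, {3.5}, {4.5}}
  //   dm2 * z^T = {{15.0}, {21.2}, {22.3}}
  //   dy        = {{17.5}, {24.7}, {26.8}}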
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
}

TEST_F(GradientsTest, UnreachableInput) {
  auto x = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto y = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto z = Const(scope_test_.WithOpName("z"), {{9.0, 10.0, 11.0}});

  auto m1 = MatMul(scope_test_, x, y);
  auto m2 = MatMul(scope_test_, y, z);
  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});

  // From m1, z is unreachable, so an error status should be returned.
  //   m2  m1
  //   |   |
  //   *   *
  //  / \ / \
  // z   y   x
  std::vector<Output> grad_outputs;
  Status status =
      AddSymbolicGradients(scope_test_, {m1}, {z}, {dm1}, &grad_outputs);
  EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
  EXPECT_EQ(status.error_message(),
            "Cannot compute the partial derivative"
            " for node 'z' as it's unreachable from the output node(s).");
}

TEST_F(GradientsTest, DependentOutputs) {
  auto x = Placeholder(scope_test_, DT_FLOAT);
  auto y0 = Square(scope_test_, x);
  auto y1 = Square(scope_test_, y0);
  auto y2 = Square(scope_test_, y1);
  // Requesting the gradients for y0 and y2 should return the sum of their
  // individual gradients.
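  // With y0 = x^2 and y2 = x^8, the summed gradient at x = 3 is
  //   2x + 8x^7 = 6 + 8 * 2187 = 17502.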
  std::vector<Output> grad_outputs;
  TF_EXPECT_OK(AddSymbolicGradients(scope_test_, {y0, y2}, {x}, &grad_outputs));
  ClientSession session(scope_test_);
  std::vector<Tensor> grad_result;
  TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
  EXPECT_EQ(grad_result.size(), 1);
  EXPECT_EQ(grad_result[0].NumElements(), 1);
  EXPECT_EQ(grad_result[0].flat<float>()(0), 17502.0f);
}

TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
  auto x = Placeholder(scope_test_, DT_FLOAT);
  auto y0 = Square(scope_test_, x);
  // y1, y2, and y3 all use y0. This means the backwards pass will need to wait
  // for the gradient for all three.
  auto y1 = Square(scope_test_, y0);
  auto y2 = Square(scope_test_, y0);
  auto y3 = Square(scope_test_, y2);
  std::vector<Output> grad_outputs;
  // By requesting y0, y1, and y3 we test that the computation correctly waits
  // for all the points in backprop where gradients need to be summed from
  // multiple branches.
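  // With y0 = x^2, y1 = x^4, and y3 = x^8, the summed gradient at x = 3 is
  //   2x + 4x^3 + 8x^7 = 6 + 108 + 17496 = 17610.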
  TF_EXPECT_OK(
      AddSymbolicGradients(scope_test_, {y0, y1, y3}, {x}, &grad_outputs));
  ClientSession session(scope_test_);
  std::vector<Tensor> grad_result;
  TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
  EXPECT_EQ(grad_result.size(), 1);
  EXPECT_EQ(grad_result[0].NumElements(), 1);
  EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}

// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by the StopGradient op) returned along multiple edges
// from a single node's output.
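// Since each output is an Identity of 'z', the expected gradient w.r.t. 'z'
// is the elementwise sum of the g_i fed along the edges that are not stopped;
// e.g. with no edges stopped, g0 + g1 + g2 = {{15, 18}, {21, 24}}.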
class StopGradientSingleOutputMultiEdgeTest : public ::testing::Test {
 protected:
  StopGradientSingleOutputMultiEdgeTest() : scope_(Scope::NewRootScope()) {}

  void CheckGrad(const std::vector<bool>& stop_outputs,
                 const Tensor& expected_grad) {
    CHECK_EQ(3, stop_outputs.size());

    auto x = Const(scope_, {{1, 0}, {0, 1}});
    auto y = Const(scope_, {{1, 0}, {0, 1}});
    auto z = MatMul(scope_, x, y);

    // Create three outgoing edges from 'z'.
    // Add StopGradients according to 'stop_outputs'.
    auto out0 = stop_outputs[0]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;
    auto out1 = stop_outputs[1]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;
    auto out2 = stop_outputs[2]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;

    auto g0 = Const(scope_, {{1, 2}, {3, 4}});
    auto g1 = Const(scope_, {{5, 6}, {7, 8}});
    auto g2 = Const(scope_, {{9, 10}, {11, 12}});

    // Call AddSymbolicGradients and compare against 'expected_grad'.
    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {out0, out1, out2}, {z},
                                      {g0, g1, g2}, &grad_outputs));

    if (expected_grad.NumElements() > 0) {
      Tensor output;
      test::GetTensor(scope_, grad_outputs[0], &output);
      test::ExpectTensorEqual<int>(output, expected_grad);
    } else {
      EXPECT_EQ(NoGradient(), grad_outputs[0]);
    }
  }

  Scope scope_;
};

TEST_F(StopGradientSingleOutputMultiEdgeTest, ValidGradAllEdges) {
  CheckGrad({false, false, false},
            test::AsTensor<int>({15, 18, 21, 24}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstEdge) {
  CheckGrad({true, false, false},
            test::AsTensor<int>({14, 16, 18, 20}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradSecondEdge) {
  CheckGrad({false, true, false},
            test::AsTensor<int>({10, 12, 14, 16}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradThirdEdge) {
  CheckGrad({false, false, true}, test::AsTensor<int>({6, 8, 10, 12}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstAndSecondEdges) {
  CheckGrad({true, true, false}, test::AsTensor<int>({9, 10, 11, 12}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradSecondAndThirdEdges) {
  CheckGrad({false, true, true}, test::AsTensor<int>({1, 2, 3, 4}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstAndThirdEdges) {
  CheckGrad({true, false, true}, test::AsTensor<int>({5, 6, 7, 8}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradAllEdges) {
  CheckGrad({true, true, true}, Tensor());
}

// StopGradientMultiOutputTest tests combinations of valid and 'NoGradient'
// (induced by the StopGradient op) returned along a single node's multiple
// outputs.
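// The gradient of Unstack re-stacks the incoming gradients; as the expected
// tensors below show, a zero tensor takes the place of each stopped output,
// so the expected gradient w.r.t. 'x' stacks g0, g1, g2 with zeros in the
// stopped slots.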
class StopGradientMultiOutputTest : public ::testing::Test {
 protected:
  StopGradientMultiOutputTest() : scope_(Scope::NewRootScope()) {}

  void CheckGrad(const std::vector<bool>& stop_outputs,
                 const Tensor& expected_grad) {
    CHECK_EQ(3, stop_outputs.size());
    auto x = ops::Const(scope_, 1, {3, 2, 4});
    auto y = Unstack(scope_, x, 3);
    TF_ASSERT_OK(scope_.status());

    // Add StopGradients according to 'stop_outputs'.
    auto out0 =
        stop_outputs[0] ? StopGradient(scope_, y.output[0]) : y.output[0];
    auto out1 =
        stop_outputs[1] ? StopGradient(scope_, y.output[1]) : y.output[1];
    auto out2 =
        stop_outputs[2] ? StopGradient(scope_, y.output[2]) : y.output[2];

    auto g0 = Const(scope_, {1, 2, 3, 4, 5, 6, 7, 8}, {2, 4});
    auto g1 = Const(scope_, {9, 10, 11, 12, 13, 14, 15, 16}, {2, 4});
    auto g2 = Const(scope_, {17, 18, 19, 20, 21, 22, 23, 24}, {2, 4});

    // Call AddSymbolicGradients and compare against 'expected_grad'.
    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {out0, out1, out2}, {x},
                                      {g0, g1, g2}, &grad_outputs));

    if (expected_grad.NumElements() > 0) {
      Tensor output;
      test::GetTensor(scope_, grad_outputs[0], &output);
      test::ExpectTensorEqual<int>(output, expected_grad);
    } else {
      EXPECT_EQ(NoGradient(), grad_outputs[0]);
    }
  }

  Scope scope_;
};

TEST_F(StopGradientMultiOutputTest, ValidGradAllOutputs) {
  // clang-format off
  CheckGrad({false, false, false}, test::AsTensor<int>(
    {1, 2, 3, 4, 5, 6, 7, 8,
     9, 10, 11, 12, 13, 14, 15, 16,
     17, 18, 19, 20, 21, 22, 23, 24},
    {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradFirstOutput) {
  // clang-format off
  CheckGrad({true, false, false}, test::AsTensor<int>(
    {0, 0, 0, 0, 0, 0, 0, 0,
     9, 10, 11, 12, 13, 14, 15, 16,
     17, 18, 19, 20, 21, 22, 23, 24},
    {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradSecondOutput) {
  // clang-format off
  CheckGrad({false, true, false}, test::AsTensor<int>(
    {1, 2, 3, 4, 5, 6, 7, 8,
     0, 0, 0, 0, 0, 0, 0, 0,
     17, 18, 19, 20, 21, 22, 23, 24},
    {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradThirdOutput) {
  // clang-format off
  CheckGrad({false, false, true}, test::AsTensor<int>(
    {1, 2, 3, 4, 5, 6, 7, 8,
     9, 10, 11, 12, 13, 14, 15, 16,
     0, 0, 0, 0, 0, 0, 0, 0},
    {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradFirstAndThirdOutputs) {
  // clang-format off
  CheckGrad({true, false, true}, test::AsTensor<int>(
    {0, 0, 0, 0, 0, 0, 0, 0,
     9, 10, 11, 12, 13, 14, 15, 16,
     0, 0, 0, 0, 0, 0, 0, 0},
    {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopAllOutputs) {
  CheckGrad({true, true, true}, Tensor());
}

}  // namespace
}  // namespace tensorflow