/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {
namespace {

using ops::Assign;
using ops::Const;
using ops::Identity;
using ops::MatMul;
using ops::OnesLike;
using ops::Placeholder;
using ops::Square;
using ops::Stack;
using ops::StopGradient;
using ops::Unstack;
using ops::Variable;

// TODO(andydavis) Add more unit tests once more gradient functions are ported.
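// The graph-comparison tests below build the same forward graph in both
// scopes: in 'scope_expected_' the gradient graph is constructed by hand,
// while in 'scope_test_' it is generated by AddSymbolicGradients. The two
// resulting GraphDefs are then compared for equality.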
class GradientsTest : public ::testing::Test {
 protected:
  GradientsTest()
      : scope_expected_(Scope::NewRootScope()),
        scope_test_(Scope::NewRootScope()) {}

  void CompareTestAndExpectedGraphs() {
    GraphDef gdef_test;
    TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test));
    GraphDef gdef_exp;
    TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp));
    TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test);
  }

  Scope scope_expected_;
  Scope scope_test_;
};

// Example:
//        ^                ^
//      dy|              dx|         (MatMul Gradient Graph)
//        |                |
//    MatMul_1         MatMul_2
//    ^      ^         ^      ^
//    |      |---------|      |
//    |            ^          |
//    |          dz|          |
//    |            |          |
//    |         Const_3       |
//    |                       |
//    |            ^          |
//    |           z|          |     (MatMul Forward Graph)
//    |            |          |
//    |         MatMul_0      |
//    |        /        \     |
//    |       ^          ^    |
//    |       |          |    |
//    |------x|         y|----|
//            |          |
//            |          |
//        Const_0     Const_1
//
TEST_F(GradientsTest, OneMatMul) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(
          AddSymbolicGradients(scope, {z}, {x, y}, {dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, OneMatMul_InferGradInputs) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      // The gradients function adds a OnesLike to create a dz of ones with
      // the shape of z.
      auto dz = OnesLike(scope, z);
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, TwoMatMuls_Chained) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto u = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto v = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto x = MatMul(scope, u, v);

    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);

    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));

      auto du = MatMul(scope, dx, v, MatMul::TransposeB(true));
      auto dv = MatMul(scope, u, dx, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(
          AddSymbolicGradients(scope, {z}, {u, v}, {dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, TwoMatMuls_Independent) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto t = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto u = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto v = MatMul(scope, t, u);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(v.node());

    auto x = Const(scope, {{5.0, 6.0}, {7.0, 8.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      auto dv = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dt = MatMul(scope, dv, u, MatMul::TransposeB(true));
      auto du = MatMul(scope, t, dv, MatMul::TransposeA(true));

      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      auto dv = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      auto dz = Const(scope, {{1.0, 1.0}, {1.0, 1.0}});
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {v, z}, {t, u, x, y}, {dv, dz},
                                        &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, StackUnstack_Chained) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto a = Const(scope, 1, {4, 2});
    auto b = Const(scope, 2, {4, 2});
    auto c = Const(scope, 3, {4, 2});

    auto pack = Stack(scope, {a, b, c});
    auto unpack = Unstack(scope, pack.output, 3);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dx = Const(scope, 4, {4, 2});
    auto dy = Const(scope, 5, {4, 2});
    auto dz = Const(scope, 6, {4, 2});

    if (expected) {
      // Construct backward graph.
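      // The gradient of Unstack is a Stack of the incoming gradients, and
      // the gradient of Stack is an Unstack, so the expected backward graph
      // mirrors the forward graph.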
      auto unpack_grad = Stack(scope, {dx, dy, dz});
      auto pack_grad = Unstack(scope, unpack_grad.output, 3);
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, unpack.output, {a, b, c},
                                        {dx, dy, dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, StackUnstack_StopBackprop) {
  // Tests that backprop stops before calculating gradients for Stack (because
  // only gradients w.r.t. the output of Stack are requested).
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto a = Const(scope, 1, {4, 2});
    auto b = Const(scope, 2, {4, 2});
    auto c = Const(scope, 3, {4, 2});

    auto pack = Stack(scope, {a, b, c});
    auto unpack = Unstack(scope, pack.output, 3);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dx = Const(scope, 4, {4, 2});
    auto dy = Const(scope, 5, {4, 2});
    auto dz = Const(scope, 6, {4, 2});

    if (expected) {
      // Construct backward graph.
      // NOTE: We should only expect the grad function for unpack in the
      // gradients graph, based on the requested grad outputs.
      auto unpack_grad = Stack(scope, {dx, dy, dz});
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, unpack.output, {pack},
                                        {dx, dy, dz}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, StackUnstack_SubsetOfUnstackOutputs) {
  // Constructs an Unstack with three outputs, and takes the gradient with
  // respect to only two of the outputs. Tests that gradients are returned
  // for just those two outputs.
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto c = Const(scope, 1, {3, 4, 2});
    auto unpack = Unstack(scope, c, 3);
    auto x = Identity(scope, unpack.output[0]);
    auto y = Identity(scope, unpack.output[1]);
    auto z = Identity(scope, unpack.output[2]);
    TF_ASSERT_OK(scope.status());

    // Construct grad inputs.
    auto dy = Const(scope, 4, {4, 2});
    auto dz = Const(scope, 5, {4, 2});

    if (expected) {
      // Construct backward graph.
      auto g1 = Identity(scope, dy);
      auto g2 = Identity(scope, dz);
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {y, z},
                                        {unpack.output[1], unpack.output[2]},
                                        {dy, dz}, &grad_outputs));
      ASSERT_EQ(grad_outputs.size(), 2);
      EXPECT_TRUE(grad_outputs[0].node() != nullptr);
      EXPECT_TRUE(grad_outputs[1].node() != nullptr);
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, DependentGradOutputs) {
  // Tests that dependent gradients (in this case the gradients w.r.t. the
  // output and one input of MatMul) are computed properly.

  // Create two chained MatMul ops.
  auto u = Const(scope_test_, {{2}});
  auto v = Const(scope_test_, {{3}});
  auto x = MatMul(scope_test_, u, v);

  auto y = Const(scope_test_, {{4}});
  auto z = MatMul(scope_test_, x, y);

  TF_ASSERT_OK(scope_test_.status());
  CHECK_NOTNULL(z.node());

  // Call AddSymbolicGradients with '5' as the initial gradient 'dz'.
  // The gradient w.r.t. 'v' (returned in grad_outputs[0]) depends on
  // the gradient w.r.t. 'x' (returned in grad_outputs[1]).
  auto dz = Const(scope_test_, {{5}});
  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(
      AddSymbolicGradients(scope_test_, {z}, {v, x}, {dz}, &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {grad_outputs[0], grad_outputs[1]}, &outputs);

  // The initial gradient 'dz' is passed into AddSymbolicGradients as '5'.
  // Since z = MatMul(x, y), the gradient w.r.t. 'x' is computed as:
  //   'dx' = 5 * 'y' = 5 * 4 = 20.
  // Since x = MatMul(u, v), the gradient w.r.t. 'v' is computed as:
  //   'dv' = 'dx' * 'u' = 20 * 2 = 40.
  test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({40}, {1, 1}));
  test::ExpectTensorEqual<int>(outputs[1], test::AsTensor<int>({20}, {1, 1}));
}

TEST_F(GradientsTest, MultipleNodeOutputGrads) {
  // Tests that gradients for multiple outputs of the same node are returned.
  auto x = Const(scope_test_, 1, {3, 4, 2});
  auto unpack = Unstack(scope_test_, x, 3);
  auto pack = Stack(scope_test_, unpack.output);

  // clang-format off
  auto dx = Const(scope_test_, {40, 41, 42, 43, 44, 45, 46, 47,
                                50, 51, 52, 53, 55, 55, 56, 57,
                                60, 61, 62, 63, 66, 66, 66, 67},
                  {3, 4, 2});
  // clang-format on

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {pack}, unpack.output, {dx},
                                    &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_,
                   {grad_outputs[0], grad_outputs[1], grad_outputs[2]},
                   &outputs);

  test::ExpectTensorEqual<int>(
      outputs[0],
      test::AsTensor<int>({40, 41, 42, 43, 44, 45, 46, 47}, {4, 2}));
  test::ExpectTensorEqual<int>(
      outputs[1],
      test::AsTensor<int>({50, 51, 52, 53, 55, 55, 56, 57}, {4, 2}));
  test::ExpectTensorEqual<int>(
      outputs[2],
      test::AsTensor<int>({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2}));
}

TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto diff_m = Const(scope_test_, {{0.5}, {0.5}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(
      AddSymbolicGradients(scope_test_, {m}, {y}, {diff_m}, &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);
  // Gradient w.r.t. 'y': dm/dy = x^T * diff_m.
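  //   x^T * diff_m = [[1, 4], [2, 5], [3, 6]] * [[0.5], [0.5]]
  //                = [[2.5], [3.5], [4.5]]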
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
}

TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
  auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
  auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto x_assign = Assign(scope_test_, x, x_const);

  auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
  auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto y_assign = Assign(scope_test_, y, y_const);

  auto m1 = MatMul(scope_test_, x, y);

  auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
  auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
  auto z_assign = Assign(scope_test_, z, z_const);

  auto m2 = MatMul(scope_test_, y, z);

  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
  auto dm2 =
      Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}});

  std::vector<Output> grad_outputs;
  TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2},
                                    &grad_outputs));

  std::vector<Tensor> outputs;
  test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
                   {grad_outputs[0]}, &outputs);

  // The gradients from m1 and m2 are summed to compute the gradient
  // w.r.t. 'y':
  //   dy = x^T * dm1 + dm2 * z^T
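  // With the values above:
  //   x^T * dm1 = [[2.5], [3.5], [4.5]]
  //   dm2 * z^T = [[15.0], [21.2], [22.3]]
  //   dy        = [[17.5], [24.7], [26.8]]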
  test::ExpectTensorNear<double>(
      outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
}

TEST_F(GradientsTest, UnreachableInput) {
  auto x = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  auto y = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
  auto z = Const(scope_test_.WithOpName("z"), {{9.0, 10.0, 11.0}});

  auto m1 = MatMul(scope_test_, x, y);
  auto m2 = MatMul(scope_test_, y, z);
  auto dm1 = Const(scope_test_, {{0.5}, {0.5}});

  // From m1, z is unreachable, so an error status should be returned.
  //   m2  m1
  //   |   |
  //   *   *
  //  / \ / \
  // z   y   x
  std::vector<Output> grad_outputs;
  Status status =
      AddSymbolicGradients(scope_test_, {m1}, {z}, {dm1}, &grad_outputs);
  EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
  EXPECT_EQ(status.error_message(),
            "Cannot compute the partial derivative"
            " for node 'z' as it's unreachable from the output node(s).");
}

TEST_F(GradientsTest, DependentOutputs) {
  auto x = Placeholder(scope_test_, DT_FLOAT);
  auto y0 = Square(scope_test_, x);
  auto y1 = Square(scope_test_, y0);
  auto y2 = Square(scope_test_, y1);
  // Requesting the gradients for y0 and y2 should return the sum of their
  // individual gradients.
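  // Here y0 = x^2 and y2 = x^8, so with x = 3 the summed gradient is
  //   2x + 8x^7 = 6 + 17496 = 17502.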
  std::vector<Output> grad_outputs;
  TF_EXPECT_OK(AddSymbolicGradients(scope_test_, {y0, y2}, {x}, &grad_outputs));
  ClientSession session(scope_test_);
  std::vector<Tensor> grad_result;
  TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
  EXPECT_EQ(grad_result.size(), 1);
  EXPECT_EQ(grad_result[0].NumElements(), 1);
  EXPECT_EQ(grad_result[0].flat<float>()(0), 17502.0f);
}

TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
  auto x = Placeholder(scope_test_, DT_FLOAT);
  auto y0 = Square(scope_test_, x);
  // y1, y2, and y3 all use y0. This means the backward pass will need to wait
  // for the gradients from all three before the gradient of y0 can be
  // computed.
  auto y1 = Square(scope_test_, y0);
  auto y2 = Square(scope_test_, y0);
  auto y3 = Square(scope_test_, y2);
  std::vector<Output> grad_outputs;
  // By requesting y0, y1, and y3 we test that the computation correctly waits
  // at all the points in backprop where gradients from multiple branches need
  // to be summed.
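  // Here y0 = x^2, y1 = x^4, and y3 = x^8, so with x = 3 the summed gradient
  // is 2x + 4x^3 + 8x^7 = 6 + 108 + 17496 = 17610.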
  TF_EXPECT_OK(
      AddSymbolicGradients(scope_test_, {y0, y1, y3}, {x}, &grad_outputs));
  ClientSession session(scope_test_);
  std::vector<Tensor> grad_result;
  TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
  EXPECT_EQ(grad_result.size(), 1);
  EXPECT_EQ(grad_result[0].NumElements(), 1);
  EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}

// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by the StopGradient op) gradients returned along
// multiple edges from a single node's output.
class StopGradientSingleOutputMultiEdgeTest : public ::testing::Test {
 protected:
  StopGradientSingleOutputMultiEdgeTest() : scope_(Scope::NewRootScope()) {}

  void CheckGrad(const std::vector<bool>& stop_outputs,
                 const Tensor& expected_grad) {
    CHECK_EQ(3, stop_outputs.size());

    auto x = Const(scope_, {{1, 0}, {0, 1}});
    auto y = Const(scope_, {{1, 0}, {0, 1}});
    auto z = MatMul(scope_, x, y);

    // Create three outgoing edges from 'z'.
    // Add StopGradients according to 'stop_outputs'.
    auto out0 = stop_outputs[0]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;
    auto out1 = stop_outputs[1]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;
    auto out2 = stop_outputs[2]
                    ? StopGradient(scope_, (Identity(scope_, z))).output
                    : Identity(scope_, z).output;

    auto g0 = Const(scope_, {{1, 2}, {3, 4}});
    auto g1 = Const(scope_, {{5, 6}, {7, 8}});
    auto g2 = Const(scope_, {{9, 10}, {11, 12}});

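    // Since each out_i is an Identity of 'z', the expected gradient w.r.t.
    // 'z' is the elementwise sum of the g_i whose edges are not stopped, or
    // NoGradient when all three edges are stopped.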
    // Call AddSymbolicGradients and compare against 'expected_grad'.
    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {out0, out1, out2}, {z},
                                      {g0, g1, g2}, &grad_outputs));

    if (expected_grad.NumElements() > 0) {
      Tensor output;
      test::GetTensor(scope_, grad_outputs[0], &output);
      test::ExpectTensorEqual<int>(output, expected_grad);
    } else {
      EXPECT_EQ(NoGradient(), grad_outputs[0]);
    }
  }

  Scope scope_;
};

TEST_F(StopGradientSingleOutputMultiEdgeTest, ValidGradAllEdges) {
  CheckGrad({false, false, false},
            test::AsTensor<int>({15, 18, 21, 24}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstEdge) {
  CheckGrad({true, false, false},
            test::AsTensor<int>({14, 16, 18, 20}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradSecondEdge) {
  CheckGrad({false, true, false},
            test::AsTensor<int>({10, 12, 14, 16}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradThirdEdge) {
  CheckGrad({false, false, true}, test::AsTensor<int>({6, 8, 10, 12}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstAndSecondEdges) {
  CheckGrad({true, true, false}, test::AsTensor<int>({9, 10, 11, 12}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradSecondAndThirdEdges) {
  CheckGrad({false, true, true}, test::AsTensor<int>({1, 2, 3, 4}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradFirstAndThirdEdges) {
  CheckGrad({true, false, true}, test::AsTensor<int>({5, 6, 7, 8}, {2, 2}));
}

TEST_F(StopGradientSingleOutputMultiEdgeTest, StopGradAllEdges) {
  CheckGrad({true, true, true}, Tensor());
}

// StopGradientMultiOutputTest tests combinations of valid and 'NoGradient'
// (induced by the StopGradient op) gradients returned along a single node's
// multiple outputs.
class StopGradientMultiOutputTest : public ::testing::Test {
 protected:
  StopGradientMultiOutputTest() : scope_(Scope::NewRootScope()) {}

  void CheckGrad(const std::vector<bool>& stop_outputs,
                 const Tensor& expected_grad) {
    CHECK_EQ(3, stop_outputs.size());
    auto x = ops::Const(scope_, 1, {3, 2, 4});
    auto y = Unstack(scope_, x, 3);
    TF_ASSERT_OK(scope_.status());

    // Add StopGradients according to 'stop_outputs'.
    auto out0 =
        stop_outputs[0] ? StopGradient(scope_, y.output[0]) : y.output[0];
    auto out1 =
        stop_outputs[1] ? StopGradient(scope_, y.output[1]) : y.output[1];
    auto out2 =
        stop_outputs[2] ? StopGradient(scope_, y.output[2]) : y.output[2];

    auto g0 = Const(scope_, {1, 2, 3, 4, 5, 6, 7, 8}, {2, 4});
    auto g1 = Const(scope_, {9, 10, 11, 12, 13, 14, 15, 16}, {2, 4});
    auto g2 = Const(scope_, {17, 18, 19, 20, 21, 22, 23, 24}, {2, 4});

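    // As the expectations in the tests below show, a stopped output
    // contributes an all-zero slice to the packed gradient w.r.t. 'x', and
    // NoGradient is returned when all outputs are stopped.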
    // Call AddSymbolicGradients and compare against 'expected_grad'.
    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {out0, out1, out2}, {x},
                                      {g0, g1, g2}, &grad_outputs));

    if (expected_grad.NumElements() > 0) {
      Tensor output;
      test::GetTensor(scope_, grad_outputs[0], &output);
      test::ExpectTensorEqual<int>(output, expected_grad);
    } else {
      EXPECT_EQ(NoGradient(), grad_outputs[0]);
    }
  }

  Scope scope_;
};

TEST_F(StopGradientMultiOutputTest, ValidGradAllOutputs) {
  // clang-format off
  CheckGrad({false, false, false}, test::AsTensor<int>(
      {1, 2, 3, 4, 5, 6, 7, 8,
       9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24},
      {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradFirstOutput) {
  // clang-format off
  CheckGrad({true, false, false}, test::AsTensor<int>(
      {0, 0, 0, 0, 0, 0, 0, 0,
       9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24},
      {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradSecondOutput) {
  // clang-format off
  CheckGrad({false, true, false}, test::AsTensor<int>(
      {1, 2, 3, 4, 5, 6, 7, 8,
       0, 0, 0, 0, 0, 0, 0, 0,
       17, 18, 19, 20, 21, 22, 23, 24},
      {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradThirdOutput) {
  // clang-format off
  CheckGrad({false, false, true}, test::AsTensor<int>(
      {1, 2, 3, 4, 5, 6, 7, 8,
       9, 10, 11, 12, 13, 14, 15, 16,
       0, 0, 0, 0, 0, 0, 0, 0},
      {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopGradFirstAndThirdOutputs) {
  // clang-format off
  CheckGrad({true, false, true}, test::AsTensor<int>(
      {0, 0, 0, 0, 0, 0, 0, 0,
       9, 10, 11, 12, 13, 14, 15, 16,
       0, 0, 0, 0, 0, 0, 0, 0},
      {3, 2, 4}));
  // clang-format on
}

TEST_F(StopGradientMultiOutputTest, StopAllOutputs) {
  CheckGrad({true, true, true}, Tensor());
}

}  // namespace
}  // namespace tensorflow