1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/client/client_session.h"
17 #include "tensorflow/cc/framework/grad_op_registry.h"
18 #include "tensorflow/cc/framework/gradient_checker.h"
19 #include "tensorflow/cc/framework/gradients.h"
20 #include "tensorflow/cc/framework/testutil.h"
21 #include "tensorflow/cc/gradients/grad_testutil.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/random/random.h"
26
27 namespace tensorflow {
28 namespace {
29
30 using ops::Abs;
31 using ops::Add;
32 using ops::AddN;
33 using ops::AddV2;
34 using ops::Atan2;
35 using ops::BatchMatMul;
36 using ops::Cast;
37 using ops::Const;
38 using ops::Cumsum;
39 using ops::Div;
40 using ops::DivNoNan;
41 using ops::MatMul;
42 using ops::Max;
43 using ops::Maximum;
44 using ops::Mean;
45 using ops::Min;
46 using ops::Minimum;
47 using ops::Mul;
48 using ops::Placeholder;
49 using ops::Pow;
50 using ops::Prod;
51 using ops::RealDiv;
52 using ops::SegmentSum;
53 using ops::SelectV2;
54 using ops::SquaredDifference;
55 using ops::Sub;
56 using ops::Sum;
57 using ops::UnsortedSegmentMax;
58 using ops::UnsortedSegmentMin;
59 using ops::UnsortedSegmentSum;
60 using ops::Where3;
61
62 // TODO(andydavis) Test gradient function against numeric gradients output.
63 // TODO(andydavis) As more gradients are added move common test functions
64 // to a testutil library.
65
66 class CWiseUnaryGradTest : public ::testing::Test {
67 protected:
CWiseUnaryGradTest()68 CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
69
70 enum UnaryOpType {
71 ABS,
72 NEG,
73 INV,
74 SQUARE,
75 SQRT,
76 RSQRT,
77 EXP,
78 EXPM1,
79 LOG,
80 LOG1P,
81 SINH,
82 COSH,
83 TANH,
84 ASINH,
85 ACOSH,
86 ATANH,
87 SIGMOID,
88 SIGN,
89 SIN,
90 COS,
91 ASIN,
92 ACOS,
93 TAN,
94 ATAN,
95 REAL,
96 IMAG,
97 CONJ,
98 COMPLEX,
99 ANGLE,
100 LGAMMA,
101 ERF,
102 ERFINV,
103 NDTRI
104 };
105
106 template <typename X_T, typename Y_T>
TestCWiseGrad(UnaryOpType op_type,const std::function<X_T (int)> & x_fn)107 void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
108 TF_ASSERT_OK(scope_.status());
109 DataType x_type = DataTypeToEnum<X_T>::v();
110 TensorShape shape({2, 3, 2});
111 auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
112 Tensor x_data(x_type, shape);
113 auto x_data_flat = x_data.flat<X_T>();
114 for (int i = 0; i < x_data_flat.size(); ++i) {
115 x_data_flat(i) = x_fn(i);
116 }
117
118 Output y;
119 switch (op_type) {
120 using namespace ops; // NOLINT(build/namespaces)
121 case ABS:
122 y = Abs(scope_, x);
123 break;
124 case NEG:
125 y = Neg(scope_, x);
126 break;
127 case INV:
128 y = Reciprocal(scope_, x);
129 break;
130 case SQUARE:
131 y = Square(scope_, x);
132 break;
133 case SQRT:
134 y = Sqrt(scope_, x);
135 break;
136 case RSQRT:
137 y = Rsqrt(scope_, x);
138 break;
139 case EXP:
140 y = Exp(scope_, x);
141 break;
142 case EXPM1:
143 y = Expm1(scope_, x);
144 break;
145 case LOG:
146 y = Log(scope_, x);
147 break;
148 case LOG1P:
149 y = Log1p(scope_, x);
150 break;
151 case SINH:
152 y = Sinh(scope_, x);
153 break;
154 case COSH:
155 y = Cosh(scope_, x);
156 break;
157 case TANH:
158 y = Tanh(scope_, x);
159 break;
160 case ASINH:
161 y = Asinh(scope_, x);
162 break;
163 case ACOSH:
164 y = Acosh(scope_, x);
165 break;
166 case ATANH:
167 y = Atanh(scope_, x);
168 break;
169 case SIGMOID:
170 y = Sigmoid(scope_, x);
171 break;
172 case SIGN:
173 y = Sign(scope_, x);
174 break;
175 case SIN:
176 y = Sin(scope_, x);
177 break;
178 case COS:
179 y = Cos(scope_, x);
180 break;
181 case ASIN:
182 y = Asin(scope_, x);
183 break;
184 case ACOS:
185 y = Acos(scope_, x);
186 break;
187 case TAN:
188 y = Tan(scope_, x);
189 break;
190 case ATAN:
191 y = Atan(scope_, x);
192 break;
193 case REAL:
194 y = Real(scope_, x);
195 break;
196 case IMAG:
197 y = Imag(scope_, x);
198 break;
199 case CONJ:
200 y = Conj(scope_, x);
201 break;
202 case COMPLEX:
203 y = Complex(scope_, x, x);
204 break;
205 case ANGLE:
206 y = Angle(scope_, x);
207 break;
208 case LGAMMA:
209 y = Lgamma(scope_, x);
210 break;
211 case ERF:
212 y = Erf(scope_, x);
213 break;
214 case ERFINV:
215 y = Erfinv(scope_, x);
216 break;
217 case NDTRI:
218 y = Ndtri(scope_, x);
219 break;
220 }
221
222 float max_error;
223 TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
224 shape, &max_error)));
225 EXPECT_LT(max_error, 1e-3f);
226 }
227
RV(const std::vector<float> & v)228 float RV(const std::vector<float>& v) {
229 return v[random::New64() % v.size()];
230 }
231
CRV(const std::vector<complex64> & v)232 complex64 CRV(const std::vector<complex64>& v) {
233 return v[random::New64() % v.size()];
234 }
235
conjugate(const complex64 & val)236 complex64 conjugate(const complex64& val) {
237 return complex64(val.real(), -val.imag());
238 }
239
240 Scope scope_;
241 };
242
TEST_F(CWiseUnaryGradTest, Abs) {
  // Inputs drawn from {-1, 0, 1}.
  TestCWiseGrad<float, float>(ABS, [this](int) { return RV({-1, 0, 1}); });
}
247
TEST_F(CWiseUnaryGradTest, Neg) {
  // Inputs drawn from {-1, 0, 1}.
  TestCWiseGrad<float, float>(NEG, [this](int) { return RV({-1, 0, 1}); });
}
252
TEST_F(CWiseUnaryGradTest, Reciprocal) {
  // 0 is excluded because 1/x is undefined there.
  TestCWiseGrad<float, float>(
      INV, [this](int) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); });
}
257
TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  // Nonzero complex inputs.
  TestCWiseGrad<complex64, complex64>(
      INV, [this](int) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); });
}
262
TEST_F(CWiseUnaryGradTest, Square) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      SQUARE, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
267
TEST_F(CWiseUnaryGradTest, Square_Complex) {
  TestCWiseGrad<complex64, complex64>(
      SQUARE, [this](int) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); });
}
272
TEST_F(CWiseUnaryGradTest, Sqrt) {
  // Strictly positive inputs only.
  TestCWiseGrad<float, float>(
      SQRT, [this](int) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); });
}
277
TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  // Every sample has a nonzero imaginary part.
  TestCWiseGrad<complex64, complex64>(SQRT, [this](int) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  });
}
284
TEST_F(CWiseUnaryGradTest, Rsqrt) {
  // Inputs >= 1.
  TestCWiseGrad<float, float>(
      RSQRT, [this](int) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); });
}
289
TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  // Every sample has a nonzero imaginary part.
  TestCWiseGrad<complex64, complex64>(RSQRT, [this](int) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  });
}
296
TEST_F(CWiseUnaryGradTest, Exp) {
  // Inputs within [-2, 2].
  TestCWiseGrad<float, float>(
      EXP, [this](int) { return RV({0, -1, 1, -1.5f, 1.5f, -2, 2}); });
}
303
TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  TestCWiseGrad<complex64, complex64>(
      EXP, [this](int) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); });
}
308
TEST_F(CWiseUnaryGradTest, Expm1) {
  // Includes a tiny input (1e-6) next to 0.
  TestCWiseGrad<float, float>(
      EXPM1, [this](int) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); });
}
313
TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  TestCWiseGrad<complex64, complex64>(
      EXPM1, [this](int) { return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}}); });
}
320
TEST_F(CWiseUnaryGradTest, Log) {
  // Strictly positive inputs only.
  TestCWiseGrad<float, float>(
      LOG, [this](int) { return RV({0.5, 1, 2, 3, 4}); });
}
325
TEST_F(CWiseUnaryGradTest, Log_Complex) {
  // Every sample has a nonzero imaginary part.
  TestCWiseGrad<complex64, complex64>(LOG, [this](int) {
    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
  });
}
332
TEST_F(CWiseUnaryGradTest, Log1p) {
  // Non-negative inputs, from a tiny value (1e-6) up to 100.
  TestCWiseGrad<float, float>(
      LOG1P, [this](int) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); });
}
337
TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  TestCWiseGrad<complex64, complex64>(LOG1P, [this](int) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  });
}
344
TEST_F(CWiseUnaryGradTest, Sinh) {
  // Symmetric inputs within [-1.5, 1.5].
  TestCWiseGrad<float, float>(
      SINH, [this](int) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); });
}
349
TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  TestCWiseGrad<complex64, complex64>(SINH, [this](int) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  });
}
356
TEST_F(CWiseUnaryGradTest, Cosh) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      COSH, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
361
TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  TestCWiseGrad<complex64, complex64>(COSH, [this](int) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  });
}
368
TEST_F(CWiseUnaryGradTest, Tanh) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      TANH, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
373
TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  TestCWiseGrad<complex64, complex64>(TANH, [this](int) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  });
}
380
TEST_F(CWiseUnaryGradTest, Asinh) {
  // Inputs within [-1.5, 1.5].
  TestCWiseGrad<float, float>(
      ASINH, [this](int) { return RV({0.5, 1, -1, -1.5, 1.5}); });
}
385
TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  TestCWiseGrad<complex64, complex64>(ASINH, [this](int) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  });
}
392
TEST_F(CWiseUnaryGradTest, Acosh) {
  // Inputs strictly greater than 1.
  TestCWiseGrad<float, float>(ACOSH,
                              [this](int) { return RV({1.5, 2, 2.5}); });
}
397
TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  TestCWiseGrad<complex64, complex64>(ACOSH, [this](int) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  });
}
404
TEST_F(CWiseUnaryGradTest, Atanh) {
  // Inputs strictly inside (-1, 1).
  TestCWiseGrad<float, float>(
      ATANH, [this](int) { return RV({0, -0.5, 0.5, -0.1, 0.1}); });
}
409
TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  TestCWiseGrad<complex64, complex64>(ATANH, [this](int) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  });
}
416
TEST_F(CWiseUnaryGradTest, Sigmoid) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      SIGMOID, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
421
TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  TestCWiseGrad<complex64, complex64>(SIGMOID, [this](int) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  });
}
428
TEST_F(CWiseUnaryGradTest, Sign) {
  // 0 is deliberately excluded from the samples.
  TestCWiseGrad<float, float>(
      SIGN, [this](int) { return RV({-1, 1, -2, 2, -3, 3}); });
}
433
TEST_F(CWiseUnaryGradTest, Sin) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      SIN, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
438
TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  TestCWiseGrad<complex64, complex64>(SIN, [this](int) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  });
}
445
TEST_F(CWiseUnaryGradTest, Cos) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      COS, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
450
TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  TestCWiseGrad<complex64, complex64>(COS, [this](int) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  });
}
457
TEST_F(CWiseUnaryGradTest, Asin) {
  // Inputs within [-0.5, 0.5], well inside the domain (-1, 1).
  TestCWiseGrad<float, float>(
      ASIN, [this](int) { return RV({0, 0.25, -0.25, -0.5, 0.5}); });
}
462
TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers
  constexpr bool kComplexKernelAvailable = false;
  if (kComplexKernelAvailable) {
    TestCWiseGrad<complex64, complex64>(ASIN, [this](int) {
      return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
    });
  }
}
473
TEST_F(CWiseUnaryGradTest, Acos) {
  // Inputs within [-0.75, 0.75], inside the domain (-1, 1).
  TestCWiseGrad<float, float>(
      ACOS, [this](int) { return RV({0, -0.5, 0.5, -0.75, 0.75}); });
}
478
TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers
  constexpr bool kComplexKernelAvailable = false;
  if (kComplexKernelAvailable) {
    TestCWiseGrad<complex64, complex64>(ACOS, [this](int) {
      return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
    });
  }
}
489
TEST_F(CWiseUnaryGradTest, Tan) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      TAN, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
494
TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  TestCWiseGrad<complex64, complex64>(TAN, [this](int) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  });
}
501
TEST_F(CWiseUnaryGradTest, Atan) {
  // Symmetric integer inputs around 0.
  TestCWiseGrad<float, float>(
      ATAN, [this](int) { return RV({0, -1, 1, -2, 2, -3, 3}); });
}
506
TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers
  constexpr bool kComplexKernelAvailable = false;
  if (kComplexKernelAvailable) {
    TestCWiseGrad<complex64, complex64>(ATAN, [this](int) {
      return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
    });
  }
}
517
TEST_F(CWiseUnaryGradTest, Real) {
  // Complex input, real-valued output.
  TestCWiseGrad<complex64, float>(REAL, [this](int) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  });
}
524
TEST_F(CWiseUnaryGradTest, Imag) {
  // Complex input, real-valued output.
  TestCWiseGrad<complex64, float>(IMAG, [this](int) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  });
}
531
TEST_F(CWiseUnaryGradTest, Conj) {
  // One sample in each quadrant of the complex plane.
  TestCWiseGrad<complex64, complex64>(CONJ, [this](int) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  });
}
538
TEST_F(CWiseUnaryGradTest, Complex) {
  // Real input used as both real and imaginary part; complex output.
  TestCWiseGrad<float, complex64>(
      COMPLEX, [this](int) { return RV({1, -1, 2, -2, 3, -3}); });
}
543
TEST_F(CWiseUnaryGradTest, Angle) {
  // Samples lie on the diagonals, away from both axes.
  TestCWiseGrad<complex64, float>(ANGLE, [this](int) {
    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
  });
}
550
TEST_F(CWiseUnaryGradTest, Lgamma) {
  // Negative samples are half-integers, so none is a non-positive integer.
  TestCWiseGrad<float, float>(LGAMMA, [this](int) {
    return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
  });
}
557
TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
  // TODO(kbsriram)
  // Add test when the lgamma kernel supports complex numbers
  constexpr bool kComplexKernelAvailable = false;
  if (kComplexKernelAvailable) {
    TestCWiseGrad<complex64, complex64>(LGAMMA, [this](int) {
      return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
    });
  }
}
568
TEST_F(CWiseUnaryGradTest, Erf) {
  // Inputs within [-1.2, 1.3].
  TestCWiseGrad<float, float>(ERF, [this](int) {
    return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
  });
}
575
TEST_F(CWiseUnaryGradTest, Erf_Complex) {
  // TODO(kbsriram)
  // Add test when the erf kernel supports complex numbers
  constexpr bool kComplexKernelAvailable = false;
  if (kComplexKernelAvailable) {
    TestCWiseGrad<complex64, complex64>(ERF, [this](int) {
      return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
    });
  }
}
586
TEST_F(CWiseUnaryGradTest, Ndtri) {
  // Probabilities strictly inside (0, 1).
  TestCWiseGrad<float, float>(NDTRI, [this](int) {
    return RV({0.1, 0.2, 0.3, 0.5, 0.7, 0.9});
  });
}
593
TEST_F(CWiseUnaryGradTest, Erfinv) {
  // Inputs strictly inside (-1, 1).
  TestCWiseGrad<float, float>(ERFINV, [this](int) {
    return RV({-0.9, -0.3, -0.1, 0.2, 0.6, 0.8});
  });
}
600
601 class MathGradTest : public ::testing::Test {
602 protected:
MathGradTest()603 MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
604
605 template <typename T>
TestMatMulGrad(const bool is_batch,const bool t_x,const bool t_y)606 void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
607 TF_ASSERT_OK(root_.status());
608 // Generate random (but compatible) shapes for matrix multiplication.
609 std::vector<TensorShape> shapes;
610 RandMatMulShapes(is_batch, t_x, t_y, &shapes);
611 TensorShape x_shape = shapes[0];
612 TensorShape y_shape = shapes[1];
613 TensorShape z_shape = shapes[2];
614 auto x =
615 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
616 auto y =
617 Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
618 Output z;
619 if (is_batch) {
620 z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
621 } else {
622 z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
623 }
624
625 float max_error;
626 TF_ASSERT_OK((ComputeGradientError<T, T, float>(
627 root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
628 EXPECT_LT(max_error, 1e-3);
629 }
630
RandMatMulShapes(const bool is_batch,const bool tx,const bool ty,std::vector<TensorShape> * shapes)631 void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
632 std::vector<TensorShape>* shapes) {
633 // Choose a random batch size in [1, 4]
634 const int b = 1 + (random::New64() % 4);
635 // z = MatMul(x, y)
636 const int m = Rand();
637 const int k = Rand();
638 const int n = Rand();
639
640 TensorShape x_shape;
641 if (is_batch) {
642 // x.shape = [b, m, k]
643 x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
644 } else {
645 // x.shape = [m, k]
646 x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
647 }
648 shapes->push_back(x_shape);
649
650 TensorShape y_shape;
651 if (is_batch) {
652 // y.shape = [b, k, n]
653 y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
654 } else {
655 // y.shape = [k, n]
656 y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
657 }
658 shapes->push_back(y_shape);
659
660 TensorShape z_shape;
661 if (is_batch) {
662 // z.shape = [b, m, n]
663 z_shape = TensorShape({b, m, n});
664 } else {
665 // z.shape = [m, n]
666 z_shape = TensorShape({m, n});
667 }
668 shapes->push_back(z_shape);
669 }
670
Rand()671 int Rand() { return 1 + (random::New64() % 10); }
672
673 Scope root_;
674 };
675
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/false);
}
679
TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/false,
                            /*t_y=*/false);
}
683
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
687
TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/false);
}
691
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}
695
TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/false, /*t_y=*/true);
}
699
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}
703
TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/false, /*t_x=*/true, /*t_y=*/true);
}
707
TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}
711
TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/false);
}
715
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}
719
TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/false);
}
723
TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}
727
TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/false, /*t_y=*/true);
}
731
TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
735
TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(/*is_batch=*/true, /*t_x=*/true, /*t_y=*/true);
}
739
740 class NaryGradTest : public ::testing::Test {
741 protected:
NaryGradTest()742 NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
743
RunTest(const OutputList & xs,const std::vector<TensorShape> & x_shapes,const OutputList & ys,const std::vector<TensorShape> & y_shapes)744 void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
745 const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
746 TF_ASSERT_OK(scope_.status());
747 float max_error;
748 TF_ASSERT_OK((ComputeGradientError<float, float, float>(
749 scope_, xs, x_shapes, ys, y_shapes, &max_error)));
750 EXPECT_LT(max_error, 1e-3);
751 }
752
RunTest(const Output & x,const Tensor & x_init_value,const Output & y,const TensorShape & y_shape)753 void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
754 const TensorShape& y_shape) {
755 TF_ASSERT_OK(scope_.status());
756 float max_error;
757 TF_ASSERT_OK((ComputeGradientError<float, float, float>(
758 scope_, x, x_init_value, y, y_shape, &max_error)));
759 EXPECT_LT(max_error, 1e-3);
760 }
761
762 Scope scope_;
763 };
764
TEST_F(NaryGradTest, Sum) {
  // Reducing over axes 1 and -1 (i.e. 3) keeps dimensions 0 and 2 of x.
  TensorShape x_shape({2, 3, 5, 7});
  TensorShape y_shape({2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
773
TEST_F(NaryGradTest, Mean) {
  // Reducing over axes 1 and -1 (i.e. 3) keeps dimensions 0 and 2 of x.
  TensorShape x_shape({2, 3, 5, 7});
  TensorShape y_shape({2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
782
TEST_F(NaryGradTest, Min) {
  // Row-wise minimum over the last axis. Each row of the fixed input has a
  // single, unambiguous minimum (0.2f and -2.8f).
  TensorShape x_shape({2, 3});
  TensorShape y_shape({2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}
793
TEST_F(NaryGradTest, Max) {
  // Row-wise maximum over the last axis. Each row of the fixed input has a
  // single, unambiguous maximum (0.7f and 1.5f).
  TensorShape x_shape({2, 3});
  TensorShape y_shape({2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}
804
TEST_F(NaryGradTest, MinMulti) {
  // Gradient of reduce_min when several entries tie for the minimum.
  //
  // Feeding a tensor with repeated minima directly does not work: the numeric
  // estimator perturbs one entry at a time, which changes how many minima
  // there are and so produces a wrong reference gradient. Instead, a single
  // input value is broadcast-multiplied into three equal entries, and the
  // minimum of that product is taken.
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // Reducing the 3-element product over axis 0 leaves a single element,
  // which is described to the checker with shape [1].
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
822
TEST_F(NaryGradTest, MaxMulti) {
  // Gradient of reduce_max with tied maxima. A single input is
  // broadcast-multiplied into three equal entries so that perturbing the
  // input does not break the tie.
  TensorShape x_shape({1});
  TensorShape y_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
831
TEST_F(NaryGradTest, AddN) {
  // Three same-shaped inputs summed with a single AddN.
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  for (int input = 0; input < 3; ++input) {
    xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  }
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}
841
TEST_F(NaryGradTest, Add) {
  // x2 is broadcast along the leading dimension of x1.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto sum = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {sum}, {x1_shape});
}
850
TEST_F(NaryGradTest, AddV2) {
  // x2 is broadcast along the leading dimension of x1.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto sum = AddV2(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {sum}, {x1_shape});
}
859
TEST_F(NaryGradTest, Sub) {
  // x2 is broadcast along the leading dimension of x1.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto difference = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {difference}, {x1_shape});
}
868
TEST_F(NaryGradTest, Mul) {
  // x2 is broadcast along the leading dimension of x1.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto product = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {product}, {x1_shape});
}
877
TEST_F(NaryGradTest, Div) {
  // Uses x / (1 + |x|) instead of a free x_1 / x_2: the denominator stays
  // >= 1, so the numeric estimator never sees a near-zero divisor.
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto denominator =
      Add(scope_, Const<float>(scope_, 1), Abs(scope_, x));
  auto y = Div(scope_, x, denominator);
  RunTest({x}, {x_shape}, {y}, {x_shape});
}
886
TEST_F(NaryGradTest, RealDiv) {
  // Uses x / (1 + |x|) instead of a free x_1 / x_2: the denominator stays
  // >= 1, so the numeric estimator never sees a near-zero divisor.
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto denominator =
      Add(scope_, Const<float>(scope_, 1), Abs(scope_, x));
  auto y = RealDiv(scope_, x, denominator);
  RunTest({x}, {x_shape}, {y}, {x_shape});
}
896
TEST_F(NaryGradTest, DivNoNan) {
  // Part 1: ordinary numeric gradient check on x / (1 + |x|).
  {
    TensorShape x_shape({3, 2, 5});
    const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
    // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
    // division errors in the numeric estimator used by the gradient checker.
    const auto y = DivNoNan(
        scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
    RunTest({x}, {x_shape}, {y}, {x_shape});
  }
  // Part 2: dividing by zero must yield a 0 gradient (rather than NaN).
  {
    const auto x = Placeholder(scope_, DT_FLOAT);
    const auto zero = Const<float>(scope_, 0.0);
    const auto y = DivNoNan(scope_, x, zero);

    // Fatal assertions: if building or running the gradient graph fails,
    // grad_result may be empty and must not be indexed below.
    std::vector<Output> grad_outputs;
    TF_ASSERT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
    ClientSession session(scope_);
    std::vector<Tensor> grad_result;
    TF_ASSERT_OK(
        session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
    ASSERT_EQ(grad_result.size(), 1u);
    ASSERT_EQ(grad_result[0].NumElements(), 3);
    EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
  }
}
926
TEST_F(NaryGradTest, SquaredDifference) {
  // The second operand has one fewer leading dimension than the first, so
  // the gradient has to be reduced back over the broadcast axis.
  const TensorShape lhs_shape({3, 2, 5});
  const TensorShape rhs_shape({2, 5});
  const auto lhs =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(lhs_shape));
  const auto rhs =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(rhs_shape));
  const auto squared_diff = SquaredDifference(scope_, lhs, rhs);
  RunTest({lhs, rhs}, {lhs_shape, rhs_shape}, {squared_diff}, {lhs_shape});
}
935
TEST_F(NaryGradTest, Pow) {
  const TensorShape shape({3});
  const auto base = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // Only the base is a free input; the exponents are small fixed constants so
  // the power cannot overflow.
  const auto exponents = Const(scope_, {1.f, 2.f, 3.f});
  const auto power = Pow(scope_, base, exponents);
  RunTest({base}, {shape}, {power}, {shape});
}
943
TEST_F(NaryGradTest, Maximum) {
  const TensorShape shape({3, 2});
  const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto threshold = Const(scope_, 1.0f);
  const auto y = Maximum(scope_, x, threshold);
  // Maximum is not differentiable where x == threshold, so the fixed inputs
  // keep every element well away from 1.0f to stabilize finite differences.
  const Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}
954
TEST_F(NaryGradTest, Minimum) {
  const TensorShape shape({3, 2});
  const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto threshold = Const(scope_, 1.0f);
  const auto y = Minimum(scope_, x, threshold);
  // Minimum is not differentiable where x == threshold, so the fixed inputs
  // keep every element well away from 1.0f to stabilize finite differences.
  const Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}
965
TEST_F(NaryGradTest, Prod) {
  TensorShape x_shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  {
    // Reduce along axis 1 with the default keep_dims=false: the reduced axis
    // is dropped, so the output shape is {2, 2} (not {2, 1, 2}).
    auto y = Prod(scope_, x, {1});
    TensorShape y_shape({2, 2});
    RunTest({x}, {x_shape}, {y}, {y_shape});
  }
  {
    // With keep_dims=true the reduced axis is retained with size 1.
    auto y = Prod(scope_, x, {1}, Prod::KeepDims(true));
    TensorShape y_shape({2, 1, 2});
    RunTest({x}, {x_shape}, {y}, {y_shape});
  }
}
974
TEST_F(NaryGradTest, SegmentSum) {
  const TensorShape data_shape({3, 4});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Segment ids {0, 0, 1} index the first dimension: rows 0 and 1 are summed
  // into segment 0 and row 2 becomes segment 1, giving two output rows.
  const auto segment_sums = SegmentSum(scope_, data, {0, 0, 1});
  const TensorShape segment_sums_shape({2, 4});
  RunTest({data}, {data_shape}, {segment_sums}, {segment_sums_shape});
}
983
984 class CumsumGradTest
985 : public NaryGradTest,
986 public ::testing::WithParamInterface<std::tuple<bool, bool, int>> {};
987
TEST_P(CumsumGradTest, CumsumGrad) {
  // Unpack the (exclusive, reverse, axis) test parameter.
  const bool exclusive = std::get<0>(GetParam());
  const bool reverse_mode = std::get<1>(GetParam());
  const int axis = std::get<2>(GetParam());

  const TensorShape shape({2, 3, 2});
  const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto attrs =
      Cumsum::Attrs().Exclusive(exclusive).Reverse(reverse_mode);
  // A cumulative sum keeps the input shape, so input and output shapes match.
  const auto y = Cumsum(scope_, x, axis, attrs);
  RunTest({x}, {shape}, {y}, {shape});
}
999
1000 INSTANTIATE_TEST_SUITE_P(CumsumGrad, CumsumGradTest,
1001 ::testing::Combine(::testing::Bool(),
1002 ::testing::Bool(),
1003 ::testing::Range(0, 2)));
1004
TEST_F(NaryGradTest, CastGrad) {
  const TensorShape shape({2, 3, 2});
  const auto input = Placeholder(scope_, DT_DOUBLE, Placeholder::Shape(shape));
  const auto casted = Cast(scope_, input, DT_FLOAT);
  TF_ASSERT_OK(scope_.status());
  // Input (DT_DOUBLE) and output (DT_FLOAT) dtypes differ, so the gradient
  // checker is called directly with explicit <X, Y, Jacobian> element types.
  double max_error = 0.0;
  TF_ASSERT_OK((ComputeGradientError<double, float, double>(
      scope_, {input}, {shape}, {casted}, {shape}, &max_error)));
  EXPECT_LT(max_error, 1e-3);
}
1015
TEST_F(NaryGradTest, Select) {
  const TensorShape shape({1, 3});
  // Fixed mask with the same shape as both branches.
  const auto mask = Const<bool>(scope_, {{false, true, true}});
  const auto on_true = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto on_false =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto selected = Where3(scope_, mask, on_true, on_false);
  RunTest({on_true, on_false}, {shape, shape}, {selected}, {shape});
}
1024
TEST_F(NaryGradTest, SelectV2_Basic) {
  const TensorShape shape({1, 3});
  // No broadcasting: mask and both branches share one shape.
  const auto mask = Const<bool>(scope_, {{false, true, true}});
  const auto on_true = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto on_false =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  const auto selected = SelectV2(scope_, mask, on_true, on_false);
  RunTest({on_true, on_false}, {shape, shape}, {selected}, {shape});
}
1033
TEST_F(NaryGradTest, SelectV2_Broadcast) {
  // The "else" branch is a scalar that SelectV2 broadcasts to the {2, 3}
  // shape of the mask and the "then" branch.
  const TensorShape full_shape({2, 3});
  const TensorShape scalar_shape({});
  const auto mask =
      Const<bool>(scope_, {{false, true, true}, {true, true, false}});
  const auto on_true =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(full_shape));
  const auto on_false =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(scalar_shape));
  const auto selected = SelectV2(scope_, mask, on_true, on_false);
  RunTest({on_true, on_false}, {full_shape, scalar_shape}, {selected},
          {full_shape});
}
1043
TEST_F(NaryGradTest, SelectV2_Broadcast2) {
  // Here the mask is the broadcast operand: a {2, 1} column applied across
  // the {2, 3} branches.
  const TensorShape branch_shape({2, 3});
  const auto mask = Const<bool>(scope_, {{false}, {true}});
  const auto on_true =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(branch_shape));
  const auto on_false =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(branch_shape));
  const auto selected = SelectV2(scope_, mask, on_true, on_false);
  RunTest({on_true, on_false}, {branch_shape, branch_shape}, {selected},
          {branch_shape});
}
1052
TEST_F(NaryGradTest, Atan2Grad) {
  TensorShape shape({3, 2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // Test with x2 = 1 + |x1| to avoid regions where the gradient is unstable
  // and might cause problems for the numeric estimator. Since x2 >= 1, atan2
  // stays on its smooth branch, and d/dx1 atan2(x1, 1 + |x1|) simplifies to
  // 1 / (x1^2 + x2^2), which is continuous even where |x1| has its kink.
  // (A form such as x2 = x1 / (1 + |x1|) gives x1 and x2 the same sign, so
  // atan2 would jump by ~pi as x1 crosses 0.)
  auto x2 = Add(scope_, Const<float>(scope_, 1), Abs(scope_, x1));
  auto y = Atan2(scope_, x1, x2);
  RunTest({x1}, {shape}, {y}, {shape});
}
1063
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad) {
  const TensorShape data_shape({3, 2, 5});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Rows 0 and 1 map to segment 0 and row 2 to segment 1.
  const auto ids = Const(scope_, {0, 0, 1});
  const auto segment_max =
      UnsortedSegmentMax(scope_, data, ids, /*num_segments=*/2);
  const TensorShape segment_max_shape({2, 2, 5});
  RunTest({data}, {data_shape}, {segment_max}, {segment_max_shape});
}
1072
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_Int64Ids) {
  const TensorShape data_shape({3, 2, 5});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Same mapping as UnsortedSegmentMaxGrad, but with 64-bit ids
  // (long long literals).
  const auto ids = Const(scope_, {0ll, 0ll, 1ll});
  const auto segment_max =
      UnsortedSegmentMax(scope_, data, ids, /*num_segments=*/2);
  const TensorShape segment_max_shape({2, 2, 5});
  RunTest({data}, {data_shape}, {segment_max}, {segment_max_shape});
}
1081
TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_NegativeIds) {
  const TensorShape data_shape({3, 2, 5});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Row 2 carries a negative id; only segment 0 is requested
  // (num_segments = 1), so the output has a single leading entry.
  const auto ids = Const(scope_, {0, 0, -1});
  const auto segment_max =
      UnsortedSegmentMax(scope_, data, ids, /*num_segments=*/1);
  const TensorShape segment_max_shape({1, 2, 5});
  RunTest({data}, {data_shape}, {segment_max}, {segment_max_shape});
}
1090
TEST_F(NaryGradTest, UnsortedSegmentMinGrad) {
  const TensorShape data_shape({3, 2, 5});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Rows 0 and 1 map to segment 0 and row 2 to segment 1.
  const auto ids = Const(scope_, {0, 0, 1});
  const auto segment_min =
      UnsortedSegmentMin(scope_, data, ids, /*num_segments=*/2);
  const TensorShape segment_min_shape({2, 2, 5});
  RunTest({data}, {data_shape}, {segment_min}, {segment_min_shape});
}
1099
TEST_F(NaryGradTest, UnsortedSegmentSumGrad) {
  const TensorShape data_shape({3, 2, 5});
  const auto data =
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  // Rows 0 and 1 map to segment 0 and row 2 to segment 1.
  const auto ids = Const(scope_, {0, 0, 1});
  const auto segment_sum =
      UnsortedSegmentSum(scope_, data, ids, /*num_segments=*/2);
  const TensorShape segment_sum_shape({2, 2, 5});
  RunTest({data}, {data_shape}, {segment_sum}, {segment_sum_shape});
}
1108
1109 } // namespace
1110 } // namespace tensorflow
1111