/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace {

using ops::Abs;
using ops::Add;
using ops::AddN;
using ops::AddV2;
using ops::Atan2;
using ops::BatchMatMul;
using ops::Cast;
using ops::Const;
using ops::Cumsum;
using ops::Div;
using ops::DivNoNan;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
using ops::Mean;
using ops::Min;
using ops::Minimum;
using ops::Mul;
using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
using ops::SegmentSum;
using ops::SelectV2;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
using ops::UnsortedSegmentMax;
using ops::UnsortedSegmentMin;
using ops::UnsortedSegmentSum;
using ops::Where3;

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added, move common test functions
// to a testutil library.

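// Tests gradients of coefficient-wise (element-wise) unary ops against
// numeric estimates.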
class CWiseUnaryGradTest : public ::testing::Test {
 protected:
  CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

  enum UnaryOpType {
    ABS,
    NEG,
    INV,
    SQUARE,
    SQRT,
    RSQRT,
    EXP,
    EXPM1,
    LOG,
    LOG1P,
    SINH,
    COSH,
    TANH,
    ASINH,
    ACOSH,
    ATANH,
    SIGMOID,
    SIGN,
    SIN,
    COS,
    ASIN,
    ACOS,
    TAN,
    ATAN,
    REAL,
    IMAG,
    CONJ,
    COMPLEX,
    ANGLE,
    LGAMMA,
    ERF,
    ERFINV,
    NDTRI
  };

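  // Builds y = f(x) for the unary op selected by op_type, fills x with values
  // produced by x_fn, and checks the symbolic gradient of y w.r.t. x against
  // a numeric estimate from the gradient checker.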
  template <typename X_T, typename Y_T>
  void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
    TF_ASSERT_OK(scope_.status());
    DataType x_type = DataTypeToEnum<X_T>::v();
    TensorShape shape({2, 3, 2});
    auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
    Tensor x_data(x_type, shape);
    auto x_data_flat = x_data.flat<X_T>();
    for (int i = 0; i < x_data_flat.size(); ++i) {
      x_data_flat(i) = x_fn(i);
    }

    Output y;
    switch (op_type) {
      using namespace ops;  // NOLINT(build/namespaces)
      case ABS:
        y = Abs(scope_, x);
        break;
      case NEG:
        y = Neg(scope_, x);
        break;
      case INV:
        y = Reciprocal(scope_, x);
        break;
      case SQUARE:
        y = Square(scope_, x);
        break;
      case SQRT:
        y = Sqrt(scope_, x);
        break;
      case RSQRT:
        y = Rsqrt(scope_, x);
        break;
      case EXP:
        y = Exp(scope_, x);
        break;
      case EXPM1:
        y = Expm1(scope_, x);
        break;
      case LOG:
        y = Log(scope_, x);
        break;
      case LOG1P:
        y = Log1p(scope_, x);
        break;
      case SINH:
        y = Sinh(scope_, x);
        break;
      case COSH:
        y = Cosh(scope_, x);
        break;
      case TANH:
        y = Tanh(scope_, x);
        break;
      case ASINH:
        y = Asinh(scope_, x);
        break;
      case ACOSH:
        y = Acosh(scope_, x);
        break;
      case ATANH:
        y = Atanh(scope_, x);
        break;
      case SIGMOID:
        y = Sigmoid(scope_, x);
        break;
      case SIGN:
        y = Sign(scope_, x);
        break;
      case SIN:
        y = Sin(scope_, x);
        break;
      case COS:
        y = Cos(scope_, x);
        break;
      case ASIN:
        y = Asin(scope_, x);
        break;
      case ACOS:
        y = Acos(scope_, x);
        break;
      case TAN:
        y = Tan(scope_, x);
        break;
      case ATAN:
        y = Atan(scope_, x);
        break;
      case REAL:
        y = Real(scope_, x);
        break;
      case IMAG:
        y = Imag(scope_, x);
        break;
      case CONJ:
        y = Conj(scope_, x);
        break;
      case COMPLEX:
        y = Complex(scope_, x, x);
        break;
      case ANGLE:
        y = Angle(scope_, x);
        break;
      case LGAMMA:
        y = Lgamma(scope_, x);
        break;
      case ERF:
        y = Erf(scope_, x);
        break;
      case ERFINV:
        y = Erfinv(scope_, x);
        break;
      case NDTRI:
        y = Ndtri(scope_, x);
        break;
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
                                                        shape, &max_error)));
    EXPECT_LT(max_error, 1e-3f);
  }

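  // RV and CRV return an element of v chosen uniformly at random.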
  float RV(const std::vector<float>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 CRV(const std::vector<complex64>& v) {
    return v[random::New64() % v.size()];
  }

  complex64 conjugate(const complex64& val) {
    return complex64(val.real(), -val.imag());
  }

  Scope scope_;
};

TEST_F(CWiseUnaryGradTest, Abs) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(ABS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Neg) {
  auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
  TestCWiseGrad<float, float>(NEG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
  TestCWiseGrad<float, float>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(INV, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Square_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
  TestCWiseGrad<float, float>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt) {
  auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
  TestCWiseGrad<float, float>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp) {
  auto x_fn = [this](const int i) {
    return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
  };
  TestCWiseGrad<float, float>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Exp_Complex) {
  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
  TestCWiseGrad<complex64, complex64>(EXP, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
  };
  TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
  TestCWiseGrad<float, float>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
  };
  TestCWiseGrad<complex64, complex64>(LOG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p) {
  auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
  TestCWiseGrad<float, float>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh) {
  auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
  TestCWiseGrad<float, float>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(SINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(COSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(TANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh) {
  auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
  TestCWiseGrad<float, float>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh) {
  auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
  TestCWiseGrad<float, float>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
  };
  TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
  TestCWiseGrad<float, float>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
  };
  TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sign) {
  auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIGN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Sin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(SIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Cos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
  };
  TestCWiseGrad<complex64, complex64>(COS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin) {
  auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
  TestCWiseGrad<float, float>(ASIN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Asin_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Enable test when the asin kernel supports complex numbers.
  if (false) {
    TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Acos) {
  auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
  TestCWiseGrad<float, float>(ACOS, x_fn);
}

TEST_F(CWiseUnaryGradTest, Acos_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
  };
  // TODO(kbsriram)
  // Add test when the acos kernel supports complex numbers.
  if (false) {
    TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Tan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(TAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Tan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  TestCWiseGrad<complex64, complex64>(TAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atan) {
  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
  TestCWiseGrad<float, float>(ATAN, x_fn);
}

TEST_F(CWiseUnaryGradTest, Atan_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
  };
  // TODO(kbsriram)
  // Add test when the atan kernel supports complex numbers.
  if (false) {
    TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Real) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(REAL, x_fn);
}

TEST_F(CWiseUnaryGradTest, Imag) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, float>(IMAG, x_fn);
}

TEST_F(CWiseUnaryGradTest, Conj) {
  auto x_fn = [this](const int i) {
    return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
  };
  TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
}

TEST_F(CWiseUnaryGradTest, Complex) {
  auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
  TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
}

TEST_F(CWiseUnaryGradTest, Angle) {
  auto x_fn = [this](const int i) {
    return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
  };
  TestCWiseGrad<complex64, float>(ANGLE, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma) {
  auto x_fn = [this](const int i) {
    return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
  };
  TestCWiseGrad<float, float>(LGAMMA, x_fn);
}

TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
  };
  // TODO(kbsriram)
  // Add test when the lgamma kernel supports complex numbers.
  if (false) {
    TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Erf) {
  auto x_fn = [this](const int i) {
    return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
  };
  TestCWiseGrad<float, float>(ERF, x_fn);
}

TEST_F(CWiseUnaryGradTest, Erf_Complex) {
  auto x_fn = [this](const int i) {
    return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
  };
  // TODO(kbsriram)
  // Add test when the erf kernel supports complex numbers.
  if (false) {
    TestCWiseGrad<complex64, complex64>(ERF, x_fn);
  }
}

TEST_F(CWiseUnaryGradTest, Ndtri) {
  auto x_fn = [this](const int i) {
    return RV({0.1, 0.2, 0.3, 0.5, 0.7, 0.9});
  };
  TestCWiseGrad<float, float>(NDTRI, x_fn);
}

TEST_F(CWiseUnaryGradTest, Erfinv) {
  auto x_fn = [this](const int i) {
    return RV({-0.9, -0.3, -0.1, 0.2, 0.6, 0.8});
  };
  TestCWiseGrad<float, float>(ERFINV, x_fn);
}

class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

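  // Builds z = x * y as MatMul (or BatchMatMul when is_batch is true) with
  // randomly generated compatible shapes, then checks the gradients w.r.t.
  // both inputs against numeric estimates. t_x and t_y select the transposed
  // (MatMul) or adjointed (BatchMatMul) variants.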
  template <typename T>
  void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
    TF_ASSERT_OK(root_.status());
    // Generate random (but compatible) shapes for matrix multiplication.
    std::vector<TensorShape> shapes;
    RandMatMulShapes(is_batch, t_x, t_y, &shapes);
    TensorShape x_shape = shapes[0];
    TensorShape y_shape = shapes[1];
    TensorShape z_shape = shapes[2];
    auto x =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
    auto y =
        Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
    Output z;
    if (is_batch) {
      z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
    } else {
      z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    }

    float max_error;
    TF_ASSERT_OK((ComputeGradientError<T, T, float>(
        root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
                        std::vector<TensorShape>* shapes) {
    // Choose a random batch size in [1, 4].
    const int b = 1 + (random::New64() % 4);
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();

    TensorShape x_shape;
    if (is_batch) {
      // x.shape = [b, m, k]
      x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
    } else {
      // x.shape = [m, k]
      x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    }
    shapes->push_back(x_shape);

    TensorShape y_shape;
    if (is_batch) {
      // y.shape = [b, k, n]
      y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
    } else {
      // y.shape = [k, n]
      y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    }
    shapes->push_back(y_shape);

    TensorShape z_shape;
    if (is_batch) {
      // z.shape = [b, m, n]
      z_shape = TensorShape({b, m, n});
    } else {
      // z.shape = [m, n]
      z_shape = TensorShape({m, n});
    }
    shapes->push_back(z_shape);
  }

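  // Returns a random dimension size in [1, 10].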
  int Rand() { return 1 + (random::New64() % 10); }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(false, false, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(false, false, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  TestMatMulGrad<float>(false, true, false);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(false, true, false);
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  TestMatMulGrad<float>(false, false, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(false, false, true);
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(false, true, true);
}

TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(false, true, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
  TestMatMulGrad<float>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
  TestMatMulGrad<complex64>(true, false, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
  TestMatMulGrad<float>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
  TestMatMulGrad<complex64>(true, true, false);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
  TestMatMulGrad<float>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
  TestMatMulGrad<complex64>(true, false, true);
}

TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
  TestMatMulGrad<float>(true, true, true);
}

TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
  TestMatMulGrad<complex64>(true, true, true);
}

class NaryGradTest : public ::testing::Test {
 protected:
  NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}

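  // Checks the symbolic gradients of ys w.r.t. xs against numeric estimates;
  // the gradient checker picks the input values.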
  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, xs, x_shapes, ys, y_shapes, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

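  // Overload that evaluates the gradient at a fixed x_init_value; used by
  // tests that must keep x away from points where finite differences are
  // unstable.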
  void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
               const TensorShape& y_shape) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK((ComputeGradientError<float, float, float>(
        scope_, x, x_init_value, y, y_shape, &max_error)));
    EXPECT_LT(max_error, 1e-3);
  }

  Scope scope_;
};

TEST_F(NaryGradTest, Sum) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Sum(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (i.e., axis 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Mean) {
  TensorShape x_shape({2, 3, 5, 7});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Mean(scope_, x, {1, -1});
  // y's shape is the result of reducing x along axes 1 and -1 (i.e., axis 3).
  TensorShape y_shape({2, 5});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, Min) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Min(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, Max) {
  TensorShape x_shape({2, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Max(scope_, x, {-1});
  // y's shape is the result of reducing x along axis -1 (i.e., axis 1).
  TensorShape y_shape({2});
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
  RunTest(x, x_init_value, y, y_shape);
}

TEST_F(NaryGradTest, MinMulti) {
  // Test the gradient when there are multiple minima.
  // Note that we cannot directly use a test Tensor with multiple
  // minima, as the numeric estimator would calculate incorrect
  // gradients when perturbing each entry in the Tensor (which then
  // changes how many minima exist).
  // Instead, we use a single input that broadcast-multiplies a larger
  // tensor with equal values, and apply reduce_min to the multiplied
  // result.
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Min(scope_, all_same, {0});
  // all_same is a [3] shaped tensor reduced along dimension 0, so y is
  // [1] shaped.
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, MaxMulti) {
  TensorShape x_shape({1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
  auto y = Max(scope_, all_same, {0});
  TensorShape y_shape({1});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, AddV2) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = AddV2(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, DivNoNan) {
  {
    TensorShape x_shape({3, 2, 5});
    const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
    // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
    // division errors in the numeric estimator used by the gradient checker.
    const auto y = DivNoNan(
        scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
    RunTest({x}, {x_shape}, {y}, {x_shape});
  }
  {
    // Return 0 gradient (rather than NaN) for division by zero.
    const auto x = Placeholder(scope_, DT_FLOAT);
    const auto zero = Const<float>(scope_, 0.0);
    const auto y = DivNoNan(scope_, x, zero);

    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
    ClientSession session(scope_);
    std::vector<Tensor> grad_result;
    TF_EXPECT_OK(
        session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
    EXPECT_EQ(grad_result.size(), 1);
    EXPECT_EQ(grad_result[0].NumElements(), 3);
    EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
  }
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Pow) {
  TensorShape shape({3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // Fix the exponent to avoid overflow.
  auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
  RunTest({x}, {shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Prod) {
  TensorShape x_shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Prod(scope_, x, {1});
  // y's shape is the result of reducing x along axis 1.
  TensorShape y_shape({2, 1, 2});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, SegmentSum) {
  TensorShape x_shape({3, 4});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = SegmentSum(scope_, x, {0, 0, 1});
  // The segment sum is always over the first dimension.
  TensorShape y_shape({2, 4});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

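// Parameterized over Cumsum's (exclusive, reverse, axis) settings.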
class CumsumGradTest
    : public NaryGradTest,
      public ::testing::WithParamInterface<std::tuple<bool, bool, int>> {};

TEST_P(CumsumGradTest, CumsumGrad) {
  int axis = std::get<2>(GetParam());

  TensorShape shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  Cumsum::Attrs attrs;
  attrs.exclusive_ = std::get<0>(GetParam());
  attrs.reverse_ = std::get<1>(GetParam());
  auto y = Cumsum(scope_, x, axis, attrs);
  RunTest({x}, {shape}, {y}, {shape});
}

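// Covers every combination of exclusive and reverse, with axis in {0, 1}.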
INSTANTIATE_TEST_SUITE_P(CumsumGrad, CumsumGradTest,
                         ::testing::Combine(::testing::Bool(),
                                            ::testing::Bool(),
                                            ::testing::Range(0, 2)));

TEST_F(NaryGradTest, CastGrad) {
  TensorShape shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_DOUBLE, Placeholder::Shape(shape));
  auto y = Cast(scope_, x, DT_FLOAT);
  TF_ASSERT_OK(scope_.status());
  double max_error;
  TF_ASSERT_OK((ComputeGradientError<double, float, double>(
      scope_, {x}, {shape}, {y}, {shape}, &max_error)));
  EXPECT_LT(max_error, 1e-3);
}

TEST_F(NaryGradTest, Select) {
  TensorShape shape({1, 3});
  auto cond = Const<bool>(scope_, {{false, true, true}});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto z = Where3(scope_, cond, x, y);
  RunTest({x, y}, {shape, shape}, {z}, {shape});
}

TEST_F(NaryGradTest, SelectV2_Basic) {
  TensorShape shape({1, 3});
  auto cond = Const<bool>(scope_, {{false, true, true}});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto z = SelectV2(scope_, cond, x, y);
  RunTest({x, y}, {shape, shape}, {z}, {shape});
}

TEST_F(NaryGradTest, SelectV2_Broadcast) {
  TensorShape x_shape({2, 3});
  TensorShape y_shape({});
  auto cond = Const<bool>(scope_, {{false, true, true}, {true, true, false}});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(y_shape));
  auto z = SelectV2(scope_, cond, x, y);
  RunTest({x, y}, {x_shape, y_shape}, {z}, {x_shape});
}

TEST_F(NaryGradTest, SelectV2_Broadcast2) {
  TensorShape x_shape({2, 3});
  auto cond = Const<bool>(scope_, {{false}, {true}});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto z = SelectV2(scope_, cond, x, y);
  RunTest({x, y}, {x_shape, x_shape}, {z}, {x_shape});
}

TEST_F(NaryGradTest, Atan2Grad) {
  TensorShape shape({3, 2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // Test with x2 = x1 / (1 + |x1|) to avoid regions where the gradient
  // is unstable and might cause problems for the numeric estimator.
  auto x2 =
      Div(scope_, x1, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x1)));
  auto y = Atan2(scope_, x1, x2);
  RunTest({x1}, {shape}, {y}, {shape});
}

TEST_F(NaryGradTest, UnsortedSegmentMaxGrad) {
  TensorShape shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2);
  TensorShape y_shape({2, 2, 5});
  RunTest({x}, {shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_Int64Ids) {
  TensorShape shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto segment_ids = Const(scope_, {0ll, 0ll, 1ll});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2);
  TensorShape y_shape({2, 2, 5});
  RunTest({x}, {shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_NegativeIds) {
  TensorShape shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto segment_ids = Const(scope_, {0, 0, -1});
  auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/1);
  TensorShape y_shape({1, 2, 5});
  RunTest({x}, {shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, UnsortedSegmentMinGrad) {
  TensorShape shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentMin(scope_, x, segment_ids, /*num_segments=*/2);
  TensorShape y_shape({2, 2, 5});
  RunTest({x}, {shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, UnsortedSegmentSumGrad) {
  TensorShape shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto segment_ids = Const(scope_, {0, 0, 1});
  auto y = UnsortedSegmentSum(scope_, x, segment_ids, /*num_segments=*/2);
  TensorShape y_shape({2, 2, 5});
  RunTest({x}, {shape}, {y}, {y_shape});
}

}  // namespace
}  // namespace tensorflow