• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/cc/client/client_session.h"
17 #include "tensorflow/cc/framework/grad_op_registry.h"
18 #include "tensorflow/cc/framework/gradient_checker.h"
19 #include "tensorflow/cc/framework/gradients.h"
20 #include "tensorflow/cc/framework/testutil.h"
21 #include "tensorflow/cc/gradients/grad_testutil.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/random/random.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 using ops::Abs;
31 using ops::Add;
32 using ops::AddN;
33 using ops::BatchMatMul;
34 using ops::Const;
35 using ops::Div;
36 using ops::DivNoNan;
37 using ops::MatMul;
38 using ops::Max;
39 using ops::Maximum;
40 using ops::Mean;
41 using ops::Min;
42 using ops::Minimum;
43 using ops::Mul;
44 using ops::Placeholder;
45 using ops::Pow;
46 using ops::Prod;
47 using ops::RealDiv;
48 using ops::SegmentSum;
49 using ops::SquaredDifference;
50 using ops::Sub;
51 using ops::Sum;
52 
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
56 
57 class CWiseUnaryGradTest : public ::testing::Test {
58  protected:
CWiseUnaryGradTest()59   CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
60 
61   enum UnaryOpType {
62     ABS,
63     NEG,
64     INV,
65     SQUARE,
66     SQRT,
67     RSQRT,
68     EXP,
69     EXPM1,
70     LOG,
71     LOG1P,
72     SINH,
73     COSH,
74     TANH,
75     ASINH,
76     ACOSH,
77     ATANH,
78     SIGMOID,
79     SIGN,
80     SIN,
81     COS,
82     ASIN,
83     ACOS,
84     TAN,
85     ATAN,
86     REAL,
87     IMAG,
88     CONJ,
89     COMPLEX,
90     ANGLE,
91     LGAMMA,
92     ERF
93   };
94 
95   template <typename X_T, typename Y_T>
TestCWiseGrad(UnaryOpType op_type,const std::function<X_T (int)> & x_fn)96   void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) {
97     TF_ASSERT_OK(scope_.status());
98     DataType x_type = DataTypeToEnum<X_T>::v();
99     TensorShape shape({2, 3, 2});
100     auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape));
101     Tensor x_data(x_type, shape);
102     auto x_data_flat = x_data.flat<X_T>();
103     for (int i = 0; i < x_data_flat.size(); ++i) {
104       x_data_flat(i) = x_fn(i);
105     }
106 
107     Output y;
108     switch (op_type) {
109       using namespace ops;  // NOLINT(build/namespaces)
110       case ABS:
111         y = Abs(scope_, x);
112         break;
113       case NEG:
114         y = Neg(scope_, x);
115         break;
116       case INV:
117         y = Reciprocal(scope_, x);
118         break;
119       case SQUARE:
120         y = Square(scope_, x);
121         break;
122       case SQRT:
123         y = Sqrt(scope_, x);
124         break;
125       case RSQRT:
126         y = Rsqrt(scope_, x);
127         break;
128       case EXP:
129         y = Exp(scope_, x);
130         break;
131       case EXPM1:
132         y = Expm1(scope_, x);
133         break;
134       case LOG:
135         y = Log(scope_, x);
136         break;
137       case LOG1P:
138         y = Log1p(scope_, x);
139         break;
140       case SINH:
141         y = Sinh(scope_, x);
142         break;
143       case COSH:
144         y = Cosh(scope_, x);
145         break;
146       case TANH:
147         y = Tanh(scope_, x);
148         break;
149       case ASINH:
150         y = Asinh(scope_, x);
151         break;
152       case ACOSH:
153         y = Acosh(scope_, x);
154         break;
155       case ATANH:
156         y = Atanh(scope_, x);
157         break;
158       case SIGMOID:
159         y = Sigmoid(scope_, x);
160         break;
161       case SIGN:
162         y = Sign(scope_, x);
163         break;
164       case SIN:
165         y = Sin(scope_, x);
166         break;
167       case COS:
168         y = Cos(scope_, x);
169         break;
170       case ASIN:
171         y = Asin(scope_, x);
172         break;
173       case ACOS:
174         y = Acos(scope_, x);
175         break;
176       case TAN:
177         y = Tan(scope_, x);
178         break;
179       case ATAN:
180         y = Atan(scope_, x);
181         break;
182       case REAL:
183         y = Real(scope_, x);
184         break;
185       case IMAG:
186         y = Imag(scope_, x);
187         break;
188       case CONJ:
189         y = Conj(scope_, x);
190         break;
191       case COMPLEX:
192         y = Complex(scope_, x, x);
193         break;
194       case ANGLE:
195         y = Angle(scope_, x);
196         break;
197       case LGAMMA:
198         y = Lgamma(scope_, x);
199         break;
200       case ERF:
201         y = Erf(scope_, x);
202         break;
203     }
204 
205     float max_error;
206     TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y,
207                                                         shape, &max_error)));
208     EXPECT_LT(max_error, 1e-3f);
209   }
210 
RV(const std::vector<float> & v)211   float RV(const std::vector<float>& v) {
212     return v[random::New64() % v.size()];
213   }
214 
CRV(const std::vector<complex64> & v)215   complex64 CRV(const std::vector<complex64>& v) {
216     return v[random::New64() % v.size()];
217   }
218 
conjugate(const complex64 & val)219   complex64 conjugate(const complex64& val) {
220     return complex64(val.real(), -val.imag());
221   }
222 
223   Scope scope_;
224 };
225 
TEST_F(CWiseUnaryGradTest,Abs)226 TEST_F(CWiseUnaryGradTest, Abs) {
227   auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
228   TestCWiseGrad<float, float>(ABS, x_fn);
229 }
230 
TEST_F(CWiseUnaryGradTest,Neg)231 TEST_F(CWiseUnaryGradTest, Neg) {
232   auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
233   TestCWiseGrad<float, float>(NEG, x_fn);
234 }
235 
TEST_F(CWiseUnaryGradTest,Reciprocal)236 TEST_F(CWiseUnaryGradTest, Reciprocal) {
237   auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
238   TestCWiseGrad<float, float>(INV, x_fn);
239 }
240 
TEST_F(CWiseUnaryGradTest,Reciprocal_Complex)241 TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
242   auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
243   TestCWiseGrad<complex64, complex64>(INV, x_fn);
244 }
245 
TEST_F(CWiseUnaryGradTest,Square)246 TEST_F(CWiseUnaryGradTest, Square) {
247   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
248   TestCWiseGrad<float, float>(SQUARE, x_fn);
249 }
250 
TEST_F(CWiseUnaryGradTest,Square_Complex)251 TEST_F(CWiseUnaryGradTest, Square_Complex) {
252   auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
253   TestCWiseGrad<complex64, complex64>(SQUARE, x_fn);
254 }
255 
TEST_F(CWiseUnaryGradTest,Sqrt)256 TEST_F(CWiseUnaryGradTest, Sqrt) {
257   auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); };
258   TestCWiseGrad<float, float>(SQRT, x_fn);
259 }
260 
TEST_F(CWiseUnaryGradTest,Sqrt_Complex)261 TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
262   auto x_fn = [this](const int i) {
263     return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
264   };
265   TestCWiseGrad<complex64, complex64>(SQRT, x_fn);
266 }
267 
TEST_F(CWiseUnaryGradTest,Rsqrt)268 TEST_F(CWiseUnaryGradTest, Rsqrt) {
269   auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
270   TestCWiseGrad<float, float>(RSQRT, x_fn);
271 }
272 
TEST_F(CWiseUnaryGradTest,Rsqrt_Complex)273 TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
274   auto x_fn = [this](const int i) {
275     return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}});
276   };
277   TestCWiseGrad<complex64, complex64>(RSQRT, x_fn);
278 }
279 
TEST_F(CWiseUnaryGradTest,Exp)280 TEST_F(CWiseUnaryGradTest, Exp) {
281   auto x_fn = [this](const int i) {
282     return RV({0, -1, 1, -1.5f, 1.5f, -2, 2});
283   };
284   TestCWiseGrad<float, float>(EXP, x_fn);
285 }
286 
TEST_F(CWiseUnaryGradTest,Exp_Complex)287 TEST_F(CWiseUnaryGradTest, Exp_Complex) {
288   auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
289   TestCWiseGrad<complex64, complex64>(EXP, x_fn);
290 }
291 
TEST_F(CWiseUnaryGradTest,Expm1)292 TEST_F(CWiseUnaryGradTest, Expm1) {
293   auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); };
294   TestCWiseGrad<float, float>(EXPM1, x_fn);
295 }
296 
TEST_F(CWiseUnaryGradTest,Expm1_Complex)297 TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
298   auto x_fn = [this](const int i) {
299     return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}});
300   };
301   TestCWiseGrad<complex64, complex64>(EXPM1, x_fn);
302 }
303 
TEST_F(CWiseUnaryGradTest,Log)304 TEST_F(CWiseUnaryGradTest, Log) {
305   auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); };
306   TestCWiseGrad<float, float>(LOG, x_fn);
307 }
308 
TEST_F(CWiseUnaryGradTest,Log_Complex)309 TEST_F(CWiseUnaryGradTest, Log_Complex) {
310   auto x_fn = [this](const int i) {
311     return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}});
312   };
313   TestCWiseGrad<complex64, complex64>(LOG, x_fn);
314 }
315 
TEST_F(CWiseUnaryGradTest,Log1p)316 TEST_F(CWiseUnaryGradTest, Log1p) {
317   auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); };
318   TestCWiseGrad<float, float>(LOG1P, x_fn);
319 }
320 
TEST_F(CWiseUnaryGradTest,Log1p_Complex)321 TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
322   auto x_fn = [this](const int i) {
323     return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
324   };
325   TestCWiseGrad<complex64, complex64>(LOG1P, x_fn);
326 }
327 
TEST_F(CWiseUnaryGradTest,Sinh)328 TEST_F(CWiseUnaryGradTest, Sinh) {
329   auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); };
330   TestCWiseGrad<float, float>(SINH, x_fn);
331 }
332 
TEST_F(CWiseUnaryGradTest,Sinh_Complex)333 TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
334   auto x_fn = [this](const int i) {
335     return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
336   };
337   TestCWiseGrad<complex64, complex64>(SINH, x_fn);
338 }
339 
TEST_F(CWiseUnaryGradTest,Cosh)340 TEST_F(CWiseUnaryGradTest, Cosh) {
341   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
342   TestCWiseGrad<float, float>(COSH, x_fn);
343 }
344 
TEST_F(CWiseUnaryGradTest,Cosh_Complex)345 TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
346   auto x_fn = [this](const int i) {
347     return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}});
348   };
349   TestCWiseGrad<complex64, complex64>(COSH, x_fn);
350 }
351 
TEST_F(CWiseUnaryGradTest,Tanh)352 TEST_F(CWiseUnaryGradTest, Tanh) {
353   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
354   TestCWiseGrad<float, float>(TANH, x_fn);
355 }
356 
TEST_F(CWiseUnaryGradTest,Tanh_Complex)357 TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
358   auto x_fn = [this](const int i) {
359     return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
360   };
361   TestCWiseGrad<complex64, complex64>(TANH, x_fn);
362 }
363 
TEST_F(CWiseUnaryGradTest,Asinh)364 TEST_F(CWiseUnaryGradTest, Asinh) {
365   auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); };
366   TestCWiseGrad<float, float>(ASINH, x_fn);
367 }
368 
TEST_F(CWiseUnaryGradTest,Asinh_Complex)369 TEST_F(CWiseUnaryGradTest, Asinh_Complex) {
370   auto x_fn = [this](const int i) {
371     return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
372   };
373   TestCWiseGrad<complex64, complex64>(ASINH, x_fn);
374 }
375 
TEST_F(CWiseUnaryGradTest,Acosh)376 TEST_F(CWiseUnaryGradTest, Acosh) {
377   auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); };
378   TestCWiseGrad<float, float>(ACOSH, x_fn);
379 }
380 
TEST_F(CWiseUnaryGradTest,Acosh_Complex)381 TEST_F(CWiseUnaryGradTest, Acosh_Complex) {
382   auto x_fn = [this](const int i) {
383     return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}});
384   };
385   TestCWiseGrad<complex64, complex64>(ACOSH, x_fn);
386 }
387 
TEST_F(CWiseUnaryGradTest,Atanh)388 TEST_F(CWiseUnaryGradTest, Atanh) {
389   auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); };
390   TestCWiseGrad<float, float>(ATANH, x_fn);
391 }
392 
TEST_F(CWiseUnaryGradTest,Atanh_Complex)393 TEST_F(CWiseUnaryGradTest, Atanh_Complex) {
394   auto x_fn = [this](const int i) {
395     return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}});
396   };
397   TestCWiseGrad<complex64, complex64>(ATANH, x_fn);
398 }
399 
TEST_F(CWiseUnaryGradTest,Sigmoid)400 TEST_F(CWiseUnaryGradTest, Sigmoid) {
401   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
402   TestCWiseGrad<float, float>(SIGMOID, x_fn);
403 }
404 
TEST_F(CWiseUnaryGradTest,Sigmoid_Complex)405 TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
406   auto x_fn = [this](const int i) {
407     return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
408   };
409   TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn);
410 }
411 
TEST_F(CWiseUnaryGradTest,Sign)412 TEST_F(CWiseUnaryGradTest, Sign) {
413   auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); };
414   TestCWiseGrad<float, float>(SIGN, x_fn);
415 }
416 
TEST_F(CWiseUnaryGradTest,Sin)417 TEST_F(CWiseUnaryGradTest, Sin) {
418   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
419   TestCWiseGrad<float, float>(SIN, x_fn);
420 }
421 
TEST_F(CWiseUnaryGradTest,Sin_Complex)422 TEST_F(CWiseUnaryGradTest, Sin_Complex) {
423   auto x_fn = [this](const int i) {
424     return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
425   };
426   TestCWiseGrad<complex64, complex64>(SIN, x_fn);
427 }
428 
TEST_F(CWiseUnaryGradTest,Cos)429 TEST_F(CWiseUnaryGradTest, Cos) {
430   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
431   TestCWiseGrad<float, float>(COS, x_fn);
432 }
433 
TEST_F(CWiseUnaryGradTest,Cos_Complex)434 TEST_F(CWiseUnaryGradTest, Cos_Complex) {
435   auto x_fn = [this](const int i) {
436     return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}});
437   };
438   TestCWiseGrad<complex64, complex64>(COS, x_fn);
439 }
440 
TEST_F(CWiseUnaryGradTest,Asin)441 TEST_F(CWiseUnaryGradTest, Asin) {
442   auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); };
443   TestCWiseGrad<float, float>(ASIN, x_fn);
444 }
445 
TEST_F(CWiseUnaryGradTest,Asin_Complex)446 TEST_F(CWiseUnaryGradTest, Asin_Complex) {
447   auto x_fn = [this](const int i) {
448     return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
449   };
450   // TODO(kbsriram)
451   // Enable test when the asin kernel supports complex numbers
452   if (false) {
453     TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
454   }
455 }
456 
TEST_F(CWiseUnaryGradTest,Acos)457 TEST_F(CWiseUnaryGradTest, Acos) {
458   auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); };
459   TestCWiseGrad<float, float>(ACOS, x_fn);
460 }
461 
TEST_F(CWiseUnaryGradTest,Acos_Complex)462 TEST_F(CWiseUnaryGradTest, Acos_Complex) {
463   auto x_fn = [this](const int i) {
464     return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}});
465   };
466   // TODO(kbsriram)
467   // Add test when the acos kernel supports complex numbers
468   if (false) {
469     TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
470   }
471 }
472 
TEST_F(CWiseUnaryGradTest,Tan)473 TEST_F(CWiseUnaryGradTest, Tan) {
474   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
475   TestCWiseGrad<float, float>(TAN, x_fn);
476 }
477 
TEST_F(CWiseUnaryGradTest,Tan_Complex)478 TEST_F(CWiseUnaryGradTest, Tan_Complex) {
479   auto x_fn = [this](const int i) {
480     return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
481   };
482   TestCWiseGrad<complex64, complex64>(TAN, x_fn);
483 }
484 
TEST_F(CWiseUnaryGradTest,Atan)485 TEST_F(CWiseUnaryGradTest, Atan) {
486   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
487   TestCWiseGrad<float, float>(ATAN, x_fn);
488 }
489 
TEST_F(CWiseUnaryGradTest,Atan_Complex)490 TEST_F(CWiseUnaryGradTest, Atan_Complex) {
491   auto x_fn = [this](const int i) {
492     return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
493   };
494   // TODO(kbsriram)
495   // Add test when the atan kernel supports complex numbers
496   if (false) {
497     TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
498   }
499 }
500 
TEST_F(CWiseUnaryGradTest,Real)501 TEST_F(CWiseUnaryGradTest, Real) {
502   auto x_fn = [this](const int i) {
503     return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
504   };
505   TestCWiseGrad<complex64, float>(REAL, x_fn);
506 }
507 
TEST_F(CWiseUnaryGradTest,Imag)508 TEST_F(CWiseUnaryGradTest, Imag) {
509   auto x_fn = [this](const int i) {
510     return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
511   };
512   TestCWiseGrad<complex64, float>(IMAG, x_fn);
513 }
514 
TEST_F(CWiseUnaryGradTest,Conj)515 TEST_F(CWiseUnaryGradTest, Conj) {
516   auto x_fn = [this](const int i) {
517     return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}});
518   };
519   TestCWiseGrad<complex64, complex64>(CONJ, x_fn);
520 }
521 
TEST_F(CWiseUnaryGradTest,Complex)522 TEST_F(CWiseUnaryGradTest, Complex) {
523   auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); };
524   TestCWiseGrad<float, complex64>(COMPLEX, x_fn);
525 }
526 
TEST_F(CWiseUnaryGradTest,Angle)527 TEST_F(CWiseUnaryGradTest, Angle) {
528   auto x_fn = [this](const int i) {
529     return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}});
530   };
531   TestCWiseGrad<complex64, float>(ANGLE, x_fn);
532 }
533 
TEST_F(CWiseUnaryGradTest,Lgamma)534 TEST_F(CWiseUnaryGradTest, Lgamma) {
535   auto x_fn = [this](const int i) {
536     return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
537   };
538   TestCWiseGrad<float, float>(LGAMMA, x_fn);
539 }
540 
TEST_F(CWiseUnaryGradTest,Lgamma_Complex)541 TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
542   auto x_fn = [this](const int i) {
543     return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
544   };
545   // TODO(kbsriram)
546   // Add test when the lgamma kernel supports complex numbers
547   if (false) {
548     TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
549   }
550 }
551 
TEST_F(CWiseUnaryGradTest,Erf)552 TEST_F(CWiseUnaryGradTest, Erf) {
553   auto x_fn = [this](const int i) {
554     return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
555   };
556   TestCWiseGrad<float, float>(ERF, x_fn);
557 }
558 
TEST_F(CWiseUnaryGradTest,Erf_Complex)559 TEST_F(CWiseUnaryGradTest, Erf_Complex) {
560   auto x_fn = [this](const int i) {
561     return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
562   };
563   // TODO(kbsriram)
564   // Add test when the erf kernel supports complex numbers
565   if (false) {
566     TestCWiseGrad<complex64, complex64>(ERF, x_fn);
567   }
568 }
569 
570 class MathGradTest : public ::testing::Test {
571  protected:
MathGradTest()572   MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
573 
574   template <typename T>
TestMatMulGrad(const bool is_batch,const bool t_x,const bool t_y)575   void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) {
576     TF_ASSERT_OK(root_.status());
577     // Generate random (but compatible) shapes for matrix multiplication.
578     std::vector<TensorShape> shapes;
579     RandMatMulShapes(is_batch, t_x, t_y, &shapes);
580     TensorShape x_shape = shapes[0];
581     TensorShape y_shape = shapes[1];
582     TensorShape z_shape = shapes[2];
583     auto x =
584         Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(x_shape));
585     auto y =
586         Placeholder(root_, DataTypeToEnum<T>::v(), Placeholder::Shape(y_shape));
587     Output z;
588     if (is_batch) {
589       z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y));
590     } else {
591       z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
592     }
593 
594     float max_error;
595     TF_ASSERT_OK((ComputeGradientError<T, T, float>(
596         root_, {x, y}, {x_shape, y_shape}, {z}, {z_shape}, &max_error)));
597     EXPECT_LT(max_error, 1e-3);
598   }
599 
RandMatMulShapes(const bool is_batch,const bool tx,const bool ty,std::vector<TensorShape> * shapes)600   void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty,
601                         std::vector<TensorShape>* shapes) {
602     // Choose a random batch size in [1, 4]
603     const int b = 1 + (random::New64() % 4);
604     // z = MatMul(x, y)
605     const int m = Rand();
606     const int k = Rand();
607     const int n = Rand();
608 
609     TensorShape x_shape;
610     if (is_batch) {
611       // x.shape = [b, m, k]
612       x_shape = tx ? TensorShape({b, k, m}) : TensorShape({b, m, k});
613     } else {
614       // x.shape = [m, k]
615       x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
616     }
617     shapes->push_back(x_shape);
618 
619     TensorShape y_shape;
620     if (is_batch) {
621       // y.shape = [b, k, n]
622       y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n});
623     } else {
624       // y.shape = [k, n]
625       y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
626     }
627     shapes->push_back(y_shape);
628 
629     TensorShape z_shape;
630     if (is_batch) {
631       // z.shape = [b, m, n]
632       z_shape = TensorShape({b, m, n});
633     } else {
634       // z.shape = [m, n]
635       z_shape = TensorShape({m, n});
636     }
637     shapes->push_back(z_shape);
638   }
639 
Rand()640   int Rand() { return 1 + (random::New64() % 10); }
641 
642   Scope root_;
643 };
644 
TEST_F(MathGradTest,MatMulGrad_NoTranspose)645 TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
646   TestMatMulGrad<float>(false, false, false);
647 }
648 
TEST_F(MathGradTest,MatMulComplexGrad_NoTranspose)649 TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) {
650   TestMatMulGrad<complex64>(false, false, false);
651 }
652 
TEST_F(MathGradTest,MatMulGrad_TransposeX)653 TEST_F(MathGradTest, MatMulGrad_TransposeX) {
654   TestMatMulGrad<float>(false, true, false);
655 }
656 
TEST_F(MathGradTest,MatMulComplexGrad_TransposeX)657 TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) {
658   TestMatMulGrad<complex64>(false, true, false);
659 }
660 
TEST_F(MathGradTest,MatMulGrad_TransposeY)661 TEST_F(MathGradTest, MatMulGrad_TransposeY) {
662   TestMatMulGrad<float>(false, false, true);
663 }
664 
TEST_F(MathGradTest,MatMulComplexGrad_TransposeY)665 TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) {
666   TestMatMulGrad<complex64>(false, false, true);
667 }
668 
TEST_F(MathGradTest,MatMulGrad_TransposeX_TransposeY)669 TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
670   TestMatMulGrad<float>(false, true, true);
671 }
672 
TEST_F(MathGradTest,MatMulComplexGrad_TransposeX_TransposeY)673 TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) {
674   TestMatMulGrad<complex64>(false, true, true);
675 }
676 
TEST_F(MathGradTest,BatchMatMulGrad_NoTranspose)677 TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) {
678   TestMatMulGrad<float>(true, false, false);
679 }
680 
TEST_F(MathGradTest,BatchMatMulComplexGrad_NoTranspose)681 TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) {
682   TestMatMulGrad<complex64>(true, false, false);
683 }
684 
TEST_F(MathGradTest,BatchMatMulGrad_TransposeX)685 TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) {
686   TestMatMulGrad<float>(true, true, false);
687 }
688 
TEST_F(MathGradTest,BatchMatMulComplexGrad_TransposeX)689 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) {
690   TestMatMulGrad<complex64>(true, true, false);
691 }
692 
TEST_F(MathGradTest,BatchMatMulGrad_TransposeY)693 TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) {
694   TestMatMulGrad<float>(true, false, true);
695 }
696 
TEST_F(MathGradTest,BatchMatMulComplexGrad_TransposeY)697 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) {
698   TestMatMulGrad<complex64>(true, false, true);
699 }
700 
TEST_F(MathGradTest,BatchMatMulGrad_TransposeX_TransposeY)701 TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) {
702   TestMatMulGrad<float>(true, true, true);
703 }
704 
TEST_F(MathGradTest,BatchMatMulComplexGrad_TransposeX_TransposeY)705 TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) {
706   TestMatMulGrad<complex64>(true, true, true);
707 }
708 
709 class NaryGradTest : public ::testing::Test {
710  protected:
NaryGradTest()711   NaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
712 
RunTest(const OutputList & xs,const std::vector<TensorShape> & x_shapes,const OutputList & ys,const std::vector<TensorShape> & y_shapes)713   void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
714                const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
715     TF_ASSERT_OK(scope_.status());
716     float max_error;
717     TF_ASSERT_OK((ComputeGradientError<float, float, float>(
718         scope_, xs, x_shapes, ys, y_shapes, &max_error)));
719     EXPECT_LT(max_error, 1e-3);
720   }
721 
RunTest(const Output & x,const Tensor & x_init_value,const Output & y,const TensorShape & y_shape)722   void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
723                const TensorShape& y_shape) {
724     TF_ASSERT_OK(scope_.status());
725     float max_error;
726     TF_ASSERT_OK((ComputeGradientError<float, float, float>(
727         scope_, x, x_init_value, y, y_shape, &max_error)));
728     EXPECT_LT(max_error, 1e-3);
729   }
730 
731   Scope scope_;
732 };
733 
TEST_F(NaryGradTest,Sum)734 TEST_F(NaryGradTest, Sum) {
735   TensorShape x_shape({2, 3, 5, 7});
736   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
737   auto y = Sum(scope_, x, {1, -1});
738   // y's shape is the result of reducing x along axes 1 and -1 (= 3)
739   TensorShape y_shape({2, 5});
740   RunTest({x}, {x_shape}, {y}, {y_shape});
741 }
742 
TEST_F(NaryGradTest,Mean)743 TEST_F(NaryGradTest, Mean) {
744   TensorShape x_shape({2, 3, 5, 7});
745   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
746   auto y = Mean(scope_, x, {1, -1});
747   // y's shape is the result of reducing x along axes 1 and -1 (= 3)
748   TensorShape y_shape({2, 5});
749   RunTest({x}, {x_shape}, {y}, {y_shape});
750 }
751 
TEST_F(NaryGradTest,Min)752 TEST_F(NaryGradTest, Min) {
753   TensorShape x_shape({2, 3});
754   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
755   auto y = Min(scope_, x, {-1});
756   // y's shape is the result of reducing x along axes -1 (= 1)
757   TensorShape y_shape({2});
758   Tensor x_init_value =
759       test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
760   RunTest(x, x_init_value, y, y_shape);
761 }
762 
TEST_F(NaryGradTest,Max)763 TEST_F(NaryGradTest, Max) {
764   TensorShape x_shape({2, 3});
765   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
766   auto y = Max(scope_, x, {-1});
767   // y's shape is the result of reducing x along axes -1 (= 1)
768   TensorShape y_shape({2});
769   Tensor x_init_value =
770       test::AsTensor<float>({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape);
771   RunTest(x, x_init_value, y, y_shape);
772 }
773 
TEST_F(NaryGradTest,MinMulti)774 TEST_F(NaryGradTest, MinMulti) {
775   // Test gradient when there are multiple minima.
776   // Note that we cannot directly use a test Tensor with multiple
777   // minima, as the numeric estimator will calculate incorrect
778   // gradients when perturbing each entry in the Tensor (which then
779   // changes how many minima exist.)
780   // Instead, we use a single input that broadcast-multiplies a larger
781   // tensor with equal values, and apply reduce_min to the multiplied
782   // result.
783   TensorShape x_shape({1});
784   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
785   auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
786   auto y = Min(scope_, all_same, {0});
787   // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped
788   TensorShape y_shape({1});
789   RunTest({x}, {x_shape}, {y}, {y_shape});
790 }
791 
TEST_F(NaryGradTest,MaxMulti)792 TEST_F(NaryGradTest, MaxMulti) {
793   TensorShape x_shape({1});
794   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
795   auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x);
796   auto y = Max(scope_, all_same, {0});
797   TensorShape y_shape({1});
798   RunTest({x}, {x_shape}, {y}, {y_shape});
799 }
800 
TEST_F(NaryGradTest, AddN) {
  TensorShape shape({3, 2, 5});
  std::vector<Output> xs;
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  xs.push_back(Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)));
  auto y = AddN(scope_, xs);
  RunTest(xs, {shape, shape, shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Add) {
  // x2 broadcasts against x1; the output takes x1's shape.
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Add(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Sub) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Sub(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Mul) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = Mul(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}

TEST_F(NaryGradTest, Div) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y = Div(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}

TEST_F(NaryGradTest, RealDiv) {
  TensorShape x_shape({3, 2, 5});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
  // division errors in the numeric estimator used by the gradient checker.
  auto y =
      RealDiv(scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
  RunTest({x}, {x_shape}, {y}, {x_shape});
}
856 
TEST_F(NaryGradTest, DivNoNan) {
  {
    TensorShape x_shape({3, 2, 5});
    const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
    // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
    // division errors in the numeric estimator used by the gradient checker.
    const auto y = DivNoNan(
        scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
    RunTest({x}, {x_shape}, {y}, {x_shape});
  }
  {
    // Return 0 gradient (rather than NaN) for division by zero.
    const auto x = Placeholder(scope_, DT_FLOAT);
    const auto zero = Const<float>(scope_, 0.0);
    const auto y = DivNoNan(scope_, x, zero);

    std::vector<Output> grad_outputs;
    TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
    ClientSession session(scope_);
    std::vector<Tensor> grad_result;
    TF_EXPECT_OK(
        session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
    EXPECT_EQ(grad_result.size(), 1);
    EXPECT_EQ(grad_result[0].NumElements(), 3);
    EXPECT_EQ(grad_result[0].flat<float>()(0), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(1), 0.0f);
    EXPECT_EQ(grad_result[0].flat<float>()(2), 0.0f);
  }
}

TEST_F(NaryGradTest, SquaredDifference) {
  TensorShape x1_shape({3, 2, 5});
  TensorShape x2_shape({2, 5});
  auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x1_shape));
  auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x2_shape));
  auto y = SquaredDifference(scope_, x1, x2);
  RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
895 
TEST_F(NaryGradTest, Pow) {
  TensorShape shape({3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  // fix exponent to avoid overflow
  auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
  RunTest({x}, {shape}, {y}, {shape});
}

TEST_F(NaryGradTest, Maximum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Maximum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Minimum) {
  TensorShape shape({3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Minimum(scope_, x, Const(scope_, 1.0f));
  // Select values away from 1.0f to avoid instability when computing
  // finite differences.
  Tensor x_init_value =
      test::AsTensor<float>({0.5f, 1.5f, -1.2f, 3.0f, 0.1f, 2.8f}, {3, 2});
  RunTest(x, x_init_value, y, shape);
}

TEST_F(NaryGradTest, Prod) {
  TensorShape x_shape({2, 3, 2});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = Prod(scope_, x, {1});
  // y's shape is the result of reducing x along axes 1
  TensorShape y_shape({2, 1, 2});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}

TEST_F(NaryGradTest, SegmentSum) {
  TensorShape x_shape({3, 4});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = SegmentSum(scope_, x, {0, 0, 1});
  // the sum is always on the first dimension
  TensorShape y_shape({2, 4});
  RunTest({x}, {x_shape}, {y}, {y_shape});
}
943 
944 }  // namespace
945 }  // namespace tensorflow
946