• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
18 
19 #include <numeric>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 namespace test {
28 
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32   Tensor ret(DataTypeToEnum<T>::value, {});
33   ret.scalar<T>()() = val;
34   return ret;
35 }
36 
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40   Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
41   std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42   return ret;
43 }
44 
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48   Tensor ret;
49   CHECK(ret.CopyFrom(AsTensor(vals), shape));
50   return ret;
51 }
52 
53 // Fills in '*tensor' with 'vals'. E.g.,
54 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 //   test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58   auto flat = tensor->flat<T>();
59   CHECK_EQ(flat.size(), vals.size());
60   if (flat.size() > 0) {
61     std::copy_n(vals.data(), vals.size(), flat.data());
62   }
63 }
64 
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68   auto flat = tensor->flat<T>();
69   CHECK_EQ(flat.size(), vals.size());
70   if (flat.size() > 0) {
71     size_t i = 0;
72     for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73       flat(i) = T(*itr);
74     }
75   }
76 }
77 
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 //   test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83   auto flat = tensor->flat<T>();
84   std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86 
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 //   test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92   auto flat = tensor->flat<T>();
93   for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95 
96 // Expects "x" and "y" are tensors of the same type, same shape, and
97 // identical values.
98 template <typename T>
99 void ExpectTensorEqual(const Tensor& x, const Tensor& y);
100 
101 // Expects "x" and "y" are tensors of the same type, same shape, and
102 // approximate equal values, each within "abs_err".
103 template <typename T>
104 void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
105 
106 // Expects "x" and "y" are tensors of the same type (float or double),
107 // same shape and element-wise difference between x and y is no more
108 // than atol + rtol * abs(x). If atol or rtol is negative, it is replaced
109 // with a default tolerance value = data type's epsilon * kSlackFactor.
110 void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
111                  double rtol = -1.0);
112 
113 // Implementation details.
114 
115 namespace internal {
116 
117 template <typename T>
118 struct is_floating_point_type {
119   static const bool value = std::is_same<T, Eigen::half>::value ||
120                             std::is_same<T, float>::value ||
121                             std::is_same<T, double>::value ||
122                             std::is_same<T, std::complex<float> >::value ||
123                             std::is_same<T, std::complex<double> >::value;
124 };
125 
126 template <typename T>
ExpectEqual(const T & a,const T & b)127 inline void ExpectEqual(const T& a, const T& b) {
128   EXPECT_EQ(a, b);
129 }
130 
131 template <>
132 inline void ExpectEqual<float>(const float& a, const float& b) {
133   EXPECT_FLOAT_EQ(a, b);
134 }
135 
136 template <>
137 inline void ExpectEqual<double>(const double& a, const double& b) {
138   EXPECT_DOUBLE_EQ(a, b);
139 }
140 
141 template <>
142 inline void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
143   EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
144   EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
145 }
146 
147 template <>
148 inline void ExpectEqual<complex128>(const complex128& a, const complex128& b) {
149   EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b;
150   EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b;
151 }
152 
AssertSameTypeDims(const Tensor & x,const Tensor & y)153 inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
154   ASSERT_EQ(x.dtype(), y.dtype());
155   ASSERT_TRUE(x.IsSameSize(y))
156       << "x.shape [" << x.shape().DebugString() << "] vs "
157       << "y.shape [ " << y.shape().DebugString() << "]";
158 }
159 
160 template <typename T, bool is_fp = is_floating_point_type<T>::value>
161 struct Expector;
162 
163 template <typename T>
164 struct Expector<T, false> {
165   static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
166 
167   static void Equal(const Tensor& x, const Tensor& y) {
168     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
169     AssertSameTypeDims(x, y);
170     const auto size = x.NumElements();
171     const T* a = x.flat<T>().data();
172     const T* b = y.flat<T>().data();
173     for (int i = 0; i < size; ++i) {
174       ExpectEqual(a[i], b[i]);
175     }
176   }
177 };
178 
179 // Partial specialization for float and double.
180 template <typename T>
181 struct Expector<T, true> {
182   static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
183 
184   static void Equal(const Tensor& x, const Tensor& y) {
185     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
186     AssertSameTypeDims(x, y);
187     const auto size = x.NumElements();
188     const T* a = x.flat<T>().data();
189     const T* b = y.flat<T>().data();
190     for (int i = 0; i < size; ++i) {
191       ExpectEqual(a[i], b[i]);
192     }
193   }
194 
195   static bool Near(const T& a, const T& b, const double abs_err) {
196     // Need a == b so that infinities are close to themselves.
197     return (a == b) ||
198            (static_cast<double>(Eigen::numext::abs(a - b)) <= abs_err);
199   }
200 
201   static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
202     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
203     AssertSameTypeDims(x, y);
204     const auto size = x.NumElements();
205     const T* a = x.flat<T>().data();
206     const T* b = y.flat<T>().data();
207     for (int i = 0; i < size; ++i) {
208       EXPECT_TRUE(Near(a[i], b[i], abs_err))
209           << "a = " << a[i] << " b = " << b[i] << " index = " << i;
210     }
211   }
212 };
213 
214 template <typename T>
215 struct Helper {
216   // Assumes atol and rtol are nonnegative.
217   static bool IsClose(const T& x, const T& y, const T& atol, const T& rtol) {
218     // Need x == y so that infinities are close to themselves.
219     return (x == y) ||
220            (Eigen::numext::abs(x - y) <= atol + rtol * Eigen::numext::abs(x));
221   }
222 };
223 
224 template <typename T>
225 struct Helper<std::complex<T>> {
226   static bool IsClose(const std::complex<T>& x, const std::complex<T>& y,
227                       const T& atol, const T& rtol) {
228     return Helper<T>::IsClose(x.real(), y.real(), atol, rtol) &&
229            Helper<T>::IsClose(x.imag(), y.imag(), atol, rtol);
230   }
231 };
232 
233 }  // namespace internal
234 
235 template <typename T>
236 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
237   internal::Expector<T>::Equal(x, y);
238 }
239 
240 template <typename T>
241 void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
242   static_assert(internal::is_floating_point_type<T>::value,
243                 "T is not a floating point types.");
244   ASSERT_GE(abs_err, 0.0) << "abs_error is negative" << abs_err;
245   internal::Expector<T>::Near(x, y, abs_err);
246 }
247 
248 }  // namespace test
249 }  // namespace tensorflow
250 
251 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
252