1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
18
19 #include <numeric>
20
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace test {
28
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32 Tensor ret(DataTypeToEnum<T>::value, {});
33 ret.scalar<T>()() = val;
34 return ret;
35 }
36
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40 Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
41 std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42 return ret;
43 }
44
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48 Tensor ret;
49 CHECK(ret.CopyFrom(AsTensor(vals), shape));
50 return ret;
51 }
52
53 // Fills in '*tensor' with 'vals'. E.g.,
54 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 // test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58 auto flat = tensor->flat<T>();
59 CHECK_EQ(flat.size(), vals.size());
60 if (flat.size() > 0) {
61 std::copy_n(vals.data(), vals.size(), flat.data());
62 }
63 }
64
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68 auto flat = tensor->flat<T>();
69 CHECK_EQ(flat.size(), vals.size());
70 if (flat.size() > 0) {
71 size_t i = 0;
72 for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73 flat(i) = T(*itr);
74 }
75 }
76 }
77
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 // test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83 auto flat = tensor->flat<T>();
84 std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 // test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92 auto flat = tensor->flat<T>();
93 for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95
96 // Expects "x" and "y" are tensors of the same type, same shape, and
97 // identical values.
98 template <typename T>
99 void ExpectTensorEqual(const Tensor& x, const Tensor& y);
100
101 // Expects "x" and "y" are tensors of the same type, same shape, and
102 // approximate equal values, each within "abs_err".
103 template <typename T>
104 void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
105
106 // Expects "x" and "y" are tensors of the same type (float or double),
107 // same shape and element-wise difference between x and y is no more
108 // than atol + rtol * abs(x). If atol or rtol is negative, it is replaced
109 // with a default tolerance value = data type's epsilon * kSlackFactor.
110 void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
111 double rtol = -1.0);
112
113 // Implementation details.
114
115 namespace internal {
116
117 template <typename T>
118 struct is_floating_point_type {
119 static const bool value = std::is_same<T, Eigen::half>::value ||
120 std::is_same<T, float>::value ||
121 std::is_same<T, double>::value ||
122 std::is_same<T, std::complex<float> >::value ||
123 std::is_same<T, std::complex<double> >::value;
124 };
125
126 template <typename T>
ExpectEqual(const T & a,const T & b)127 inline void ExpectEqual(const T& a, const T& b) {
128 EXPECT_EQ(a, b);
129 }
130
131 template <>
132 inline void ExpectEqual<float>(const float& a, const float& b) {
133 EXPECT_FLOAT_EQ(a, b);
134 }
135
136 template <>
137 inline void ExpectEqual<double>(const double& a, const double& b) {
138 EXPECT_DOUBLE_EQ(a, b);
139 }
140
141 template <>
142 inline void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
143 EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
144 EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
145 }
146
147 template <>
148 inline void ExpectEqual<complex128>(const complex128& a, const complex128& b) {
149 EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b;
150 EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b;
151 }
152
AssertSameTypeDims(const Tensor & x,const Tensor & y)153 inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
154 ASSERT_EQ(x.dtype(), y.dtype());
155 ASSERT_TRUE(x.IsSameSize(y))
156 << "x.shape [" << x.shape().DebugString() << "] vs "
157 << "y.shape [ " << y.shape().DebugString() << "]";
158 }
159
160 template <typename T, bool is_fp = is_floating_point_type<T>::value>
161 struct Expector;
162
163 template <typename T>
164 struct Expector<T, false> {
165 static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
166
167 static void Equal(const Tensor& x, const Tensor& y) {
168 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
169 AssertSameTypeDims(x, y);
170 const auto size = x.NumElements();
171 const T* a = x.flat<T>().data();
172 const T* b = y.flat<T>().data();
173 for (int i = 0; i < size; ++i) {
174 ExpectEqual(a[i], b[i]);
175 }
176 }
177 };
178
179 // Partial specialization for float and double.
180 template <typename T>
181 struct Expector<T, true> {
182 static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
183
184 static void Equal(const Tensor& x, const Tensor& y) {
185 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
186 AssertSameTypeDims(x, y);
187 const auto size = x.NumElements();
188 const T* a = x.flat<T>().data();
189 const T* b = y.flat<T>().data();
190 for (int i = 0; i < size; ++i) {
191 ExpectEqual(a[i], b[i]);
192 }
193 }
194
195 static bool Near(const T& a, const T& b, const double abs_err) {
196 // Need a == b so that infinities are close to themselves.
197 return (a == b) ||
198 (static_cast<double>(Eigen::numext::abs(a - b)) <= abs_err);
199 }
200
201 static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
202 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
203 AssertSameTypeDims(x, y);
204 const auto size = x.NumElements();
205 const T* a = x.flat<T>().data();
206 const T* b = y.flat<T>().data();
207 for (int i = 0; i < size; ++i) {
208 EXPECT_TRUE(Near(a[i], b[i], abs_err))
209 << "a = " << a[i] << " b = " << b[i] << " index = " << i;
210 }
211 }
212 };
213
214 template <typename T>
215 struct Helper {
216 // Assumes atol and rtol are nonnegative.
217 static bool IsClose(const T& x, const T& y, const T& atol, const T& rtol) {
218 // Need x == y so that infinities are close to themselves.
219 return (x == y) ||
220 (Eigen::numext::abs(x - y) <= atol + rtol * Eigen::numext::abs(x));
221 }
222 };
223
224 template <typename T>
225 struct Helper<std::complex<T>> {
226 static bool IsClose(const std::complex<T>& x, const std::complex<T>& y,
227 const T& atol, const T& rtol) {
228 return Helper<T>::IsClose(x.real(), y.real(), atol, rtol) &&
229 Helper<T>::IsClose(x.imag(), y.imag(), atol, rtol);
230 }
231 };
232
233 } // namespace internal
234
235 template <typename T>
236 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
237 internal::Expector<T>::Equal(x, y);
238 }
239
240 template <typename T>
241 void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
242 static_assert(internal::is_floating_point_type<T>::value,
243 "T is not a floating point types.");
244 ASSERT_GE(abs_err, 0.0) << "abs_error is negative" << abs_err;
245 internal::Expector<T>::Near(x, y, abs_err);
246 }
247
248 } // namespace test
249 } // namespace tensorflow
250
251 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
252