1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
18
19 #include <numeric>
20
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace test {
28
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32 Tensor ret(DataTypeToEnum<T>::value, {});
33 ret.scalar<T>()() = val;
34 return ret;
35 }
36
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40 Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
41 std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42 return ret;
43 }
44
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48 Tensor ret;
49 CHECK(ret.CopyFrom(AsTensor(vals), shape));
50 return ret;
51 }
52
53 // Fills in '*tensor' with 'vals'. E.g.,
54 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 // test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58 auto flat = tensor->flat<T>();
59 CHECK_EQ(flat.size(), vals.size());
60 if (flat.size() > 0) {
61 std::copy_n(vals.data(), vals.size(), flat.data());
62 }
63 }
64
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68 auto flat = tensor->flat<T>();
69 CHECK_EQ(flat.size(), vals.size());
70 if (flat.size() > 0) {
71 size_t i = 0;
72 for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73 flat(i) = T(*itr);
74 }
75 }
76 }
77
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 // test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83 auto flat = tensor->flat<T>();
84 std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 // test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92 auto flat = tensor->flat<T>();
93 for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95
96 // Expects "x" and "y" are tensors of the same type, same shape, and
97 // identical values.
98 template <typename T>
99 void ExpectTensorEqual(const Tensor& x, const Tensor& y);
100
101 // Expects "x" and "y" are tensors of the same type, same shape, and
102 // approximate equal values, each within "abs_err".
103 template <typename T>
104 void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
105
106 // Expects "x" and "y" are tensors of the same type (float or double),
107 // same shape and element-wise difference between x and y is no more
108 // than atol + rtol * abs(x).
109 void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
110 double rtol = 1e-6);
111
112 // Implementation details.
113
114 namespace internal {
115
116 template <typename T>
117 struct is_floating_point_type {
118 static const bool value = std::is_same<T, Eigen::half>::value ||
119 std::is_same<T, float>::value ||
120 std::is_same<T, double>::value ||
121 std::is_same<T, std::complex<float> >::value ||
122 std::is_same<T, std::complex<double> >::value;
123 };
124
// Generic element comparison: records a non-fatal gtest failure unless
// a == b. Specializations below handle float, double and the complex types.
template <class T>
inline void ExpectEqual(const T& a, const T& b) {
  EXPECT_EQ(a, b);
}
129
130 template <>
131 inline void ExpectEqual<float>(const float& a, const float& b) {
132 EXPECT_FLOAT_EQ(a, b);
133 }
134
135 template <>
136 inline void ExpectEqual<double>(const double& a, const double& b) {
137 EXPECT_DOUBLE_EQ(a, b);
138 }
139
140 template <>
141 inline void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
142 EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
143 EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
144 }
145
146 template <>
147 inline void ExpectEqual<complex128>(const complex128& a, const complex128& b) {
148 EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b;
149 EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b;
150 }
151
AssertSameTypeDims(const Tensor & x,const Tensor & y)152 inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
153 ASSERT_EQ(x.dtype(), y.dtype());
154 ASSERT_TRUE(x.IsSameSize(y))
155 << "x.shape [" << x.shape().DebugString() << "] vs "
156 << "y.shape [ " << y.shape().DebugString() << "]";
157 }
158
159 template <typename T, bool is_fp = is_floating_point_type<T>::value>
160 struct Expector;
161
162 template <typename T>
163 struct Expector<T, false> {
164 static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
165
166 static void Equal(const Tensor& x, const Tensor& y) {
167 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
168 AssertSameTypeDims(x, y);
169 const auto size = x.NumElements();
170 const T* a = x.flat<T>().data();
171 const T* b = y.flat<T>().data();
172 for (int i = 0; i < size; ++i) {
173 ExpectEqual(a[i], b[i]);
174 }
175 }
176 };
177
178 // Partial specialization for float and double.
179 template <typename T>
180 struct Expector<T, true> {
181 static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
182
183 static void Equal(const Tensor& x, const Tensor& y) {
184 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
185 AssertSameTypeDims(x, y);
186 const auto size = x.NumElements();
187 const T* a = x.flat<T>().data();
188 const T* b = y.flat<T>().data();
189 for (int i = 0; i < size; ++i) {
190 ExpectEqual(a[i], b[i]);
191 }
192 }
193
194 static void Near(const T& a, const T& b, const double abs_err, int index) {
195 if (a != b) { // Takes care of inf.
196 EXPECT_LE(double(Eigen::numext::abs(a - b)), abs_err)
197 << "a = " << a << " b = " << b << " index = " << index;
198 }
199 }
200
201 static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
202 ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
203 AssertSameTypeDims(x, y);
204 const auto size = x.NumElements();
205 const T* a = x.flat<T>().data();
206 const T* b = y.flat<T>().data();
207 for (int i = 0; i < size; ++i) {
208 Near(a[i], b[i], abs_err, i);
209 }
210 }
211 };
212
213 } // namespace internal
214
215 template <typename T>
216 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
217 internal::Expector<T>::Equal(x, y);
218 }
219
220 template <typename T>
221 void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
222 static_assert(internal::is_floating_point_type<T>::value,
223 "T is not a floating point types.");
224 internal::Expector<T>::Near(x, y, abs_err);
225 }
226
227 } // namespace test
228 } // namespace tensorflow
229
230 #endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
231