• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
18 
19 #include <numeric>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 namespace test {
28 
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32   Tensor ret(DataTypeToEnum<T>::value, {});
33   ret.scalar<T>()() = val;
34   return ret;
35 }
36 
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40   Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
41   std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42   return ret;
43 }
44 
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48   Tensor ret;
49   CHECK(ret.CopyFrom(AsTensor(vals), shape));
50   return ret;
51 }
52 
53 // Fills in '*tensor' with 'vals'. E.g.,
54 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 //   test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58   auto flat = tensor->flat<T>();
59   CHECK_EQ(flat.size(), vals.size());
60   if (flat.size() > 0) {
61     std::copy_n(vals.data(), vals.size(), flat.data());
62   }
63 }
64 
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68   auto flat = tensor->flat<T>();
69   CHECK_EQ(flat.size(), vals.size());
70   if (flat.size() > 0) {
71     size_t i = 0;
72     for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73       flat(i) = T(*itr);
74     }
75   }
76 }
77 
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 //   test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83   auto flat = tensor->flat<T>();
84   std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86 
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 //   test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92   auto flat = tensor->flat<T>();
93   for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95 
96 // Expects "x" and "y" are tensors of the same type, same shape, and
97 // identical values.
98 template <typename T>
99 void ExpectTensorEqual(const Tensor& x, const Tensor& y);
100 
101 // Expects "x" and "y" are tensors of the same type, same shape, and
102 // approximate equal values, each within "abs_err".
103 template <typename T>
104 void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
105 
106 // Expects "x" and "y" are tensors of the same type (float or double),
107 // same shape and element-wise difference between x and y is no more
108 // than atol + rtol * abs(x).
109 void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
110                  double rtol = 1e-6);
111 
112 // Implementation details.
113 
114 namespace internal {
115 
116 template <typename T>
117 struct is_floating_point_type {
118   static const bool value = std::is_same<T, Eigen::half>::value ||
119                             std::is_same<T, float>::value ||
120                             std::is_same<T, double>::value ||
121                             std::is_same<T, std::complex<float> >::value ||
122                             std::is_same<T, std::complex<double> >::value;
123 };
124 
125 template <typename T>
ExpectEqual(const T & a,const T & b)126 inline void ExpectEqual(const T& a, const T& b) {
127   EXPECT_EQ(a, b);
128 }
129 
130 template <>
131 inline void ExpectEqual<float>(const float& a, const float& b) {
132   EXPECT_FLOAT_EQ(a, b);
133 }
134 
135 template <>
136 inline void ExpectEqual<double>(const double& a, const double& b) {
137   EXPECT_DOUBLE_EQ(a, b);
138 }
139 
140 template <>
141 inline void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
142   EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
143   EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
144 }
145 
146 template <>
147 inline void ExpectEqual<complex128>(const complex128& a, const complex128& b) {
148   EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b;
149   EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b;
150 }
151 
AssertSameTypeDims(const Tensor & x,const Tensor & y)152 inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
153   ASSERT_EQ(x.dtype(), y.dtype());
154   ASSERT_TRUE(x.IsSameSize(y))
155       << "x.shape [" << x.shape().DebugString() << "] vs "
156       << "y.shape [ " << y.shape().DebugString() << "]";
157 }
158 
159 template <typename T, bool is_fp = is_floating_point_type<T>::value>
160 struct Expector;
161 
162 template <typename T>
163 struct Expector<T, false> {
164   static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
165 
166   static void Equal(const Tensor& x, const Tensor& y) {
167     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
168     AssertSameTypeDims(x, y);
169     const auto size = x.NumElements();
170     const T* a = x.flat<T>().data();
171     const T* b = y.flat<T>().data();
172     for (int i = 0; i < size; ++i) {
173       ExpectEqual(a[i], b[i]);
174     }
175   }
176 };
177 
178 // Partial specialization for float and double.
179 template <typename T>
180 struct Expector<T, true> {
181   static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
182 
183   static void Equal(const Tensor& x, const Tensor& y) {
184     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
185     AssertSameTypeDims(x, y);
186     const auto size = x.NumElements();
187     const T* a = x.flat<T>().data();
188     const T* b = y.flat<T>().data();
189     for (int i = 0; i < size; ++i) {
190       ExpectEqual(a[i], b[i]);
191     }
192   }
193 
194   static void Near(const T& a, const T& b, const double abs_err, int index) {
195     if (a != b) {  // Takes care of inf.
196       EXPECT_LE(double(Eigen::numext::abs(a - b)), abs_err)
197           << "a = " << a << " b = " << b << " index = " << index;
198     }
199   }
200 
201   static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
202     ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
203     AssertSameTypeDims(x, y);
204     const auto size = x.NumElements();
205     const T* a = x.flat<T>().data();
206     const T* b = y.flat<T>().data();
207     for (int i = 0; i < size; ++i) {
208       Near(a[i], b[i], abs_err, i);
209     }
210   }
211 };
212 
213 }  // namespace internal
214 
215 template <typename T>
216 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
217   internal::Expector<T>::Equal(x, y);
218 }
219 
220 template <typename T>
221 void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
222   static_assert(internal::is_floating_point_type<T>::value,
223                 "T is not a floating point types.");
224   internal::Expector<T>::Near(x, y, abs_err);
225 }
226 
227 }  // namespace test
228 }  // namespace tensorflow
229 
230 #endif  // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
231