• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
18 
19 #include <numeric>
20 
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 namespace test {
28 
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32   Tensor ret(DataTypeToEnum<T>::value, {});
33   ret.scalar<T>()() = val;
34   return ret;
35 }
36 
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40   Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64_t>(vals.size())});
41   std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42   return ret;
43 }
44 
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48   Tensor ret;
49   CHECK(ret.CopyFrom(AsTensor(vals), shape));
50   return ret;
51 }
52 
53 // Fills in '*tensor' with 'vals'. E.g.,
54 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 //   test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58   auto flat = tensor->flat<T>();
59   CHECK_EQ(flat.size(), vals.size());
60   if (flat.size() > 0) {
61     std::copy_n(vals.data(), vals.size(), flat.data());
62   }
63 }
64 
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68   auto flat = tensor->flat<T>();
69   CHECK_EQ(flat.size(), vals.size());
70   if (flat.size() > 0) {
71     size_t i = 0;
72     for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73       flat(i) = T(*itr);
74     }
75   }
76 }
77 
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 //   test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83   auto flat = tensor->flat<T>();
84   std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86 
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 //   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 //   test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92   auto flat = tensor->flat<T>();
93   for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95 
96 // Expects "x" and "y" are tensors of the same type, same shape, and identical
97 // values (within 4 ULPs for floating point types unless explicitly disabled).
98 enum class Tolerance {
99   kNone,
100   kDefault,
101 };
102 void ExpectEqual(const Tensor& x, const Tensor& y,
103                  Tolerance t = Tolerance ::kDefault);
104 
105 // Expects "x" and "y" are tensors of the same (floating point) type,
106 // same shape and element-wise difference between x and y is no more
107 // than atol + rtol * abs(x). If atol or rtol is negative, the data type's
108 // epsilon * kSlackFactor is used.
109 void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
110                  double rtol = -1.0);
111 
112 // Expects "x" and "y" are tensors of the same type T, same shape, and
113 // equal values. Consider using ExpectEqual above instead.
114 template <typename T>
ExpectTensorEqual(const Tensor & x,const Tensor & y)115 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
116   EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
117   ExpectEqual(x, y);
118 }
119 
120 ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y);
121 ::testing::AssertionResult IsSameShape(const Tensor& x, const Tensor& y);
122 
123 template <typename T>
ExpectTensorEqual(const Tensor & x,const Tensor & y,std::function<bool (const T &,const T &)> is_equal)124 void ExpectTensorEqual(const Tensor& x, const Tensor& y,
125                        std::function<bool(const T&, const T&)> is_equal) {
126   EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
127   ASSERT_TRUE(IsSameType(x, y));
128   ASSERT_TRUE(IsSameShape(x, y));
129 
130   const T* Tx = x.unaligned_flat<T>().data();
131   const T* Ty = y.unaligned_flat<T>().data();
132   auto size = x.NumElements();
133   int max_failures = 10;
134   int num_failures = 0;
135   for (decltype(size) i = 0; i < size; ++i) {
136     EXPECT_TRUE(is_equal(Tx[i], Ty[i])) << "i = " << (++num_failures, i);
137     ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
138   }
139 }
140 
141 // Expects "x" and "y" are tensors of the same type T, same shape, and
142 // approximate equal values. Consider using ExpectClose above instead.
143 template <typename T>
ExpectTensorNear(const Tensor & x,const Tensor & y,double atol)144 void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
145   EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
146   ExpectClose(x, y, atol, /*rtol=*/0.0);
147 }
148 
149 // For tensor_testutil_test only.
150 namespace internal_test {
151 ::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
152                                    double atol = -1.0, double rtol = -1.0);
153 ::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
154                                    double rtol = -1.0);
155 ::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
156                                    double rtol = -1.0);
157 }  // namespace internal_test
158 
159 }  // namespace test
160 }  // namespace tensorflow
161 
162 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
163