1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
18
19 #include <numeric>
20
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace test {
28
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32 Tensor ret(DataTypeToEnum<T>::value, {});
33 ret.scalar<T>()() = val;
34 return ret;
35 }
36
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40 Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64_t>(vals.size())});
41 std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42 return ret;
43 }
44
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48 Tensor ret;
49 CHECK(ret.CopyFrom(AsTensor(vals), shape));
50 return ret;
51 }
52
53 // Fills in '*tensor' with 'vals'. E.g.,
54 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 // test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58 auto flat = tensor->flat<T>();
59 CHECK_EQ(flat.size(), vals.size());
60 if (flat.size() > 0) {
61 std::copy_n(vals.data(), vals.size(), flat.data());
62 }
63 }
64
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68 auto flat = tensor->flat<T>();
69 CHECK_EQ(flat.size(), vals.size());
70 if (flat.size() > 0) {
71 size_t i = 0;
72 for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73 flat(i) = T(*itr);
74 }
75 }
76 }
77
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 // test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83 auto flat = tensor->flat<T>();
84 std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 // test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92 auto flat = tensor->flat<T>();
93 for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95
96 // Expects "x" and "y" are tensors of the same type, same shape, and identical
97 // values (within 4 ULPs for floating point types unless explicitly disabled).
98 enum class Tolerance {
99 kNone,
100 kDefault,
101 };
102 void ExpectEqual(const Tensor& x, const Tensor& y,
103 Tolerance t = Tolerance ::kDefault);
104
105 // Expects "x" and "y" are tensors of the same (floating point) type,
106 // same shape and element-wise difference between x and y is no more
107 // than atol + rtol * abs(x). If atol or rtol is negative, the data type's
108 // epsilon * kSlackFactor is used.
109 void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
110 double rtol = -1.0);
111
112 // Expects "x" and "y" are tensors of the same type T, same shape, and
113 // equal values. Consider using ExpectEqual above instead.
114 template <typename T>
ExpectTensorEqual(const Tensor & x,const Tensor & y)115 void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
116 EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
117 ExpectEqual(x, y);
118 }
119
120 ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y);
121 ::testing::AssertionResult IsSameShape(const Tensor& x, const Tensor& y);
122
123 template <typename T>
ExpectTensorEqual(const Tensor & x,const Tensor & y,std::function<bool (const T &,const T &)> is_equal)124 void ExpectTensorEqual(const Tensor& x, const Tensor& y,
125 std::function<bool(const T&, const T&)> is_equal) {
126 EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
127 ASSERT_TRUE(IsSameType(x, y));
128 ASSERT_TRUE(IsSameShape(x, y));
129
130 const T* Tx = x.unaligned_flat<T>().data();
131 const T* Ty = y.unaligned_flat<T>().data();
132 auto size = x.NumElements();
133 int max_failures = 10;
134 int num_failures = 0;
135 for (decltype(size) i = 0; i < size; ++i) {
136 EXPECT_TRUE(is_equal(Tx[i], Ty[i])) << "i = " << (++num_failures, i);
137 ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
138 }
139 }
140
141 // Expects "x" and "y" are tensors of the same type T, same shape, and
142 // approximate equal values. Consider using ExpectClose above instead.
143 template <typename T>
ExpectTensorNear(const Tensor & x,const Tensor & y,double atol)144 void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
145 EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
146 ExpectClose(x, y, atol, /*rtol=*/0.0);
147 }
148
149 // For tensor_testutil_test only.
150 namespace internal_test {
151 ::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
152 double atol = -1.0, double rtol = -1.0);
153 ::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
154 double rtol = -1.0);
155 ::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
156 double rtol = -1.0);
157 } // namespace internal_test
158
159 } // namespace test
160 } // namespace tensorflow
161
162 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
163