1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
18
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <numeric>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace test {
28
29 // Constructs a scalar tensor with 'val'.
30 template <typename T>
AsScalar(const T & val)31 Tensor AsScalar(const T& val) {
32 Tensor ret(DataTypeToEnum<T>::value, {});
33 ret.scalar<T>()() = val;
34 return ret;
35 }
36
37 // Constructs a flat tensor with 'vals'.
38 template <typename T>
AsTensor(gtl::ArraySlice<T> vals)39 Tensor AsTensor(gtl::ArraySlice<T> vals) {
40 Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
41 std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
42 return ret;
43 }
44
45 // Constructs a tensor of "shape" with values "vals".
46 template <typename T>
AsTensor(gtl::ArraySlice<T> vals,const TensorShape & shape)47 Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
48 Tensor ret;
49 CHECK(ret.CopyFrom(AsTensor(vals), shape));
50 return ret;
51 }
52
53 // Fills in '*tensor' with 'vals'. E.g.,
54 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
55 // test::FillValues<float>(&x, {11, 21, 21, 22});
56 template <typename T>
FillValues(Tensor * tensor,gtl::ArraySlice<T> vals)57 void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
58 auto flat = tensor->flat<T>();
59 CHECK_EQ(flat.size(), vals.size());
60 if (flat.size() > 0) {
61 std::copy_n(vals.data(), vals.size(), flat.data());
62 }
63 }
64
65 // Fills in '*tensor' with 'vals', converting the types as needed.
66 template <typename T, typename SrcType>
FillValues(Tensor * tensor,std::initializer_list<SrcType> vals)67 void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
68 auto flat = tensor->flat<T>();
69 CHECK_EQ(flat.size(), vals.size());
70 if (flat.size() > 0) {
71 size_t i = 0;
72 for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
73 flat(i) = T(*itr);
74 }
75 }
76 }
77
78 // Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
79 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
80 // test::FillIota<float>(&x, 1.0);
81 template <typename T>
FillIota(Tensor * tensor,const T & val)82 void FillIota(Tensor* tensor, const T& val) {
83 auto flat = tensor->flat<T>();
84 std::iota(flat.data(), flat.data() + flat.size(), val);
85 }
86
87 // Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
88 // Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
89 // test::FillFn<float>(&x, [](int i)->float { return i*i; });
90 template <typename T>
FillFn(Tensor * tensor,std::function<T (int)> fn)91 void FillFn(Tensor* tensor, std::function<T(int)> fn) {
92 auto flat = tensor->flat<T>();
93 for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
94 }
95
// Expects that "x" and "y" are tensors with the same dtype, the same shape,
// and identical element values. Floating point elements are not compared
// exactly: they may differ by up to 4 ULPs.
void ExpectEqual(const Tensor& x, const Tensor& y);
99
// Expects that "x" and "y" are tensors of the same (floating point) dtype and
// the same shape, and that for every element
//   abs(x - y) <= atol + rtol * abs(x).
// Note the relative term is scaled by "x", not "y", so the check is not
// symmetric in its arguments.
//
// atol: absolute tolerance. If negative, the dtype's epsilon * kSlackFactor
//   is used instead.
// rtol: relative tolerance. If negative, the dtype's epsilon * kSlackFactor
//   is used instead.
void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
                 double rtol = -1.0);
106
// Expects that "x" has dtype T, and that "x" and "y" are tensors of the same
// type, same shape, and equal values. Prefer calling ExpectEqual directly.
template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
  // Only "x" is checked against T here; the x/y dtype agreement is left to
  // ExpectEqual. EXPECT_EQ is non-fatal, so the value comparison still runs
  // (and reports) even if the dtype check fails.
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectEqual(x, y);
}
114
// Expects that "x" has dtype T, and that "x" and "y" are tensors of the same
// type and shape whose values are approximately equal: each element pair may
// differ by at most 'atol'. Prefer calling ExpectClose directly.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
  // Non-fatal dtype check on "x" only; the comparison below still runs.
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  // rtol is pinned to 0, so the tolerance is purely absolute.
  // NOTE(review): a negative atol is forwarded as-is. Per ExpectClose's
  // contract that selects the dtype's default tolerance; confirm intended.
  ExpectClose(x, y, atol, /*rtol=*/0.0);
}
122
// For tensor_testutil_test only: exposes the per-value closeness predicate
// for Eigen::half, float and double so the test can exercise it directly.
// Not intended for use by other callers.
// NOTE(review): presumably this is the element-wise check behind ExpectClose,
// with the same "negative atol/rtol means use the dtype default" convention;
// confirm against the definition in the .cc file.
namespace internal_test {
::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
                                   double atol = -1.0, double rtol = -1.0);
::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
                                   double rtol = -1.0);
::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
                                   double rtol = -1.0);
}  // namespace internal_test
132
133 } // namespace test
134 } // namespace tensorflow
135
136 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
137