1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/experimental/base/public/tensorhandle.h"
17
#include <stddef.h>
#include <stdint.h>

#include <memory>
#include <vector>
22
23 #include "tensorflow/c/tf_datatype.h"
24 #include "tensorflow/cc/experimental/base/public/runtime.h"
25 #include "tensorflow/cc/experimental/base/public/runtime_builder.h"
26 #include "tensorflow/cc/experimental/base/public/tensor.h"
27 #include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
28 #include "tensorflow/core/lib/gtl/array_slice.h"
29 #include "tensorflow/core/platform/test.h"
30
31 namespace tensorflow {
32 namespace {
33
34 using tensorflow::experimental::cc::Runtime;
35 using tensorflow::experimental::cc::RuntimeBuilder;
36 using tensorflow::experimental::cc::Status;
37 using tensorflow::experimental::cc::Tensor;
38 using tensorflow::experimental::cc::TensorHandle;
39
40 using SimpleTypes = ::testing::Types<
41 tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
42 tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
43 tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
44
45 template <typename T>
46 class ConstructScalarTensorHandleTest : public ::testing::Test {};
47 TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
48
49 // This test constructs a scalar tensor for each of the types in "SimpleTypes",
50 // then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
51 // verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,ValidTensorAttributesAfterConstruction)52 TYPED_TEST(ConstructScalarTensorHandleTest,
53 ValidTensorAttributesAfterConstruction) {
54 Status status;
55 RuntimeBuilder runtime_builder;
56 std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
57 ASSERT_TRUE(status.ok()) << status.message();
58
59 TF_DataType dtype = TypeParam::kDType;
60 typename TypeParam::type value = 42;
61 Tensor original_tensor =
62 Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
63 /*data=*/&value,
64 /*len=*/sizeof(value),
65 /*deleter=*/[](void*, size_t) {}, &status);
66 ASSERT_TRUE(status.ok()) << status.message();
67
68 TensorHandle handle =
69 TensorHandle::FromTensor(original_tensor, *runtime, &status);
70 ASSERT_TRUE(status.ok()) << status.message();
71
72 Tensor tensor = handle.Resolve(&status);
73 ASSERT_TRUE(status.ok()) << status.message();
74
75 EXPECT_EQ(tensor.dims(), 0);
76 EXPECT_EQ(tensor.dtype(), dtype);
77 EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
78 EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
79 EXPECT_EQ(tensor.num_elements(), 1);
80 }
81
82 template <typename T>
83 class Construct1DTensorHandleTest : public ::testing::Test {};
84 TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
85
86 // This test constructs a 1D tensor for each of the types in "SimpleTypes",
87 // and verifies the expected dimensions, dtype, value, number of bytes, and
88 // number of elements.
TYPED_TEST(Construct1DTensorHandleTest,ValidTensorAttributesAfterConstruction)89 TYPED_TEST(Construct1DTensorHandleTest,
90 ValidTensorAttributesAfterConstruction) {
91 Status status;
92 RuntimeBuilder runtime_builder;
93 std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
94 ASSERT_TRUE(status.ok()) << status.message();
95
96 TF_DataType dtype = TypeParam::kDType;
97 // This is our 1D tensor of varying dtype.
98 std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
99 // Shape is Rank 1 vector.
100 std::vector<int64_t> shape;
101 shape.push_back(value.size());
102
103 Tensor original_tensor = Tensor::FromBuffer(
104 /*dtype=*/dtype, /*shape=*/shape,
105 /*data=*/value.data(),
106 /*len=*/value.size() * sizeof(typename TypeParam::type),
107 /*deleter=*/[](void*, size_t) {}, &status);
108 ASSERT_TRUE(status.ok()) << status.message();
109
110 TensorHandle handle =
111 TensorHandle::FromTensor(original_tensor, *runtime, &status);
112 ASSERT_TRUE(status.ok()) << status.message();
113
114 Tensor tensor = handle.Resolve(&status);
115 ASSERT_TRUE(status.ok()) << status.message();
116
117 EXPECT_EQ(tensor.dims(), 1);
118 EXPECT_EQ(tensor.dtype(), dtype);
119 tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
120 reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
121 EXPECT_EQ(tensor_view[0], 42);
122 EXPECT_EQ(tensor_view[1], 100);
123 EXPECT_EQ(tensor_view[2], 0);
124 EXPECT_EQ(tensor_view[3], 1);
125 EXPECT_EQ(tensor_view[4], 4);
126 EXPECT_EQ(tensor_view[5], 29);
127
128 EXPECT_EQ(tensor.num_bytes(),
129 value.size() * sizeof(typename TypeParam::type));
130 EXPECT_EQ(tensor.num_elements(), value.size());
131 }
132
133 template <typename T>
134 class Construct2DTensorHandleTest : public ::testing::Test {};
135 TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
136
137 // This test constructs a 2D tensor for each of the types in "SimpleTypes",
138 // and verifies the expected dimensions, dtype, value, number of bytes, and
139 // number of elements.
TYPED_TEST(Construct2DTensorHandleTest,ValidTensorAttributesAfterConstruction)140 TYPED_TEST(Construct2DTensorHandleTest,
141 ValidTensorAttributesAfterConstruction) {
142 Status status;
143 RuntimeBuilder runtime_builder;
144 std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
145 ASSERT_TRUE(status.ok()) << status.message();
146
147 TF_DataType dtype = TypeParam::kDType;
148 // This is our 1D tensor of varying dtype.
149 std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
150 // Shape is Rank 2 vector with shape 2 x 3.
151 std::vector<int64_t> shape({2, 3});
152
153 Tensor original_tensor = Tensor::FromBuffer(
154 /*dtype=*/dtype, /*shape=*/shape,
155 /*data=*/value.data(),
156 /*len=*/value.size() * sizeof(typename TypeParam::type),
157 /*deleter=*/[](void*, size_t) {}, &status);
158 ASSERT_TRUE(status.ok()) << status.message();
159
160 TensorHandle handle =
161 TensorHandle::FromTensor(original_tensor, *runtime, &status);
162 ASSERT_TRUE(status.ok()) << status.message();
163
164 Tensor tensor = handle.Resolve(&status);
165 ASSERT_TRUE(status.ok()) << status.message();
166
167 EXPECT_EQ(tensor.dims(), 2);
168 EXPECT_EQ(tensor.dtype(), dtype);
169 tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
170 reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
171 EXPECT_EQ(tensor_view[0], 42);
172 EXPECT_EQ(tensor_view[1], 100);
173 EXPECT_EQ(tensor_view[2], 0);
174 EXPECT_EQ(tensor_view[3], 1);
175 EXPECT_EQ(tensor_view[4], 4);
176 EXPECT_EQ(tensor_view[5], 29);
177
178 EXPECT_EQ(tensor.num_bytes(),
179 value.size() * sizeof(typename TypeParam::type));
180 EXPECT_EQ(tensor.num_elements(), value.size());
181 }
182
183 } // namespace
184 } // namespace tensorflow
185