• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"

#include <stddef.h>
#include <stdint.h>

#include <memory>
#include <vector>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 using tensorflow::experimental::cc::Runtime;
35 using tensorflow::experimental::cc::RuntimeBuilder;
36 using tensorflow::experimental::cc::Status;
37 using tensorflow::experimental::cc::Tensor;
38 using tensorflow::experimental::cc::TensorHandle;
39 
40 using SimpleTypes = ::testing::Types<
41     tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
42     tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
43     tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
44 
45 template <typename T>
46 class ConstructScalarTensorHandleTest : public ::testing::Test {};
47 TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
48 
49 // This test constructs a scalar tensor for each of the types in "SimpleTypes",
50 // then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
51 // verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,ValidTensorAttributesAfterConstruction)52 TYPED_TEST(ConstructScalarTensorHandleTest,
53            ValidTensorAttributesAfterConstruction) {
54   Status status;
55   RuntimeBuilder runtime_builder;
56   std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
57   ASSERT_TRUE(status.ok()) << status.message();
58 
59   TF_DataType dtype = TypeParam::kDType;
60   typename TypeParam::type value = 42;
61   Tensor original_tensor =
62       Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
63                          /*data=*/&value,
64                          /*len=*/sizeof(value),
65                          /*deleter=*/[](void*, size_t) {}, &status);
66   ASSERT_TRUE(status.ok()) << status.message();
67 
68   TensorHandle handle =
69       TensorHandle::FromTensor(original_tensor, *runtime, &status);
70   ASSERT_TRUE(status.ok()) << status.message();
71 
72   Tensor tensor = handle.Resolve(&status);
73   ASSERT_TRUE(status.ok()) << status.message();
74 
75   EXPECT_EQ(tensor.dims(), 0);
76   EXPECT_EQ(tensor.dtype(), dtype);
77   EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
78   EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
79   EXPECT_EQ(tensor.num_elements(), 1);
80 }
81 
82 template <typename T>
83 class Construct1DTensorHandleTest : public ::testing::Test {};
84 TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
85 
86 // This test constructs a 1D tensor for each of the types in "SimpleTypes",
87 // and verifies the expected dimensions, dtype, value, number of bytes, and
88 // number of elements.
TYPED_TEST(Construct1DTensorHandleTest,ValidTensorAttributesAfterConstruction)89 TYPED_TEST(Construct1DTensorHandleTest,
90            ValidTensorAttributesAfterConstruction) {
91   Status status;
92   RuntimeBuilder runtime_builder;
93   std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
94   ASSERT_TRUE(status.ok()) << status.message();
95 
96   TF_DataType dtype = TypeParam::kDType;
97   // This is our 1D tensor of varying dtype.
98   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
99   // Shape is Rank 1 vector.
100   std::vector<int64_t> shape;
101   shape.push_back(value.size());
102 
103   Tensor original_tensor = Tensor::FromBuffer(
104       /*dtype=*/dtype, /*shape=*/shape,
105       /*data=*/value.data(),
106       /*len=*/value.size() * sizeof(typename TypeParam::type),
107       /*deleter=*/[](void*, size_t) {}, &status);
108   ASSERT_TRUE(status.ok()) << status.message();
109 
110   TensorHandle handle =
111       TensorHandle::FromTensor(original_tensor, *runtime, &status);
112   ASSERT_TRUE(status.ok()) << status.message();
113 
114   Tensor tensor = handle.Resolve(&status);
115   ASSERT_TRUE(status.ok()) << status.message();
116 
117   EXPECT_EQ(tensor.dims(), 1);
118   EXPECT_EQ(tensor.dtype(), dtype);
119   tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
120       reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
121   EXPECT_EQ(tensor_view[0], 42);
122   EXPECT_EQ(tensor_view[1], 100);
123   EXPECT_EQ(tensor_view[2], 0);
124   EXPECT_EQ(tensor_view[3], 1);
125   EXPECT_EQ(tensor_view[4], 4);
126   EXPECT_EQ(tensor_view[5], 29);
127 
128   EXPECT_EQ(tensor.num_bytes(),
129             value.size() * sizeof(typename TypeParam::type));
130   EXPECT_EQ(tensor.num_elements(), value.size());
131 }
132 
133 template <typename T>
134 class Construct2DTensorHandleTest : public ::testing::Test {};
135 TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
136 
137 // This test constructs a 2D tensor for each of the types in "SimpleTypes",
138 // and verifies the expected dimensions, dtype, value, number of bytes, and
139 // number of elements.
TYPED_TEST(Construct2DTensorHandleTest,ValidTensorAttributesAfterConstruction)140 TYPED_TEST(Construct2DTensorHandleTest,
141            ValidTensorAttributesAfterConstruction) {
142   Status status;
143   RuntimeBuilder runtime_builder;
144   std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
145   ASSERT_TRUE(status.ok()) << status.message();
146 
147   TF_DataType dtype = TypeParam::kDType;
148   // This is our 1D tensor of varying dtype.
149   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
150   // Shape is Rank 2 vector with shape 2 x 3.
151   std::vector<int64_t> shape({2, 3});
152 
153   Tensor original_tensor = Tensor::FromBuffer(
154       /*dtype=*/dtype, /*shape=*/shape,
155       /*data=*/value.data(),
156       /*len=*/value.size() * sizeof(typename TypeParam::type),
157       /*deleter=*/[](void*, size_t) {}, &status);
158   ASSERT_TRUE(status.ok()) << status.message();
159 
160   TensorHandle handle =
161       TensorHandle::FromTensor(original_tensor, *runtime, &status);
162   ASSERT_TRUE(status.ok()) << status.message();
163 
164   Tensor tensor = handle.Resolve(&status);
165   ASSERT_TRUE(status.ok()) << status.message();
166 
167   EXPECT_EQ(tensor.dims(), 2);
168   EXPECT_EQ(tensor.dtype(), dtype);
169   tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
170       reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
171   EXPECT_EQ(tensor_view[0], 42);
172   EXPECT_EQ(tensor_view[1], 100);
173   EXPECT_EQ(tensor_view[2], 0);
174   EXPECT_EQ(tensor_view[3], 1);
175   EXPECT_EQ(tensor_view[4], 4);
176   EXPECT_EQ(tensor_view[5], 29);
177 
178   EXPECT_EQ(tensor.num_bytes(),
179             value.size() * sizeof(typename TypeParam::type));
180   EXPECT_EQ(tensor.num_elements(), value.size());
181 }
182 
183 }  // namespace
184 }  // namespace tensorflow
185