1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_test_util.h"
17
18 #include "tensorflow/c/eager/c_api.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21
22 using tensorflow::string;
23
TestScalarTensorHandle(float value)24 TFE_TensorHandle* TestScalarTensorHandle(float value) {
25 float data[] = {value};
26 TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
27 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
28 TF_Status* status = TF_NewStatus();
29 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
30 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
31 TF_DeleteTensor(t);
32 TF_DeleteStatus(status);
33 return th;
34 }
35
TestScalarTensorHandle(int value)36 TFE_TensorHandle* TestScalarTensorHandle(int value) {
37 int data[] = {value};
38 TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
39 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
40 TF_Status* status = TF_NewStatus();
41 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
42 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
43 TF_DeleteTensor(t);
44 TF_DeleteStatus(status);
45 return th;
46 }
47
TestScalarTensorHandle(bool value)48 TFE_TensorHandle* TestScalarTensorHandle(bool value) {
49 bool data[] = {value};
50 TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
51 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
52 TF_Status* status = TF_NewStatus();
53 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
54 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
55 TF_DeleteTensor(t);
56 TF_DeleteStatus(status);
57 return th;
58 }
59
DoubleTestMatrixTensorHandle()60 TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
61 int64_t dims[] = {2, 2};
62 double data[] = {1.0, 2.0, 3.0, 4.0};
63 TF_Tensor* t = TF_AllocateTensor(
64 TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
65 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
66 TF_Status* status = TF_NewStatus();
67 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
68 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
69 TF_DeleteTensor(t);
70 TF_DeleteStatus(status);
71 return th;
72 }
73
TestMatrixTensorHandle()74 TFE_TensorHandle* TestMatrixTensorHandle() {
75 int64_t dims[] = {2, 2};
76 float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
77 TF_Tensor* t = TF_AllocateTensor(
78 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
79 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
80 TF_Status* status = TF_NewStatus();
81 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
82 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
83 TF_DeleteTensor(t);
84 TF_DeleteStatus(status);
85 return th;
86 }
87
TestMatrixTensorHandle100x100()88 TFE_TensorHandle* TestMatrixTensorHandle100x100() {
89 constexpr int64_t dims[] = {100, 100};
90 constexpr int num_elements = dims[0] * dims[1];
91 float data[num_elements];
92 for (int i = 0; i < num_elements; ++i) {
93 data[i] = 1.0f;
94 }
95 TF_Tensor* t = TF_AllocateTensor(
96 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
97 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
98 TF_Status* status = TF_NewStatus();
99 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
100 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
101 TF_DeleteTensor(t);
102 TF_DeleteStatus(status);
103 return th;
104 }
105
DoubleTestMatrixTensorHandle3X2()106 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
107 int64_t dims[] = {3, 2};
108 double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
109 TF_Tensor* t = TF_AllocateTensor(
110 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
111 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
112 TF_Status* status = TF_NewStatus();
113 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
114 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
115 TF_DeleteTensor(t);
116 TF_DeleteStatus(status);
117 return th;
118 }
119
TestMatrixTensorHandle3X2()120 TFE_TensorHandle* TestMatrixTensorHandle3X2() {
121 int64_t dims[] = {3, 2};
122 float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
123 TF_Tensor* t = TF_AllocateTensor(
124 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
125 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
126 TF_Status* status = TF_NewStatus();
127 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
128 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
129 TF_DeleteTensor(t);
130 TF_DeleteStatus(status);
131 return th;
132 }
133
MatMulOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)134 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
135 TF_Status* status = TF_NewStatus();
136
137 TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
138 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
139 TFE_OpAddInput(op, a, status);
140 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141 TFE_OpAddInput(op, b, status);
142 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
143 TF_DeleteStatus(status);
144 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
145
146 return op;
147 }
148
IdentityOp(TFE_Context * ctx,TFE_TensorHandle * a)149 TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
150 TF_Status* status = TF_NewStatus();
151
152 TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
153 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
154 TFE_OpAddInput(op, a, status);
155 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
156 TF_DeleteStatus(status);
157 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
158
159 return op;
160 }
161
ShapeOp(TFE_Context * ctx,TFE_TensorHandle * a)162 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
163 TF_Status* status = TF_NewStatus();
164
165 TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
166 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
167 TFE_OpAddInput(op, a, status);
168 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
169 TF_DeleteStatus(status);
170 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
171
172 return op;
173 }
174
TestAxisTensorHandle()175 TFE_TensorHandle* TestAxisTensorHandle() {
176 int64_t dims[] = {1};
177 int data[] = {1};
178 TF_Tensor* t = TF_AllocateTensor(
179 TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
180 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
181 TF_Status* status = TF_NewStatus();
182 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
183 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
184 TF_DeleteTensor(t);
185 TF_DeleteStatus(status);
186 return th;
187 }
188
MinOp(TFE_Context * ctx,TFE_TensorHandle * input,TFE_TensorHandle * axis)189 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
190 TFE_TensorHandle* axis) {
191 TF_Status* status = TF_NewStatus();
192
193 TFE_Op* op = TFE_NewOp(ctx, "Min", status);
194 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
195 TFE_OpAddInput(op, input, status);
196 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
197 TFE_OpAddInput(op, axis, status);
198 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
199 TFE_OpSetAttrBool(op, "keep_dims", 1);
200 TFE_OpSetAttrType(op, "Tidx", TF_INT32);
201 TF_DeleteStatus(status);
202 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
203
204 return op;
205 }
206
GetDeviceName(TFE_Context * ctx,string * device_name,const char * device_type)207 bool GetDeviceName(TFE_Context* ctx, string* device_name,
208 const char* device_type) {
209 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
210 TF_NewStatus(), TF_DeleteStatus);
211 TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
212 CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
213
214 const int num_devices = TF_DeviceListCount(devices);
215 for (int i = 0; i < num_devices; ++i) {
216 const string dev_type(TF_DeviceListType(devices, i, status.get()));
217 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
218 const string dev_name(TF_DeviceListName(devices, i, status.get()));
219 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
220 if (dev_type == device_type) {
221 *device_name = dev_name;
222 LOG(INFO) << "Found " << device_type << " device " << *device_name;
223 TF_DeleteDeviceList(devices);
224 return true;
225 }
226 }
227 TF_DeleteDeviceList(devices);
228 return false;
229 }
230