• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api_test_util.h"
17 
18 #include "tensorflow/c/eager/c_api.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21 
22 using tensorflow::string;
23 
TestScalarTensorHandle(float value)24 TFE_TensorHandle* TestScalarTensorHandle(float value) {
25   float data[] = {value};
26   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
27   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
28   TF_Status* status = TF_NewStatus();
29   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
30   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
31   TF_DeleteTensor(t);
32   TF_DeleteStatus(status);
33   return th;
34 }
35 
TestScalarTensorHandle(int value)36 TFE_TensorHandle* TestScalarTensorHandle(int value) {
37   int data[] = {value};
38   TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
39   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
40   TF_Status* status = TF_NewStatus();
41   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
42   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
43   TF_DeleteTensor(t);
44   TF_DeleteStatus(status);
45   return th;
46 }
47 
TestScalarTensorHandle(bool value)48 TFE_TensorHandle* TestScalarTensorHandle(bool value) {
49   bool data[] = {value};
50   TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
51   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
52   TF_Status* status = TF_NewStatus();
53   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
54   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
55   TF_DeleteTensor(t);
56   TF_DeleteStatus(status);
57   return th;
58 }
59 
DoubleTestMatrixTensorHandle()60 TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
61   int64_t dims[] = {2, 2};
62   double data[] = {1.0, 2.0, 3.0, 4.0};
63   TF_Tensor* t = TF_AllocateTensor(
64       TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
65   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
66   TF_Status* status = TF_NewStatus();
67   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
68   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
69   TF_DeleteTensor(t);
70   TF_DeleteStatus(status);
71   return th;
72 }
73 
TestMatrixTensorHandle()74 TFE_TensorHandle* TestMatrixTensorHandle() {
75   int64_t dims[] = {2, 2};
76   float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
77   TF_Tensor* t = TF_AllocateTensor(
78       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
79   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
80   TF_Status* status = TF_NewStatus();
81   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
82   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
83   TF_DeleteTensor(t);
84   TF_DeleteStatus(status);
85   return th;
86 }
87 
TestMatrixTensorHandle100x100()88 TFE_TensorHandle* TestMatrixTensorHandle100x100() {
89   constexpr int64_t dims[] = {100, 100};
90   constexpr int num_elements = dims[0] * dims[1];
91   float data[num_elements];
92   for (int i = 0; i < num_elements; ++i) {
93     data[i] = 1.0f;
94   }
95   TF_Tensor* t = TF_AllocateTensor(
96       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
97   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
98   TF_Status* status = TF_NewStatus();
99   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
100   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
101   TF_DeleteTensor(t);
102   TF_DeleteStatus(status);
103   return th;
104 }
105 
DoubleTestMatrixTensorHandle3X2()106 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
107   int64_t dims[] = {3, 2};
108   double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
109   TF_Tensor* t = TF_AllocateTensor(
110       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
111   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
112   TF_Status* status = TF_NewStatus();
113   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
114   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
115   TF_DeleteTensor(t);
116   TF_DeleteStatus(status);
117   return th;
118 }
119 
TestMatrixTensorHandle3X2()120 TFE_TensorHandle* TestMatrixTensorHandle3X2() {
121   int64_t dims[] = {3, 2};
122   float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
123   TF_Tensor* t = TF_AllocateTensor(
124       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
125   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
126   TF_Status* status = TF_NewStatus();
127   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
128   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
129   TF_DeleteTensor(t);
130   TF_DeleteStatus(status);
131   return th;
132 }
133 
MatMulOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)134 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
135   TF_Status* status = TF_NewStatus();
136 
137   TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
138   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
139   TFE_OpAddInput(op, a, status);
140   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141   TFE_OpAddInput(op, b, status);
142   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
143   TF_DeleteStatus(status);
144   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
145 
146   return op;
147 }
148 
IdentityOp(TFE_Context * ctx,TFE_TensorHandle * a)149 TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
150   TF_Status* status = TF_NewStatus();
151 
152   TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
153   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
154   TFE_OpAddInput(op, a, status);
155   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
156   TF_DeleteStatus(status);
157   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
158 
159   return op;
160 }
161 
ShapeOp(TFE_Context * ctx,TFE_TensorHandle * a)162 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
163   TF_Status* status = TF_NewStatus();
164 
165   TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
166   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
167   TFE_OpAddInput(op, a, status);
168   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
169   TF_DeleteStatus(status);
170   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
171 
172   return op;
173 }
174 
TestAxisTensorHandle()175 TFE_TensorHandle* TestAxisTensorHandle() {
176   int64_t dims[] = {1};
177   int data[] = {1};
178   TF_Tensor* t = TF_AllocateTensor(
179       TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
180   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
181   TF_Status* status = TF_NewStatus();
182   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
183   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
184   TF_DeleteTensor(t);
185   TF_DeleteStatus(status);
186   return th;
187 }
188 
MinOp(TFE_Context * ctx,TFE_TensorHandle * input,TFE_TensorHandle * axis)189 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
190               TFE_TensorHandle* axis) {
191   TF_Status* status = TF_NewStatus();
192 
193   TFE_Op* op = TFE_NewOp(ctx, "Min", status);
194   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
195   TFE_OpAddInput(op, input, status);
196   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
197   TFE_OpAddInput(op, axis, status);
198   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
199   TFE_OpSetAttrBool(op, "keep_dims", 1);
200   TFE_OpSetAttrType(op, "Tidx", TF_INT32);
201   TF_DeleteStatus(status);
202   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
203 
204   return op;
205 }
206 
GetDeviceName(TFE_Context * ctx,string * device_name,const char * device_type)207 bool GetDeviceName(TFE_Context* ctx, string* device_name,
208                    const char* device_type) {
209   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
210       TF_NewStatus(), TF_DeleteStatus);
211   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
212   CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
213 
214   const int num_devices = TF_DeviceListCount(devices);
215   for (int i = 0; i < num_devices; ++i) {
216     const string dev_type(TF_DeviceListType(devices, i, status.get()));
217     CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
218     const string dev_name(TF_DeviceListName(devices, i, status.get()));
219     CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
220     if (dev_type == device_type) {
221       *device_name = dev_name;
222       LOG(INFO) << "Found " << device_type << " device " << *device_name;
223       TF_DeleteDeviceList(devices);
224       return true;
225     }
226   }
227   TF_DeleteDeviceList(devices);
228   return false;
229 }
230