1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"

#include <algorithm>
#include <array>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
25
26 // NOTE(allenl): These tests currently go through TFE_Execute and so are
27 // integration testing rather than purely testing the parallel device. They
28 // correspond fairly well to the implementation, but testing the C++ directly is
29 // another option.
30
31 namespace tensorflow {
32 namespace parallel_device {
33
Create(TFE_Context * context,TF_DataType type,const int64_t * dims,const int num_dims,const char * device,TF_Status * status)34 Variable* Variable::Create(TFE_Context* context, TF_DataType type,
35 const int64_t* dims, const int num_dims,
36 const char* device, TF_Status* status) {
37 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
38 TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
39 if (TF_GetCode(status) != TF_OK) return nullptr;
40 TFE_OpSetAttrType(op.get(), "dtype", type);
41 TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
42 TFE_OpSetAttrString(op.get(), "container", "", 0);
43 // Use the special GUID for no buffer sharing
44 //
45 // TODO(allenl): Should we provide a better API for this? AFAIK this is the
46 // only reasonable way to make variables with no aliasing using the eager C
47 // API.
48 std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
49 TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
50 no_sharing.length());
51 TFE_OpSetDevice(op.get(), device, status);
52 if (TF_GetCode(status) != TF_OK) return nullptr;
53 TFE_TensorHandle* var_handle = nullptr;
54 int num_retvals = 1;
55 TFE_Execute(op.get(), &var_handle, &num_retvals, status);
56 if (TF_GetCode(status) != TF_OK) return nullptr;
57 return new Variable(var_handle, type);
58 }
59
// Releases the variable's backing buffer by running DestroyResourceOp on the
// handle's device, then deletes the handle itself. The first failing step
// sets `status` and returns early.
//
// NOTE(review): on any early error return, handle_ is never deleted — the
// TFE_TensorHandle leaks on error paths. Presumably acceptable for test
// code; confirm.
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
  // Free the backing buffer for the variable.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Run the destroy op on the same device the handle lives on.
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Delete the variable handle itself.
  TFE_DeleteTensorHandle(handle_);
}
77
Read(TFE_Context * context,TF_Status * status)78 TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
79 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
80 TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
81 if (TF_GetCode(status) != TF_OK) return nullptr;
82 TFE_OpAddInput(op.get(), handle_, status);
83 if (TF_GetCode(status) != TF_OK) return nullptr;
84 const char* device = TFE_TensorHandleDeviceName(handle_, status);
85 if (TF_GetCode(status) != TF_OK) return nullptr;
86 TFE_OpSetDevice(op.get(), device, status);
87 if (TF_GetCode(status) != TF_OK) return nullptr;
88 TFE_OpSetAttrType(op.get(), "dtype", type_);
89 int num_retvals = 1;
90 TFE_TensorHandle* var_value = nullptr;
91 TFE_Execute(op.get(), &var_value, &num_retvals, status);
92 if (TF_GetCode(status) != TF_OK) return nullptr;
93 return TensorHandlePtr(var_value);
94 }
95
GeneralAssignment(const char * op_name,TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)96 void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
97 TFE_TensorHandle* value, TF_Status* status) {
98 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
99 TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
100 if (TF_GetCode(status) != TF_OK) return;
101 TFE_OpSetAttrType(op.get(), "dtype", type_);
102 TFE_OpAddInput(op.get(), handle_, status);
103 if (TF_GetCode(status) != TF_OK) return;
104 TFE_OpAddInput(op.get(), value, status);
105 if (TF_GetCode(status) != TF_OK) return;
106 const char* device = TFE_TensorHandleDeviceName(handle_, status);
107 if (TF_GetCode(status) != TF_OK) return;
108 TFE_OpSetDevice(op.get(), device, status);
109
110 int num_retvals = 0;
111 TFE_Execute(op.get(), nullptr, &num_retvals, status);
112 if (TF_GetCode(status) != TF_OK) return;
113 }
114
// Adds `value` to the variable in place (AssignAddVariableOp). Failures are
// reported through `status`.
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}
119
// Overwrites the variable with `value` (AssignVariableOp). Failures are
// reported through `status`.
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}
124
// Deallocator handed to `TF_NewTensor`: frees the heap-allocated float array
// backing a tensor once TensorFlow no longer needs it.
static void FloatDeallocator(void* data, size_t, void* arg) {
  float* buffer = static_cast<float*>(data);
  delete[] buffer;
}
130
131 // Creates a TFE_TensorHandle with value `v`.
FloatTensorHandle(float v,TF_Status * status)132 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
133 const int num_bytes = sizeof(float);
134 float* values = new float[1];
135 values[0] = v;
136 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
137 TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
138 nullptr),
139 TF_DeleteTensor);
140 return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
141 }
142
143 // Creates a rank-one TFE_TensorHandle with value `v`.
VectorFloatTensorHandle(const std::vector<float> & v,TF_Status * status)144 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
145 TF_Status* status) {
146 const int num_bytes = v.size() * sizeof(float);
147 float* values = new float[v.size()];
148 memcpy(values, v.data(), num_bytes);
149 int64_t dims = v.size();
150 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
151 TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
152 &FloatDeallocator, nullptr),
153 TF_DeleteTensor);
154 return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
155 }
156
157 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
158 template <std::size_t num_replicas>
ExtractPerDeviceValues(TFE_Context * context,TFE_TensorHandle * input,std::array<TensorHandlePtr,num_replicas> * components,TF_Status * status)159 void ExtractPerDeviceValues(
160 TFE_Context* context, TFE_TensorHandle* input,
161 std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
162 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
163 TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
164 if (TF_GetCode(status) != TF_OK) return;
165 TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
166 TFE_OpAddInput(op.get(), input, status);
167 if (TF_GetCode(status) != TF_OK) return;
168 const char* device = TFE_TensorHandleDeviceName(input, status);
169 if (TF_GetCode(status) != TF_OK) return;
170 TFE_OpSetDevice(op.get(), device, status);
171 if (TF_GetCode(status) != TF_OK) return;
172
173 TFE_TensorHandle* result_handles[num_replicas];
174 int num_retvals = num_replicas;
175 TFE_Execute(op.get(), result_handles, &num_retvals, status);
176 if (TF_GetCode(status) != TF_OK) return;
177 for (int i = 0; i < num_replicas; ++i) {
178 (*components)[i].reset(result_handles[i]);
179 }
180 }
181
Multiply(TFE_Context * context,TFE_TensorHandle * first,TFE_TensorHandle * second,TF_Status * status)182 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
183 TFE_TensorHandle* second, TF_Status* status) {
184 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
185 TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
186 if (TF_GetCode(status) != TF_OK) return nullptr;
187 TFE_OpAddInput(op.get(), first, status);
188 if (TF_GetCode(status) != TF_OK) return nullptr;
189 TFE_OpAddInput(op.get(), second, status);
190 if (TF_GetCode(status) != TF_OK) return nullptr;
191 const char* first_device = TFE_TensorHandleDeviceName(first, status);
192 if (TF_GetCode(status) != TF_OK) return nullptr;
193 TFE_OpSetDevice(op.get(), first_device, status);
194
195 TFE_TensorHandle* result_handle;
196 int num_retvals = 1;
197 TFE_Execute(op.get(), &result_handle, &num_retvals, status);
198 if (TF_GetCode(status) != TF_OK) return nullptr;
199 return TensorHandlePtr(result_handle);
200 }
201
202 // Create and modify a variable placed on a parallel device which composes
203 // `first_device` and `second_device`.
BasicTestsForTwoDevices(TFE_Context * context,const char * first_device,const char * second_device)204 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
205 const char* second_device) {
206 // Register the custom device
207 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
208 TF_NewStatus(), TF_DeleteStatus);
209 const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
210 std::array<const char*, 2> underlying_devices{first_device, second_device};
211 RegisterParallelDevice(context, device_name, underlying_devices,
212 status.get());
213 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
214
215 // Create a variable handle (uninitialized to start) placed on the parallel
216 // device.
217 std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
218 to_delete->Destroy(context, status.get());
219 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
220 delete to_delete;
221 };
222 std::unique_ptr<Variable, decltype(variable_deleter)> variable(
223 Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
224 status.get()),
225 variable_deleter);
226 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
227
228 // Assign an initial value to the variable, implicitly mirroring it to each
229 // component device.
230 {
231 TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
232 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
233
234 variable->Assign(context, initial_value.get(), status.get());
235 }
236
237 // Read from the variable and verify that we have a parallel tensor.
238 {
239 TensorHandlePtr read = variable->Read(context, status.get());
240 std::array<TensorHandlePtr, 2> components;
241 ExtractPerDeviceValues(context, read.get(), &components, status.get());
242 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
243
244 ExpectScalarEq<float>(components[0].get(), 20.);
245 ExpectScalarEq<float>(components[1].get(), 20.);
246
247 std::string first_device =
248 TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
249 ASSERT_EQ(underlying_devices[0], first_device);
250 std::string second_device =
251 TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
252 ASSERT_EQ(underlying_devices[1], second_device);
253 }
254
255 // Add a parallel tensor with different values on each device to the variable.
256 {
257 TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
258 TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
259 std::array<TFE_TensorHandle*, 2> components{value_one.get(),
260 value_two.get()};
261 TensorHandlePtr combined_value =
262 CreatePerDeviceValues(context, components, device_name, status.get());
263 variable->AssignAdd(context, combined_value.get(), status.get());
264 }
265
266 // Read the variable and verify that each component has the right modified
267 // value.
268 {
269 TensorHandlePtr read = variable->Read(context, status.get());
270 std::array<TensorHandlePtr, 2> components;
271 ExtractPerDeviceValues(context, read.get(), &components, status.get());
272 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
273
274 ExpectScalarEq<float>(components[0].get(), 23.);
275 ExpectScalarEq<float>(components[1].get(), 18.);
276
277 std::string first_device =
278 TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
279 ASSERT_EQ(underlying_devices[0], first_device);
280 std::string second_device =
281 TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
282 ASSERT_EQ(underlying_devices[1], second_device);
283 }
284 }
285
286 } // namespace parallel_device
287 } // namespace tensorflow
288