/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <memory>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
 public:
  NamedParallelDevice(const std::string& name,
                      std::unique_ptr<ParallelDevice> parallel_device)
      : device_name_(name), parallel_device_(std::move(parallel_device)) {}
  const std::string& name() const { return device_name_; }
  const ParallelDevice& device() const { return *parallel_device_; }

 private:
  std::string device_name_;
  std::unique_ptr<ParallelDevice> parallel_device_;
};
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    const ParallelDevice& parallel_device,
    const std::string& parallel_device_name, TFE_Context* context,
    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    TF_Status* status) {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very
  // least, or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " inputs to TPUReplicatedInput, but got ", inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    std::vector<TensorHandlePtr> components;
    components.reserve(inputs.size());
    for (int i = 0; i < inputs.size(); ++i) {
      if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
        std::string message(absl::StrCat(
            "Expected all inputs to TPUReplicatedInput to be non-parallel "
            "TensorHandles. The input ",
            i,
            " was a parallel tensor (already "
            "placed on the parallel device)."));
        TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
        return result;
      }
      components.emplace_back(TFE_TensorHandleCopySharingTensor(
          absl::get<TFE_TensorHandle*>(inputs[i]), status));
    }
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        parallel_device, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  } else if (operation_name == std::string("TPUReplicatedOutput")) {
    // Special-cased operation for un-packing one parallel tensor into
    // per-device tensors.
    OpPtr op(TFE_NewOp(context, operation_name, status));
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (expected_outputs != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "Expected the input to "
                   "TPUReplicatedOutput to be a parallel tensor (placed on the "
                   "parallel device).");
      return result;
    }
    ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
    std::vector<MaybeParallelTensorOwned> outputs;
    outputs.reserve(t->num_tensors());
    for (int i = 0; i < t->num_tensors(); ++i) {
      TensorHandlePtr this_output(
          TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
      outputs.emplace_back(std::move(this_output));
      if (TF_GetCode(status) != TF_OK) return result;
    }
    result.emplace(std::move(outputs));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
  parallel_inputs.reserve(inputs.size());
  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
  for (const auto& input : inputs) {
    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
      // to each parallel operation.
      //
      // TODO(allenl): There may be smarter ways to do this copy in some
      // cases, i.e. with a collective broadcast. We'll need to be careful
      // about things that are taken as inputs on the host or on their
      // existing device (for multi-device functions).
      std::unique_ptr<ParallelTensor> parallel_tensor(
          parallel_device.CopyToParallelDevice(
              context, absl::get<TFE_TensorHandle*>(input), status));
      if (TF_GetCode(status) != TF_OK) return result;
      parallel_inputs.push_back(parallel_tensor.get());
      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
    } else {
      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
    }
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, parallel_inputs, operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));
  std::vector<MaybeParallelTensorOwned> result_content;
  result_content.reserve(parallel_results.size());
  for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
    result_content.push_back(
        MaybeParallelTensorOwned(std::move(parallel_result)));
  }
  result.emplace(std::move(result_content));
  return result;
}
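
// Example (sketch, for illustration only): packing two per-device tensors into
// one parallel tensor by running TPUReplicatedInput on the parallel device via
// the eager C API. `context`, `status`, the per-device handles `cpu0_handle`
// and `cpu1_handle`, and `parallel_device_name` (the name this device was
// registered under) are assumed to exist, with two underlying devices.
//
//   OpPtr pack(TFE_NewOp(context, "TPUReplicatedInput", status));
//   TFE_OpSetAttrInt(pack.get(), "N", 2);
//   TFE_OpAddInput(pack.get(), cpu0_handle, status);
//   TFE_OpAddInput(pack.get(), cpu1_handle, status);
//   TFE_OpSetDevice(pack.get(), parallel_device_name, status);
//   TFE_TensorHandle* packed = nullptr;
//   int num_retvals = 1;
//   TFE_Execute(pack.get(), &packed, &num_retvals, status);
//
// `packed` is then placed on the parallel device and wraps a ParallelTensor
// with one component per underlying device.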

// Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, void* arg) {
  delete reinterpret_cast<ParallelTensor*>(data);
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the
// number of dimensions of a parallel tensor.
int ParallelTensorNumDims(void* data, void* arg, TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return shape->size();
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a
// dimension of a parallel tensor.
int64_t ParallelTensorDim(void* data, int dim_index, void* arg,
                          TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return (*shape)[dim_index];
}

TensorHandlePtr ParallelTensorToTensorHandle(
    const std::string& parallel_device_name, TFE_Context* context,
    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory",
  // which for a ParallelDevice is really a ParallelTensor. When the
  // TensorHandle is deleted, it will call ParallelTensorDeallocator to free
  // the struct.
  ParallelTensor* t_released = t.release();
  return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle(
      context, parallel_device_name.c_str(), t_released->dtype(), t_released,
      &ParallelTensorNumDims, &ParallelTensorDim, &ParallelTensorDeallocator,
      nullptr, status));
}

// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle
// containing a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  const ParallelDevice& dev = named_device->device();
  std::unique_ptr<ParallelTensor> parallel_tensor(
      dev.CopyToParallelDevice(context, tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return ParallelTensorToTensorHandle(named_device->name(), context,
                                      std::move(parallel_tensor), status)
      .release();
}

// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel
// device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                               TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "Trying to copy a tensor out of a parallel device. Since there "
               "are multiple components to parallel tensors, they must be "
               "unpacked explicitly.");
  return nullptr;
}
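
// Example (sketch, for illustration only): the explicit unpacking mentioned
// above, running TPUReplicatedOutput on the parallel device instead of copying
// a tensor off it. `context`, `status`, `packed` (a handle placed on the
// parallel device), and `parallel_device_name` are assumed to exist; the
// component count of 2 assumes a two-device parallel device.
//
//   OpPtr unpack(TFE_NewOp(context, "TPUReplicatedOutput", status));
//   TFE_OpSetAttrInt(unpack.get(), "num_replicas", 2);
//   TFE_OpAddInput(unpack.get(), packed, status);
//   TFE_OpSetDevice(unpack.get(), parallel_device_name, status);
//   TFE_TensorHandle* components[2] = {nullptr, nullptr};
//   int num_retvals = 2;
//   TFE_Execute(unpack.get(), components, &num_retvals, status);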

// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  const char* requested_placement = TFE_OpGetDevice(original_op, status);
  if (*requested_placement == '\0') {
    TF_SetStatus(
        status, TF_INTERNAL,
        "Ops must be placed on the parallel device explicitly, or their inputs "
        "first un-packed. Got an un-placed op with an input placed on the "
        "parallel device.");
    return;
  }
  TFE_Context* context = TFE_OpGetContext(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);

  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(input, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(input, status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(input);
    }
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
                            context, std::move(typed_inputs), operation_name,
                            attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
    return;
  }

  std::vector<MaybeParallelTensorOwned> typed_outputs(
      std::move(maybe_typed_outputs.value()));

  if (typed_outputs.size() > *num_outputs) {
    TF_SetStatus(status, TF_INTERNAL,
                 "The allocated output buffer was too small.");
    return;
  }

  for (int i = 0; i < typed_outputs.size(); ++i) {
    MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
    } else {
      outputs[i] = ParallelTensorToTensorHandle(
                       named_device->name(), context,
                       std::move(absl::get<std::unique_ptr<ParallelTensor>>(
                           typed_output)),
                       status)
                       .release();
      if (TF_GetCode(status) != TF_OK) return;
    }
  }
  *num_outputs = typed_outputs.size();
}

// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
  delete reinterpret_cast<NamedParallelDevice*>(device_info);
}

}  // namespace

void AllocateParallelDevice(const char* device_name,
                            const char* const* underlying_devices,
                            int num_underlying_devices,
                            TFE_CustomDevice* device, void** device_info) {
  device->copy_tensor_to_device = &CopyToParallelDevice;
  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
  device->delete_device = &DeleteParallelDevice;
  device->execute = &ParallelDeviceExecute;
  std::vector<std::string> underlying_devices_vector;
  underlying_devices_vector.reserve(num_underlying_devices);
  for (int device_index = 0; device_index < num_underlying_devices;
       ++device_index) {
    underlying_devices_vector.push_back(underlying_devices[device_index]);
  }
  std::unique_ptr<ParallelDevice> parallel_device(
      new ParallelDevice(underlying_devices_vector));
  *device_info =
      new NamedParallelDevice{device_name, std::move(parallel_device)};
}
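
// Example (sketch, not part of this translation unit's behavior): registering
// a parallel device with a context using this allocator. `context` and
// `status` are assumed to exist; the device name and underlying device names
// below are placeholders chosen for illustration.
//
//   const char* underlying[] = {
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:CPU:1"};
//   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
//   TFE_CustomDevice device;
//   void* device_info;
//   AllocateParallelDevice(name, underlying, 2, &device, &device_info);
//   TFE_RegisterCustomDevice(context, device, name, device_info, status);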

}  // namespace parallel_device
}  // namespace tensorflow