/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include <string>
#include <utility>

#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

}  // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
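//
// A minimal usage sketch (hypothetical device names, handles, and attributes,
// shown only to illustrate the StartExecute/Join pairing and the consistent
// thread ordering required above):
//
//   DeviceThread thread_a("/job:localhost/replica:0/task:0/device:CPU:0",
//                         /*is_async=*/false);
//   DeviceThread thread_b("/job:localhost/replica:0/task:0/device:CPU:1",
//                         /*is_async=*/false);
//   CancellationManager cancellation_manager;
//   // Always start in the same order (a, then b) everywhere in the program.
//   thread_a.StartExecute(context, "Identity", {handle_a}, attributes,
//                         /*expected_max_outputs=*/1, cancellation_manager);
//   thread_b.StartExecute(context, "Identity", {handle_b}, attributes,
//                         /*expected_max_outputs=*/1, cancellation_manager);
//   std::vector<TensorHandlePtr> outputs_a = thread_a.Join(status);
//   std::vector<TensorHandlePtr> outputs_b = thread_b.Join(status);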
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device, const bool is_async)
      : status_(TF_NewStatus()),
        // If the context's default executor is set to async, re-using that in
        // each thread would cause collectives to deadlock. For consistency we
        // create a new sync executor for every thread.
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        device_(device),
        executor_(TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  //
  // `cancellation_manager` must live until after `Join` finishes and pending
  // `is_async` operations finish. In addition to allowing the caller to cancel
  // the operation, its `StartCancel` method will be called if op execution
  // fails on any device in order to cancel the others.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs,
                    CancellationManager& cancellation_manager,
                    absl::optional<int64_t> step_id = absl::nullopt);
  // Blocks until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

  // Blocks until all ops have finished running on the thread.
  void AsyncWait(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

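  // The worker thread and the caller communicate through a small state
  // machine guarded by `execution_mutex_` (a summary of the logic in `Run`,
  // `StartExecute`, `Join`, and the destructor below):
  //   kIdle           -> kReadyToExecute  (StartExecute stores the op's args)
  //   kReadyToExecute -> kHasResult       (the worker thread runs Execute)
  //   kHasResult      -> kIdle            (Join collects status and outputs)
  //   any state       -> kShuttingDown    (destructor)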
  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  //
  // Inputs; pointers are to objects not owned by the DeviceThread, but which
  // are expected to live at least until `Join` finishes:
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  absl::optional<int64_t> step_id_ TF_GUARDED_BY(execution_mutex_) =
      absl::nullopt;
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  CancellationManager* cancellation_manager_ TF_GUARDED_BY(execution_mutex_);
  // Outputs:
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // unnecessary allocations on each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::AsyncWait(TF_Status* status) {
  tensorflow::mutex_lock l(execution_mutex_);
  TFE_ExecutorWaitForAllPendingNodes(executor_.get(), status);
  TFE_ExecutorClearError(executor_.get());
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been std::move()d by a previous Join, so reset
        // it before reuse.
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs,
                                CancellationManager& cancellation_manager,
                                absl::optional<int64_t> step_id) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes
      // before starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    step_id_ = step_id;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    cancellation_manager_ = &cancellation_manager;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
      // Reset the member `status_` so future op executions (after recovery
      // from the bad `status`) start with an OK status.
      TF_SetStatus(status_.get(), TF_OK, "");
    }
    cancellation_manager_ = nullptr;
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    TFE_ContextSetExecutorForThread(context, executor_.get());
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  TFE_OpSetCancellationManager(op_.get(), wrap(cancellation_manager_), status);
  if (TF_GetCode(status) != TF_OK) return;

  // Unwrap op_ and set the step id only if a valid step id value was set.
  // Currently this is only required for non-TFRT use cases, e.g., EagerOp.
  if (step_id_.has_value()) {
    tensorflow::unwrap(op_.get())->SetStepId(step_id_.value());
  }

  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) {
    cancellation_manager_->StartCancel();
    return;
  }
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
                               const bool is_async)
    : underlying_devices_(devices),
      default_cancellation_manager_(absl::make_unique<CancellationManager>()) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str(), is_async));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

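// Builds a parallel tensor whose component on underlying device i is the
// scalar i (e.g. with two underlying devices the result holds 0 on the first
// device and 1 on the second).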
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  std::vector<int32_t> ids;
  ids.reserve(num_underlying_devices());
  for (int i = 0; i < num_underlying_devices(); ++i) {
    ids.push_back(i);
  }
  return ScalarsFromSequence<int32_t>(ids, context, status);
}

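// Convenience wrapper around StartExecute/Join: launches `operation_name` on
// every underlying device using the shared default cancellation manager,
// waits for all per-device results, and replaces the cancellation manager on
// failure so that later operations are not spuriously cancelled.
//
// A minimal calling sketch (hypothetical names, shown for illustration only;
// `parallel_device` wraps the underlying devices and `parallel_input` was
// produced by CopyToParallelDevice above):
//
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
//       parallel_device.Execute(context, {parallel_input.get()}, "Abs",
//                               attributes, /*expected_max_outputs=*/1,
//                               status);
//   if (TF_GetCode(status) == TF_OK) {
//     // (*outputs)[0] holds one component per underlying device.
//   }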
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
  StartExecute(context, inputs, operation_name, attributes,
               expected_max_outputs, *default_cancellation_manager_);
  auto result = Join(expected_output_shapes, status);
  if (TF_GetCode(status) != TF_OK) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> await_status(
        TF_NewStatus(), TF_DeleteStatus);
    // Wait until all pending nodes have completed since they may have a
    // reference to default_cancellation_manager_. We ignore the status return
    // since we already have a bad status to propagate.
    TFE_ContextAsyncWait(context, await_status.get());
    // Reset the cancellation manager on a bad status. Otherwise we'll cancel
    // all future operations.
    default_cancellation_manager_ = absl::make_unique<CancellationManager>();
  }
  return result;
}

void ParallelDevice::StartExecute(TFE_Context* context,
                                  const std::vector<ParallelTensor*>& inputs,
                                  const char* operation_name,
                                  const TFE_OpAttrs* attributes,
                                  int expected_max_outputs,
                                  CancellationManager& cancellation_manager,
                                  absl::optional<int64_t> step_id) const {
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    device_thread->StartExecute(
        context, operation_name, std::move(device_inputs), attributes,
        expected_max_outputs, cancellation_manager, step_id);
  }
}

void ParallelDevice::AsyncWait(TFE_Context* context, TF_Status* status) const {
  StatusPtr first_bad_status(nullptr);

  for (const auto& dt : device_threads_) {
    StatusPtr async_wait_status(TF_NewStatus());
    dt->AsyncWait(async_wait_status.get());
    // Prefer non-cancelled errors to uncover real failures.
    if (TF_GetCode(async_wait_status.get()) != TF_OK &&
        (first_bad_status == nullptr ||
         TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(async_wait_status.get()),
                   TF_Message(async_wait_status.get()));
    }
  }

  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
  }
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Join(
    const std::vector<PartialTensorShape>& expected_output_shapes,
    TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors.
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  int first_op_output_count = 0;
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK &&
        (first_bad_status == nullptr
         // Prefer propagating non-cancellation related statuses to avoid
         // shadowing the original failure.
         || TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    if (expected_output_shapes[i].IsFullyDefined()) {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components),
          absl::Span<const int64_t>(expected_output_shapes[i].dim_sizes()),
          status));
    } else {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components), status));
    }
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

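// Returns short names like {"CPU:0", "CPU:1"} when all underlying devices are
// in the same address space, and the fully-qualified device names otherwise.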
std::vector<std::string> ParallelDevice::SummarizeDeviceNames() const {
  std::vector<DeviceNameUtils::ParsedName> parsed_components(
      underlying_devices_.size());
  for (int component_index = 0; component_index < underlying_devices_.size();
       ++component_index) {
    if (!DeviceNameUtils::ParseFullName(underlying_devices_[component_index],
                                        &parsed_components[component_index]) ||
        !DeviceNameUtils::IsSameAddressSpace(
            underlying_devices_[component_index], underlying_devices_[0])) {
      // Device names are from different address spaces, or we can't figure out
      // whether they are, so we'll fully-qualify everything.
      return underlying_devices_;
    }
  }
  std::vector<std::string> local_names;
  local_names.reserve(underlying_devices_.size());
  for (const DeviceNameUtils::ParsedName& parsed_component :
       parsed_components) {
    local_names.push_back(
        absl::StrCat(parsed_component.type, ":", parsed_component.id));
  }
  return local_names;
}

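// Two construction paths: this overload takes an already-known `shape` for
// the packed tensor, while the overload below omits it, in which case
// ParallelTensor::Shape (further down) computes a shape lazily from the
// components.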
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, absl::Span<const int64_t> shape,
    TF_Status* status) {
  TFE_TensorHandleGetStatus(components[0].get(), status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that all of the component dtypes match the first component's dtype
  // (the caller-provided shape is not re-checked here).
  for (TensorHandlePtr& component : components) {
    TFE_TensorHandleGetStatus(component.get(), status);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TFE_TensorHandleGetStatus(components[0].get(), status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the combined TensorHandle's dtype matches all of the component
  // dtypes.
  for (TensorHandlePtr& component : components) {
    TFE_TensorHandleGetStatus(component.get(), status);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), dtype));
}

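// Sketch of the generalization rule implemented below: components of shapes
// [2, 3] and [5, 3] combine to [-1, 3] (differing axis lengths become
// "unknown"), while components with different ranks produce an Unimplemented
// error.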
Status ParallelTensor::Shape(const std::vector<int64_t>** shape) const {
  if (!shape_.has_value()) {
    TF_Status status;
    PartialTensorShape combined_shape;
    TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&combined_shape));

    for (const TensorHandlePtr& component : tensors_) {
      PartialTensorShape component_shape;
      TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape));
      if (combined_shape.dims() < 0 ||
          combined_shape.dims() != component_shape.dims()) {
        PartialTensorShape first_shape;
        TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape));
        return errors::Unimplemented(absl::StrCat(
            "Computing the shape of a ParallelTensor when the components do "
            "not all have the same rank is not supported. One tensor had "
            "shape ",
            first_shape.DebugString(), " and another had shape ",
            component_shape.DebugString()));
      } else {
        // Generalize differing axis lengths to "variable"/"unknown".
        for (int axis_index = 0; axis_index < combined_shape.dims();
             ++axis_index) {
          int64_t axis_length = combined_shape.dim_size(axis_index);
          if (axis_length != component_shape.dim_size(axis_index)) {
            axis_length = -1;
          }
          TF_RETURN_IF_ERROR(
              combined_shape.SetDimWithStatus(axis_index, axis_length));
        }
      }
    }
    auto dim_sizes = combined_shape.dim_sizes();
    shape_ = std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
  }
  *shape = &*shape_;
  return OkStatus();
}

Status ParallelTensor::SummarizeValue(std::string& summary) {
  summary = "{";
  std::vector<std::string> summarized_devices = device_.SummarizeDeviceNames();
  for (int component_index = 0; component_index < tensors_.size();
       ++component_index) {
    // TODO(allenl): Add a C API for summarizing tensors. Currently custom
    // devices limiting themselves to a C API (for ABI compatibility) would
    // need to implement summarization for component tensors themselves.
    ImmediateExecutionTensorHandle* component =
        tensorflow::unwrap(tensors_[component_index].get());
    std::string component_summary;
    TF_RETURN_IF_ERROR(component->SummarizeValue(component_summary));
    absl::StrAppend(&summary, component_index == 0 ? "" : ", ", "\"",
                    summarized_devices[component_index],
                    "\": ", component_summary);
  }
  summary += "}";
  return OkStatus();
}

}  // namespace parallel_device
}  // namespace tensorflow