/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include <string>
#include <utility>

#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

}  // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device, const bool is_async)
      : status_(TF_NewStatus()),
        // If the context's default executor is set to async, re-using that in
        // each thread would cause collectives to deadlock. For consistency we
        // create a new sync executor for every thread.
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        device_(device),
        executor_(TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  //
  // `cancellation_manager` must live until after `Join` finishes and pending
  // `is_async` operations finish. In addition to allowing the caller to cancel
  // the operation, its `StartCancel` method will be called if op execution
  // fails on any device in order to cancel the others.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs,
                    CancellationManager& cancellation_manager,
                    absl::optional<int64_t> step_id = absl::nullopt);
  // Block until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

  // Block until all ops have finished running on the thread.
  void AsyncWait(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  //
  //   Inputs; pointers are to objects not owned by the DeviceThread, but which
  //   are expected to live at least until `Join` finishes:
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  absl::optional<int64_t> step_id_ TF_GUARDED_BY(execution_mutex_) =
      absl::nullopt;
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  CancellationManager* cancellation_manager_ TF_GUARDED_BY(execution_mutex_);
  //   Outputs:
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  // TF_Status is an incomplete type and so can't be stack allocated. To avoid
  // unnecessary allocations on each Execute call, we keep one heap-allocated
  // version for the thread.
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};
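
// Illustrative usage sketch for DeviceThread (added comment, not part of the
// original file; `context`, `input_handle0`, `input_handle1`, and `attributes`
// are placeholders for values the caller already owns):
//
//   DeviceThread thread0("/job:localhost/replica:0/task:0/device:CPU:0",
//                        /*is_async=*/false);
//   DeviceThread thread1("/job:localhost/replica:0/task:0/device:CPU:1",
//                        /*is_async=*/false);
//   CancellationManager cancellation_manager;
//   // Start on every thread before joining any of them, always in the same
//   // order, since StartExecute behaves like acquiring a per-thread lock.
//   thread0.StartExecute(context, "Identity", {input_handle0}, attributes,
//                        /*expected_max_outputs=*/1, cancellation_manager);
//   thread1.StartExecute(context, "Identity", {input_handle1}, attributes,
//                        /*expected_max_outputs=*/1, cancellation_manager);
//   StatusPtr status(TF_NewStatus());
//   std::vector<TensorHandlePtr> outputs0 = thread0.Join(status.get());
//   std::vector<TensorHandlePtr> outputs1 = thread1.Join(status.get());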

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::AsyncWait(TF_Status* status) {
  tensorflow::mutex_lock l(execution_mutex_);
  TFE_ExecutorWaitForAllPendingNodes(executor_.get(), status);
  TFE_ExecutorClearError(executor_.get());
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been std::moved
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}
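
// Summary of the state machine above (added comment, not in the original
// file): StartExecute moves kIdle -> kReadyToExecute, Run executes the op and
// moves kReadyToExecute -> kHasResult, and Join collects the result and moves
// kHasResult -> kIdle. The destructor sets kShuttingDown and notifies the
// worker, so Run returns and thread_ can be destroyed.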

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs,
                                CancellationManager& cancellation_manager,
                                absl::optional<int64_t> step_id) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes before
      // starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    step_id_ = step_id;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    cancellation_manager_ = &cancellation_manager;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
      // Reset the member `status_` so future op executions (after recovery from
      // the bad `status`) start with an OK status.
      TF_SetStatus(status_.get(), TF_OK, "");
    }
    cancellation_manager_ = nullptr;
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    TFE_ContextSetExecutorForThread(context, executor_.get());
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  TFE_OpSetCancellationManager(op_.get(), wrap(cancellation_manager_), status);
  if (TF_GetCode(status) != TF_OK) return;

  // Unwrap op_ and set the step id only if a valid step id value was set.
  // Currently this is only required for non-TFRT use cases, e.g., EagerOp.
  if (step_id_.has_value()) {
    tensorflow::unwrap(op_.get())->SetStepId(step_id_.value());
  }

  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) {
    cancellation_manager_->StartCancel();
    return;
  }
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
                               const bool is_async)
    : underlying_devices_(devices),
      default_cancellation_manager_(absl::make_unique<CancellationManager>()) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str(), is_async));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  std::vector<int32_t> ids;
  ids.reserve(num_underlying_devices());
  for (int i = 0; i < num_underlying_devices(); ++i) {
    ids.push_back(i);
  }
  return ScalarsFromSequence<int32_t>(ids, context, status);
}
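
// Example (added for illustration): with three underlying devices, DeviceIDs
// returns a ParallelTensor whose components are the int32 scalars 0, 1, and 2,
// one per underlying device.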

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
  StartExecute(context, inputs, operation_name, attributes,
               expected_max_outputs, *default_cancellation_manager_);
  auto result = Join(expected_output_shapes, status);
  if (TF_GetCode(status) != TF_OK) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> await_status(
        TF_NewStatus(), TF_DeleteStatus);
    // Wait until all pending nodes have completed since they may have a
    // reference to default_cancellation_manager_. We ignore the status return
    // since we already have a bad status to propagate.
    TFE_ContextAsyncWait(context, await_status.get());
    // Reset the cancellation manager on a bad status. Otherwise we'll cancel
    // all future operations.
    default_cancellation_manager_ = absl::make_unique<CancellationManager>();
  }
  return result;
}
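
// Illustrative usage sketch for ParallelDevice::Execute (added comment, not
// part of the original file; `context`, `x_handle`, and `attributes` are
// placeholders for values the caller already owns):
//
//   ParallelDevice device({"/job:localhost/replica:0/task:0/device:CPU:0",
//                          "/job:localhost/replica:0/task:0/device:CPU:1"},
//                         /*is_async=*/false);
//   StatusPtr status(TF_NewStatus());
//   std::unique_ptr<ParallelTensor> x =
//       device.CopyToParallelDevice(context, x_handle, status.get());
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
//       device.Execute(context, {x.get()}, "Abs", attributes,
//                      /*expected_max_outputs=*/1, status.get());
//   if (TF_GetCode(status.get()) == TF_OK && outputs.has_value()) {
//     // (*outputs)[0] holds one "Abs" result per underlying device.
//   }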

void ParallelDevice::StartExecute(TFE_Context* context,
                                  const std::vector<ParallelTensor*>& inputs,
                                  const char* operation_name,
                                  const TFE_OpAttrs* attributes,
                                  int expected_max_outputs,
                                  CancellationManager& cancellation_manager,
                                  absl::optional<int64_t> step_id) const {
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    device_thread->StartExecute(
        context, operation_name, std::move(device_inputs), attributes,
        expected_max_outputs, cancellation_manager, step_id);
  }
}

void ParallelDevice::AsyncWait(TFE_Context* context, TF_Status* status) const {
  StatusPtr first_bad_status(nullptr);

  for (const auto& dt : device_threads_) {
    StatusPtr async_wait_status(TF_NewStatus());
    dt->AsyncWait(async_wait_status.get());
    // Prefer non-cancelled errors to uncover real failures.
    if (TF_GetCode(async_wait_status.get()) != TF_OK &&
        (first_bad_status == nullptr ||
         TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(async_wait_status.get()),
                   TF_Message(async_wait_status.get()));
    }
  }

  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
  }
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Join(
    const std::vector<PartialTensorShape>& expected_output_shapes,
    TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  int first_op_output_count = 0;
  StatusPtr first_bad_status(nullptr);
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    // We will run every Join even if there are bad statuses in case the user
    // wants to recover and continue running ops on the parallel device (which
    // would otherwise deadlock).
    if (TF_GetCode(status) != TF_OK &&
        (first_bad_status == nullptr
         // Prefer propagating non-cancellation related statuses to avoid
         // shadowing the original failure.
         || TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
      first_bad_status.reset(TF_NewStatus());
      TF_SetStatus(first_bad_status.get(), TF_GetCode(status),
                   TF_Message(status));
    }

    if (device_index == 0) {
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (first_bad_status == nullptr &&
          per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        first_bad_status.reset(TF_NewStatus());
        TF_SetStatus(first_bad_status.get(), TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
      }
    }
  }
  if (first_bad_status != nullptr) {
    TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
                 TF_Message(first_bad_status.get()));
    return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    if (expected_output_shapes[i].IsFullyDefined()) {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components),
          absl::Span<const int64_t>(expected_output_shapes[i].dim_sizes()),
          status));
    } else {
      per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
          *this, std::move(components), status));
    }
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

std::vector<std::string> ParallelDevice::SummarizeDeviceNames() const {
  std::vector<DeviceNameUtils::ParsedName> parsed_components(
      underlying_devices_.size());
  for (int component_index = 0; component_index < underlying_devices_.size();
       ++component_index) {
    if (!DeviceNameUtils::ParseFullName(underlying_devices_[component_index],
                                        &parsed_components[component_index]) ||
        !DeviceNameUtils::IsSameAddressSpace(
            underlying_devices_[component_index], underlying_devices_[0])) {
      // Device names are from different address spaces, or we can't figure out
      // whether they are, so we'll fully-qualify everything.
      return underlying_devices_;
    }
  }
  std::vector<std::string> local_names;
  local_names.reserve(underlying_devices_.size());
  for (const DeviceNameUtils::ParsedName& parsed_component :
       parsed_components) {
    local_names.push_back(
        absl::StrCat(parsed_component.type, ":", parsed_component.id));
  }
  return local_names;
}
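
// Example (added for illustration): if the underlying devices are
// "/job:localhost/replica:0/task:0/device:CPU:0" and
// "/job:localhost/replica:0/task:0/device:CPU:1", the summary is
// {"CPU:0", "CPU:1"}; names spanning different address spaces are returned
// fully qualified instead.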

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, absl::Span<const int64_t> shape,
    TF_Status* status) {
  TFE_TensorHandleGetStatus(components[0].get(), status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    TFE_TensorHandleGetStatus(component.get(), status);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TFE_TensorHandleGetStatus(components[0].get(), status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  // Verify that the combined TensorHandle's dtype matches all of the component
  // dtypes.
  for (TensorHandlePtr& component : components) {
    TFE_TensorHandleGetStatus(component.get(), status);
    if (!status->status.ok()) {
      return nullptr;
    }
    if (TFE_TensorHandleDataType(component.get()) != dtype) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Components of a ParallelTensor must all have "
                   "the same dtype");
      return nullptr;
    }
  }
  return std::unique_ptr<ParallelTensor>(
      new ParallelTensor(parallel_device, std::move(components), dtype));
}

Status ParallelTensor::Shape(const std::vector<int64_t>** shape) const {
  if (!shape_.has_value()) {
    TF_Status status;
    PartialTensorShape combined_shape;
    TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&combined_shape));

    for (const TensorHandlePtr& component : tensors_) {
      PartialTensorShape component_shape;
      TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape));
      if (combined_shape.dims() < 0 ||
          combined_shape.dims() != component_shape.dims()) {
        PartialTensorShape first_shape;
        TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape));
        return errors::Unimplemented(absl::StrCat(
            "Computing the shape of a ParallelTensor when the components do "
            "not all have the same rank is not supported. One tensor had "
            "shape ",
            first_shape.DebugString(), " and another had shape ",
            component_shape.DebugString()));
      } else {
        // Generalize differing axis lengths to "variable"/"unknown".
        for (int axis_index = 0; axis_index < combined_shape.dims();
             ++axis_index) {
          int64_t axis_length = combined_shape.dim_size(axis_index);
          if (axis_length != component_shape.dim_size(axis_index)) {
            axis_length = -1;
          }
          TF_RETURN_IF_ERROR(
              combined_shape.SetDimWithStatus(axis_index, axis_length));
        }
      }
    }
    auto dim_sizes = combined_shape.dim_sizes();
    shape_ = std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
  }
  *shape = &*shape_;
  return OkStatus();
}
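
// Example of the generalization above (added for illustration): components
// with shapes [2, 3] and [2, 5] yield the combined shape [2, -1], while
// components whose ranks differ produce an Unimplemented error.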

Status ParallelTensor::SummarizeValue(std::string& summary) {
  summary = "{";
  std::vector<std::string> summarized_devices = device_.SummarizeDeviceNames();
  for (int component_index = 0; component_index < tensors_.size();
       ++component_index) {
    // TODO(allenl): Add a C API for summarizing tensors. Currently custom
    // devices limiting themselves to a C API (for ABI compatibility) would need
    // to implement summarization for component tensors themselves.
    ImmediateExecutionTensorHandle* component =
        tensorflow::unwrap(tensors_[component_index].get());
    std::string component_summary;
    TF_RETURN_IF_ERROR(component->SummarizeValue(component_summary));
    absl::StrAppend(&summary, component_index == 0 ? "" : ", ", "\"",
                    summarized_devices[component_index],
                    "\": ", component_summary);
  }
  summary += "}";
  return OkStatus();
}
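
// Example of the summary format (added for illustration): a two-component
// tensor on CPU:0 and CPU:1 whose components summarize as 1 and 2 produces
// {"CPU:0": 1, "CPU:1": 2}.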

}  // namespace parallel_device
}  // namespace tensorflow