1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ 17 18 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" 19 #include "tensorflow/core/distributed_runtime/eager/eager_client.h" 20 21 namespace tensorflow { 22 23 // Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only 24 // the shape is known. 25 class RemoteTensorHandleData : public TensorHandleData { 26 public: 27 RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape, 28 const string& remote_task, uint64 context_id, 29 EagerContext* ctx); 30 ~RemoteTensorHandleData() override; 31 32 // A remote tensor handle does not have a Tensor object, hence it can only 33 // support the shape requests. 34 Status Tensor(const tensorflow::Tensor** t) const override; 35 Status TensorValue(tensorflow::TensorValue* t) override; 36 Status Shape(TensorShape* shape) const override; 37 Status NumDims(int* num_dims) const override; 38 Status Dim(int dim_index, int64* dim) const override; 39 Status NumElements(int64* num_elements) const override; 40 41 string DebugString() const override; 42 op_id()43 int64 op_id() const { return op_id_; } output_num()44 int32 output_num() const { return output_num_; } 45 46 private: 47 // IDs required when this class is representing a remote tensor handle. 48 const int64 op_id_; 49 const int32 output_num_; 50 const TensorShape shape_; 51 string remote_task_; 52 uint64 context_id_; 53 EagerContext* const ctx_; 54 }; 55 56 // Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the 57 // shape has been computed this is replaced with a remote tensor handle. 58 class UnshapedRemoteTensorHandleData : public TensorHandleData { 59 public: 60 UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num, 61 const string& remote_task, uint64 context_id, 62 EagerContext* ctx); 63 ~UnshapedRemoteTensorHandleData() override; 64 65 // Unshaped remote tensor handles are not ready and hence cannot satisfy any 66 // of these requests. 67 Status Tensor(const tensorflow::Tensor** t) const override; 68 Status TensorValue(tensorflow::TensorValue* t) override; 69 Status Shape(TensorShape* shape) const override; 70 Status NumDims(int* num_dims) const override; 71 Status Dim(int dim_index, int64* dim) const override; 72 Status NumElements(int64* num_elements) const override; 73 74 string DebugString() const override; 75 op_id()76 int64 op_id() const { return op_id_; } output_num()77 int32 output_num() const { return output_num_; } remote_task()78 string remote_task() const { return remote_task_; } context_id()79 uint64 context_id() const { return context_id_; } ctx()80 EagerContext* ctx() const { return ctx_; } 81 82 // When constructed, UnshapedRemoteTensorHandleData owns the remote 83 // TensorHandle and should delete it by issuing an RPC. Once the remote 84 // shape has been learned, the ownership is transferred to 85 // RemoteTensorHandleData. This method must be called to let `this` know 86 // that it no longer owns the remote handle. 87 // TODO(iga): Add a factory method here that will create a new 88 // RemoteTensorHandleData from this and transfer ownership in the process. ReleaseRemoteTensorHandle()89 void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; } 90 91 private: 92 // IDs required when this class is representing a remote tensor handle. 93 const int64 op_id_; 94 const int32 output_num_; 95 bool delete_remote_tensor_; 96 string remote_task_; 97 uint64 context_id_; 98 EagerContext* const ctx_; 99 }; 100 101 } // namespace tensorflow 102 103 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ 104