1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ 18 19 #include "tensorflow/core/framework/allocator.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/platform/protobuf.h" 23 #include "tensorflow/core/platform/types.h" 24 #include "tensorflow/core/protobuf/worker.pb.h" 25 26 namespace tensorflow { 27 28 class Allocator; 29 class DeviceBase; 30 class TensorProto; 31 32 // TensorResponse can be used as the destination of an RPC that returns 33 // a RecvTensorResponse. It efficiently decodes the incoming data 34 // into Tensor contents as well as associated metadata. 35 class TensorResponse { 36 public: TensorResponse()37 TensorResponse() {} 38 39 // Reset to initial state. 40 void Clear(); 41 42 // Clear just tensor_ and meta_ members without setting allocation 43 // related members. 44 void ClearTensor(); 45 46 // Initialize memory allocation related members. 47 void InitAlloc(DeviceBase* d, const AllocatorAttributes& aa); 48 49 // Source provides a way for a particular RPC implementation to provide 50 // received data to ParseFrom. 51 class Source { 52 public: 53 virtual ~Source(); 54 55 // Return the stream that contains the data to be parsed. 56 // Note that this method might be invoked more than once if 57 // ParseFrom needs to fall back to a more expensive parsing method. 58 // Every call must return a stream pointing at the beginning of 59 // the serialized RecvTensorResponse. 60 // 61 // Note that a subsequent call to contents() invalidates previous 62 // results of contents(). 63 // 64 // Ownership of the returned stream is retained by the Source and 65 // should not be deleted by the caller. 66 virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0; 67 }; 68 69 // Parse the RecvTensorResponse encoded in the data yielded by 70 // source->contents() into *this. 71 Status ParseFrom(Source* source); 72 73 // Initialize tensor from *response. 74 // Leaves *response with unspecified contents. 75 Status InitFrom(RecvTensorResponse* response); 76 77 // Initialize tensor metadata from response and allocate 78 // uninitialized backing storage for actual contents. 79 void InitPartial(const RecvTensorResponse& response, 80 const AllocationAttributes& allocation_attr); 81 82 // Return a reference to the parsed tensor. The tensor will remain 83 // live only until *this is destroyed or modified. tensor()84 const Tensor& tensor() const { return tensor_; } 85 86 // Return a reference to the parsed tensor metadata (no contents). 87 // The result will remain live only until *this is destroyed or 88 // modified. metadata()89 const RecvTensorResponse& metadata() const { return meta_; } 90 91 // Return pointer to the device hosting the tensor. device()92 DeviceBase* device() const { return device_; } 93 94 private: 95 bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input, 96 TensorProto* tensor_meta); 97 bool ParseFast(Source* source); 98 bool ParseSlow(Source* source); 99 100 bool on_host_ = false; 101 DeviceBase* device_ = nullptr; 102 AllocatorAttributes alloc_attrs_; 103 Allocator* allocator_ = nullptr; 104 bool already_used_ = false; 105 Tensor tensor_; 106 RecvTensorResponse meta_; 107 }; 108 109 } // namespace tensorflow 110 111 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ 112