/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_

#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class TensorHandle;
class EagerOperation;
class CustomDeviceTensorHandle;

// Custom devices intercept the execution of operations (the `Execute` method),
// typically implemented with one or more of the custom device's own executions.
34 class CustomDevice { 35 public: ~CustomDevice()36 virtual ~CustomDevice() {} 37 virtual const string& name() = 0; 38 virtual Status CopyTensorToDevice( 39 ImmediateExecutionTensorHandle* tensor, 40 ImmediateExecutionTensorHandle** result) = 0; 41 42 virtual Status CopyTensorFromDevice( 43 ImmediateExecutionTensorHandle* tensor, const string& target_device_name, 44 ImmediateExecutionTensorHandle** result) = 0; 45 46 virtual Status Execute(const ImmediateExecutionOperation* op, 47 ImmediateExecutionTensorHandle** retvals, 48 int* num_retvals) = 0; 49 50 // Creates a packed TensorHandle from a group of custom device TensorHandles, 51 // one of which is on this custom device. 52 virtual Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles, 53 ImmediateExecutionTensorHandle** result) = 0; 54 }; 55 56 // Custom devices do many of the same things as physical Devices, but have a 57 // much more restricted interface. We pass around ambiguous pointers since 58 // operations may be placed either on custom or physical devices. 59 using VariantDevice = absl::variant<Device*, CustomDevice*>; 60 61 // Indicates either HostCPU or an unset physical device. We never set a null 62 // CustomDevice*. 63 const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr); 64 65 // A tensor handle produced by a custom device. Generally they can only be 66 // consumed by executing an operation on the same custom device that produced it 67 // originally, or by attempting to copy the handle off the custom device. 68 // 69 // TODO(allenl): Currently custom devices are tied to the eager C API. They 70 // should be renamed op handlers and subclass AbstractTensorHandle instead so 71 // they are eager/graph agnostic. 
72 class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { 73 public: CustomDeviceTensorHandle(ImmediateExecutionContext * context,CustomDevice * device,tensorflow::DataType dtype)74 CustomDeviceTensorHandle(ImmediateExecutionContext* context, 75 CustomDevice* device, tensorflow::DataType dtype) 76 : ImmediateExecutionTensorHandle(kCustomDevice), 77 context_(context), 78 device_(device), 79 dtype_(dtype) {} 80 81 // TODO(allenl): Should this be a generic method of 82 // ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer? 83 virtual void* DevicePointer() const = 0; 84 DataType()85 tensorflow::DataType DataType() const override { return dtype_; } 86 Status Shape(PartialTensorShape* shape) const override; 87 Status NumElements(int64* num_elements) const override; 88 DeviceName(Status * status)89 const char* DeviceName(Status* status) const override { 90 return device_->name().c_str(); 91 } BackingDeviceName(Status * status)92 const char* BackingDeviceName(Status* status) const override { 93 return device_->name().c_str(); 94 } device()95 CustomDevice* device() const { return device_; } 96 const char* DeviceType(Status* status) const override; 97 int DeviceId(Status* status) const override; 98 99 AbstractTensorInterface* Resolve(Status* status) override; 100 Copy()101 ImmediateExecutionTensorHandle* Copy() override { 102 Ref(); 103 return this; 104 } Release()105 void Release() override { Unref(); } 106 107 // For LLVM style RTTI. 
classof(const AbstractTensorHandle * ptr)108 static bool classof(const AbstractTensorHandle* ptr) { 109 return ptr->getKind() == kCustomDevice; 110 } 111 112 protected: 113 const DeviceNameUtils::ParsedName* ParsedName(Status* status) const; 114 115 ImmediateExecutionContext* const context_; 116 CustomDevice* const device_; 117 const tensorflow::DataType dtype_; 118 119 mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_; 120 }; 121 122 } // namespace tensorflow 123 124 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ 125