• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
17 
18 #include <string>
19 
20 #include "tensorflow/c/eager/immediate_execution_context.h"
21 #include "tensorflow/c/eager/immediate_execution_operation.h"
22 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/util/device_name_utils.h"
25 
26 namespace tensorflow {
27 
28 class TensorHandle;
29 class EagerOperation;
30 class CustomDeviceTensorHandle;
31 
32 // Custom devices intercept the execution of operations (the `Execute` method),
33 // typically implemented with one or more of the custom device's own executions.
34 class CustomDevice {
35  public:
~CustomDevice()36   virtual ~CustomDevice() {}
37   virtual const string& name() = 0;
38   virtual Status CopyTensorToDevice(
39       ImmediateExecutionTensorHandle* tensor,
40       ImmediateExecutionTensorHandle** result) = 0;
41 
42   virtual Status CopyTensorFromDevice(
43       ImmediateExecutionTensorHandle* tensor, const string& target_device_name,
44       ImmediateExecutionTensorHandle** result) = 0;
45 
46   virtual Status Execute(const ImmediateExecutionOperation* op,
47                          ImmediateExecutionTensorHandle** retvals,
48                          int* num_retvals) = 0;
49 
50   // Creates a packed TensorHandle from a group of custom device TensorHandles,
51   // one of which is on this custom device.
52   virtual Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
53                       ImmediateExecutionTensorHandle** result) = 0;
54 };
55 
56 // Custom devices do many of the same things as physical Devices, but have a
57 // much more restricted interface. We pass around ambiguous pointers since
58 // operations may be placed either on custom or physical devices.
59 using VariantDevice = absl::variant<Device*, CustomDevice*>;
60 
61 // Indicates either HostCPU or an unset physical device. We never set a null
62 // CustomDevice*.
63 const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
64 
65 // A tensor handle produced by a custom device. Generally they can only be
66 // consumed by executing an operation on the same custom device that produced it
67 // originally, or by attempting to copy the handle off the custom device.
68 //
69 // TODO(allenl): Currently custom devices are tied to the eager C API. They
70 // should be renamed op handlers and subclass AbstractTensorHandle instead so
71 // they are eager/graph agnostic.
72 class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
73  public:
CustomDeviceTensorHandle(ImmediateExecutionContext * context,CustomDevice * device,tensorflow::DataType dtype)74   CustomDeviceTensorHandle(ImmediateExecutionContext* context,
75                            CustomDevice* device, tensorflow::DataType dtype)
76       : ImmediateExecutionTensorHandle(kCustomDevice),
77         context_(context),
78         device_(device),
79         dtype_(dtype) {}
80 
81   // TODO(allenl): Should this be a generic method of
82   // ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer?
83   virtual void* DevicePointer() const = 0;
84 
DataType()85   tensorflow::DataType DataType() const override { return dtype_; }
86   Status Shape(PartialTensorShape* shape) const override;
87   Status NumElements(int64* num_elements) const override;
88 
DeviceName(Status * status)89   const char* DeviceName(Status* status) const override {
90     return device_->name().c_str();
91   }
BackingDeviceName(Status * status)92   const char* BackingDeviceName(Status* status) const override {
93     return device_->name().c_str();
94   }
device()95   CustomDevice* device() const { return device_; }
96   const char* DeviceType(Status* status) const override;
97   int DeviceId(Status* status) const override;
98 
99   AbstractTensorInterface* Resolve(Status* status) override;
100 
Copy()101   ImmediateExecutionTensorHandle* Copy() override {
102     Ref();
103     return this;
104   }
Release()105   void Release() override { Unref(); }
106 
107   // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)108   static bool classof(const AbstractTensorHandle* ptr) {
109     return ptr->getKind() == kCustomDevice;
110   }
111 
112  protected:
113   const DeviceNameUtils::ParsedName* ParsedName(Status* status) const;
114 
115   ImmediateExecutionContext* const context_;
116   CustomDevice* const device_;
117   const tensorflow::DataType dtype_;
118 
119   mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_;
120 };
121 
122 }  // namespace tensorflow
123 
124 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
125