1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
17
18 #include "tensorflow/core/platform/errors.h"
19
20 namespace tensorflow {
21
Clear()22 void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
23
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)24 Status CustomDeviceOpHandler::RegisterCustomDevice(
25 const string& device_name, std::unique_ptr<CustomDevice> device) {
26 DeviceNameUtils::ParsedName parsed;
27 if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
28 !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
29 !parsed.has_type || !parsed.has_id) {
30 return errors::InvalidArgument(
31 device_name,
32 " could not be parsed as a device name. Use the full "
33 "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
34 "format.");
35 }
36
37 if (!custom_devices_.emplace(device_name, std::move(device)).second) {
38 return errors::AlreadyExists(device_name,
39 " already registered as a custom device.");
40 }
41 return Status::OK();
42 }
43
FindCustomDeviceFromName(const string & name,CustomDevice ** device) const44 bool CustomDeviceOpHandler::FindCustomDeviceFromName(
45 const string& name, CustomDevice** device) const {
46 auto dev_it = custom_devices_.find(name);
47 if (dev_it == custom_devices_.end()) {
48 return false;
49 }
50 *device = dev_it->second.get();
51 return true;
52 }
53
Execute(ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)54 Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
55 ImmediateExecutionTensorHandle** retvals,
56 int* num_retvals) {
57 tensorflow::CustomDevice* custom_device = nullptr;
58
59 TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
60
61 if (custom_device != nullptr) {
62 return custom_device->Execute(op, retvals, num_retvals);
63 }
64
65 // The op will be placed on physical device. However, it contains custom
66 // device tensor handles. The tensor handles will be copy to physical device
67 // first.
68 if (op->HasCustomDeviceInput()) {
69 auto inputs = op->GetInputs();
70 for (int i = 0; i < inputs.size(); ++i) {
71 auto target_device = op->DeviceName();
72 if (target_device.empty()) {
73 target_device = op->GetContext()->HostCPUName();
74 }
75 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
76 // here.
77 if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
78 tensorflow::CustomDeviceTensorHandle* previous =
79 tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
80 inputs[i]);
81 tensorflow::ImmediateExecutionTensorHandle* new_tesnor;
82 TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
83 previous, target_device, &new_tesnor));
84 Status s = op->SetInput(i, new_tesnor);
85 new_tesnor->Unref();
86 TF_RETURN_IF_ERROR(s);
87 }
88 }
89 }
90
91 return op->Execute(
92 absl::MakeSpan(
93 reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
94 *num_retvals),
95 num_retvals);
96 }
97
MaybePinToCustomDevice(CustomDevice ** device,const ImmediateExecutionOperation & op) const98 Status CustomDeviceOpHandler::MaybePinToCustomDevice(
99 CustomDevice** device, const ImmediateExecutionOperation& op) const {
100 *device = nullptr;
101 if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
102 !op.HasCustomDeviceInput()) {
103 return Status::OK();
104 }
105
106 // Ops are placed on a custom device if there's no other explicit requested
107 // placement and there is only one custom device in the op
108 // inputs.
109 //
110 // Resource-dtype inputs take precedence over non-resource inputs and explicit
111 // placements; this function pins ops with a resource-dtype custom device
112 // input to that custom device.
113 CustomDevice* first = nullptr;
114 if (!op.GetInputs().empty()) {
115 for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
116 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
117 // here.
118 if (CustomDeviceTensorHandle::classof(generic_input)) {
119 const CustomDeviceTensorHandle* input =
120 down_cast<const CustomDeviceTensorHandle*>(generic_input);
121 CustomDevice* current = input->device();
122 if (first == nullptr) {
123 first = current;
124 } else if (first != current) {
125 return errors::InvalidArgument(absl::StrCat(
126 "If an operation has one of its inputs in a custom device, then "
127 "all inputs should be on that same custom device or another "
128 "physical device. Operation ",
129 op.Name(),
130 " has one input in custom "
131 "device ",
132 first->name(),
133 " and at least one input in a different custom device ",
134 current->name()));
135 }
136 }
137 }
138 for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
139 if (generic_input->DataType() == DT_RESOURCE) {
140 if (CustomDeviceTensorHandle::classof(generic_input)) {
141 const CustomDeviceTensorHandle* input =
142 down_cast<const CustomDeviceTensorHandle*>(generic_input);
143 // There's only one custom device input, and it's a resource input, so
144 // we'll force-place the op on to that custom device. As with physical
145 // devices, this overrides any explicit placement for the op.
146 *device = input->device();
147 return Status::OK();
148 } else {
149 // Don't set a custom device if there's a physical-device resource
150 // input.
151 return Status::OK();
152 }
153 }
154 }
155 }
156 // Since there are no resource-dtype inputs, we'll respect explicit placements
157 // before considering input-based placement.
158 if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
159 // If there are non-resource inputs on a custom device we will default the
160 // op to that custom device, but not override an explicit op placement.
161 *device = first;
162 return Status::OK();
163 }
164 return Status::OK();
165 }
166
167 } // namespace tensorflow
168