/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

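// Drops every registered custom device.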
void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }

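// Registers `device` under `device_name`. The name must be a fully specified
// device name and must not already be registered as a custom device.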
Status CustomDeviceOpHandler::RegisterCustomDevice(
    const string& device_name, std::unique_ptr<CustomDevice> device) {
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
      !parsed.has_type || !parsed.has_id) {
    return errors::InvalidArgument(
        device_name,
        " could not be parsed as a device name. Use the full "
        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
        "format.");
  }

  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
    return errors::AlreadyExists(device_name,
                                 " already registered as a custom device.");
  }
  return Status::OK();
}

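// Looks up a registered custom device by name. Returns true and sets `*device`
// on success, or false if no custom device with that name is registered.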
bool CustomDeviceOpHandler::FindCustomDeviceFromName(
    const string& name, CustomDevice** device) const {
  auto dev_it = custom_devices_.find(name);
  if (dev_it == custom_devices_.end()) {
    return false;
  }
  *device = dev_it->second.get();
  return true;
}

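// Executes `op`, dispatching to a custom device when the pinning rules select
// one. Otherwise any custom device inputs are copied to the target physical
// device before the op runs there.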
Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
                                      ImmediateExecutionTensorHandle** retvals,
                                      int* num_retvals) {
  tensorflow::CustomDevice* custom_device = nullptr;

  TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));

  if (custom_device != nullptr) {
    return custom_device->Execute(op, retvals, num_retvals);
  }

  // The op will be placed on a physical device, but it has custom device
  // tensor handle inputs. Those handles must first be copied to the physical
  // device.
  if (op->HasCustomDeviceInput()) {
    auto inputs = op->GetInputs();
    for (int i = 0; i < inputs.size(); ++i) {
      auto target_device = op->DeviceName();
      if (target_device.empty()) {
        target_device = op->GetContext()->HostCPUName();
      }
      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
      // here.
      if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
        tensorflow::CustomDeviceTensorHandle* previous =
            tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
                inputs[i]);
        tensorflow::ImmediateExecutionTensorHandle* new_tensor;
        TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
            previous, target_device, &new_tensor));
        Status s = op->SetInput(i, new_tensor);
        new_tensor->Unref();
        TF_RETURN_IF_ERROR(s);
      }
    }
  }

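  // No custom device is involved at this point, so run the op on the physical
  // device it resolved to.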
  return op->Execute(
      absl::MakeSpan(
          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
          *num_retvals),
      num_retvals);
}

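// Decides whether `op` should be pinned to a custom device based on its
// requested placement and the devices of its inputs. Sets `*device` to the
// chosen custom device, or leaves it null for physical-device execution.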
Status CustomDeviceOpHandler::MaybePinToCustomDevice(
    CustomDevice** device, const ImmediateExecutionOperation& op) const {
  *device = nullptr;
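  // If the requested device is not a registered custom device and no input
  // lives on a custom device, there is nothing to pin.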
  if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
      !op.HasCustomDeviceInput()) {
    return Status::OK();
  }

  // Ops are placed on a custom device if there is no other explicitly
  // requested placement and there is only one custom device in the op
  // inputs.
  //
  // Resource-dtype inputs take precedence over non-resource inputs and
  // explicit placements; this function pins ops with a resource-dtype custom
  // device input to that custom device.
  CustomDevice* first = nullptr;
  if (!op.GetInputs().empty()) {
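    // First pass: check that all custom device inputs agree on a single
    // custom device.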
    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
      // here.
      if (CustomDeviceTensorHandle::classof(generic_input)) {
        const CustomDeviceTensorHandle* input =
            down_cast<const CustomDeviceTensorHandle*>(generic_input);
        CustomDevice* current = input->device();
        if (first == nullptr) {
          first = current;
        } else if (first != current) {
          return errors::InvalidArgument(absl::StrCat(
              "If an operation has one of its inputs in a custom device, then "
              "all inputs should be on that same custom device or another "
              "physical device. Operation ",
              op.Name(),
              " has one input in custom "
              "device ",
              first->name(),
              " and at least one input in a different custom device ",
              current->name()));
        }
      }
    }
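    // Second pass: a resource-dtype input decides the placement outright.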
    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
      if (generic_input->DataType() == DT_RESOURCE) {
        if (CustomDeviceTensorHandle::classof(generic_input)) {
          const CustomDeviceTensorHandle* input =
              down_cast<const CustomDeviceTensorHandle*>(generic_input);
          // There's only one custom device input, and it's a resource input,
          // so we'll force-place the op onto that custom device. As with
          // physical devices, this overrides any explicit placement for the
          // op.
          *device = input->device();
          return Status::OK();
        } else {
          // Don't set a custom device if there's a physical-device resource
          // input.
          return Status::OK();
        }
      }
    }
  }
  // Since there are no resource-dtype inputs, we'll respect explicit placements
  // before considering input-based placement.
  if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
    // If there are non-resource inputs on a custom device we will default the
    // op to that custom device, but not override an explicit op placement.
    *device = first;
    return Status::OK();
  }
  return Status::OK();
}

}  // namespace tensorflow