/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/placement_utils.h"

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace eager {

// These ops are not pinnable since they generate data. It can be slower to
// generate and then copy the data instead of just generating the data on the
// device directly.
static bool IsPinnableOp(StringPiece op_name) {
  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
      "RandomUniform",
      "RandomUniformInt",
      "RandomStandardNormal",
      "StatelessRandomUniform",
      "StatelessRandomUniformInt",
      "StatelessRandomUniformFullInt",
      "StatelessRandomNormal",
  });

  // XRT ops refer to per-device handles that are not safe to move between
  // devices.
  return unpinnable_ops->find(string(op_name)) == unpinnable_ops->end() &&
         !absl::StartsWith(op_name, "XRT");
}

// Checks whether the remote device with the given incarnation is known to the
// remote device manager of the current eager context.
static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
                                               int64 device_incarnation) {
  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
    return Status::OK();
  }
  return errors::InvalidArgument(
      "Resource input tensor contains an invalid device. This might happen "
      "when the client has connected to a different cluster, or some remote "
      "workers have been restarted.");
}

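// Returns true if `op_name` is in the input-colocation exemption registry,
// i.e. the op is exempt from being colocated with its inputs.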
bool IsColocationExempt(StringPiece op_name) {
  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
  return exempt_ops.find(string(op_name)) != exempt_ops.end();
}

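// Returns true if `op_name` does not name a registered op, in which case it
// is assumed to be a function.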
bool IsFunction(StringPiece op_name) {
  const OpDef* op_def = nullptr;
  Status s = OpDefForOp(string(op_name), &op_def);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      LOG(WARNING) << "Looking up OpDef failed with error: " << s.ToString();
    }
    // No OpDef was found, so the name must refer to a function.
    return true;
  }
  return false;
}

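// Pins the op to the CPU when every input is a small (at most 64 elements)
// int32/int64 tensor that is already on the CPU. Sets *result to true if the
// op should be pinned and false otherwise.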
Status MaybePinSmallOpsToCpu(
    bool* result, StringPiece op_name,
    absl::Span<ImmediateExecutionTensorHandle* const> args,
    StringPiece cpu_device_name) {
  if (IsFunction(op_name) || IsColocationExempt(op_name) ||
      !IsPinnableOp(op_name)) {
    *result = false;
    return Status::OK();
  }

  // Ops without inputs are usually ops that generate a tensor in some way and
  // usually require being present on whatever device they are scheduled on
  // (e.g. VarHandleOp or _Recv).
  if (args.empty()) {
    *result = false;
    return Status::OK();
  }

  int i = 0;
  for (auto* arg : args) {
    Status s;
    const char* device_name = arg->DeviceName(&s);
    DataType dtype = arg->DataType();
    TF_RETURN_IF_ERROR(s);

    DVLOG(2) << "for op " << op_name << " input " << i << " "
             << DataTypeString(dtype) << " input device = " << device_name;

    // The input must already be on the CPU.
    if (device_name != cpu_device_name) {
      *result = false;
      return Status::OK();
    }

    // Only int32/int64 inputs qualify for pinning.
    if (dtype != DataType::DT_INT32 && dtype != DataType::DT_INT64) {
      *result = false;
      return Status::OK();
    }

    // Only small inputs (at most 64 elements) qualify for pinning.
    int64 num_elements;
    TF_RETURN_IF_ERROR(arg->NumElements(&num_elements));
    if (num_elements > 64) {
      *result = false;
      return Status::OK();
    }
    i++;
  }

  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
  // an op, but there is a GPU kernel?
  DVLOG(1) << "Forcing op " << op_name
           << " to be on the CPU since all input tensors have an "
              "int32/int64 dtype, and are small (at most 64 elements).";
  *result = true;
  return Status::OK();
}

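// If any input to the op is a resource, pins the op to the device that holds
// the resource, since a resource handle is only usable on the device where
// the resource lives. Leaves *device untouched if there is no resource input.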
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
  if (op.colocation_exempt()) {
    return Status::OK();
  }
  EagerContext& ctx = op.EagerContext();
  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  TF_RETURN_IF_ERROR(op.TensorHandleInputs(&inputs));
  Device* op_device = op.Device() == kVariantDeviceNull
                          ? ctx.HostCPU()
                          : absl::get<Device*>(op.Device());
  for (int i = 0; i < inputs->size(); ++i) {
    TensorHandle* tensor_handle = (*inputs)[i];
    if (tensor_handle->dtype == DT_RESOURCE) {
      if (tensor_handle->resource_remote_device_incarnation() != 0) {
        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
            &ctx, tensor_handle->resource_remote_device_incarnation()));
      }
      Device* resource_device = tensor_handle->resource_device();
      DVLOG(2) << "for op " << op.Name() << " input " << i << " "
               << DataTypeString(tensor_handle->dtype)
               << " input device = " << resource_device->name()
               << ", op device = " << op_device->name();
      // We also check for `op.Device() == kVariantDeviceNull` because a null
      // device can later be interpreted as an unspecified device, in which
      // case a different device could be selected based on device priority.
      // If any input to an op is a resource, we must pin the op to prevent
      // such a selection.
      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
      if (resource_device != op_device || op.Device() == kVariantDeviceNull) {
        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                 << "device of operation " << op.Name() << " to "
                 << resource_device->name() << " because input #" << i
                 << " is a resource on that device.";
        // No point in looking at other inputs: if there are other resources,
        // they must be on the same device, and we have already declared the
        // op ineligible for CPU pinning.
        *device = resource_device;
        return Status::OK();
      }
    }
  }
  return Status::OK();
}

}  // namespace eager
}  // namespace tensorflow