/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/placement_utils.h"

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace eager {

// These ops are not pinnable since they generate data. It can be slower to
// generate and then copy the data than to generate the data on the device
// directly.
static bool IsPinnableOp(StringPiece op_name) {
  static const gtl::FlatSet<string>* unpinnable_ops = new gtl::FlatSet<string>({
      "RandomUniform",
      "RandomUniformInt",
      "RandomStandardNormal",
      "StatelessRandomUniform",
      "StatelessRandomUniformInt",
      "StatelessRandomUniformFullInt",
      "StatelessRandomNormal",
  });

  // XRT ops refer to per-device handles that are not safe to move between
  // devices.
  return unpinnable_ops->find(string(op_name)) == unpinnable_ops->end() &&
         !absl::StartsWith(op_name, "XRT");
}

// Validates that the remote device with the given incarnation is known to the
// remote device manager of the current eager context.
static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
                                               int64 device_incarnation) {
  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
    return Status::OK();
  }
  return errors::InvalidArgument(
      "Resource input tensor contains an invalid device. This might happen "
      "when the client has connected to a different cluster, or some remote "
      "workers have been restarted.");
}

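// Returns true if the op is registered as exempt from input colocation.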
bool IsColocationExempt(StringPiece op_name) {
  const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get();
  return exempt_ops.find(string(op_name)) != exempt_ops.end();
}

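// Returns true if `op_name` does not name a registered op, in which case it
// is assumed to refer to a function.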
bool IsFunction(StringPiece op_name) {
  const OpDef* op_def = nullptr;
  Status s = OpDefForOp(string(op_name), &op_def);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      LOG(WARNING) << "Looking up OpDef failed with error: " << s.ToString();
    }
    // No OpDef was found, so assume the name refers to a function.
    return true;
  }
  return false;
}

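// Decides whether the op should be forced onto the host CPU: sets *result to
// true only when the op is pinnable, has at least one input, and every input
// is a small (64 elements or fewer) int32/int64 tensor that already lives on
// the CPU.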
Status MaybePinSmallOpsToCpu(
    bool* result, StringPiece op_name,
    absl::Span<ImmediateExecutionTensorHandle* const> args,
    StringPiece cpu_device_name) {
  if (IsFunction(op_name) || IsColocationExempt(op_name) ||
      !IsPinnableOp(op_name)) {
    *result = false;
    return Status::OK();
  }

  // Ops without inputs are usually ops that generate a tensor in some way and
  // usually require being present on whatever device they are scheduled on
  // (e.g. VarHandleOp or _Recv).
  if (args.empty()) {
    *result = false;
    return Status::OK();
  }

  int i = 0;
  for (auto* arg : args) {
    Status s;
    const char* device_name = arg->DeviceName(&s);
    DataType dtype = arg->DataType();
    TF_RETURN_IF_ERROR(s);

    DVLOG(2) << "for op " << op_name << " input " << i << " "
             << DataTypeString(dtype) << " input device = " << device_name;

    // Only pin the op when every input is already on the CPU.
    if (device_name != cpu_device_name) {
      *result = false;
      return Status::OK();
    }

    // Only pin the op when every input is an int32/int64 tensor.
    if (dtype != DataType::DT_INT32 && dtype != DataType::DT_INT64) {
      *result = false;
      return Status::OK();
    }

    // Only pin the op when every input is small.
    int64 num_elements;
    TF_RETURN_IF_ERROR(arg->NumElements(&num_elements));
    if (num_elements > 64) {
      *result = false;
      return Status::OK();
    }
    i++;
  }

  // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
  // an op, but there is a GPU kernel?
  DVLOG(1) << "Forcing op " << op_name
           << " to be on the CPU since all input tensors have an "
              "int32/int64 dtype, and are small (64 elements or fewer).";
  *result = true;
  return Status::OK();
}

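// If any input to the op is a DT_RESOURCE tensor, pins the op to the device
// holding that resource and returns that device in *device.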
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
  if (op.colocation_exempt()) {
    return Status::OK();
  }
  EagerContext& ctx = op.EagerContext();
  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  TF_RETURN_IF_ERROR(op.TensorHandleInputs(&inputs));
  Device* op_device = op.Device() == kVariantDeviceNull
                          ? ctx.HostCPU()
                          : absl::get<Device*>(op.Device());
  for (int i = 0; i < inputs->size(); ++i) {
    TensorHandle* tensor_handle = (*inputs)[i];
    if (tensor_handle->dtype == DT_RESOURCE) {
      if (tensor_handle->resource_remote_device_incarnation() != 0) {
        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
            &ctx, tensor_handle->resource_remote_device_incarnation()));
      }
      Device* resource_device = tensor_handle->resource_device();
      DVLOG(2) << "for op " << op.Name() << " input " << i << " "
               << DataTypeString(tensor_handle->dtype)
               << " input device = " << resource_device->name()
               << ", op device = " << op_device->name();
      // We check for `op.Device() == kVariantDeviceNull` because an unset
      // device can later be interpreted as "unspecified", allowing a
      // different device to be selected based on device priority. If any
      // input to an op is a resource, we must pin the op to prevent that
      // selection.
      // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
      if (resource_device != op_device || op.Device() == kVariantDeviceNull) {
        DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                 << "device of operation " << op.Name() << " to "
                 << resource_device->name() << " because input #" << i
                 << " is a resource on that device.";
        *device = resource_device;
        return Status::OK();
        // No point in looking at other inputs. If there are other resources,
        // they must have the same device, and we already declared the op to
        // be ineligible for CPU pinning.
      }
    }
  }
  return Status::OK();
}

}  // namespace eager
}  // namespace tensorflow