1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <cstdlib>
19 #include <cstring>
20 #include <mutex>  // NOLINT
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/base/call_once.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/match.h"
29 #include "tensorflow/core/framework/allocation_description.pb.h"
30 #include "tensorflow/core/framework/attr_value.pb.h"
31 #include "tensorflow/core/framework/attr_value_util.h"
32 #include "tensorflow/core/framework/device_attributes.pb.h"
33 #include "tensorflow/core/framework/device_factory.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/kernel_def.pb.h"
36 #include "tensorflow/core/framework/kernel_def_util.h"
37 #include "tensorflow/core/framework/log_memory.h"
38 #include "tensorflow/core/framework/memory_types.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/node_properties.h"
42 #include "tensorflow/core/framework/op_def_util.h"
43 #include "tensorflow/core/framework/tensor_reference.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/notification.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/io/path.h"
50 #include "tensorflow/core/lib/strings/str_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/platform/cpu_info.h"
53 #include "tensorflow/core/platform/env.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/mutex.h"
56 #include "tensorflow/core/platform/platform_strings.h"
57 #include "tensorflow/core/platform/types.h"
58 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
59 #include "tensorflow/core/profiler/lib/traceme.h"
60 #include "tensorflow/core/util/ptr_util.h"
61 
62 namespace tensorflow {
63 
// Kernel label string associated with JIT-compiled kernels (presumably matched
// against the label a kernel is registered with — confirm against its users).
const char* kJitKernelLabel = "JITCompiledKernel";
// Name of the environment variable consulted (elsewhere; not read in this
// chunk) to disable JIT-compiled kernels.
const char* kDisableJitKernelsEnvVar = "TF_DISABLE_JIT_KERNELS";
66 
67 namespace {
68 
MatchSignatureHelper(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs,const DataTypeSlice inputs,const DataTypeSlice outputs)69 Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
70                             const DataTypeSlice expected_outputs,
71                             const DataTypeSlice inputs,
72                             const DataTypeSlice outputs) {
73   bool signature_mismatch = false;
74 
75   if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
76   for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
77     if (!TypesCompatible(expected_inputs[i], inputs[i])) {
78       signature_mismatch = true;
79     }
80   }
81 
82   if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
83   for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
84     if (!TypesCompatible(expected_outputs[i], outputs[i])) {
85       signature_mismatch = true;
86     }
87   }
88 
89   if (signature_mismatch) {
90     return errors::InvalidArgument(
91         "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
92         DataTypeSliceString(outputs),
93         " expected: ", DataTypeSliceString(expected_inputs), "->",
94         DataTypeSliceString(expected_outputs));
95   }
96   return OkStatus();
97 }
98 
GetOpNodeDefsToLogFromEnv()99 const absl::flat_hash_set<std::string>* GetOpNodeDefsToLogFromEnv() {
100   auto* result = new absl::flat_hash_set<std::string>;
101   const char* env = getenv("TF_DEBUG_OPS_TO_LOG_NODEDEFS");
102   if (!env) {
103     return result;
104   }
105 
106   std::vector<absl::string_view> ops = absl::StrSplit(env, ',');
107   LOG(INFO) << "Will log NodeDefs from the following ops: ";
108   for (absl::string_view op : ops) {
109     result->insert(std::string(op));
110     LOG(INFO) << "  |" << op << "|";
111   }
112 
113   return result;
114 }
115 
116 // Returns true if the NodeDef for the OpKernel should be logged. The
117 // envionrmental variable TF_DEBUG_OPS_TO_LOG_NODEDEFS can be set to a
118 // comma-separated list of op types. The NodeDef for each is printed, which is
119 // useful for debugging purposes.
ShouldLogNodeDef(OpKernel * op_kernel)120 bool ShouldLogNodeDef(OpKernel* op_kernel) {
121   static const absl::flat_hash_set<std::string>& ops_to_log_nodedefs =
122       *GetOpNodeDefsToLogFromEnv();
123   return ops_to_log_nodedefs.count(op_kernel->type_string());
124 }
125 
126 }  // namespace
127 
128 // OpKernel ------------------------------------------------------------------
129 
OpKernel(OpKernelConstruction * context)130 OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}
131 
OpKernel(OpKernelConstruction * context,bool is_deferred)132 OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
133     : props_(context->props_),
134       input_memory_types_(context->input_memory_types().begin(),
135                           context->input_memory_types().end()),
136       output_memory_types_(context->output_memory_types().begin(),
137                            context->output_memory_types().end()),
138       input_name_map_(context->num_inputs()),
139       output_name_map_(context->num_outputs()),
140       name_view_(props_->node_def.name()),
141       type_string_view_(props_->node_def.op()),
142       graph_def_version_(context->graph_def_version()),
143       is_deferred_(is_deferred) {
144   OP_REQUIRES_OK(context,
145                  NameRangesForNode(props_->node_def, *props_->op_def,
146                                    &input_name_map_, &output_name_map_));
147   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
148                                              context->graph_def_version()));
149 
150   // Kernels executing on GPU tie very few resources on the CPU where the
151   // scheduler runs: we consider them as inexpensive.
152   expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
153                !DeviceFactory::IsPluggableDevice(
154                    DeviceTypeString(context->device_type()));
155 
156   if (ShouldLogNodeDef(this)) {
157     LOG(INFO) << "NodeDef for " << name() << ":\n" << def().ShortDebugString();
158   }
159 }
160 
OpKernel(OpKernelConstruction * context,NodeDef && custom_def,bool is_deferred)161 OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
162                    bool is_deferred)
163     : props_(std::make_shared<const NodeProperties>(
164           context->props_->op_def, std::move(custom_def),
165           context->props_->input_types, context->props_->output_types)),
166       input_memory_types_(context->input_memory_types().begin(),
167                           context->input_memory_types().end()),
168       output_memory_types_(context->output_memory_types().begin(),
169                            context->output_memory_types().end()),
170       input_name_map_(context->num_inputs()),
171       output_name_map_(context->num_outputs()),
172       name_view_(props_->node_def.name()),
173       type_string_view_(props_->node_def.op()),
174       graph_def_version_(context->graph_def_version()),
175       is_deferred_(is_deferred) {
176   OP_REQUIRES_OK(context,
177                  NameRangesForNode(props_->node_def, *props_->op_def,
178                                    &input_name_map_, &output_name_map_));
179   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
180                                              context->graph_def_version()));
181 
182   // Kernels executing on GPU tie very few resources on the CPU where the
183   // scheduler runs: we consider them as inexpensive.
184   expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
185                !DeviceFactory::IsPluggableDevice(
186                    DeviceTypeString(context->device_type()));
187 }
188 
~OpKernel()189 OpKernel::~OpKernel() {}
190 
InputRange(StringPiece input_name,int * start,int * stop) const191 Status OpKernel::InputRange(StringPiece input_name, int* start,
192                             int* stop) const {
193   const auto result = input_name_map_.find(input_name);
194   if (result == input_name_map_.end()) {
195     return errors::InvalidArgument("Unknown input name: ", input_name);
196   } else {
197     *start = result->second.first;
198     *stop = result->second.second;
199     return OkStatus();
200   }
201 }
202 
OutputRange(StringPiece output_name,int * start,int * stop) const203 Status OpKernel::OutputRange(StringPiece output_name, int* start,
204                              int* stop) const {
205   const auto result = output_name_map_.find(output_name);
206   if (result == output_name_map_.end()) {
207     return errors::InvalidArgument("Unknown output name: ", output_name);
208   } else {
209     *start = result->second.first;
210     *stop = result->second.second;
211     return OkStatus();
212   }
213 }
214 
ShapeTraceString(const OpKernelContext & ctx) const215 string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
216   int num_inputs = ctx.num_inputs();
217   if (num_inputs == 0) return "";
218   std::vector<string> tensor_shapes;
219   tensor_shapes.reserve(num_inputs);
220   for (int i = 0; i < num_inputs; i++) {
221     if (!ctx.has_input(i)) {
222       tensor_shapes.emplace_back();  // Placeholder
223       continue;
224     }
225     DataType input_dtype = ctx.input_dtype(i);
226     if (input_dtype == DataType::DT_RESOURCE ||
227         input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
228       tensor_shapes.emplace_back();  // Placeholder
229       continue;
230     }
231     tensor_shapes.emplace_back(strings::StrCat(
232         DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
233   }
234   return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
235 }
236 
TraceString(const OpKernelContext & ctx,bool verbose) const237 string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
238   string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
239   if (verbose) {
240     string shape = ShapeTraceString(ctx);
241     if (!shape.empty()) {
242       trace_string =
243           profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
244     }
245   }
246   return trace_string;
247 }
248 
Compute(OpKernelContext * context)249 void AsyncOpKernel::Compute(OpKernelContext* context) {
250   Notification n;
251   ComputeAsync(context, [&n]() { n.Notify(); });
252   n.WaitForNotification();
253 }
254 
255 // OpKernelConstruction ------------------------------------------------------
256 
OpKernelConstruction(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,ResourceMgr * resource_mgr,const std::shared_ptr<const NodeProperties> & props,const MemoryTypeSlice & input_memory_types,const MemoryTypeSlice & output_memory_types,int graph_def_version,Status * status)257 OpKernelConstruction::OpKernelConstruction(
258     DeviceType device_type, DeviceBase* device, Allocator* allocator,
259     FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
260     const std::shared_ptr<const NodeProperties>& props,
261     const MemoryTypeSlice& input_memory_types,
262     const MemoryTypeSlice& output_memory_types, int graph_def_version,
263     Status* status)
264     : device_type_(std::move(device_type)),
265       device_(device),
266       allocator_(allocator),
267       flib_(flib),
268       resource_mgr_(resource_mgr),
269       props_(props),
270       input_memory_types_(input_memory_types),
271       output_memory_types_(output_memory_types),
272       graph_def_version_(graph_def_version),
273       status_(status) {}
274 
HasAttr(StringPiece attr_name) const275 bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
276   return HasNodeAttr(def(), attr_name);
277 }
278 
SetStatus(const Status & status)279 void OpKernelConstruction::SetStatus(const Status& status) {
280   status_->Update(status);
281 }
282 
MatchSignature(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs)283 Status OpKernelConstruction::MatchSignature(
284     const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
285   return MatchSignatureHelper(expected_inputs, expected_outputs,
286                               props_->input_types, props_->output_types);
287 }
288 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)289 Status OpKernelConstruction::allocate_temp(DataType type,
290                                            const TensorShape& shape,
291                                            Tensor* out_temp) {
292   AllocationAttributes attr;
293   attr.allocation_will_be_logged = true;
294   Tensor new_temp(allocator_, type, shape, attr);
295 
296   if (!new_temp.IsInitialized()) {
297     return errors::ResourceExhausted(
298         "OOM when allocating temporary tensor with shape", shape.DebugString());
299   }
300   if (LogMemory::IsEnabled()) {
301     LogMemory::RecordTensorAllocation(
302         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
303   }
304   *out_temp = new_temp;
305   return OkStatus();
306 }
307 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)308 Status OpKernelConstruction::allocate_temp(DataType type,
309                                            const TensorShape& shape,
310                                            Tensor* out_temp,
311                                            AllocatorAttributes allocator_attr) {
312   if (allocator_attr.scope_id != 0) {
313     return errors::InvalidArgument(
314         "ScopedAllocator cannot be used via OpKernelConstruction.");
315   }
316   Allocator* a = device_->GetAllocator(allocator_attr);
317   AllocationAttributes attr;
318   attr.allocation_will_be_logged = true;
319   Tensor new_temp(a, type, shape, attr);
320 
321   if (!new_temp.IsInitialized()) {
322     return errors::ResourceExhausted(
323         "OOM when allocating temporary tensor with shape", shape.DebugString());
324   }
325   if (LogMemory::IsEnabled()) {
326     LogMemory::RecordTensorAllocation(
327         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
328   }
329   *out_temp = new_temp;
330   return OkStatus();
331 }
332 
333 // OpKernelContext -----------------------------------------------------------
334 
// Out-of-line definitions of static const members declared in the header.
// These presumably exist so the constants can be ODR-used (e.g. bound to a
// const reference) — TODO confirm whether still required.
const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;
337 
OpKernelContext(Params * params)338 OpKernelContext::OpKernelContext(Params* params)
339     : OpKernelContext(
340           params, static_cast<int>(params->op_kernel->output_types().size())) {}
341 
OpKernelContext(Params * params,int num_outputs)342 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
343     : params_(params), outputs_(num_outputs) {
344   if (params_->track_allocations) {
345     tracking_state_ = absl::make_unique<TrackingState>();
346   }
347 
348   params_->ensure_eigen_gpu_device();
349   if (params_->eigen_gpu_device != nullptr) {
350     Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
351     Status s = params_->device->ReinitializeGpuDevice(
352         this, params_->eigen_gpu_device, params_->op_device_context,
353         eigen_gpu_allocator);
354     if (!s.ok()) {
355       SetStatus(s);
356     }
357   }
358 }
359 
~OpKernelContext()360 OpKernelContext::~OpKernelContext() {
361   for (TensorValue& value : outputs_) {
362     if (!value.is_ref()) {
363       delete value.tensor;
364     }
365   }
366   if (params_->track_allocations &&
367       !tracking_state_->wrapped_allocators.empty()) {
368     LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
369                  << "being consumed by the StepStatsCollector.";
370     for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
371       wrapped_allocator.second->GetRecordsAndUnRef();
372     }
373   }
374 }
375 
get_allocator(AllocatorAttributes attr)376 Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
377   Allocator* allocator = nullptr;
378   if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
379     allocator = params_->device->GetScopedAllocator(attr, step_id());
380     CHECK(allocator);
381   } else {
382     allocator = params_->device->GetAllocator(attr);
383   }
384   if (TF_PREDICT_FALSE(track_allocations())) {
385     DCHECK(tracking_state_);
386     mutex_lock lock(tracking_state_->mu);
387     for (const auto& wrapped : tracking_state_->wrapped_allocators) {
388       if (wrapped.first == allocator) {
389         return wrapped.second;
390       }
391     }
392     TrackingAllocator* wrapped_allocator =
393         new TrackingAllocator(allocator, params_->track_allocations);
394     tracking_state_->wrapped_allocators.push_back(
395         std::make_pair(allocator, wrapped_allocator));
396     return wrapped_allocator;
397   } else {
398     return allocator;
399   }
400 }
401 
SetStatus(const Status & status)402 void OpKernelContext::SetStatus(const Status& status) {
403   status_.Update(status);
404 }
405 
input(StringPiece name,const Tensor ** tensor)406 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
407   int index;
408   TF_RETURN_IF_ERROR(get_input_index(name, &index));
409   if (input_is_ref(index)) {
410     return errors::InvalidArgument("OpKernel used ref input name '", name,
411                                    "' when non-ref input was expected");
412   }
413   *tensor = params_->inputs[index].tensor;
414   return OkStatus();
415 }
416 
input_dtype(StringPiece name,DataType * dtype) const417 Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
418   int index;
419   TF_RETURN_IF_ERROR(get_input_index(name, &index));
420   const TensorValue& value(params_->inputs[index]);
421   *dtype = value.dtype();
422   return OkStatus();
423 }
424 
input_ref_mutex(StringPiece name,mutex ** out_mutex)425 Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
426   int index;
427   TF_RETURN_IF_ERROR(get_input_index(name, &index));
428   *out_mutex = input_ref_mutex(index);
429   return OkStatus();
430 }
431 
input(int index) const432 const Tensor& OpKernelContext::input(int index) const {
433   CHECK_GE(index, 0);
434   CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
435   CHECK(!input_is_ref(index));
436   const Tensor& tensor = *params_->inputs[index].tensor;
437   return tensor;
438 }
439 
mutable_input(int index,bool lock_held)440 Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
441   CHECK_GE(index, 0);
442   CHECK_LT(index, num_inputs());
443   CHECK(input_is_ref(index));
444   // return a copy of the Ref acquired while holding the mutex
445   if (lock_held) {
446     Tensor& tensor = *params_->inputs[index].tensor;
447     return tensor;
448   } else {
449     tf_shared_lock l(*input_ref_mutex(index));
450     Tensor& tensor = *params_->inputs[index].tensor;
451     return tensor;
452   }
453 }
454 
replace_ref_input(int index,const Tensor & tensor,bool lock_held)455 void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
456                                         bool lock_held) {
457   CHECK_GE(index, 0);
458   CHECK_LT(index, num_inputs());
459   CHECK(input_is_ref(index));
460   // should only modify the tensor while holding the mutex
461   if (lock_held) {
462     *params_->inputs[index].tensor = tensor;
463   } else {
464     mutex_lock l(*input_ref_mutex(index));
465     *params_->inputs[index].tensor = tensor;
466   }
467 }
468 
forward_ref_input_to_ref_output(int input_index,int output_index)469 void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
470                                                       int output_index) {
471   CHECK_GE(input_index, 0);
472   CHECK_LT(input_index, num_inputs());
473   CHECK(input_is_ref(input_index));
474   set_output_ref(output_index, params_->inputs[input_index].mutex_if_ref,
475                  params_->inputs[input_index].tensor);
476 }
477 
forward_input_to_output_with_shape(int input_index,int output_index,const TensorShape & output_shape,Tensor ** output)478 bool OpKernelContext::forward_input_to_output_with_shape(
479     int input_index, int output_index, const TensorShape& output_shape,
480     Tensor** output) {
481   const auto output_attr = params_->output_attr_array == nullptr
482                                ? AllocatorAttributes()
483                                : output_alloc_attr(output_index);
484   std::unique_ptr<Tensor> new_tensor = forward_input(
485       input_index, output_index, expected_output_dtype(output_index),
486       output_shape, output_memory_type(output_index), output_attr);
487   if (new_tensor != nullptr) {
488     // Transfer ownership to the output slot in OpKernelContext.
489     outputs_[output_index] = TensorValue(new_tensor.release());
490     *output = outputs_[output_index].tensor;
491     return true;
492   } else {
493     return false;
494   }
495 }
496 
forward_input_to_output_with_shape(StringPiece input_name,StringPiece output_name,const TensorShape & output_shape,Tensor ** output)497 Status OpKernelContext::forward_input_to_output_with_shape(
498     StringPiece input_name, StringPiece output_name,
499     const TensorShape& output_shape, Tensor** output) {
500   int input_index, output_index;
501   TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
502   TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
503   if (!forward_input_to_output_with_shape(input_index, output_index,
504                                           output_shape, output)) {
505     return errors::FailedPrecondition("OpKernel could not forward input '",
506                                       input_name, "' to output '", output_name);
507   }
508   return OkStatus();
509 }
510 
forward_input(int input_index,int output_index,DataType output_dtype,const TensorShape & output_shape,MemoryType output_memory_type,const AllocatorAttributes & output_attr)511 std::unique_ptr<Tensor> OpKernelContext::forward_input(
512     int input_index, int output_index, DataType output_dtype,
513     const TensorShape& output_shape, MemoryType output_memory_type,
514     const AllocatorAttributes& output_attr) {
515   CHECK_GE(input_index, 0);
516   CHECK_LT(input_index, num_inputs());
517   const TensorValue& input = params_->inputs[input_index];
518   // Check whether at graph construction time this output was marked
519   // either for no forwarding or with a reservation for this input.
520   // If it's reserved for this input we'll skip the refcount and
521   // AllocatorAttribute checks.
522   // TODO(tucker): Maybe we should skip all of the checks?
523   bool never_forward =
524       (params_->forward_from_array != nullptr && output_index >= 0 &&
525        params_->forward_from_array[output_index] == Params::kNeverForward);
526   if (never_forward) return nullptr;
527   bool forward_expected =
528       (params_->forward_from_array != nullptr && output_index >= 0 &&
529        params_->forward_from_array[output_index] == input_index);
530   if (!forward_expected && params_->forward_from_array != nullptr) {
531     // Check for possibly conflicting forward.
532     for (int i = 0; i < num_outputs(); ++i) {
533       if (params_->forward_from_array[i] == input_index) {
534         // This input is reserved for output i.
535         return nullptr;
536       }
537     }
538   }
539   // Check that input tensor exists and is not a ref.
540   if (input.tensor == nullptr || input.is_ref()) {
541     CHECK(!forward_expected);
542     return nullptr;
543   }
544   // Check that input type matches.
545   if (input_dtype(input_index) != output_dtype) {
546     CHECK(!forward_expected);
547     return nullptr;
548   }
549   // Check that the input and output sizes are compatible.
550   if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
551     CHECK(!forward_expected);
552     return nullptr;
553   }
554   // Check that input and output memory types match, i.e.
555   // that they either both live in host or both live in device memory.
556   if (input_memory_type(input_index) != output_memory_type) {
557     CHECK(!forward_expected);
558     return nullptr;
559   }
560   if (!forward_expected) {
561     if (!input->RefCountIsOne()) {
562       return nullptr;
563     }
564     // Check that output allocator attributes are not more restrictive than
565     // input allocator attributes.
566     const auto input_attr = params_->input_alloc_attrs.empty()
567                                 ? AllocatorAttributes()
568                                 : input_alloc_attr(input_index);
569     if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
570       return nullptr;
571     }
572   }
573 
574   auto output_tensor = MakeUnique<Tensor>();
575   CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
576   return output_tensor;
577 }
578 
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,const AllocatorAttributes & allocator_attr,Tensor * out_temp)579 Status OpKernelContext::forward_input_or_allocate_temp(
580     gtl::ArraySlice<int> candidate_input_indices, DataType type,
581     const TensorShape& shape, const AllocatorAttributes& allocator_attr,
582     Tensor* out_temp) {
583   for (int input_index : candidate_input_indices) {
584     std::unique_ptr<Tensor> new_tensor =
585         forward_input(input_index, Params::kNoReservation /*output_index*/,
586                       type, shape, DEVICE_MEMORY, allocator_attr);
587     if (new_tensor != nullptr) {
588       *out_temp = std::move(*new_tensor);
589       return OkStatus();
590     }
591   }
592   return allocate_temp(type, shape, out_temp, allocator_attr);
593 }
594 
forward_input_or_allocate_output(gtl::ArraySlice<int> candidate_input_indices,int output_index,const TensorShape & output_shape,Tensor ** output,int * forwarded_input)595 Status OpKernelContext::forward_input_or_allocate_output(
596     gtl::ArraySlice<int> candidate_input_indices, int output_index,
597     const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
598   for (int input_index : candidate_input_indices) {
599     if (forward_input_to_output_with_shape(input_index, output_index,
600                                            output_shape, output)) {
601       if (forwarded_input != nullptr) {
602         *forwarded_input = input_index;
603       }
604       return OkStatus();
605     }
606   }
607   if (forwarded_input != nullptr) {
608     *forwarded_input = -1;
609   }
610   return allocate_output(output_index, output_shape, output);
611 }
612 
forward_input_or_allocate_output(gtl::ArraySlice<StringPiece> candidate_input_names,StringPiece output_name,const TensorShape & output_shape,Tensor ** output)613 Status OpKernelContext::forward_input_or_allocate_output(
614     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
615     const TensorShape& output_shape, Tensor** output) {
616   for (const StringPiece& input_name : candidate_input_names) {
617     if (forward_input_to_output_with_shape(input_name, output_name,
618                                            output_shape, output)
619             .ok()) {
620       return OkStatus();
621     }
622   }
623   return allocate_output(output_name, output_shape, output);
624 }
625 
delete_ref_input(int index,bool lock_held)626 void OpKernelContext::delete_ref_input(int index, bool lock_held) {
627   CHECK_GE(index, 0);
628   CHECK_LT(index, num_inputs());
629   CHECK(input_is_ref(index));
630   // should only modify the tensor while holding the mutex
631   if (lock_held) {
632     delete params_->inputs[index].tensor;
633   } else {
634     mutex_lock l(*input_ref_mutex(index));
635     delete params_->inputs[index].tensor;
636   }
637 }
638 
mutable_input(StringPiece name,Tensor * tensor,bool lock_held)639 Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
640                                       bool lock_held) {
641   int index;
642   TF_RETURN_IF_ERROR(get_input_index(name, &index));
643   if (!input_is_ref(index)) {
644     return errors::InvalidArgument("OpKernel used non-ref input name '", name,
645                                    "' when ref input was expected");
646   }
647   // return a copy of the Ref acquired while holding the mutex
648   if (lock_held) {
649     *tensor = *params_->inputs[index].tensor;
650   } else {
651     tf_shared_lock l(*input_ref_mutex(index));
652     *tensor = *params_->inputs[index].tensor;
653   }
654   return OkStatus();
655 }
656 
replace_ref_input(StringPiece name,const Tensor & tensor,bool lock_held)657 Status OpKernelContext::replace_ref_input(StringPiece name,
658                                           const Tensor& tensor,
659                                           bool lock_held) {
660   int index;
661   TF_RETURN_IF_ERROR(get_input_index(name, &index));
662   if (!input_is_ref(index)) {
663     return errors::InvalidArgument("OpKernel used immutable input name '", name,
664                                    "' when ref input was expected");
665   }
666   replace_ref_input(index, tensor, lock_held);
667   return OkStatus();
668 }
669 
input_list(StringPiece name,OpInputList * list)670 Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
671   int start, stop;
672   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
673   *list = OpInputList(this, start, stop);
674   return OkStatus();
675 }
676 
mutable_input_list(StringPiece name,OpMutableInputList * list)677 Status OpKernelContext::mutable_input_list(StringPiece name,
678                                            OpMutableInputList* list) {
679   int start, stop;
680   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
681   *list = OpMutableInputList(this, start, stop);
682   return OkStatus();
683 }
684 
output_list(StringPiece name,OpOutputList * list)685 Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
686   int start, stop;
687   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
688   *list = OpOutputList(this, start, stop);
689   return OkStatus();
690 }
691 
maybe_initialize_scope_id_set()692 void OpKernelContext::maybe_initialize_scope_id_set() {
693   if (allocated_scope_ids_ == nullptr) {
694     allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
695   }
696 }
697 
allocate_output(int index,const TensorShape & shape,Tensor ** tensor)698 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
699                                         Tensor** tensor) {
700   if (index < 0) {
701     return errors::Internal("allocate_output with bad index=", index,
702                             " kernel=", params_->op_kernel->name());
703   }
704   if (index >= num_outputs()) {
705     return errors::Internal("allocate_output with bad index=", index,
706                             " num_outputs=", num_outputs(),
707                             " kernel=", params_->op_kernel->name());
708   }
709   bool forward_expected =
710       (params_->forward_from_array != nullptr && index >= 0 &&
711        params_->forward_from_array[index] >= 0);
712   if (forward_expected) {
713     return errors::Internal(
714         "Explicit allocate_output call where input forwarding required.  Try "
715         "turning off the ScopedAllocator optimizer.");
716   }
717   AllocatorAttributes attr = output_alloc_attr(index);
718   return allocate_output(index, shape, tensor, attr);
719 }
720 
allocate_output(StringPiece name,const TensorShape & shape,Tensor ** tensor)721 Status OpKernelContext::allocate_output(StringPiece name,
722                                         const TensorShape& shape,
723                                         Tensor** tensor) {
724   int start, stop;
725   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
726   if (stop != start + 1) {
727     return errors::InvalidArgument("OpKernel used list-valued output name '",
728                                    name,
729                                    "' when single-valued output was "
730                                    "expected");
731   }
732   return allocate_output(start, shape, tensor);
733 }
734 
allocate_output(StringPiece name,const TensorShape & shape,Tensor ** tensor,AllocatorAttributes attr)735 Status OpKernelContext::allocate_output(StringPiece name,
736                                         const TensorShape& shape,
737                                         Tensor** tensor,
738                                         AllocatorAttributes attr) {
739   int start, stop;
740   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
741   if (stop != start + 1) {
742     return errors::InvalidArgument("OpKernel used list-valued output name '",
743                                    name,
744                                    "' when single-valued output was "
745                                    "expected");
746   }
747   return allocate_output(start, shape, tensor, attr);
748 }
749 
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes attr,const AllocationAttributes & allocation_attr)750 Status OpKernelContext::allocate_tensor(
751     DataType type, const TensorShape& shape, Tensor* out_tensor,
752     AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
753   Allocator* a = get_allocator(attr);
754   Tensor new_tensor(
755       a, type, shape,
756       AllocationAttributes(
757           /*retry_on_failure=*/allocation_attr.retry_on_failure,
758           /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));
759 
760   if (!new_tensor.IsInitialized()) {
761     return errors::ResourceExhausted(
762         "OOM when allocating tensor with shape", shape.DebugString(),
763         " and type ", DataTypeString(type), " on ", params_->device->name(),
764         " by allocator ", a->Name());
765   }
766   if (params_->log_memory) {
767     LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
768                                       params_->step_id, new_tensor);
769   }
770   *out_tensor = std::move(new_tensor);
771   return OkStatus();
772 }
773 
allocate_output(int index,const TensorShape & shape,Tensor ** output,AllocatorAttributes attr)774 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
775                                         Tensor** output,
776                                         AllocatorAttributes attr) {
777   if (index < 0) {
778     return errors::Internal("allocate_output with bad index=", index,
779                             " kernel=", params_->op_kernel->name());
780   }
781   if (index >= num_outputs()) {
782     return errors::Internal("allocate_output with bad index=", index,
783                             " num_outputs=", outputs_.size(),
784                             " kernel=", params_->op_kernel->name());
785   }
786   const DataType type = params_->op_kernel->output_type(index);
787   if (IsRefType(type)) {
788     return errors::Internal("allocate_output with ref type. index=", index,
789                             " type=", type,
790                             " kernel=", params_->op_kernel->name());
791   }
792   if (mutable_output(index) != nullptr) {
793     return errors::Internal("allocate_output on same index multiple times.",
794                             " index = ", index,
795                             " mutable_output(index) = ", mutable_output(index),
796                             " kernel=", params_->op_kernel->name());
797   }
798   if (attr.scope_id > 0) {
799     maybe_initialize_scope_id_set();
800     if (!allocated_scope_ids_->insert(attr.scope_id).second) {
801       return errors::Internal(
802           "OpKernel ", params_->op_kernel->name(),
803           " called allocate_output at index ", index, " with scope_id ",
804           attr.scope_id,
805           " more than once.  Try turning off the ScopedAllocator optimizer.");
806     }
807   }
808   profiler::ScopedMemoryDebugAnnotation op_annotation(
809       op_kernel().name_view().data(), step_id(), "output", type,
810       [&shape]() { return shape.DebugString(); });
811   auto output_tensor = MakeUnique<Tensor>();
812   Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
813   if (s.ok()) {
814     outputs_[index] = TensorValue(output_tensor.release());
815     *output = outputs_[index].tensor;
816   }
817   return s;
818 }
819 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr,const AllocationAttributes & allocation_attr)820 Status OpKernelContext::allocate_temp(
821     DataType type, const TensorShape& shape, Tensor* out_temp,
822     AllocatorAttributes allocator_attr,
823     const AllocationAttributes& allocation_attr) {
824   if (allocator_attr.scope_id > 0) {
825     // We do not allow ScopedAllocator calls from allocate_temp.
826     // Here we clear the scope_id and return a temporary buffer.
827     // This is because it is legal for a kernel to call allocate_temp
828     // and then set_output with the temp tensor.
829     //
830     // We achieve memory correctness by forcing an allocation in set_output and
831     // copying over the tensor from the temp buffer.  Kernels which would like
832     // to avoid this performance penalty should switch to calling
833     // allocate_output.
834     VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
835             << " called allocate_temp with scope_id " << allocator_attr.scope_id
836             << ".  Switch to allocate_output to avoid performance penalty.";
837     allocator_attr.scope_id = -1;
838   }
839   profiler::ScopedMemoryDebugAnnotation op_annotation(
840       op_kernel().name_view().data(), step_id(), "temp", type,
841       [&shape]() { return shape.DebugString(); });
842   Status s =
843       allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
844   if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
845     Allocator* a = get_allocator(allocator_attr);
846     if (a->TracksAllocationSizes()) {
847       int64_t alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
848       record_temp_memory_allocation(alloc_size, *out_temp);
849     }
850   } else if (record_memory_consumption_) {
851     DCHECK(tracking_state_);
852     mutex_lock l(tracking_state_->stats_mu);
853     tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
854   }
855   return s;
856 }
857 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)858 Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape,
859                                       Tensor* out_temp,
860                                       AllocatorAttributes allocator_attr) {
861   return allocate_temp(type, shape, out_temp, allocator_attr,
862                        AllocationAttributes());
863 }
864 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)865 Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape,
866                                       Tensor* out_temp) {
867   return allocate_temp(type, shape, out_temp, AllocatorAttributes());
868 }
869 
get_input_index(StringPiece name,int * out_index) const870 Status OpKernelContext::get_input_index(StringPiece name,
871                                         int* out_index) const {
872   int start, stop;
873   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
874   if (stop != start + 1) {
875     return errors::InvalidArgument("OpKernel used list-valued input name '",
876                                    name,
877                                    "' when single-valued input was "
878                                    "expected");
879   }
880   *out_index = start;
881   return OkStatus();
882 }
883 
get_output_index(StringPiece name,int * out_index) const884 Status OpKernelContext::get_output_index(StringPiece name,
885                                          int* out_index) const {
886   int start, stop;
887   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
888   if (stop != start + 1) {
889     return errors::InvalidArgument("OpKernel used list-valued output name '",
890                                    name,
891                                    "' when single-valued output was "
892                                    "expected");
893   }
894   *out_index = start;
895   return OkStatus();
896 }
897 
set_output(StringPiece name,const Tensor & tensor)898 Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
899   int index;
900   TF_RETURN_IF_ERROR(get_output_index(name, &index));
901   set_output(index, tensor);
902   return OkStatus();
903 }
904 
set_output(StringPiece name,Tensor && tensor)905 Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
906   int index;
907   TF_RETURN_IF_ERROR(get_output_index(name, &index));
908   set_output(index, std::move(tensor));
909   return OkStatus();
910 }
911 
maybe_set_output_by_allocate_and_copy(int index,const Tensor & tensor)912 bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
913     int index, const Tensor& tensor) {
914   bool allocate_and_copy = false;
915   const bool never_forward =
916       (params_->forward_from_array != nullptr &&
917        params_->forward_from_array[index] == Params::kNeverForward);
918   if (TF_PREDICT_FALSE(never_forward)) {
919     maybe_initialize_scope_id_set();
920     if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
921         allocated_scope_ids_->end()) {
922       allocate_and_copy = true;
923     } else {
924       // The output at `index` must have been previously allocated via a call to
925       // `allocate_output(index, ...)`.  That call would ensure that we return
926       // the correct slice of the ScopedAllocated buffer, so we do not
927       // re-allocate and copy here.
928       LOG(WARNING)
929           << "OpKernel " << params_->op_kernel->name()
930           << " called both allocate_output and set_output with scope_id "
931           << output_alloc_attr(index).scope_id;
932     }
933   }
934 
935   if (TF_PREDICT_FALSE(allocate_and_copy)) {
936     // This output was marked to not be forwarded either during graph
937     // construction or grappler passes.  Force an allocation and copy input to
938     // output.
939     VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
940             << tensor.DebugString() << " never_forward " << never_forward
941             << " params_->forward_from_array[index] "
942             << params_->forward_from_array[index] << " alloc_attr.scope_id "
943             << output_alloc_attr(index).scope_id;
944     profiler::ScopedMemoryDebugAnnotation op_annotation(
945         op_kernel().name_view().data(), step_id(), "output", tensor.dtype(),
946         [&tensor]() { return tensor.shape().DebugString(); });
947     auto new_tensor = MakeUnique<Tensor>();
948     Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
949                                output_alloc_attr(index));
950     TF_CHECK_OK(s);
951     device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
952                                      op_device_context(), [](const Status&) {});
953     outputs_[index] = TensorValue(new_tensor.release());
954   }
955   return allocate_and_copy;
956 }
957 
maybe_track_allocations_for_set_output(const Tensor & tensor)958 void OpKernelContext::maybe_track_allocations_for_set_output(
959     const Tensor& tensor) {
960   if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
961     DCHECK(tracking_state_);
962     mutex_lock l(tracking_state_->stats_mu);
963     const auto it = std::find_if(
964         tracking_state_->temp_tensor_buffer_and_size.begin(),
965         tracking_state_->temp_tensor_buffer_and_size.end(),
966         [&tensor](const std::pair<const void*, int64>& e) {
967           return e.first ==
968                  static_cast<const void*>(tensor.tensor_data().data());
969         });
970     if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
971       tracking_state_->temp_memory_allocated -= it->second;
972       tracking_state_->temp_tensor_buffer_and_size.erase(it);
973     }
974   }
975 }
976 
set_output(int index,const Tensor & tensor)977 void OpKernelContext::set_output(int index, const Tensor& tensor) {
978   CHECK_GE(index, 0);
979   CHECK_LT(index, outputs_.size());
980   const DataType type = params_->op_kernel->output_type(index);
981   CHECK(!IsRefType(type));
982   CHECK_EQ(outputs_[index].tensor, nullptr);
983   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
984     // Input can be forwarded to output; incref on `tensor` and set output at
985     // `index` to this tensor.
986     outputs_[index] = TensorValue(new Tensor(tensor));
987     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
988   }
989 }
990 
set_output(int index,Tensor && tensor)991 void OpKernelContext::set_output(int index, Tensor&& tensor) {
992   CHECK_GE(index, 0);
993   CHECK_LT(index, outputs_.size());
994   const DataType type = params_->op_kernel->output_type(index);
995   CHECK(!IsRefType(type));
996   CHECK_EQ(outputs_[index].tensor, nullptr);
997   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
998     // Input can be forwarded to output; set output at `index` to this tensor.
999     outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
1000     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
1001   }
1002 }
1003 
set_output_ref(int index,mutex * mu,Tensor * tensor_for_ref)1004 void OpKernelContext::set_output_ref(int index, mutex* mu,
1005                                      Tensor* tensor_for_ref) {
1006   CHECK_GE(index, 0);
1007   CHECK_LT(index, outputs_.size());
1008   CHECK(IsRefType(params_->op_kernel->output_type(index)));
1009   outputs_[index] = TensorValue(mu, tensor_for_ref);
1010 }
1011 
set_output_ref(StringPiece name,mutex * mu,Tensor * tensor_for_ref)1012 Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
1013                                        Tensor* tensor_for_ref) {
1014   int index;
1015   TF_RETURN_IF_ERROR(get_output_index(name, &index));
1016   set_output_ref(index, mu, tensor_for_ref);
1017   return OkStatus();
1018 }
1019 
mutable_output(StringPiece name,Tensor ** tensor)1020 Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
1021   int index;
1022   TF_RETURN_IF_ERROR(get_output_index(name, &index));
1023   *tensor = mutable_output(index);
1024   return OkStatus();
1025 }
1026 
ValidateInputsAreSameShape(OpKernel * op)1027 bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
1028   const auto& inputs = params_->inputs;
1029   for (size_t i = 1; i < inputs.size(); ++i) {
1030     if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
1031       SetStatus(errors::InvalidArgument(
1032           "Inputs to operation ", op->name(), " of type ", op->type_string(),
1033           " must have the same size and shape.  Input 0: ",
1034           inputs[0]->shape().DebugString(), " != input ", i, ": ",
1035           inputs[i]->shape().DebugString()));
1036       return false;
1037     }
1038   }
1039   return true;
1040 }
1041 
MatchSignature(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs)1042 Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
1043                                        const DataTypeSlice expected_outputs) {
1044   DataTypeVector inputs;
1045   for (const TensorValue& t : params_->inputs) {
1046     inputs.push_back(t.dtype());
1047   }
1048   DataTypeVector outputs = params_->op_kernel->output_types();
1049   return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
1050                               outputs);
1051 }
1052 
record_temp_memory_allocation(int64_t size,const Tensor & t)1053 void OpKernelContext::record_temp_memory_allocation(int64_t size,
1054                                                     const Tensor& t) {
1055   if (tracking_state_) {
1056     mutex_lock l(tracking_state_->stats_mu);
1057     tracking_state_->temp_memory_allocated += size;
1058     tracking_state_->temp_tensor_buffer_and_size.emplace_back(
1059         static_cast<const void*>(t.tensor_data().data()), size);
1060   }
1061 }
1062 
temp_memory_allocated() const1063 int64_t OpKernelContext::temp_memory_allocated() const {
1064   if (tracking_state_) {
1065     mutex_lock l(tracking_state_->stats_mu);
1066     return tracking_state_->temp_memory_allocated;
1067   } else {
1068     return 0;
1069   }
1070 }
1071 
record_persistent_memory_allocation(int64_t size,int64_t alloc_id)1072 void OpKernelContext::record_persistent_memory_allocation(int64_t size,
1073                                                           int64_t alloc_id) {
1074   if (tracking_state_) {
1075     mutex_lock l(tracking_state_->stats_mu);
1076     tracking_state_->persistent_memory_allocated += size;
1077     if (alloc_id >= 0) {
1078       tracking_state_->persistent_alloc_ids.push_back(alloc_id);
1079     }
1080   }
1081 }
1082 
persistent_memory_allocated() const1083 int64_t OpKernelContext::persistent_memory_allocated() const {
1084   if (tracking_state_) {
1085     mutex_lock l(tracking_state_->stats_mu);
1086     return tracking_state_->persistent_memory_allocated;
1087   } else {
1088     return 0;
1089   }
1090 }
1091 
persistent_alloc_ids() const1092 std::vector<int64_t> OpKernelContext::persistent_alloc_ids() const {
1093   if (tracking_state_) {
1094     mutex_lock l(tracking_state_->stats_mu);
1095     return std::vector<int64_t>(tracking_state_->persistent_alloc_ids.begin(),
1096                                 tracking_state_->persistent_alloc_ids.end());
1097   } else {
1098     return std::vector<int64_t>();
1099   }
1100 }
1101 
clear_recorded_memory()1102 void OpKernelContext::clear_recorded_memory() {
1103   if (tracking_state_) {
1104     mutex_lock l(tracking_state_->stats_mu);
1105     tracking_state_->temp_memory_allocated = 0;
1106     tracking_state_->persistent_memory_allocated = 0;
1107     tracking_state_->temp_tensor_buffer_and_size.clear();
1108     tracking_state_->persistent_alloc_ids.clear();
1109   }
1110 }
1111 
set_record_memory_consumption(bool v)1112 void OpKernelContext::set_record_memory_consumption(bool v) {
1113   record_memory_consumption_ = v;
1114   if (v && !tracking_state_) {
1115     tracking_state_ = absl::make_unique<TrackingState>();
1116   }
1117 }
1118 
executor_type() const1119 const string& OpKernelContext::executor_type() const {
1120   if (params_->executor_type) {
1121     return *params_->executor_type;
1122   } else {
1123     static const string& kEmptyString = *new string("");
1124     return kEmptyString;
1125   }
1126 }
1127 
1128 // OpKernel registration ------------------------------------------------------
1129 
1130 struct KernelRegistration {
KernelRegistrationtensorflow::KernelRegistration1131   KernelRegistration(const KernelDef& d, StringPiece c,
1132                      std::unique_ptr<kernel_factory::OpKernelFactory> f)
1133       : def(d), kernel_class_name(c), factory(std::move(f)) {}
1134 
1135   const KernelDef def;
1136   const string kernel_class_name;
1137   std::unique_ptr<kernel_factory::OpKernelFactory> factory;
1138 };
1139 
1140 // This maps from 'op_type' + DeviceType to the set of KernelDefs and
1141 // factory functions for instantiating the OpKernel that matches the
1142 // KernelDef.
1143 struct KernelRegistry {
1144   mutex mu;
1145   std::unordered_multimap<string, KernelRegistration> registry
1146       TF_GUARDED_BY(mu);
1147 };
1148 
// Per-platform glob matching kernel plugin libraries that
// LoadDynamicKernelsInternal() loads from the runfiles kernel directory.
#if defined(_WIN32)
static constexpr char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static constexpr char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static constexpr char kKernelLibPattern[] = "libtfkernel*.so";
#endif

// Expands to the brace pair {x, "x"}: a CPU feature value together with its
// spelled-out name, so error messages can name the missing feature.
#define FEATURE(x) \
  { x, #x }
1159 
1160 // Returns Status::OK if the dynamic library at the given path is safe to
1161 // load with some level of confidence.
IsProbablySafeToLoad(const string & path)1162 static Status IsProbablySafeToLoad(const string& path) {
1163   // A map of platform string to required CPU feature.
1164   using port::CPUFeature;
1165   static const auto* feature_map =
1166       new std::map<string, std::pair<CPUFeature, string>>{
1167           {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
1168       };
1169 
1170   std::vector<std::string> platform_strings;
1171   int result = GetPlatformStrings(path, &platform_strings);
1172   if (result) {
1173     return Status(error::Code::UNKNOWN, strerror(result));
1174   }
1175   if (platform_strings.empty()) {
1176     return Status(error::Code::FAILED_PRECONDITION,
1177                   "Didn't find any platform strings");
1178   }
1179   std::vector<std::string> missing_features;
1180   for (const auto& platform_string : platform_strings) {
1181     const auto& entry = feature_map->find(platform_string);
1182     if (entry != feature_map->end() &&
1183         !port::TestCPUFeature(entry->second.first)) {
1184       missing_features.emplace_back(entry->second.second);
1185     }
1186   }
1187   if (!missing_features.empty()) {
1188     string errmsg = "Missing CPU features: ";
1189     errmsg.append(absl::StrJoin(missing_features, ", "));
1190     return errors::FailedPrecondition(errmsg);
1191   }
1192   return OkStatus();
1193 }
1194 
LoadDynamicKernelsInternal()1195 void LoadDynamicKernelsInternal() {
1196   Env* env = Env::Default();
1197 
1198   // Override to allow loading unsafe packages for development.
1199   // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
1200   char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
1201   bool override_abi_check = false;
1202   if (_abi_check_env_var != nullptr) {
1203     override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
1204   }
1205 
1206   string bazel_kernel_dir =
1207       io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
1208   std::vector<string> files;
1209   Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
1210   if (s_kernel_dir.ok()) {
1211     string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
1212     for (const auto& file : files) {
1213       string fullpath = io::JoinPath(bazel_kernel_dir, file);
1214       if (env->MatchPath(fullpath, dll_spec)) {
1215         Status s = IsProbablySafeToLoad(fullpath);
1216         if (!s.ok() && override_abi_check) {
1217           LOG(WARNING) << "Loading UNSAFE library " << fullpath
1218                        << " because ABI check override is set: "
1219                        << s.error_message();
1220         }
1221         if (s.ok() || override_abi_check) {
1222           // TODO(gunan): Store the handles to the opened files.
1223           void* unused_filehandle;
1224           TF_CHECK_OK(
1225               env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
1226         } else {
1227           LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
1228                        << s.error_message();
1229         }
1230       }
1231     }
1232   }
1233 }
1234 
1235 // Mechanism for loading existing kernel libraries.
LoadDynamicKernels()1236 void LoadDynamicKernels() {
1237   // TODO(gunan): As more features are available, add intelligent kernel
1238   // selection, and dropping unsuitable kernel logic here.
1239   static absl::once_flag dll_loader_flag;
1240   absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
1241 }
1242 
Key(StringPiece op_type,const DeviceType & device_type,StringPiece label)1243 static string Key(StringPiece op_type, const DeviceType& device_type,
1244                   StringPiece label) {
1245   return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
1246                          label);
1247 }
1248 
1249 // Provide a way for users to disable JIT kernels for a transitional period.
1250 // Until this is removed, this function also removes the JIT label that is added
1251 // to JIT kernels during the static registration, to allow them to be found
1252 // during lookup as normal kernels.
SetupOrDisableJit(KernelRegistry * registry)1253 void SetupOrDisableJit(KernelRegistry* registry) {
1254   std::unordered_multimap<string, KernelRegistration> jit_kernels;
1255   bool remove_jit_kernels = absl::StrContains(
1256       absl::NullSafeStringView(getenv(kDisableJitKernelsEnvVar)), "1");
1257 
1258   mutex_lock l(registry->mu);
1259   std::unordered_multimap<string, KernelRegistration>& all_kernels =
1260       registry->registry;
1261   auto it = all_kernels.begin();
1262   while (it != all_kernels.end()) {
1263     if (absl::StrContains(it->second.def.label(), kJitKernelLabel)) {
1264       // Remove all kernels that have the jit label. They will be added back
1265       // without the label if they are not to be disabled.
1266       KernelDef def_without_label = it->second.def;
1267       def_without_label.set_label("");
1268 
1269       if (!remove_jit_kernels) {
1270         jit_kernels.emplace(
1271             Key(def_without_label.op(),
1272                 DeviceType(def_without_label.device_type()),
1273                 def_without_label.label()),
1274             KernelRegistration(def_without_label, it->second.kernel_class_name,
1275                                std::move(it->second.factory)));
1276       }
1277 
1278       it = all_kernels.erase(it);
1279     } else {
1280       it++;
1281     }
1282   }
1283 
1284   // Add back kernels if they are not disabled. This new key-value pair have all
1285   // references to the label removed.
1286   for (auto& jit_kernel : jit_kernels) {
1287     all_kernels.insert(std::move(jit_kernel));
1288   }
1289 }
1290 
1291 namespace register_kernel {
1292 
1293 // Defined out of line to save code space
Name(const char * op)1294 Name::Name(const char* op) : KernelDefBuilder(op) {}
1295 
1296 }  // namespace register_kernel
1297 
GlobalKernelRegistry()1298 void* GlobalKernelRegistry() {
1299   static KernelRegistry* global_kernel_registry = []() {
1300     KernelRegistry* registry = new KernelRegistry;
1301     OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
1302     return registry;
1303   }();
1304   return global_kernel_registry;
1305 }
1306 
GlobalKernelRegistryTyped()1307 static KernelRegistry* GlobalKernelRegistryTyped() {
1308 #ifdef AUTOLOAD_DYNAMIC_KERNELS
1309   LoadDynamicKernels();
1310 #endif  // AUTOLOAD_DYNAMIC_KERNELS
1311   auto* registry = reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1312   // Update or disable JIT kernels based on user configuration. This is a
1313   // temporary fallback as part of the initial release of JIT kernels.
1314   static absl::once_flag setup_or_disable_jit;
1315   absl::call_once(setup_or_disable_jit, SetupOrDisableJit, registry);
1316   return registry;
1317 }
1318 
1319 namespace kernel_factory {
1320 
InitInternal(const KernelDef * kernel_def,StringPiece kernel_class_name,std::unique_ptr<OpKernelFactory> factory)1321 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
1322                                      StringPiece kernel_class_name,
1323                                      std::unique_ptr<OpKernelFactory> factory) {
1324   const string key =
1325       Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
1326           kernel_def->label());
1327 
1328   // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
1329   // here.
1330   // InitInternal gets called by static initializers, so it ends up executing
1331   // before main. This causes LoadKernelLibraries function to get called
1332   // before some file libraries can initialize, which in turn crashes the
1333   // program flakily. Until we get rid of static initializers in kernel
1334   // registration mechanism, we have this workaround here.
1335   auto global_registry =
1336       reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1337   mutex_lock l(global_registry->mu);
1338   global_registry->registry.emplace(
1339       key,
1340       KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
1341   delete kernel_def;
1342 }
1343 
Create(OpKernelConstruction * context)1344 OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
1345     OpKernelConstruction* context) {
1346   return (*create_func_)(context);
1347 }
1348 
1349 }  // namespace kernel_factory
1350 
1351 namespace {
1352 
1353 // Label defaults to empty if not found in NodeDef.
GetKernelLabelAttr(const AttrSlice & node_attrs)1354 const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
1355   static const string& kKernelAttr = *new string("_kernel");
1356   static const string& kEmptyString = *new string("");
1357 
1358   // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
1359   // to use the `AttrSlice::FindByString()` overload, which does a more
1360   // efficient map lookup (instead of a linear scan) when the attribute name is
1361   // already a `const string&`.
1362   const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
1363   if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
1364     return kEmptyString;
1365   else
1366     return attr_value->s();
1367 }
1368 
1369 // TODO(irving): Replace with const Node& version below.
FindKernelRegistration(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,AttrSlice node_attrs,const KernelRegistration ** reg,bool * was_attr_mismatch)1370 Status FindKernelRegistration(
1371     const DeviceType& device_type, StringPiece node_name,
1372     bool has_experimental_debug_info,
1373     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1374     StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
1375     bool* was_attr_mismatch) {
1376   *reg = nullptr;
1377   *was_attr_mismatch = false;
1378 
1379   const string& label = GetKernelLabelAttr(node_attrs);
1380 
1381   const string key = Key(node_op, device_type, label);
1382   auto typed_registry = GlobalKernelRegistryTyped();
1383   tf_shared_lock lock(typed_registry->mu);
1384   auto regs = typed_registry->registry.equal_range(key);
1385   for (auto iter = regs.first; iter != regs.second; ++iter) {
1386     // If there is a kernel registered for the op and device_type,
1387     // check that the attrs match.
1388     bool match;
1389     TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
1390     if (match) {
1391       if (*reg != nullptr) {
1392         if ((*reg)->def.priority() == iter->second.def.priority()) {
1393           return errors::InvalidArgument(
1394               "Multiple OpKernel registrations match NodeDef at the same "
1395               "priority '",
1396               FormatNodeDefForError(node_name, has_experimental_debug_info,
1397                                     experimental_debug_info),
1398               "': '", (*reg)->def.ShortDebugString(), "' and '",
1399               iter->second.def.ShortDebugString(), "'");
1400         } else if ((*reg)->def.priority() > iter->second.def.priority()) {
1401           continue;
1402         }
1403         // iter->second's priority is higher than *reg.
1404       }
1405       *reg = &iter->second;
1406     } else {
1407       *was_attr_mismatch = true;
1408     }
1409   }
1410   // Check if no device specific registrations found. If not, try finding a
1411   // default kernel.
1412   if (*reg == nullptr &&
1413       !IsSymbolicExecutionDevice(device_type.type_string())) {
1414     const string default_key = Key(node_op, DEVICE_DEFAULT, label);
1415     auto regs = typed_registry->registry.equal_range(default_key);
1416     for (auto iter = regs.first; iter != regs.second; ++iter) {
1417       // If there is a kernel registered for the op and device_type,
1418       // check that the attrs match.
1419       bool match;
1420       TF_RETURN_IF_ERROR(
1421           KernelAttrsMatch(iter->second.def, node_attrs, &match));
1422       if (match) {
1423         if (*reg != nullptr) {
1424           return errors::InvalidArgument(
1425               "Multiple Default OpKernel registrations match NodeDef '",
1426               FormatNodeDefForError(node_name, has_experimental_debug_info,
1427                                     experimental_debug_info),
1428               "': '", (*reg)->def.ShortDebugString(), "' and '",
1429               iter->second.def.ShortDebugString(), "'");
1430         }
1431         *reg = &iter->second;
1432       } else {
1433         *was_attr_mismatch = true;
1434       }
1435     }
1436 
1437     if (*reg != nullptr) {
1438       VLOG(1) << "No device-specific kernels found for NodeDef '"
1439               << FormatNodeDefForError(node_name, has_experimental_debug_info,
1440                                        experimental_debug_info)
1441               << "'"
1442               << "Will fall back to a default kernel." << std::endl;
1443     }
1444   }
1445 
1446   return OkStatus();
1447 }
1448 
FindKernelRegistration(const DeviceType & device_type,const NodeDef & node_def,const KernelRegistration ** reg,bool * was_attr_mismatch)1449 Status FindKernelRegistration(const DeviceType& device_type,
1450                               const NodeDef& node_def,
1451                               const KernelRegistration** reg,
1452                               bool* was_attr_mismatch) {
1453   return FindKernelRegistration(
1454       device_type, node_def.name(), node_def.has_experimental_debug_info(),
1455       node_def.experimental_debug_info(), node_def.op(),
1456       AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
1457 }
1458 
1459 }  // namespace
1460 
KernelDefAvailable(const DeviceType & device_type,const NodeDef & node_def)1461 bool KernelDefAvailable(const DeviceType& device_type,
1462                         const NodeDef& node_def) {
1463   const KernelRegistration* reg = nullptr;
1464   bool was_attr_mismatch;
1465   Status result =
1466       FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch);
1467   return result.ok() && reg != nullptr;
1468 }
1469 
1470 // TODO(irving): Change const NodeDef& to const Node&
FindKernelDef(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,StringPiece node_device,AttrSlice node_attrs,const KernelDef ** def,string * kernel_class_name)1471 Status FindKernelDef(
1472     const DeviceType& device_type, StringPiece node_name,
1473     bool has_experimental_debug_info,
1474     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1475     StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1476     const KernelDef** def, string* kernel_class_name) {
1477   const KernelRegistration* reg = nullptr;
1478   bool was_attr_mismatch;
1479   TF_RETURN_IF_ERROR(FindKernelRegistration(
1480       device_type, node_name, has_experimental_debug_info,
1481       experimental_debug_info, node_op, node_attrs, ®, &was_attr_mismatch));
1482   if (reg == nullptr) {
1483     const std::string device_str = DeviceTypeString(device_type);
1484     Status s = errors::NotFound(
1485         "No registered '", node_op, "' OpKernel for ", device_str,
1486         " devices compatible with node ",
1487         FormatNodeDefForError(node_name, has_experimental_debug_info,
1488                               experimental_debug_info));
1489     if (was_attr_mismatch) {
1490       errors::AppendToMessage(
1491           &s, " (OpKernel was found, but attributes didn't match) ",
1492           "Requested Attributes: ",
1493           SummarizeAttrsHelper(node_attrs, node_device));
1494     }
1495 
1496     // Do not print kernel registrations for other devices when using _JIT
1497     // devices for compilation or for MKL ops.
1498     // TODO (intel-tf) : Remove the check for MKL ops when support for
1499     // block format is removed.
1500     if (!absl::StrContains(device_str, "JIT") &&
1501         !absl::StartsWith(node_name, "_Mkl")) {
1502       errors::AppendToMessage(
1503           &s, ".  Registered:", KernelsRegisteredForOp(node_op));
1504     }
1505 
1506     return s;
1507   }
1508   if (def != nullptr) *def = ®->def;
1509   if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
1510   return OkStatus();
1511 }
1512 
FindKernelDef(const DeviceType & device_type,const NodeDef & node_def,const KernelDef ** def,string * kernel_class_name)1513 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1514                      const KernelDef** def, string* kernel_class_name) {
1515   return FindKernelDef(
1516       device_type, node_def.name(), node_def.has_experimental_debug_info(),
1517       node_def.experimental_debug_info(), node_def.op(), node_def.device(),
1518       AttrSlice(&node_def.attr()), def, kernel_class_name);
1519 }
1520 
SupportedDeviceTypesForNode(const std::vector<DeviceType> & prioritized_types,const NodeDef & def,PrioritizedDeviceTypeVector * prioritized_device_types,const DeviceNameUtils::ParsedName * local_address_spec)1521 Status SupportedDeviceTypesForNode(
1522     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1523     PrioritizedDeviceTypeVector* prioritized_device_types,
1524     const DeviceNameUtils::ParsedName* local_address_spec) {
1525   // TODO(zhifengc): Changes the callers (SimplePlacer and
1526   // DynamicPlacer) to consider the possibility that 'def' is call to
1527   // a user-defined function and only calls this
1528   // SupportedDeviceTypesForNode for primitive ops.
1529   const OpRegistrationData* op_reg_data;
1530   const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
1531   if (s.ok()) {
1532     bool exists_attr_mismatch = false;
1533     for (const DeviceType& device_type : prioritized_types) {
1534       const KernelRegistration* reg = nullptr;
1535       bool was_attr_mismatch = false;
1536       TF_RETURN_IF_ERROR(
1537           FindKernelRegistration(device_type, def, ®, &was_attr_mismatch));
1538       exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
1539       if (reg != nullptr) {
1540         int32_t priority = reg->def.priority();
1541         prioritized_device_types->emplace_back(device_type, priority);
1542       }
1543     }
1544     // Add extra supported device types if the following conditions are
1545     // satisfied:
1546     // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
1547     // 2) A device is requested for this node which specifies job/replica/task
1548     // 3) A local device is provided which specifies job/replica/task
1549     // 4) The local device does not have the same (job, replica, task) as the
1550     //    requested device
1551     //
1552     // The goal is to address the issue where a graph includes op (e.g. PyFunc)
1553     // whose kernel is known to a remote process but not to the current process.
1554     if (prioritized_device_types->empty() && !exists_attr_mismatch &&
1555         local_address_spec != nullptr) {
1556       DeviceNameUtils::ParsedName requested_device_name;
1557       DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
1558       if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
1559                                                    requested_device_name)) {
1560         if (requested_device_name.has_type) {
1561           prioritized_device_types->push_back(
1562               std::make_pair(DeviceType(requested_device_name.type), 0));
1563         } else {
1564           for (const DeviceType& device_type : prioritized_types) {
1565             prioritized_device_types->push_back(std::make_pair(device_type, 0));
1566           }
1567         }
1568       }
1569     }
1570 
1571     // If we were unable to find any valid devices let's validate if the node is
1572     // even valid.
1573     if (prioritized_device_types->empty()) {
1574       TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
1575     }
1576 
1577     std::stable_sort(prioritized_device_types->begin(),
1578                      prioritized_device_types->end(),
1579                      [](const std::pair<DeviceType, int32>& a,
1580                         const std::pair<DeviceType, int32>& b) {
1581                        return a.second > b.second;
1582                      });
1583   } else {
1584     // Assumes that all device types support this node.
1585     for (const DeviceType& device_type : prioritized_types) {
1586       prioritized_device_types->push_back(std::make_pair(device_type, 0));
1587     }
1588   }
1589   return OkStatus();
1590 }
1591 
LogAllRegisteredKernels()1592 void LogAllRegisteredKernels() {
1593   KernelList kernel_list = GetAllRegisteredKernels();
1594   for (const auto& kernel_def : kernel_list.kernel()) {
1595     LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
1596   }
1597 }
1598 
GetAllRegisteredKernels()1599 KernelList GetAllRegisteredKernels() {
1600   return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
1601 }
1602 
GetFilteredRegisteredKernels(const std::function<bool (const KernelDef &)> & predicate)1603 KernelList GetFilteredRegisteredKernels(
1604     const std::function<bool(const KernelDef&)>& predicate) {
1605   KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
1606   KernelList kernel_list;
1607   tf_shared_lock lock(typed_registry->mu);
1608   kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
1609   for (const auto& p : typed_registry->registry) {
1610     const KernelDef& kernel_def = p.second.def;
1611     if (predicate(kernel_def)) {
1612       *kernel_list.add_kernel() = kernel_def;
1613     }
1614   }
1615   return kernel_list;
1616 }
1617 
GetRegisteredKernelsForOp(StringPiece op_name)1618 KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
1619   auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
1620   return GetFilteredRegisteredKernels(op_pred);
1621 }
1622 
KernelsRegisteredForOp(StringPiece op_name)1623 string KernelsRegisteredForOp(StringPiece op_name) {
1624   KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
1625   if (kernel_list.kernel_size() == 0) return "  <no registered kernels>\n";
1626   string ret;
1627   for (const auto& kernel_def : kernel_list.kernel()) {
1628     strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
1629     if (!kernel_def.label().empty()) {
1630       strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
1631     }
1632     for (int i = 0; i < kernel_def.constraint_size(); ++i) {
1633       strings::StrAppend(
1634           &ret, "; ", kernel_def.constraint(i).name(), " in ",
1635           SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
1636     }
1637     strings::StrAppend(&ret, "\n");
1638   }
1639   return ret;
1640 }
1641 
1642 /* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
1643  * copying the NodeDef. */
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,const NodeDef & node_def,int graph_def_version,Status * status)1644 std::unique_ptr<OpKernel> CreateOpKernel(
1645     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1646     const NodeDef& node_def, int graph_def_version, Status* status) {
1647   // Look up the Op registered for this op name.
1648   std::shared_ptr<const NodeProperties> props;
1649   status->Update(NodeProperties::CreateFromNodeDef(
1650       node_def, OpRegistry::Global(), &props));
1651   if (!status->ok()) {
1652     errors::AppendToMessage(status,
1653                             " for node: ", FormatNodeDefForError(node_def));
1654     return nullptr;
1655   }
1656   return CreateOpKernel(device_type, device, allocator, props,
1657                         graph_def_version, status);
1658 }
1659 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,Status * status)1660 std::unique_ptr<OpKernel> CreateOpKernel(
1661     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1662     const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1663     Status* status) {
1664   OpKernel* kernel = nullptr;
1665   *status = CreateOpKernel(std::move(device_type), device, allocator,
1666                            /*flib=*/nullptr, props, graph_def_version, &kernel);
1667   return std::unique_ptr<OpKernel>(kernel);
1668 }
1669 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1670 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1671                       Allocator* allocator, FunctionLibraryRuntime* flib,
1672                       const std::shared_ptr<const NodeProperties>& props,
1673                       int graph_def_version, OpKernel** kernel) {
1674   return CreateOpKernel(std::move(device_type), device, allocator, flib,
1675                         /* resource_mgr= */ nullptr, props, graph_def_version,
1676                         kernel);
1677 }
1678 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,ResourceMgr * resource_mgr,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1679 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1680                       Allocator* allocator, FunctionLibraryRuntime* flib,
1681                       ResourceMgr* resource_mgr,
1682                       const std::shared_ptr<const NodeProperties>& props,
1683                       int graph_def_version, OpKernel** kernel) {
1684   const NodeDef& node_def = props->node_def;
1685   bool was_attr_mismatch;
1686   const KernelRegistration* registration = nullptr;
1687   Status s;
1688   if (props != nullptr) {
1689     VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
1690 
1691     // Validate node_def against OpDef.
1692     TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));
1693 
1694     // Look up kernel registration.
1695     s = FindKernelRegistration(device_type, node_def, ®istration,
1696                                &was_attr_mismatch);
1697     if (!s.ok()) {
1698       errors::AppendToMessage(&s, " when instantiating ", node_def.op());
1699       return s;
1700     }
1701   }
1702   if (registration == nullptr) {
1703     s.Update(errors::NotFound("No registered '", node_def.op(),
1704                               "' OpKernel for '", DeviceTypeString(device_type),
1705                               "' devices compatible with node ",
1706                               FormatNodeDefForError(node_def)));
1707     if (was_attr_mismatch) {
1708       errors::AppendToMessage(
1709           &s, " (OpKernel was found, but attributes didn't match) ",
1710           "Requested Attributes: ", SummarizeAttrs(node_def));
1711     }
1712     errors::AppendToMessage(
1713         &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
1714     return s;
1715   }
1716 
1717   // We are creating a kernel for an op registered in
1718   // OpRegistry::Global(), we consult the kernel registry to decide
1719   // the kernel's input and output memory types.
1720   MemoryTypeVector input_memory_types;
1721   MemoryTypeVector output_memory_types;
1722   TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
1723                                         node_def, &input_memory_types,
1724                                         &output_memory_types));
1725 
1726   // Everything needed for OpKernel construction.
1727   OpKernelConstruction context(std::move(device_type), device, allocator, flib,
1728                                resource_mgr, props, input_memory_types,
1729                                output_memory_types, graph_def_version, &s);
1730   *kernel = registration->factory->Create(&context);
1731   if (!s.ok()) {
1732     delete *kernel;
1733     *kernel = nullptr;
1734   }
1735   return s;
1736 }
1737 
1738 namespace {
1739 
FindArgInOp(StringPiece arg_name,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)1740 bool FindArgInOp(StringPiece arg_name,
1741                  const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
1742   for (const auto& arg : args) {
1743     if (arg_name == arg.name()) {
1744       return true;
1745     }
1746   }
1747   return false;
1748 }
1749 
1750 }  // namespace
1751 
ValidateKernelRegistrations(const OpRegistryInterface & op_registry)1752 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
1753   auto typed_registry = GlobalKernelRegistryTyped();
1754   tf_shared_lock lock(typed_registry->mu);
1755   for (const auto& key_registration : typed_registry->registry) {
1756     const KernelDef& kernel_def(key_registration.second.def);
1757     const OpRegistrationData* op_reg_data;
1758     const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
1759     if (!status.ok()) {
1760       // TODO(josh11b): Make this a hard error.
1761       LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
1762                  << "') for unknown op: " << kernel_def.op();
1763       continue;
1764     }
1765     const OpDef& op_def = op_reg_data->op_def;
1766     for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
1767       if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
1768           !FindArgInOp(host_memory_arg, op_def.output_arg())) {
1769         return errors::InvalidArgument(
1770             "HostMemory arg '", host_memory_arg,
1771             "' not found in OpDef: ", SummarizeOpDef(op_def));
1772       }
1773     }
1774   }
1775   return OkStatus();
1776 }
1777 
1778 template <>
eigen_device() const1779 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
1780   return eigen_cpu_device();
1781 }
1782 
1783 template <>
eigen_device() const1784 const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
1785   return eigen_gpu_device();
1786 }
1787 
CtxFailure(const Status & s)1788 void OpKernelConstruction::CtxFailure(const Status& s) {
1789   VLOG(1) << s;
1790   SetStatus(s);
1791 }
1792 
CtxFailureWithWarning(const Status & s)1793 void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
1794   LOG(WARNING) << s;
1795   SetStatus(s);
1796 }
1797 
CtxFailure(const char * file,int line,const Status & s)1798 void OpKernelConstruction::CtxFailure(const char* file, int line,
1799                                       const Status& s) {
1800   VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1801           << " : " << s;
1802   SetStatus(s);
1803 }
1804 
CtxFailureWithWarning(const char * file,int line,const Status & s)1805 void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
1806                                                  const Status& s) {
1807   LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1808                << " : " << s;
1809   SetStatus(s);
1810 }
1811 
CtxFailure(const Status & s)1812 void OpKernelContext::CtxFailure(const Status& s) {
1813   VLOG(1) << s;
1814   SetStatus(s);
1815 }
1816 
CtxFailureWithWarning(const Status & s)1817 void OpKernelContext::CtxFailureWithWarning(const Status& s) {
1818   LOG(WARNING) << s;
1819   SetStatus(s);
1820 }
1821 
CtxFailure(const char * file,int line,const Status & s)1822 void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
1823   VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1824           << " : " << s;
1825   SetStatus(s);
1826 }
1827 
CtxFailureWithWarning(const char * file,int line,const Status & s)1828 void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
1829                                             const Status& s) {
1830   LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1831                << " : " << s;
1832   SetStatus(s);
1833 }
1834 
CheckNotInComputeAsync(OpKernelContext * ctx,const char * correct_macro_name)1835 void CheckNotInComputeAsync(OpKernelContext* ctx,
1836                             const char* correct_macro_name) {
1837   CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
1838       << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
1839 }
1840 
1841 }  // namespace tensorflow
1842