1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <cstdlib>
19 #include <cstring>
20 #include <mutex>  // NOLINT
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/base/call_once.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/framework/allocation_description.pb.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/attr_value_util.h"
31 #include "tensorflow/core/framework/device_attributes.pb.h"
32 #include "tensorflow/core/framework/device_factory.h"
33 #include "tensorflow/core/framework/graph.pb.h"
34 #include "tensorflow/core/framework/kernel_def.pb.h"
35 #include "tensorflow/core/framework/kernel_def_util.h"
36 #include "tensorflow/core/framework/log_memory.h"
37 #include "tensorflow/core/framework/memory_types.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/node_properties.h"
41 #include "tensorflow/core/framework/op_def_util.h"
42 #include "tensorflow/core/framework/tensor_reference.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/core/notification.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/io/path.h"
49 #include "tensorflow/core/lib/strings/str_util.h"
50 #include "tensorflow/core/lib/strings/strcat.h"
51 #include "tensorflow/core/platform/cpu_info.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/platform_strings.h"
56 #include "tensorflow/core/platform/types.h"
57 #include "tensorflow/core/profiler/lib/traceme.h"
58 #include "tensorflow/core/util/ptr_util.h"
59 
60 namespace tensorflow {
61 
62 namespace {
63 
64 Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
65                             const DataTypeSlice expected_outputs,
66                             const DataTypeSlice inputs,
67                             const DataTypeSlice outputs) {
68   bool signature_mismatch = false;
69 
70   if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
71   for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
72     if (!TypesCompatible(expected_inputs[i], inputs[i])) {
73       signature_mismatch = true;
74     }
75   }
76 
77   if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
78   for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
79     if (!TypesCompatible(expected_outputs[i], outputs[i])) {
80       signature_mismatch = true;
81     }
82   }
83 
84   if (signature_mismatch) {
85     return errors::InvalidArgument(
86         "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
87         DataTypeSliceString(outputs),
88         " expected: ", DataTypeSliceString(expected_inputs), "->",
89         DataTypeSliceString(expected_outputs));
90   }
91   return Status::OK();
92 }
93 
94 }  // namespace
95 
96 // OpKernel ------------------------------------------------------------------
97 
98 OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}
99 
100 OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
101     : props_(context->props_),
102       input_memory_types_(context->input_memory_types().begin(),
103                           context->input_memory_types().end()),
104       output_memory_types_(context->output_memory_types().begin(),
105                            context->output_memory_types().end()),
106       input_name_map_(context->num_inputs()),
107       output_name_map_(context->num_outputs()),
108       name_view_(props_->node_def.name()),
109       type_string_view_(props_->node_def.op()),
110       graph_def_version_(context->graph_def_version()),
111       is_deferred_(is_deferred) {
112   OP_REQUIRES_OK(context,
113                  NameRangesForNode(props_->node_def, *props_->op_def,
114                                    &input_name_map_, &output_name_map_));
115   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
116                                              context->graph_def_version()));
117 
118   // Kernels executing on GPU tie very few resources on the CPU where the
119   // scheduler runs: we consider them as inexpensive.
120   expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
121                !DeviceFactory::IsPluggableDevice(
122                    DeviceTypeString(context->device_type()));
123 }
124 
125 OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
126                    bool is_deferred)
127     : props_(std::make_shared<const NodeProperties>(
128           context->props_->op_def, std::move(custom_def),
129           context->props_->input_types, context->props_->output_types)),
130       input_memory_types_(context->input_memory_types().begin(),
131                           context->input_memory_types().end()),
132       output_memory_types_(context->output_memory_types().begin(),
133                            context->output_memory_types().end()),
134       input_name_map_(context->num_inputs()),
135       output_name_map_(context->num_outputs()),
136       name_view_(props_->node_def.name()),
137       type_string_view_(props_->node_def.op()),
138       graph_def_version_(context->graph_def_version()),
139       is_deferred_(is_deferred) {
140   OP_REQUIRES_OK(context,
141                  NameRangesForNode(props_->node_def, *props_->op_def,
142                                    &input_name_map_, &output_name_map_));
143   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
144                                              context->graph_def_version()));
145 
146   // Kernels executing on GPU tie very few resources on the CPU where the
147   // scheduler runs: we consider them as inexpensive.
148   expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
149                !DeviceFactory::IsPluggableDevice(
150                    DeviceTypeString(context->device_type()));
151 }
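// Illustrative sketch (not part of the original file): a typical synchronous
// kernel derives from OpKernel, reads its attributes once in the constructor,
// and overrides Compute().  The names "MyScaleOp" and "scale" are hypothetical.
//
//   class MyScaleOp : public OpKernel {
//    public:
//     explicit MyScaleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//       OP_REQUIRES_OK(ctx, ctx->GetAttr("scale", &scale_));
//     }
//     void Compute(OpKernelContext* ctx) override {
//       const Tensor& input = ctx->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
//       // ... fill `output` from `input` scaled by `scale_` ...
//     }
//    private:
//     float scale_;
//   };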
152 
153 OpKernel::~OpKernel() {}
154 
155 Status OpKernel::InputRange(StringPiece input_name, int* start,
156                             int* stop) const {
157   const auto result = input_name_map_.find(input_name);
158   if (result == input_name_map_.end()) {
159     return errors::InvalidArgument("Unknown input name: ", input_name);
160   } else {
161     *start = result->second.first;
162     *stop = result->second.second;
163     return Status::OK();
164   }
165 }
166 
167 Status OpKernel::OutputRange(StringPiece output_name, int* start,
168                              int* stop) const {
169   const auto result = output_name_map_.find(output_name);
170   if (result == output_name_map_.end()) {
171     return errors::InvalidArgument("Unknown output name: ", output_name);
172   } else {
173     *start = result->second.first;
174     *stop = result->second.second;
175     return Status::OK();
176   }
177 }
178 
179 string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
180   int num_inputs = ctx.num_inputs();
181   if (num_inputs == 0) return "";
182   std::vector<string> tensor_shapes;
183   tensor_shapes.reserve(num_inputs);
184   for (int i = 0; i < num_inputs; i++) {
185     if (!ctx.has_input(i)) {
186       tensor_shapes.emplace_back();  // Placeholder
187       continue;
188     }
189     DataType input_dtype = ctx.input_dtype(i);
190     if (input_dtype == DataType::DT_RESOURCE ||
191         input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
192       tensor_shapes.emplace_back();  // Placeholder
193       continue;
194     }
195     tensor_shapes.emplace_back(strings::StrCat(
196         DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
197   }
198   return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
199 }
200 
201 string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
202   string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
203   if (verbose) {
204     string shape = ShapeTraceString(ctx);
205     if (!shape.empty()) {
206       trace_string =
207           profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
208     }
209   }
210   return trace_string;
211 }
212 
213 void AsyncOpKernel::Compute(OpKernelContext* context) {
214   Notification n;
215   ComputeAsync(context, [&n]() { n.Notify(); });
216   n.WaitForNotification();
217 }
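// Illustrative sketch (not part of the original file): subclasses implement
// ComputeAsync() and must invoke `done` exactly once when they finish; the
// Compute() adapter above simply blocks until that happens.  "MyAsyncOp" is a
// hypothetical name.
//
//   class MyAsyncOp : public AsyncOpKernel {
//    public:
//     using AsyncOpKernel::AsyncOpKernel;
//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//       // Hand the work off to another thread or device stream, then signal
//       // completion (possibly from that other thread):
//       done();
//     }
//   };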
218 
219 // OpKernelConstruction ------------------------------------------------------
220 
221 OpKernelConstruction::OpKernelConstruction(
222     DeviceType device_type, DeviceBase* device, Allocator* allocator,
223     FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
224     const std::shared_ptr<const NodeProperties>& props,
225     const MemoryTypeSlice& input_memory_types,
226     const MemoryTypeSlice& output_memory_types, int graph_def_version,
227     Status* status)
228     : device_type_(std::move(device_type)),
229       device_(device),
230       allocator_(allocator),
231       flib_(flib),
232       resource_mgr_(resource_mgr),
233       props_(props),
234       input_memory_types_(input_memory_types),
235       output_memory_types_(output_memory_types),
236       graph_def_version_(graph_def_version),
237       status_(status) {}
238 
239 bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
240   return HasNodeAttr(def(), attr_name);
241 }
242 
243 void OpKernelConstruction::SetStatus(const Status& status) {
244   status_->Update(status);
245 }
246 
247 Status OpKernelConstruction::MatchSignature(
248     const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
249   return MatchSignatureHelper(expected_inputs, expected_outputs,
250                               props_->input_types, props_->output_types);
251 }
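// Illustrative usage (an assumption, not taken from this file): a kernel
// constructor can assert the dtypes it was compiled for up front, e.g. a
// binary float kernel might check
//
//   OP_REQUIRES_OK(context,
//                  context->MatchSignature({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}));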
252 
253 Status OpKernelConstruction::allocate_temp(DataType type,
254                                            const TensorShape& shape,
255                                            Tensor* out_temp) {
256   AllocationAttributes attr;
257   attr.allocation_will_be_logged = true;
258   Tensor new_temp(allocator_, type, shape, attr);
259 
260   if (!new_temp.IsInitialized()) {
261     return errors::ResourceExhausted(
262         "OOM when allocating temporary tensor with shape", shape.DebugString());
263   }
264   if (LogMemory::IsEnabled()) {
265     LogMemory::RecordTensorAllocation(
266         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
267   }
268   *out_temp = new_temp;
269   return Status::OK();
270 }
271 
272 Status OpKernelConstruction::allocate_temp(DataType type,
273                                            const TensorShape& shape,
274                                            Tensor* out_temp,
275                                            AllocatorAttributes allocator_attr) {
276   if (allocator_attr.scope_id != 0) {
277     return errors::InvalidArgument(
278         "ScopedAllocator cannot be used via OpKernelConstruction.");
279   }
280   Allocator* a = device_->GetAllocator(allocator_attr);
281   AllocationAttributes attr;
282   attr.allocation_will_be_logged = true;
283   Tensor new_temp(a, type, shape, attr);
284 
285   if (!new_temp.IsInitialized()) {
286     return errors::ResourceExhausted(
287         "OOM when allocating temporary tensor with shape", shape.DebugString());
288   }
289   if (LogMemory::IsEnabled()) {
290     LogMemory::RecordTensorAllocation(
291         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
292   }
293   *out_temp = new_temp;
294   return Status::OK();
295 }
296 
297 // OpKernelContext -----------------------------------------------------------
298 
299 const int OpKernelContext::Params::kNeverForward;
300 const int OpKernelContext::Params::kNoReservation;
301 
302 OpKernelContext::OpKernelContext(Params* params)
303     : OpKernelContext(
304           params, static_cast<int>(params->op_kernel->output_types().size())) {}
305 
306 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
307     : params_(params), outputs_(num_outputs) {
308   if (params_->track_allocations) {
309     tracking_state_ = absl::make_unique<TrackingState>();
310   }
311 
312   params_->ensure_eigen_gpu_device();
313   if (params_->eigen_gpu_device != nullptr) {
314     Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
315     Status s = params_->device->ReinitializeGpuDevice(
316         this, params_->eigen_gpu_device, params_->op_device_context,
317         eigen_gpu_allocator);
318     if (!s.ok()) {
319       SetStatus(s);
320     }
321   }
322 }
323 
324 OpKernelContext::~OpKernelContext() {
325   for (TensorValue& value : outputs_) {
326     if (!value.is_ref()) {
327       delete value.tensor;
328     }
329   }
330   if (params_->track_allocations &&
331       !tracking_state_->wrapped_allocators.empty()) {
332     LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
333                  << "being consumed by the StepStatsCollector.";
334     for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
335       wrapped_allocator.second->GetRecordsAndUnRef();
336     }
337   }
338 }
339 
340 Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
341   Allocator* allocator = nullptr;
342   if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
343     allocator = params_->device->GetScopedAllocator(attr, step_id());
344     CHECK(allocator);
345   } else {
346     allocator = params_->device->GetAllocator(attr);
347   }
348   if (TF_PREDICT_FALSE(track_allocations())) {
349     DCHECK(tracking_state_);
350     mutex_lock lock(tracking_state_->mu);
351     for (const auto& wrapped : tracking_state_->wrapped_allocators) {
352       if (wrapped.first == allocator) {
353         return wrapped.second;
354       }
355     }
356     TrackingAllocator* wrapped_allocator =
357         new TrackingAllocator(allocator, params_->track_allocations);
358     tracking_state_->wrapped_allocators.push_back(
359         std::make_pair(allocator, wrapped_allocator));
360     return wrapped_allocator;
361   } else {
362     return allocator;
363   }
364 }
365 
366 void OpKernelContext::SetStatus(const Status& status) {
367   status_.Update(status);
368 }
369 
370 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
371   int index;
372   TF_RETURN_IF_ERROR(get_input_index(name, &index));
373   if (input_is_ref(index)) {
374     return errors::InvalidArgument("OpKernel used ref input name '", name,
375                                    "' when non-ref input was expected");
376   }
377   *tensor = (*params_->inputs)[index].tensor;
378   return Status::OK();
379 }
380 
381 Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
382   int index;
383   TF_RETURN_IF_ERROR(get_input_index(name, &index));
384   const TensorValue& value((*params_->inputs)[index]);
385   *dtype = value.dtype();
386   return Status::OK();
387 }
388 
389 Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
390   int index;
391   TF_RETURN_IF_ERROR(get_input_index(name, &index));
392   *out_mutex = input_ref_mutex(index);
393   return Status::OK();
394 }
395 
396 const Tensor& OpKernelContext::input(int index) const {
397   CHECK_GE(index, 0);
398   CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
399   CHECK(!input_is_ref(index));
400   const Tensor& tensor = *((*params_->inputs)[index].tensor);
401   return tensor;
402 }
403 
404 Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
405   CHECK_GE(index, 0);
406   CHECK_LT(index, num_inputs());
407   CHECK(input_is_ref(index));
408   // return a copy of the Ref acquired while holding the mutex
409   if (lock_held) {
410     Tensor& tensor = *((*params_->inputs)[index].tensor);
411     return tensor;
412   } else {
413     tf_shared_lock l(*input_ref_mutex(index));
414     Tensor& tensor = *((*params_->inputs)[index].tensor);
415     return tensor;
416   }
417 }
418 
419 void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
420                                         bool lock_held) {
421   CHECK_GE(index, 0);
422   CHECK_LT(index, num_inputs());
423   CHECK(input_is_ref(index));
424   // should only modify the tensor while holding the mutex
425   if (lock_held) {
426     *(*params_->inputs)[index].tensor = tensor;
427   } else {
428     mutex_lock l(*input_ref_mutex(index));
429     *(*params_->inputs)[index].tensor = tensor;
430   }
431 }
432 
433 void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
434                                                       int output_index) {
435   CHECK_GE(input_index, 0);
436   CHECK_LT(input_index, num_inputs());
437   CHECK(input_is_ref(input_index));
438   set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
439                  (*params_->inputs)[input_index].tensor);
440 }
441 
442 bool OpKernelContext::forward_input_to_output_with_shape(
443     int input_index, int output_index, const TensorShape& output_shape,
444     Tensor** output) {
445   const auto output_attr = params_->output_attr_array == nullptr
446                                ? AllocatorAttributes()
447                                : output_alloc_attr(output_index);
448   std::unique_ptr<Tensor> new_tensor = forward_input(
449       input_index, output_index, expected_output_dtype(output_index),
450       output_shape, output_memory_type(output_index), output_attr);
451   if (new_tensor != nullptr) {
452     // Transfer ownership to the output slot in OpKernelContext.
453     outputs_[output_index] = TensorValue(new_tensor.release());
454     *output = outputs_[output_index].tensor;
455     return true;
456   } else {
457     return false;
458   }
459 }
460 
461 Status OpKernelContext::forward_input_to_output_with_shape(
462     StringPiece input_name, StringPiece output_name,
463     const TensorShape& output_shape, Tensor** output) {
464   int input_index, output_index;
465   TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
466   TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
467   if (!forward_input_to_output_with_shape(input_index, output_index,
468                                           output_shape, output)) {
469     return errors::FailedPrecondition("OpKernel could not forward input '",
470                                       input_name, "' to output '", output_name);
471   }
472   return Status::OK();
473 }
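// Illustrative usage (an assumption, not taken from this file): an
// element-wise kernel can attempt to reuse its input buffer for the output
// and fall back to a fresh allocation only when forwarding fails:
//
//   Tensor* out = nullptr;
//   if (!ctx->forward_input_to_output_with_shape(0, 0, input.shape(), &out)) {
//     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
//   }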
474 
475 std::unique_ptr<Tensor> OpKernelContext::forward_input(
476     int input_index, int output_index, DataType output_dtype,
477     const TensorShape& output_shape, MemoryType output_memory_type,
478     const AllocatorAttributes& output_attr) {
479   CHECK_GE(input_index, 0);
480   CHECK_LT(input_index, num_inputs());
481   const TensorValue& input = (*params_->inputs)[input_index];
482   // Check whether at graph construction time this output was marked
483   // either for no forwarding or with a reservation for this input.
484   // If it's reserved for this input we'll skip the refcount and
485   // AllocatorAttribute checks.
486   // TODO(tucker): Maybe we should skip all of the checks?
487   bool never_forward =
488       (params_->forward_from_array != nullptr && output_index >= 0 &&
489        params_->forward_from_array[output_index] == Params::kNeverForward);
490   if (never_forward) return nullptr;
491   bool forward_expected =
492       (params_->forward_from_array != nullptr && output_index >= 0 &&
493        params_->forward_from_array[output_index] == input_index);
494   if (!forward_expected && params_->forward_from_array != nullptr) {
495     // Check for possibly conflicting forward.
496     for (int i = 0; i < num_outputs(); ++i) {
497       if (params_->forward_from_array[i] == input_index) {
498         // This input is reserved for output i.
499         return nullptr;
500       }
501     }
502   }
503   // Check that input tensor exists and is not a ref.
504   if (input.tensor == nullptr || input.is_ref()) {
505     CHECK(!forward_expected);
506     return nullptr;
507   }
508   // Check that input type matches.
509   if (input_dtype(input_index) != output_dtype) {
510     CHECK(!forward_expected);
511     return nullptr;
512   }
513   // Check that the input and output sizes are compatible.
514   if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
515     CHECK(!forward_expected);
516     return nullptr;
517   }
518   // Check that input and output memory types match, i.e.
519   // that they either both live in host or both live in device memory.
520   if (input_memory_type(input_index) != output_memory_type) {
521     CHECK(!forward_expected);
522     return nullptr;
523   }
524   if (!forward_expected) {
525     if (!input->RefCountIsOne()) {
526       return nullptr;
527     }
528     // Check that output allocator attributes are not more restrictive than
529     // input allocator attributes.
530     const auto input_attr = params_->input_alloc_attrs == nullptr
531                                 ? AllocatorAttributes()
532                                 : input_alloc_attr(input_index);
533     if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
534       return nullptr;
535     }
536   }
537 
538   auto output_tensor = MakeUnique<Tensor>();
539   CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
540   return output_tensor;
541 }
542 
543 Status OpKernelContext::forward_input_or_allocate_temp(
544     gtl::ArraySlice<int> candidate_input_indices, DataType type,
545     const TensorShape& shape, const AllocatorAttributes& allocator_attr,
546     Tensor* out_temp) {
547   for (int input_index : candidate_input_indices) {
548     std::unique_ptr<Tensor> new_tensor =
549         forward_input(input_index, Params::kNoReservation /*output_index*/,
550                       type, shape, DEVICE_MEMORY, allocator_attr);
551     if (new_tensor != nullptr) {
552       *out_temp = std::move(*new_tensor);
553       return Status::OK();
554     }
555   }
556   return allocate_temp(type, shape, out_temp, allocator_attr);
557 }
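// Illustrative usage (an assumption, not taken from this file): a kernel that
// needs writable scratch space the same size as one of its inputs can try to
// steal that input's buffer before paying for a fresh allocation:
//
//   Tensor scratch;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_temp(
//                           {0}, DT_FLOAT, input.shape(),
//                           AllocatorAttributes(), &scratch));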
558 
559 void OpKernelContext::delete_ref_input(int index, bool lock_held) {
560   CHECK_GE(index, 0);
561   CHECK_LT(index, num_inputs());
562   CHECK(input_is_ref(index));
563   // should only modify the tensor while holding the mutex
564   if (lock_held) {
565     delete (*params_->inputs)[index].tensor;
566   } else {
567     mutex_lock l(*input_ref_mutex(index));
568     delete (*params_->inputs)[index].tensor;
569   }
570 }
571 
572 Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
573                                       bool lock_held) {
574   int index;
575   TF_RETURN_IF_ERROR(get_input_index(name, &index));
576   if (!input_is_ref(index)) {
577     return errors::InvalidArgument("OpKernel used non-ref input name '", name,
578                                    "' when ref input was expected");
579   }
580   // return a copy of the Ref acquired while holding the mutex
581   if (lock_held) {
582     *tensor = *(*params_->inputs)[index].tensor;
583   } else {
584     tf_shared_lock l(*input_ref_mutex(index));
585     *tensor = *(*params_->inputs)[index].tensor;
586   }
587   return Status::OK();
588 }
589 
590 Status OpKernelContext::replace_ref_input(StringPiece name,
591                                           const Tensor& tensor,
592                                           bool lock_held) {
593   int index;
594   TF_RETURN_IF_ERROR(get_input_index(name, &index));
595   if (!input_is_ref(index)) {
596     return errors::InvalidArgument("OpKernel used immutable input name '", name,
597                                    "' when ref input was expected");
598   }
599   replace_ref_input(index, tensor, lock_held);
600   return Status::OK();
601 }
602 
603 Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
604   int start, stop;
605   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
606   *list = OpInputList(this, start, stop);
607   return Status::OK();
608 }
609 
610 Status OpKernelContext::mutable_input_list(StringPiece name,
611                                            OpMutableInputList* list) {
612   int start, stop;
613   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
614   *list = OpMutableInputList(this, start, stop);
615   return Status::OK();
616 }
617 
618 Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
619   int start, stop;
620   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
621   *list = OpOutputList(this, start, stop);
622   return Status::OK();
623 }
624 
625 void OpKernelContext::maybe_initialize_scope_id_set() {
626   if (allocated_scope_ids_ == nullptr) {
627     allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
628   }
629 }
630 
631 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
632                                         Tensor** tensor) {
633   if (index < 0) {
634     return errors::Internal("allocate_output with bad index=", index,
635                             " kernel=", params_->op_kernel->name());
636   }
637   if (index >= num_outputs()) {
638     return errors::Internal("allocate_output with bad index=", index,
639                             " num_outputs=", num_outputs(),
640                             " kernel=", params_->op_kernel->name());
641   }
642   bool forward_expected =
643       (params_->forward_from_array != nullptr && index >= 0 &&
644        params_->forward_from_array[index] >= 0);
645   if (forward_expected) {
646     return errors::Internal(
647         "Explicit allocate_output call where input forwarding required.  Try "
648         "turning off the ScopedAllocator optimizer.");
649   }
650   AllocatorAttributes attr = output_alloc_attr(index);
651   return allocate_output(index, shape, tensor, attr);
652 }
653 
654 Status OpKernelContext::allocate_output(StringPiece name,
655                                         const TensorShape& shape,
656                                         Tensor** tensor) {
657   int start, stop;
658   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
659   if (stop != start + 1) {
660     return errors::InvalidArgument("OpKernel used list-valued output name '",
661                                    name,
662                                    "' when single-valued output was "
663                                    "expected");
664   }
665   return allocate_output(start, shape, tensor);
666 }
667 
668 Status OpKernelContext::allocate_output(StringPiece name,
669                                         const TensorShape& shape,
670                                         Tensor** tensor,
671                                         AllocatorAttributes attr) {
672   int start, stop;
673   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
674   if (stop != start + 1) {
675     return errors::InvalidArgument("OpKernel used list-valued output name '",
676                                    name,
677                                    "' when single-valued output was "
678                                    "expected");
679   }
680   return allocate_output(start, shape, tensor, attr);
681 }
682 
683 Status OpKernelContext::allocate_tensor(
684     DataType type, const TensorShape& shape, Tensor* out_tensor,
685     AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
686   Allocator* a = get_allocator(attr);
687   Tensor new_tensor(
688       a, type, shape,
689       AllocationAttributes(
690           /*retry_on_failure=*/allocation_attr.retry_on_failure,
691           /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));
692 
693   if (!new_tensor.IsInitialized()) {
694     return errors::ResourceExhausted(
695         "OOM when allocating tensor with shape", shape.DebugString(),
696         " and type ", DataTypeString(type), " on ", params_->device->name(),
697         " by allocator ", a->Name());
698   }
699   if (params_->log_memory) {
700     LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
701                                       params_->step_id, new_tensor);
702   }
703   *out_tensor = std::move(new_tensor);
704   return Status::OK();
705 }
706 
707 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
708                                         Tensor** output,
709                                         AllocatorAttributes attr) {
710   if (index < 0) {
711     return errors::Internal("allocate_output with bad index=", index,
712                             " kernel=", params_->op_kernel->name());
713   }
714   if (index >= num_outputs()) {
715     return errors::Internal("allocate_output with bad index=", index,
716                             " num_outputs=", outputs_.size(),
717                             " kernel=", params_->op_kernel->name());
718   }
719   const DataType type = params_->op_kernel->output_type(index);
720   if (IsRefType(type)) {
721     return errors::Internal("allocate_output with ref type. index=", index,
722                             " type=", type,
723                             " kernel=", params_->op_kernel->name());
724   }
725   if (mutable_output(index) != nullptr) {
726     return errors::Internal("allocate_output on same index multiple times.",
727                             " index = ", index,
728                             " mutable_output(index) = ", mutable_output(index),
729                             " kernel=", params_->op_kernel->name());
730   }
731   if (attr.scope_id > 0) {
732     maybe_initialize_scope_id_set();
733     if (!allocated_scope_ids_->insert(attr.scope_id).second) {
734       return errors::Internal(
735           "OpKernel ", params_->op_kernel->name(),
736           " called allocate_output at index ", index, " with scope_id ",
737           attr.scope_id,
738           " more than once.  Try turning off the ScopedAllocator optimizer.");
739     }
740   }
741   ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
742                                             step_id(), "output", type, &shape);
743   auto output_tensor = MakeUnique<Tensor>();
744   Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
745   if (s.ok()) {
746     outputs_[index] = TensorValue(output_tensor.release());
747     *output = outputs_[index].tensor;
748   }
749   return s;
750 }
751 
752 Status OpKernelContext::allocate_temp(
753     DataType type, const TensorShape& shape, Tensor* out_temp,
754     AllocatorAttributes allocator_attr,
755     const AllocationAttributes& allocation_attr) {
756   if (allocator_attr.scope_id > 0) {
757     // We do not allow ScopedAllocator calls from allocate_temp.
758     // Here we clear the scope_id and return a temporary buffer.
759     // This is because it is legal for a kernel to call allocate_temp
760     // and then set_output with the temp tensor.
761     //
762     // We achieve memory correctness by forcing an allocation in set_output and
763     // copying over the tensor from the temp buffer.  Kernels which would like
764     // to avoid this performance penalty should switch to calling
765     // allocate_output.
766     VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
767             << " called allocate_temp with scope_id " << allocator_attr.scope_id
768             << ".  Switch to allocate_output to avoid performance penalty.";
769     allocator_attr.scope_id = -1;
770   }
771   ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
772                                             step_id(), "temp", type, &shape);
773   Status s =
774       allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
775   if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
776     Allocator* a = get_allocator(allocator_attr);
777     if (a->TracksAllocationSizes()) {
778       int64_t alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
779       record_temp_memory_allocation(alloc_size, *out_temp);
780     }
781   } else if (record_memory_consumption_) {
782     DCHECK(tracking_state_);
783     mutex_lock l(tracking_state_->stats_mu);
784     tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
785   }
786   return s;
787 }
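// Illustrative usage (an assumption, not taken from this file): scratch
// buffers requested via allocate_temp() live only as long as the Tensor
// object, so they are typically declared locally inside Compute().
// `rows` and `cols` below are placeholders.
//
//   Tensor scratch;
//   OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT,
//                                          TensorShape({rows, cols}),
//                                          &scratch));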
788 
789 Status OpKernelContext::get_input_index(StringPiece name,
790                                         int* out_index) const {
791   int start, stop;
792   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
793   if (stop != start + 1) {
794     return errors::InvalidArgument("OpKernel used list-valued input name '",
795                                    name,
796                                    "' when single-valued input was "
797                                    "expected");
798   }
799   *out_index = start;
800   return Status::OK();
801 }
802 
803 Status OpKernelContext::get_output_index(StringPiece name,
804                                          int* out_index) const {
805   int start, stop;
806   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
807   if (stop != start + 1) {
808     return errors::InvalidArgument("OpKernel used list-valued output name '",
809                                    name,
810                                    "' when single-valued output was "
811                                    "expected");
812   }
813   *out_index = start;
814   return Status::OK();
815 }
816 
817 Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
818   int index;
819   TF_RETURN_IF_ERROR(get_output_index(name, &index));
820   set_output(index, tensor);
821   return Status::OK();
822 }
823 
824 Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
825   int index;
826   TF_RETURN_IF_ERROR(get_output_index(name, &index));
827   set_output(index, std::move(tensor));
828   return Status::OK();
829 }
830 
831 bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
832     int index, const Tensor& tensor) {
833   bool allocate_and_copy = false;
834   const bool never_forward =
835       (params_->forward_from_array != nullptr &&
836        params_->forward_from_array[index] == Params::kNeverForward);
837   if (TF_PREDICT_FALSE(never_forward)) {
838     maybe_initialize_scope_id_set();
839     if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
840         allocated_scope_ids_->end()) {
841       allocate_and_copy = true;
842     } else {
843       // The output at `index` must have been previously allocated via a call to
844       // `allocate_output(index, ...)`.  That call would ensure that we return
845       // the correct slice of the ScopedAllocated buffer, so we do not
846       // re-allocate and copy here.
847       LOG(WARNING)
848           << "OpKernel " << params_->op_kernel->name()
849           << " called both allocate_output and set_output with scope_id "
850           << output_alloc_attr(index).scope_id;
851     }
852   }
853 
854   if (TF_PREDICT_FALSE(allocate_and_copy)) {
855     // This output was marked to not be forwarded either during graph
856     // construction or grappler passes.  Force an allocation and copy input to
857     // output.
858     VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
859             << tensor.DebugString() << " never_forward " << never_forward
860             << " params_->forward_from_array[index] "
861             << params_->forward_from_array[index] << " alloc_attr.scope_id "
862             << output_alloc_attr(index).scope_id;
863     ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
864                                               step_id(), "output",
865                                               tensor.dtype(), &tensor.shape());
866     auto new_tensor = MakeUnique<Tensor>();
867     Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
868                                output_alloc_attr(index));
869     TF_CHECK_OK(s);
870     device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
871                                      op_device_context(), [](const Status&) {});
872     outputs_[index] = TensorValue(new_tensor.release());
873   }
874   return allocate_and_copy;
875 }
876 
877 void OpKernelContext::maybe_track_allocations_for_set_output(
878     const Tensor& tensor) {
879   if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
880     DCHECK(tracking_state_);
881     mutex_lock l(tracking_state_->stats_mu);
882     const auto it = std::find_if(
883         tracking_state_->temp_tensor_buffer_and_size.begin(),
884         tracking_state_->temp_tensor_buffer_and_size.end(),
885         [&tensor](const std::pair<const void*, int64>& e) {
886           return e.first ==
887                  static_cast<const void*>(tensor.tensor_data().data());
888         });
889     if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
890       tracking_state_->temp_memory_allocated -= it->second;
891       tracking_state_->temp_tensor_buffer_and_size.erase(it);
892     }
893   }
894 }
895 
896 void OpKernelContext::set_output(int index, const Tensor& tensor) {
897   CHECK_GE(index, 0);
898   CHECK_LT(index, outputs_.size());
899   const DataType type = params_->op_kernel->output_type(index);
900   CHECK(!IsRefType(type));
901   CHECK_EQ(outputs_[index].tensor, nullptr);
902   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
903     // Input can be forwarded to output; incref on `tensor` and set output at
904     // `index` to this tensor.
905     outputs_[index] = TensorValue(new Tensor(tensor));
906     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
907   }
908 }
909 
910 void OpKernelContext::set_output(int index, Tensor&& tensor) {
911   CHECK_GE(index, 0);
912   CHECK_LT(index, outputs_.size());
913   const DataType type = params_->op_kernel->output_type(index);
914   CHECK(!IsRefType(type));
915   CHECK_EQ(outputs_[index].tensor, nullptr);
916   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
917     // Input can be forwarded to output; set output at `index` to this tensor.
918     outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
919     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
920   }
921 }
922 
923 void OpKernelContext::set_output_ref(int index, mutex* mu,
924                                      Tensor* tensor_for_ref) {
925   CHECK_GE(index, 0);
926   CHECK_LT(index, outputs_.size());
927   CHECK(IsRefType(params_->op_kernel->output_type(index)));
928   outputs_[index] = TensorValue(mu, tensor_for_ref);
929 }
930 
931 Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
932                                        Tensor* tensor_for_ref) {
933   int index;
934   TF_RETURN_IF_ERROR(get_output_index(name, &index));
935   set_output_ref(index, mu, tensor_for_ref);
936   return Status::OK();
937 }
938 
939 Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
940   int index;
941   TF_RETURN_IF_ERROR(get_output_index(name, &index));
942   *tensor = mutable_output(index);
943   return Status::OK();
944 }
945 
946 bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
947   const auto& inputs = *params_->inputs;
948   for (size_t i = 1; i < inputs.size(); ++i) {
949     if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
950       SetStatus(errors::InvalidArgument(
951           "Inputs to operation ", op->name(), " of type ", op->type_string(),
952           " must have the same size and shape.  Input 0: ",
953           inputs[0]->shape().DebugString(), " != input ", i, ": ",
954           inputs[i]->shape().DebugString()));
955       return false;
956     }
957   }
958   return true;
959 }
960 
961 Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
962                                        const DataTypeSlice expected_outputs) {
963   DataTypeVector inputs;
964   for (const TensorValue& t : *params_->inputs) {
965     inputs.push_back(t.dtype());
966   }
967   DataTypeVector outputs = params_->op_kernel->output_types();
968   return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
969                               outputs);
970 }
971 
972 void OpKernelContext::record_temp_memory_allocation(int64_t size,
973                                                     const Tensor& t) {
974   if (tracking_state_) {
975     mutex_lock l(tracking_state_->stats_mu);
976     tracking_state_->temp_memory_allocated += size;
977     tracking_state_->temp_tensor_buffer_and_size.emplace_back(
978         static_cast<const void*>(t.tensor_data().data()), size);
979   }
980 }
981 
982 int64 OpKernelContext::temp_memory_allocated() const {
983   if (tracking_state_) {
984     mutex_lock l(tracking_state_->stats_mu);
985     return tracking_state_->temp_memory_allocated;
986   } else {
987     return 0;
988   }
989 }
990 
991 void OpKernelContext::record_persistent_memory_allocation(int64_t size,
992                                                           int64_t alloc_id) {
993   if (tracking_state_) {
994     mutex_lock l(tracking_state_->stats_mu);
995     tracking_state_->persistent_memory_allocated += size;
996     if (alloc_id >= 0) {
997       tracking_state_->persistent_alloc_ids.push_back(alloc_id);
998     }
999   }
1000 }
1001 
1002 int64 OpKernelContext::persistent_memory_allocated() const {
1003   if (tracking_state_) {
1004     mutex_lock l(tracking_state_->stats_mu);
1005     return tracking_state_->persistent_memory_allocated;
1006   } else {
1007     return 0;
1008   }
1009 }
1010 
1011 std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
1012   if (tracking_state_) {
1013     mutex_lock l(tracking_state_->stats_mu);
1014     return std::vector<int64>(tracking_state_->persistent_alloc_ids.begin(),
1015                               tracking_state_->persistent_alloc_ids.end());
1016   } else {
1017     return std::vector<int64>();
1018   }
1019 }
1020 
1021 void OpKernelContext::clear_recorded_memory() {
1022   if (tracking_state_) {
1023     mutex_lock l(tracking_state_->stats_mu);
1024     tracking_state_->temp_memory_allocated = 0;
1025     tracking_state_->persistent_memory_allocated = 0;
1026     tracking_state_->temp_tensor_buffer_and_size.clear();
1027     tracking_state_->persistent_alloc_ids.clear();
1028   }
1029 }
1030 
1031 void OpKernelContext::set_record_memory_consumption(bool v) {
1032   record_memory_consumption_ = v;
1033   if (v && !tracking_state_) {
1034     tracking_state_ = absl::make_unique<TrackingState>();
1035   }
1036 }
1037 
1038 const string& OpKernelContext::executor_type() const {
1039   if (params_->executor_type) {
1040     return *params_->executor_type;
1041   } else {
1042     static const string& kEmptyString = *new string("");
1043     return kEmptyString;
1044   }
1045 }
1046 
1047 // OpKernel registration ------------------------------------------------------
1048 
1049 struct KernelRegistration {
1050   KernelRegistration(const KernelDef& d, StringPiece c,
1051                      std::unique_ptr<kernel_factory::OpKernelFactory> f)
1052       : def(d), kernel_class_name(c), factory(std::move(f)) {}
1053 
1054   const KernelDef def;
1055   const string kernel_class_name;
1056   std::unique_ptr<kernel_factory::OpKernelFactory> factory;
1057 };
1058 
1059 // This maps from 'op_type' + DeviceType to the set of KernelDefs and
1060 // factory functions for instantiating the OpKernel that matches the
1061 // KernelDef.
1062 struct KernelRegistry {
1063   mutex mu;
1064   std::unordered_multimap<string, KernelRegistration> registry
1065       TF_GUARDED_BY(mu);
1066 };
1067 
1068 #if defined(_WIN32)
1069 static const char kKernelLibPattern[] = "libtfkernel*.dll";
1070 #elif defined(__APPLE__)
1071 static const char kKernelLibPattern[] = "libtfkernel*.dylib";
1072 #else
1073 static const char kKernelLibPattern[] = "libtfkernel*.so";
1074 #endif
1075 
1076 #define FEATURE(x) \
1077   { x, #x }
1078 
1079 // Returns Status::OK if the dynamic library at the given path is safe to
1080 // load with some level of confidence.
1081 static Status IsProbablySafeToLoad(const string& path) {
1082   // A map of platform string to required CPU feature.
1083   using port::CPUFeature;
1084   static const auto* feature_map =
1085       new std::map<string, std::pair<CPUFeature, string>>{
1086           {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
1087       };
1088 
1089   std::vector<std::string> platform_strings;
1090   int result = GetPlatformStrings(path, &platform_strings);
1091   if (result) {
1092     return Status(error::Code::UNKNOWN, strerror(result));
1093   }
1094   if (platform_strings.empty()) {
1095     return Status(error::Code::FAILED_PRECONDITION,
1096                   "Didn't find any platform strings");
1097   }
1098   std::vector<std::string> missing_features;
1099   for (const auto& platform_string : platform_strings) {
1100     const auto& entry = feature_map->find(platform_string);
1101     if (entry != feature_map->end() &&
1102         !port::TestCPUFeature(entry->second.first)) {
1103       missing_features.emplace_back(entry->second.second);
1104     }
1105   }
1106   if (!missing_features.empty()) {
1107     string errmsg = "Missing CPU features: ";
1108     errmsg.append(absl::StrJoin(missing_features, ", "));
1109     return Status(errors::Code::FAILED_PRECONDITION, errmsg);
1110   }
1111   return Status::OK();
1112 }
1113 
1114 void LoadDynamicKernelsInternal() {
1115   Env* env = Env::Default();
1116 
1117   // Override to allow loading unsafe packages for development.
1118   // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
1119   char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
1120   bool override_abi_check = false;
1121   if (_abi_check_env_var != nullptr) {
1122     override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
1123   }
1124 
1125   string bazel_kernel_dir =
1126       io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
1127   std::vector<string> files;
1128   Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
1129   if (s_kernel_dir.ok()) {
1130     string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
1131     for (const auto& file : files) {
1132       string fullpath = io::JoinPath(bazel_kernel_dir, file);
1133       if (env->MatchPath(fullpath, dll_spec)) {
1134         Status s = IsProbablySafeToLoad(fullpath);
1135         if (!s.ok() && override_abi_check) {
1136           LOG(WARNING) << "Loading UNSAFE library " << fullpath
1137                        << " because ABI check override is set: "
1138                        << s.error_message();
1139         }
1140         if (s.ok() || override_abi_check) {
1141           // TODO(gunan): Store the handles to the opened files.
1142           void* unused_filehandle;
1143           TF_CHECK_OK(
1144               env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
1145         } else {
1146           LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
1147                        << s.error_message();
1148         }
1149       }
1150     }
1151   }
1152 }
1153 
1154 // Mechanism for loading existing kernel libraries.
1155 void LoadDynamicKernels() {
1156   // TODO(gunan): As more features are available, add intelligent kernel
1157   // selection, and dropping unsuitable kernel logic here.
1158   static absl::once_flag dll_loader_flag;
1159   absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
1160 }
1161 
1162 void* GlobalKernelRegistry() {
1163   static KernelRegistry* global_kernel_registry = []() {
1164     KernelRegistry* registry = new KernelRegistry;
1165     OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
1166     return registry;
1167   }();
1168   return global_kernel_registry;
1169 }
1170 
1171 static KernelRegistry* GlobalKernelRegistryTyped() {
1172 #ifdef AUTOLOAD_DYNAMIC_KERNELS
1173   LoadDynamicKernels();
1174 #endif  // AUTOLOAD_DYNAMIC_KERNELS
1175   return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1176 }
1177 
1178 static string Key(StringPiece op_type, const DeviceType& device_type,
1179                   StringPiece label) {
1180   return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
1181                          label);
1182 }
1183 
1184 namespace kernel_factory {
1185 
1186 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
1187                                      StringPiece kernel_class_name,
1188                                      std::unique_ptr<OpKernelFactory> factory) {
1189   const string key =
1190       Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
1191           kernel_def->label());
1192 
1193   // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
1194   // here.
1195   // InitInternal gets called by static initializers, so it ends up executing
1196   // before main. This causes LoadKernelLibraries function to get called
1197   // before some file libraries can initialize, which in turn crashes the
1198   // program flakily. Until we get rid of static initializers in kernel
1199   // registration mechanism, we have this workaround here.
1200   auto global_registry =
1201       reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1202   mutex_lock l(global_registry->mu);
1203   global_registry->registry.emplace(
1204       key,
1205       KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
1206   delete kernel_def;
1207 }
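// Illustrative sketch (not part of the original file): kernels normally reach
// InitInternal() through the REGISTER_KERNEL_BUILDER macro, which builds a
// KernelDef and constructs an OpKernelRegistrar at static-initialization time.
// "MyScale" / "MyScaleOp" are hypothetical names.
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MyScale").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MyScaleOp);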
1208 
1209 OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
1210     OpKernelConstruction* context) {
1211   return (*create_func_)(context);
1212 }
1213 
1214 }  // namespace kernel_factory
1215 
1216 namespace {
1217 
1218 // Label defaults to empty if not found in NodeDef.
1219 const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
1220   static const string& kKernelAttr = *new string("_kernel");
1221   static const string& kEmptyString = *new string("");
1222 
1223   // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
1224   // to use the `AttrSlice::FindByString()` overload, which does a more
1225   // efficient map lookup (instead of a linear scan) when the attribute name is
1226   // already a `const string&`.
1227   const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
1228   if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
1229     return kEmptyString;
1230   else
1231     return attr_value->s();
1232 }
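
// The label is how a graph selects a specialized kernel variant. As a sketch
// (op and class names hypothetical), a kernel registered with
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MyOp").Device(DEVICE_CPU).Label("special"), MySpecialKernel);
//
// is only matched for nodes whose NodeDef carries the attr _kernel="special".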

// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
    bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;

  const string& label = GetKernelLabelAttr(node_attrs);

  const string key = Key(node_op, device_type, label);
  auto typed_registry = GlobalKernelRegistryTyped();
  tf_shared_lock lock(typed_registry->mu);
  auto regs = typed_registry->registry.equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
    if (match) {
      if (*reg != nullptr) {
        if ((*reg)->def.priority() == iter->second.def.priority()) {
          return errors::InvalidArgument(
              "Multiple OpKernel registrations match NodeDef at the same "
              "priority '",
              FormatNodeDefForError(node_name, has_experimental_debug_info,
                                    experimental_debug_info),
              "': '", (*reg)->def.ShortDebugString(), "' and '",
              iter->second.def.ShortDebugString(), "'");
        } else if ((*reg)->def.priority() > iter->second.def.priority()) {
          continue;
        }
        // iter->second's priority is higher than *reg's.
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
  // If no device-specific registration was found, try finding a default
  // kernel.
  if (*reg == nullptr &&
      !IsSymbolicExecutionDevice(device_type.type_string())) {
    const string default_key = Key(node_op, DEVICE_DEFAULT, label);
    auto regs = typed_registry->registry.equal_range(default_key);
    for (auto iter = regs.first; iter != regs.second; ++iter) {
      // If there is a default kernel registered for the op,
      // check that the attrs match.
      bool match;
      TF_RETURN_IF_ERROR(
          KernelAttrsMatch(iter->second.def, node_attrs, &match));
      if (match) {
        if (*reg != nullptr) {
          return errors::InvalidArgument(
              "Multiple Default OpKernel registrations match NodeDef '",
              FormatNodeDefForError(node_name, has_experimental_debug_info,
                                    experimental_debug_info),
              "': '", (*reg)->def.ShortDebugString(), "' and '",
              iter->second.def.ShortDebugString(), "'");
        }
        *reg = &iter->second;
      } else {
        *was_attr_mismatch = true;
      }
    }

    if (*reg != nullptr) {
      VLOG(1) << "No device-specific kernels found for NodeDef '"
              << FormatNodeDefForError(node_name, has_experimental_debug_info,
                                       experimental_debug_info)
              << "'. Will fall back to a default kernel.";
    }
  }

  return Status::OK();
}
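
// Priority-based selection, sketched with hypothetical registrations:
//
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), BaselineKernel);
//   REGISTER_KERNEL_BUILDER(
//       Name("MyOp").Device(DEVICE_CPU).Priority(1), FasterKernel);
//
// FindKernelRegistration prefers FasterKernel (priority 1 beats the default
// 0); two matching registrations at the same priority are an error, as above.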

Status FindKernelRegistration(const DeviceType& device_type,
                              const NodeDef& node_def,
                              const KernelRegistration** reg,
                              bool* was_attr_mismatch) {
  return FindKernelRegistration(
      device_type, node_def.name(), node_def.has_experimental_debug_info(),
      node_def.experimental_debug_info(), node_def.op(),
      AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
}

}  // namespace


bool KernelDefAvailable(const DeviceType& device_type,
                        const NodeDef& node_def) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  Status result =
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
  return result.ok() && reg != nullptr;
}

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(FindKernelRegistration(
      device_type, node_name, has_experimental_debug_info,
      experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    const std::string device_str = DeviceTypeString(device_type);
    Status s = errors::NotFound(
        "No registered '", node_op, "' OpKernel for ", device_str,
        " devices compatible with node ",
        FormatNodeDefForError(node_name, has_experimental_debug_info,
                              experimental_debug_info));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ",
          SummarizeAttrsHelper(node_attrs, node_device));
    }

    // Do not print kernel registrations for other devices when using _JIT
    // devices for compilation.
    if (!absl::StrContains(device_str, "JIT")) {
      errors::AppendToMessage(
          &s, ".  Registered:", KernelsRegisteredForOp(node_op));
    }

    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  return FindKernelDef(
      device_type, node_def.name(), node_def.has_experimental_debug_info(),
      node_def.experimental_debug_info(), node_def.op(), node_def.device(),
      AttrSlice(&node_def.attr()), def, kernel_class_name);
}
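
// A minimal usage sketch (`ndef` is a hypothetical NodeDef owned by the
// caller): code that only needs to know which kernel would be picked can do
//
//   const KernelDef* kdef = nullptr;
//   string kernel_class;
//   Status s =
//       FindKernelDef(DeviceType(DEVICE_CPU), ndef, &kdef, &kernel_class);
//
// and inspect `kdef` and `kernel_class` on success, without constructing the
// kernel itself.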

Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* prioritized_device_types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  // TODO(zhifengc): Change the callers (SimplePlacer and DynamicPlacer) to
  // consider the possibility that 'def' is a call to a user-defined function
  // and only call SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    bool exists_attr_mismatch = false;
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch = false;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
      if (reg != nullptr) {
        int32_t priority = reg->def.priority();
        prioritized_device_types->emplace_back(device_type, priority);
      }
    }
    // Add extra supported device types if all of the following conditions are
    // satisfied:
    // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
    // 2) A device is requested for this node which specifies job/replica/task
    // 3) A local device is provided which specifies job/replica/task
    // 4) The local device does not have the same (job, replica, task) as the
    //    requested device
    //
    // The goal is to address the issue where a graph includes an op (e.g.
    // PyFunc) whose kernel is known to a remote process but not to the current
    // process.
    if (prioritized_device_types->empty() && !exists_attr_mismatch &&
        local_address_spec != nullptr) {
      DeviceNameUtils::ParsedName requested_device_name;
      DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
      if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
                                                   requested_device_name)) {
        if (requested_device_name.has_type) {
          prioritized_device_types->push_back(
              std::make_pair(DeviceType(requested_device_name.type), 0));
        } else {
          for (const DeviceType& device_type : prioritized_types) {
            prioritized_device_types->push_back(std::make_pair(device_type, 0));
          }
        }
      }
    }

    // If we were unable to find any valid devices, validate whether the node
    // itself is even valid.
    if (prioritized_device_types->empty()) {
      TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
    }

    std::stable_sort(prioritized_device_types->begin(),
                     prioritized_device_types->end(),
                     [](const std::pair<DeviceType, int32>& a,
                        const std::pair<DeviceType, int32>& b) {
                       return a.second > b.second;
                     });
  } else {
    // Assume that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  KernelList kernel_list = GetAllRegisteredKernels();
  for (const auto& kernel_def : kernel_list.kernel()) {
    LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
  }
}

KernelList GetAllRegisteredKernels() {
  return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}

KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate) {
  KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
  KernelList kernel_list;
  tf_shared_lock lock(typed_registry->mu);
  kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
  for (const auto& p : typed_registry->registry) {
    const KernelDef& kernel_def = p.second.def;
    if (predicate(kernel_def)) {
      *kernel_list.add_kernel() = kernel_def;
    }
  }
  return kernel_list;
}
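
// Filtering sketch: to list only GPU kernels, a caller could pass a predicate
// on the KernelDef's device type, e.g.
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });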

KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
  auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
  return GetFilteredRegisteredKernels(op_pred);
}

string KernelsRegisteredForOp(StringPiece op_name) {
  KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
  if (kernel_list.kernel_size() == 0) return "  <no registered kernels>\n";
  string ret;
  for (const auto& kernel_def : kernel_list.kernel()) {
    strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
    if (!kernel_def.label().empty()) {
      strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
    }
    for (int i = 0; i < kernel_def.constraint_size(); ++i) {
      strings::StrAppend(
          &ret, "; ", kernel_def.constraint(i).name(), " in ",
          SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
    }
    strings::StrAppend(&ret, "\n");
  }
  return ret;
}
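
// The returned string is a human-readable listing with one registration per
// line, roughly of the form (values illustrative):
//
//   device='CPU'; T in [DT_FLOAT, DT_DOUBLE]
//   device='GPU'; label='special'; T in [DT_FLOAT]
//
// It is appended to "no kernel found" errors, e.g. in FindKernelDef above.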

/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
 * copying the NodeDef. */
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  // Look up the Op registered for this op name.
  std::shared_ptr<const NodeProperties> props;
  status->Update(NodeProperties::CreateFromNodeDef(
      node_def, OpRegistry::Global(), &props));
  if (!status->ok()) {
    errors::AppendToMessage(status,
                            " for node: ", FormatNodeDefForError(node_def));
    return nullptr;
  }
  return CreateOpKernel(device_type, device, allocator, props,
                        graph_def_version, status);
}

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator,
                           /*flib=*/nullptr, props, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  return CreateOpKernel(std::move(device_type), device, allocator, flib,
                        /* resource_mgr= */ nullptr, props, graph_def_version,
                        kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  const NodeDef& node_def = props->node_def;
  bool was_attr_mismatch = false;
  const KernelRegistration* registration = nullptr;
  Status s;
  if (props != nullptr) {
    VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

    // Validate node_def against OpDef.
    TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));

    // Look up kernel registration.
    s = FindKernelRegistration(device_type, node_def, &registration,
                               &was_attr_mismatch);
    if (!s.ok()) {
      errors::AppendToMessage(&s, " when instantiating ", node_def.op());
      return s;
    }
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for '", DeviceTypeString(device_type),
                              "' devices compatible with node ",
                              FormatNodeDefForError(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide
  // the kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(std::move(device_type), device, allocator, flib,
                               resource_mgr, props, input_memory_types,
                               output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}
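
// End-to-end sketch (`device`, `allocator`, and `props` are hypothetical
// objects owned by the caller):
//
//   OpKernel* kernel = nullptr;
//   Status s = CreateOpKernel(DeviceType(DEVICE_CPU), device, allocator,
//                             /*flib=*/nullptr, props, TF_GRAPH_DEF_VERSION,
//                             &kernel);
//   if (s.ok()) { /* drive kernel->Compute() through an OpKernelContext */ }
//   delete kernel;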

namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  auto typed_registry = GlobalKernelRegistryTyped();
  tf_shared_lock lock(typed_registry->mu);
  for (const auto& key_registration : typed_registry->registry) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}
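
// Kernel code reaches these specializations through the templated accessor,
// e.g. (sketch; `ctx` is a hypothetical OpKernelContext* and `in`/`out` are
// Eigen tensor expressions backed by its tensors):
//
//   const Eigen::ThreadPoolDevice& d =
//       ctx->eigen_device<Eigen::ThreadPoolDevice>();
//   out.device(d) = in * in;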

void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name) {
  CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
      << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}
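
// This helper is what lets the synchronous OP_REQUIRES* macros catch misuse
// inside AsyncOpKernel::ComputeAsync: correct_macro_name names the
// asynchronous variant that should be used instead (e.g. "OP_REQUIRES_ASYNC"),
// and the CHECK above fires if the currently running kernel is asynchronous.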

}  // namespace tensorflow