/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <mutex>  // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>

#include <cstdlib>
#include <cstring>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

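// Checks a kernel's actual input/output dtype lists against the expected
// lists, comparing element-wise with TypesCompatible(); returns
// InvalidArgument describing both signatures on mismatch.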
Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

}  // namespace

// OpKernel ------------------------------------------------------------------

OpKernel::OpKernel(OpKernelConstruction* context)
    : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}

OpKernel::OpKernel(OpKernelConstruction* context,
                   std::unique_ptr<const NodeDef> node_def)
    : def_(std::move(node_def)),
      input_types_(context->input_types().begin(),
                   context->input_types().end()),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_types_(context->output_types().begin(),
                    context->output_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      graph_def_version_(context->graph_def_version()),
      is_internal_(str_util::StartsWith(type_string(), "_")),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
                                   &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
                                             context->graph_def_version()));

  // Kernels executing on GPU/SYCL tie up very few resources on the CPU where
  // the scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               context->device_type() != DeviceType(DEVICE_SYCL);
}

OpKernel::~OpKernel() {}

const uint64 OpKernel::kInitialCostEstimateCycles;
const uint64 OpKernel::kOpIsExpensiveThresholdCycles;
const uint64 OpKernel::kCostDecay;

const string& OpKernel::name() const { return def_->name(); }
const string& OpKernel::type_string() const { return def_->op(); }
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }

// This static function exists only because device_attributes.pb.h is
// already included here, and it can't be introduced elsewhere.
/*static*/ int OpKernel::DeviceNumaNode(const DeviceBase* device) {
  return device->attributes().locality().numa_node();
}

Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
  if (!IsLegacyVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}

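// Default synchronous entry point for async kernels: runs ComputeAsync() and
// blocks on a Notification until the done callback fires.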
void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}

// PersistentTensor ----------------------------------------------------------

Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
  // the caller has to have a valid context
  CHECK(context);
  return &tensor_;
}

Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
  context->NotifyUseOfPersistentTensor(tensor_);
  return &tensor_;
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
                              output_types_);
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_persistent(
    DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
    Tensor** out_tensor) {
  // for now just do the same thing as allocate_temp
  // TODO(misard) add specific memory tracking for persistent tensors
  Tensor persistent;
  Status s = allocate_temp(type, shape, &persistent);
  if (!s.ok()) {
    return s;
  }
  *out_persistent = PersistentTensor(persistent);
  Tensor* allocated = out_persistent->AccessTensor(this);
  if (out_tensor) {
    *out_tensor = allocated;
  }
  return s;
}

// OpKernelContext -----------------------------------------------------------

const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params),
      outputs_(num_outputs),
      temp_memory_allocated_(0),
      persistent_memory_allocated_(0) {
  params_->ensure_eigen_gpu_device();
  if (params_->eigen_gpu_device != nullptr) {
    Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
    Status s = params_->device->ReinitializeGpuDevice(
        this, params_->eigen_gpu_device, params_->op_device_context,
        eigen_gpu_allocator);
    if (!s.ok()) {
      SetStatus(s);
    }
  }
  if (params_->record_tensor_accesses) {
    referenced_tensors_.Init();
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
  if (params_->track_allocations && !wrapped_allocators_.empty()) {
    LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                 << "being consumed by the StepStatsCollector.";
    for (auto& wrapped_allocator : wrapped_allocators_) {
      wrapped_allocator.second->GetRecordsAndUnRef();
    }
  }
}

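// Resolves the allocator for the given attributes: a scoped allocator when
// attr.scope_id > 0, otherwise the device allocator; when allocation tracking
// is enabled the result is wrapped in (and cached as) a TrackingAllocator.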
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator = nullptr;
  if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
    allocator = params_->device->GetScopedAllocator(attr, step_id());
    CHECK(allocator);
  } else {
    allocator = params_->device->GetAllocator(attr);
  }
  if (TF_PREDICT_FALSE(track_allocations())) {
    mutex_lock lock(mu_);
    for (const auto& wrapped : wrapped_allocators_) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
  mutex_lock l(mu_);
  // Keep a reference to the underlying memory around.
  referenced_tensors_->Add(tensor);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  if (input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[start].tensor;
  record_tensor_reference(**tensor);
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  const TensorValue& value((*params_->inputs)[start]);
  if (value.is_ref()) {
    *dtype = MakeRefType(value->dtype());
  } else {
    *dtype = value->dtype();
  }
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  *out_mutex = input_ref_mutex(start);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
  DCHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  record_tensor_reference(tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
  record_tensor_reference(tensor);
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  DCHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, output_index, expected_output_dtype(output_index),
      output_shape, output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index, stop;
  TF_RETURN_IF_ERROR(
      params_->op_kernel->InputRange(input_name, &input_index, &stop));
  if (stop != input_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   input_name,
                                   "' when single-valued input was "
                                   "expected");
  }
  TF_RETURN_IF_ERROR(
      params_->op_kernel->OutputRange(output_name, &output_index, &stop));
  if (stop != output_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   output_name,
                                   "' when single-valued output was "
                                   "expected");
  }
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return Status::OK();
}

std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, int output_index, DataType output_dtype,
    const TensorShape& output_shape, MemoryType output_memory_type,
    const AllocatorAttributes& output_attr) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check whether at graph construction time this output was marked
  // either for no forwarding or with a reservation for this input.
  // If it's reserved for this input we'll skip the refcount and
  // AllocatorAttribute checks.
  // TODO(tucker): Maybe we should skip all of the checks?
  bool never_forward =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == Params::kNeverForward);
  if (never_forward) return nullptr;
  bool forward_expected =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == input_index);
  if (!forward_expected && params_->forward_from_array != nullptr) {
    // Check for possibly conflicting forward.
    for (int i = 0; i < num_outputs(); ++i) {
      if (params_->forward_from_array[i] == input_index) {
        // This input is reserved for output i.
        return nullptr;
      }
    }
  }
  // Check that input tensor exists and is not a ref.
  if (input.tensor == nullptr || input.is_ref()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    CHECK(!forward_expected);
    return nullptr;
  }
  if (!forward_expected) {
    if (!input->RefCountIsOne()) {
      return nullptr;
    }
    // Check that output allocator attributes are not more restrictive than
    // input allocator attributes.
    const auto input_attr = params_->input_alloc_attrs == nullptr
                                ? AllocatorAttributes()
                                : input_alloc_attr(input_index);
    if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
      return nullptr;
    }
  }

  auto output_tensor = MakeUnique<Tensor>();
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

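// Tries each candidate input in order for buffer forwarding; if none can be
// forwarded, falls back to allocating a fresh temporary tensor.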
Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, Params::kNoReservation /*output_index*/,
                      type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[start].tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(start));
    *tensor = *(*params_->inputs)[start].tensor;
  }
  record_tensor_reference(*tensor);
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(start, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  bool forward_expected =
      (params_->forward_from_array != nullptr && index >= 0 &&
       params_->forward_from_array[index] >= 0);
  if (forward_expected) {
    return errors::Internal(
        "Explicit allocate_output call where input forwarding required.  Try "
        "turning off the ScopedAllocator optimizer.");
  }
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, output, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

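// Common allocation path behind allocate_output/allocate_temp/
// allocate_persistent: picks the allocator from the attributes, returns
// ResourceExhausted on OOM, and records the allocation when memory logging
// is enabled.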
Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  AllocationAttributes logged_attr(allocation_attr);
  logged_attr.allocation_will_be_logged = true;
  Tensor new_tensor(a, type, shape, logged_attr);

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape ", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  record_tensor_reference(new_tensor);
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  DCHECK(!IsRefType(type));
  DCHECK(mutable_output(index) == nullptr);
  auto output_tensor = MakeUnique<Tensor>();
  Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor.release());
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  }
  return s;
}

Status OpKernelContext::allocate_persistent(DataType type,
                                            const TensorShape& shape,
                                            PersistentTensor* out_persistent,
                                            Tensor** out_tensor,
                                            AllocatorAttributes attr) {
  Tensor persistent;
  Status s = allocate_tensor(type, shape, &persistent, attr);
  if (s.ok()) {
    *out_persistent = PersistentTensor(persistent);
    if (out_tensor) {
      *out_tensor = out_persistent->AccessTensor(this);
    }
    if (track_allocations()) {
      Tensor* t = out_persistent->AccessTensor(this);
      Allocator* a = get_allocator(attr);
      if (a->TracksAllocationSizes()) {
        int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
        int64 alloc_id = a->AllocationId(t->tensor_data().data());
        record_persistent_memory_allocation(alloc_size, alloc_id);
      }
    }
  }
  return s;
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output(start, tensor);
  return Status::OK();
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
  DCHECK_EQ(mutable_output(index), nullptr);
  record_tensor_reference(tensor);
  outputs_[index] = TensorValue(new Tensor(tensor));
  if (track_allocations() && tensor.TotalBytes() > 0) {
    mutex_lock l(stats_mu_);
    if (!temp_tensor_buffer_and_size_) {
      return;
    }
    auto it = std::find_if(temp_tensor_buffer_and_size_->begin(),
                           temp_tensor_buffer_and_size_->end(),
                           [&tensor](const std::pair<const void*, int64>& e) {
                             return e.first == static_cast<const void*>(
                                                   tensor.tensor_data().data());
                           });
    if (it != temp_tensor_buffer_and_size_->end()) {
      temp_memory_allocated_ -= it->second;
      temp_tensor_buffer_and_size_->erase(it);
    }
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(IsRefType(params_->op_kernel->output_type(index)));
  record_tensor_reference(*tensor_for_ref);
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output_ref(start, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *tensor = mutable_output(start);
  return Status::OK();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape.  Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                    const Tensor& t) {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ += size;
  if (!temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_.reset(
        new gtl::InlinedVector<std::pair<const void*, int64>, 2>());
  }
  temp_tensor_buffer_and_size_->emplace_back(
      static_cast<const void*>(t.tensor_data().data()), size);
}

int64 OpKernelContext::temp_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return temp_memory_allocated_;
}

void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                          int64 alloc_id) {
  mutex_lock l(stats_mu_);
  persistent_memory_allocated_ += size;
  if (alloc_id >= 0) {
    if (!persistent_alloc_ids_) {
      persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>());
    }
    persistent_alloc_ids_->push_back(alloc_id);
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return persistent_memory_allocated_;
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  mutex_lock l(stats_mu_);
  if (persistent_alloc_ids_) {
    return std::vector<int64>(persistent_alloc_ids_->begin(),
                              persistent_alloc_ids_->end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ = 0;
  persistent_memory_allocated_ = 0;
  if (temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_->clear();
  }
  if (persistent_alloc_ids_) {
    persistent_alloc_ids_->clear();
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
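//
// Illustrative example (not part of this file): a registration written with
// the public macro, e.g.
//
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), MyOpCpuKernel);
//
// ends up as one entry in this multimap, keyed by Key("MyOp", DEVICE_CPU, "")
// (see Key() below), with a factory that constructs MyOpCpuKernel. "MyOp" and
// MyOpCpuKernel are hypothetical names used only for illustration.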

#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif

#define FEATURE(x) \
  { x, #x }

// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
  // A map of platform string to required CPU feature.
  using port::CPUFeature;
  static const auto* feature_map =
      new std::map<string, std::pair<CPUFeature, string>>{
          {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
      };

  std::vector<std::string> platform_strings;
  int result = GetPlatformStrings(path, &platform_strings);
  if (result) {
    return Status(error::Code::UNKNOWN, strerror(result));
  }
  if (platform_strings.empty()) {
    return Status(error::Code::FAILED_PRECONDITION,
                  "Didn't find any platform strings");
  }
  std::vector<std::string> missing_features;
  for (const auto& platform_string : platform_strings) {
    const auto& entry = feature_map->find(platform_string);
    if (entry != feature_map->end() &&
        !port::TestCPUFeature(entry->second.first)) {
      missing_features.emplace_back(entry->second.second);
    }
  }
  if (!missing_features.empty()) {
    string errmsg = "Missing CPU features: ";
    errmsg.append(str_util::Join(missing_features, ", "));
    return Status(errors::Code::FAILED_PRECONDITION, errmsg);
  }
  return Status::OK();
}

void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  const char* load_unsafe = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check =
      load_unsafe != nullptr && strcmp(load_unsafe, "1") == 0;

  string bazel_kernel_dir = io::JoinPath(env->GetRunfilesDir(),
                                         "tensorflow",
                                         "core",
                                         "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(env->LoadLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static std::once_flag dll_loader_flag;
  std::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
  LoadDynamicKernels();
#endif  // AUTOLOAD_DYNAMIC_KERNELS
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

    // To avoid calling LoadDynamicKernels, DO NOT CALL
    // GlobalKernelRegistryTyped here.
    // InitInternal gets called by static initializers, so it ends up executing
    // before main. This causes the LoadKernelLibraries function to get called
    // before some file libraries can initialize, which in turn crashes the
    // program flakily. Until we get rid of static initializers in the kernel
    // registration mechanism, we have this workaround here.
    reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry())
        ->emplace(key, KernelRegistration(*kernel_def, kernel_class_name,
                                          std::move(factory)));
  }
  delete kernel_def;
}

OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
    OpKernelConstruction* context) {
  return (*create_func_)(context);
}

}  // namespace kernel_factory

namespace {

static const StringPiece kKernelAttr("_kernel");

// TODO(irving): Replace with const Node& version below.
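// Looks up all kernels registered under Key(node_def.op(), device_type, label)
// (the label comes from the node's "_kernel" attr) and returns the unique one
// whose KernelDef attr constraints match node_def. *was_attr_mismatch is set
// if a candidate was found but its attr constraints did not match.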
Status FindKernelRegistration(const DeviceType& device_type,
                              const NodeDef& node_def,
                              const KernelRegistration** reg,
                              bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;
  // Label defaults to empty if not found in NodeDef.
  const string& label = GetNodeAttrString(node_def, kKernelAttr);

  const string key = Key(node_def.op(), device_type, label);
  auto regs = GlobalKernelRegistryTyped()->equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
    if (match) {
      if (*reg != nullptr) {
        return errors::InvalidArgument(
            "Multiple OpKernel registrations match NodeDef '",
            FormatNodeDefForError(node_def), "': '",
            ProtoShortDebugString((*reg)->def), "' and '",
            ProtoShortDebugString(iter->second.def), "'");
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
  return Status::OK();
}

}  // namespace

bool KernelDefAvailable(const DeviceType& device_type,
                        const NodeDef& node_def) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  Status result =
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
  return result.ok() && reg != nullptr;
}

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    Status s = errors::NotFound(
        "No registered '", node_def.op(), "' OpKernel for ",
        DeviceTypeString(device_type), " devices compatible with node ",
        FormatNodeDefForError(node_def));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

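// For each device type in prioritized_types that has a kernel registered for
// 'def', appends (device_type, kernel priority) to the output, sorted by
// descending priority; if the op itself is unregistered (e.g. a call to a
// user-defined function), every requested device type is assumed to be
// supported with priority 0.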
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* prioritized_device_types) {
  // TODO(zhifengc): Change the callers (SimplePlacer and
  // DynamicPlacer) to consider the possibility that 'def' is a call to
  // a user-defined function and to only call
  // SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      if (reg != nullptr) {
        int32 priority = reg->def.priority();
        prioritized_device_types->emplace_back(device_type, priority);
      }
    }
    std::sort(prioritized_device_types->begin(),
              prioritized_device_types->end(),
              [](const std::pair<DeviceType, int32>& a,
                 const std::pair<DeviceType, int32>& b) {
                return a.second > b.second;
              });
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  KernelList kernel_list = GetAllRegisteredKernels();
  for (const auto& kernel_def : kernel_list.kernel()) {
    LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
  }
}

KernelList GetAllRegisteredKernels() {
  return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}

KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate) {
  const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
  KernelList kernel_list;
  kernel_list.mutable_kernel()->Reserve(typed_registry->size());
  for (const auto& p : *typed_registry) {
    const KernelDef& kernel_def = p.second.def;
    if (predicate(kernel_def)) {
      *kernel_list.add_kernel() = kernel_def;
    }
  }
  return kernel_list;
}

KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
  auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
  return GetFilteredRegisteredKernels(op_pred);
}

KernelsRegisteredForOp(StringPiece op_name)1237 string KernelsRegisteredForOp(StringPiece op_name) {
1238   KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
1239   if (kernel_list.kernel_size() == 0) return "  <no registered kernels>\n";
1240   string ret;
1241   for (const auto& kernel_def : kernel_list.kernel()) {
1242     strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
1243     if (!kernel_def.label().empty()) {
1244       strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
1245     }
1246     for (int i = 0; i < kernel_def.constraint_size(); ++i) {
1247       strings::StrAppend(
1248           &ret, "; ", kernel_def.constraint(i).name(), " in ",
1249           SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
1250     }
1251     strings::StrAppend(&ret, "\n");
1252   }
1253   return ret;
1254 }
1255 
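// For illustration (hypothetical op and constraints), the string built above
// contains one line per registered kernel, roughly of the form:
//
//     device='CPU'; T in [DT_FLOAT, DT_DOUBLE]
//     device='GPU'; label='custom'; T in [DT_FLOAT]
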
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
                           node_def, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

  // Look up the Op registered for this op name.
  const OpDef* op_def = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
  if (!s.ok()) return s;

  // Validate node_def against OpDef.
  s = ValidateNodeDef(node_def, *op_def);
  if (!s.ok()) return s;

  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  if (!s.ok()) {
    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
    return s;
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for ", DeviceTypeString(device_type),
                              " devices compatible with node ",
                              FormatNodeDefForError(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Get signature from the OpDef & NodeDef
  DataTypeVector inputs;
  DataTypeVector outputs;
  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!s.ok()) {
    errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide
  // the kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}

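// A hedged usage sketch (not from the original source): instantiating a
// kernel for a NodeDef via the std::unique_ptr overload above.  `device` is
// assumed to be a valid DeviceBase*, `node_def` an already-built NodeDef;
// cpu_allocator() and TF_GRAPH_DEF_VERSION come from the public headers.
//
//   Status status;
//   std::unique_ptr<OpKernel> op_kernel =
//       CreateOpKernel(DeviceType(DEVICE_CPU), device, cpu_allocator(),
//                      node_def, TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) {
//     LOG(ERROR) << "Failed to create kernel: " << status;
//   }
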
namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

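// Sketch of the mismatch this validation catches (op and kernel names below
// are hypothetical): a registration may pin an argument to host memory only
// if the OpDef actually declares an input or output with that name.
//
//   REGISTER_OP("MyFill").Input("shape: int32").Output("out: float");
//   REGISTER_KERNEL_BUILDER(
//       Name("MyFill").Device(DEVICE_GPU).HostMemory("shape"), MyFillOp);
//   // Had HostMemory() named an argument absent from the OpDef, the loop
//   // above would return an InvalidArgument error.
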
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
  return eigen_sycl_device();
}
#endif

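// Brief usage sketch (assumes a CPU kernel's Compute method and Eigen tensor
// maps `in`/`out`): the specializations above let a kernel fetch the Eigen
// device matching its placement and evaluate expressions on it.
//
//   const Eigen::ThreadPoolDevice& d =
//       context->eigen_device<Eigen::ThreadPoolDevice>();
//   out.device(d) = in * in;
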
void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name) {
  CHECK_EQ(nullptr, ctx->op_kernel().AsAsync())
      << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}

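// Illustrative note (hypothetical kernel code): the CtxFailure* methods above
// back the OP_REQUIRES family of macros, and CheckNotInComputeAsync enforces
// that the *_ASYNC variants are used from AsyncOpKernel::ComputeAsync.
//
//   // In a synchronous kernel:
//   OP_REQUIRES(ctx, input.dims() == 2,
//               errors::InvalidArgument("input must be a matrix"));
//
//   // In an AsyncOpKernel, the async variant also invokes `done`:
//   OP_REQUIRES_OK_ASYNC(ctx, some_status, done);
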
}  // namespace tensorflow