/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

}  // namespace

// OpKernel ------------------------------------------------------------------

// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
    : OpKernel(context,
               std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}

OpKernel::OpKernel(OpKernelConstruction* context,
                   std::unique_ptr<const NodeDef> node_def)
    : def_(std::move(node_def)),
      input_types_(context->input_types().begin(),
                   context->input_types().end()),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_types_(context->output_types().begin(),
                    context->output_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      graph_def_version_(context->graph_def_version()),
      is_internal_(StringPiece(type_string()).starts_with("_")),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
                                   &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
                                             context->graph_def_version()));

  // Kernels executing on GPU/SYCL tie up very few resources on the CPU where
  // the scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               context->device_type() != DeviceType(DEVICE_SYCL);
}

OpKernel::~OpKernel() {}

const string& OpKernel::name() const { return def_->name(); }
const string& OpKernel::type_string() const { return def_->op(); }
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }

Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
  if (!IsLegacyVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}

void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}
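
// Example (hypothetical kernel, not part of this file): a subclass implements
// ComputeAsync() and invokes `done` once its work has finished; the
// synchronous Compute() wrapper above then unblocks. `worker_threads_` is an
// assumed member, for illustration only.
//
//   class MyAsyncOp : public AsyncOpKernel {
//    public:
//     explicit MyAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//       worker_threads_->Schedule([ctx, done]() {
//         // ... produce outputs via ctx, then signal completion.
//         done();
//       });
//     }
//   };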

// PersistentTensor ----------------------------------------------------------

Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
  // the caller has to have a valid context
  CHECK(context);
  return &tensor_;
}

Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
  context->NotifyUseOfPersistentTensor(tensor_);
  return &tensor_;
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
                              output_types_);
}
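
// Example (hypothetical, for illustration): a kernel constructor can assert
// its expected signature against the types recorded in the NodeDef, e.g. for
// a unary float op:
//
//   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//     OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_FLOAT}, {DT_FLOAT}));
//   }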

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_persistent(
    DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
    Tensor** out_tensor) {
  // for now just do the same thing as allocate_temp
  // TODO(misard) add specific memory tracking for persistent tensors
  Tensor persistent;
  Status s = allocate_temp(type, shape, &persistent);
  if (!s.ok()) {
    return s;
  }
  *out_persistent = PersistentTensor(persistent);
  Tensor* allocated = out_persistent->AccessTensor(this);
  if (out_tensor) {
    *out_tensor = allocated;
  }
  return s;
}

// OpKernelContext -----------------------------------------------------------

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params),
      outputs_(num_outputs),
      temp_memory_allocated_(0),
      persistent_memory_allocated_(0) {
  Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
  params_->ensure_eigen_gpu_device();
  params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
                                         params_->op_device_context,
                                         eigen_gpu_allocator);
  if (params_->record_tensor_accesses) {
    referenced_tensors_.Init();
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
}

Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator =
      params_->device->GetStepAllocator(attr, resource_manager());
  if (track_allocations()) {
    mutex_lock lock(mu_);
    for (const auto& wrapped : wrapped_allocators_) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
  mutex_lock l(mu_);
  // Keep a reference to the underlying memory around.
  referenced_tensors_->Add(tensor);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  if (input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[start].tensor;
  record_tensor_reference(**tensor);
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  const TensorValue& value((*params_->inputs)[start]);
  if (value.is_ref()) {
    *dtype = MakeRefType(value->dtype());
  } else {
    *dtype = value->dtype();
  }
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  *out_mutex = input_ref_mutex(start);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  record_tensor_reference(tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
  record_tensor_reference(tensor);
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  DCHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, expected_output_dtype(output_index), output_shape,
      output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}
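
// Example (hypothetical, for illustration): an element-wise kernel can try to
// reuse its input buffer for output 0 and fall back to a fresh allocation:
//
//   const Tensor& input = ctx->input(0);
//   Tensor* output = nullptr;
//   if (!ctx->forward_input_to_output_with_shape(0, 0, input.shape(),
//                                                &output)) {
//     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
//   }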

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index, stop;
  TF_RETURN_IF_ERROR(
      params_->op_kernel->InputRange(input_name, &input_index, &stop));
  if (stop != input_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   input_name,
                                   "' when single-valued input was "
                                   "expected");
  }
  TF_RETURN_IF_ERROR(
      params_->op_kernel->OutputRange(output_name, &output_index, &stop));
  if (stop != output_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   output_name,
                                   "' when single-valued output was "
                                   "expected");
  }
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return Status::OK();
}

std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, DataType output_dtype, const TensorShape& output_shape,
    MemoryType output_memory_type, const AllocatorAttributes& output_attr) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check that input tensor exists, is not a ref, and has no other consumers.
  if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    return nullptr;
  }
  // Check that output allocator attributes are not more restrictive than
  // input allocator attributes.
  const auto input_attr = params_->input_alloc_attrs == nullptr
                              ? AllocatorAttributes()
                              : input_alloc_attr(input_index);
  if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
    return nullptr;
  }
  // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
  // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
  // general cleanup of ownership in this code.
  std::unique_ptr<Tensor> output_tensor(new Tensor());
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[start].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(start));
    *tensor = *(*params_->inputs)[start].tensor;
  }
  record_tensor_reference(*tensor);
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(start, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, output, attr);
}
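
// Example (hypothetical, for illustration): the canonical pattern in a
// kernel's Compute() method:
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
//   // ... write into output->flat<T>() ...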

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  AllocationAttributes logged_attr(allocation_attr);
  logged_attr.allocation_will_be_logged = true;
  Tensor new_tensor(a, type, shape, logged_attr);

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape ", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  record_tensor_reference(new_tensor);
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  DCHECK(!IsRefType(type));
  DCHECK(mutable_output(index) == nullptr);
  Tensor* output_tensor = new Tensor();
  Status s = allocate_tensor(type, shape, output_tensor, attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor);
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  }
  return s;
}

Status OpKernelContext::allocate_persistent(DataType type,
                                            const TensorShape& shape,
                                            PersistentTensor* out_persistent,
                                            Tensor** out_tensor,
                                            AllocatorAttributes attr) {
  Tensor persistent;
  Status s = allocate_tensor(type, shape, &persistent, attr);
  if (s.ok()) {
    *out_persistent = PersistentTensor(persistent);
    if (out_tensor) {
      *out_tensor = out_persistent->AccessTensor(this);
    }
    if (track_allocations()) {
      Tensor* t = out_persistent->AccessTensor(this);
      Allocator* a = get_allocator(attr);
      if (a->TracksAllocationSizes()) {
        int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
        int64 alloc_id = a->AllocationId(t->tensor_data().data());
        record_persistent_memory_allocation(alloc_size, alloc_id);
      }
    }
  }
  return s;
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output(start, tensor);
  return Status::OK();
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
  DCHECK_EQ(mutable_output(index), nullptr);
  record_tensor_reference(tensor);
  outputs_[index] = TensorValue(new Tensor(tensor));
  if (track_allocations() && tensor.TotalBytes() > 0) {
    mutex_lock l(stats_mu_);
    if (!temp_tensor_buffer_and_size_) {
      return;
    }
    auto it = std::find_if(temp_tensor_buffer_and_size_->begin(),
                           temp_tensor_buffer_and_size_->end(),
                           [&tensor](const std::pair<const void*, int64>& e) {
                             return e.first == static_cast<const void*>(
                                                   tensor.tensor_data().data());
                           });
    if (it != temp_tensor_buffer_and_size_->end()) {
      temp_memory_allocated_ -= it->second;
      temp_tensor_buffer_and_size_->erase(it);
    }
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(IsRefType(params_->op_kernel->output_type(index)));
  record_tensor_reference(*tensor_for_ref);
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output_ref(start, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *tensor = mutable_output(start);
  return Status::OK();
}

Status OpKernelContext::release_output(StringPiece name, TensorValue* value) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *value = release_output(start);
  return Status::OK();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape.  Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                    const Tensor& t) {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ += size;
  if (!temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_.reset(
        new gtl::InlinedVector<std::pair<const void*, int64>, 2>());
  }
  temp_tensor_buffer_and_size_->emplace_back(
      static_cast<const void*>(t.tensor_data().data()), size);
}

int64 OpKernelContext::temp_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return temp_memory_allocated_;
}

void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                          int64 alloc_id) {
  mutex_lock l(stats_mu_);
  persistent_memory_allocated_ += size;
  if (alloc_id >= 0) {
    if (!persistent_alloc_ids_) {
      persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>());
    }
    persistent_alloc_ids_->push_back(alloc_id);
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return persistent_memory_allocated_;
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  mutex_lock l(stats_mu_);
  if (persistent_alloc_ids_) {
    return std::vector<int64>(persistent_alloc_ids_->begin(),
                              persistent_alloc_ids_->end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ = 0;
  persistent_memory_allocated_ = 0;
  if (temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_->clear();
  }
  if (persistent_alloc_ids_) {
    persistent_alloc_ids_->clear();
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     kernel_factory::OpKernelRegistrar::Factory f)
      : def(d), kernel_class_name(c.ToString()), factory(f) {}
  const KernelDef def;
  const string kernel_class_name;
  const kernel_factory::OpKernelRegistrar::Factory factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}
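
// For example, Key("MatMul", DeviceType(DEVICE_GPU), "") produces the
// registry key "MatMul:GPU:" (an unlabeled GPU MatMul kernel).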

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     Factory factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
    GlobalKernelRegistryTyped()->insert(std::make_pair(
        key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
  }
  delete kernel_def;
}
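
// Registrars are normally created via the REGISTER_KERNEL_BUILDER macro at
// static-initialization time. A typical (illustrative) registration:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MatMulOp<CPUDevice, float>);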

}  // namespace kernel_factory

namespace {

// Helper for AttrsMatch().
bool InTypeList(DataType dt, const AttrValue& type_list) {
  for (int in_list : type_list.list().type()) {
    if (dt == in_list) return true;
  }
  return false;
}

// Returns whether the attrs satisfy the constraints in the kernel_def.  Returns
// an error if attrs in kernel_def are not found, or have a mismatching type.
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
  *match = false;
  for (const auto& constraint : kernel_def.constraint()) {
    if (constraint.allowed_values().list().type_size() == 0) {
      return errors::Unimplemented(
          "KernelDef '", ProtoShortDebugString(kernel_def),
          "' has constraint on attr '", constraint.name(),
          "' with unsupported type: ",
          SummarizeAttrValue(constraint.allowed_values()));
    }

    const AttrValue* found = attrs.Find(constraint.name());
    if (found) {
      if (found->type() != DT_INVALID) {
        if (!InTypeList(found->type(), constraint.allowed_values())) {
          return Status::OK();
        }
      } else {
        if (!AttrValueHasType(*found, "list(type)").ok()) {
          return errors::InvalidArgument(
              "KernelDef '", ProtoShortDebugString(kernel_def),
              "' has constraint on attr '", constraint.name(),
              "' that has value '", SummarizeAttrValue(*found),
              "' that does not have type 'type' or 'list(type)' in NodeDef "
              "'",
              attrs.SummarizeNode(), "'");
        }

        for (int t : found->list().type()) {
          if (!InTypeList(static_cast<DataType>(t),
                          constraint.allowed_values())) {
            return Status::OK();
          }
        }
      }
    } else {
      return errors::InvalidArgument(
          "OpKernel '", kernel_def.op(), "' has constraint on attr '",
          constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
          "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
    }
  }
  *match = true;
  return Status::OK();
}
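
// For example, a kernel registered with TypeConstraint<float>("T") matches a
// NodeDef whose attr T is DT_FLOAT; with T = DT_DOUBLE, AttrsMatch() returns
// OK but leaves *match false, so lookup falls through to other candidates.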

static const StringPiece kKernelAttr("_kernel");

// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
                              const NodeDef& node_def,
                              const KernelRegistration** reg,
                              bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;
  // Label defaults to empty if not found in NodeDef.
  const string& label = GetNodeAttrString(node_def, kKernelAttr);

  const string key = Key(node_def.op(), device_type, label);
  auto regs = GlobalKernelRegistryTyped()->equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
    if (match) {
      if (*reg != nullptr) {
        return errors::InvalidArgument(
            "Multiple OpKernel registrations match NodeDef '",
            SummarizeNodeDef(node_def), "': '",
            ProtoShortDebugString((*reg)->def), "' and '",
            ProtoShortDebugString(iter->second.def), "'");
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
  return Status::OK();
}

}  // namespace

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    Status s = errors::NotFound(
        "No registered '", node_def.op(), "' OpKernel for ",
        DeviceTypeString(device_type), " devices compatible with node ",
        SummarizeNodeDef(node_def));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match)");
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    DeviceTypeVector* device_types) {
  // TODO(zhifengc): Change the callers (SimplePlacer and
  // DynamicPlacer) to consider the possibility that 'def' is a call to
  // a user-defined function and only call this
  // SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      if (reg != nullptr) device_types->push_back(device_type);
    }
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      device_types->push_back(device_type);
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
  }
}

string KernelsRegisteredForOp(StringPiece op_name) {
  string ret;
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    if (kernel_def.op() == op_name) {
      strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
      if (!kernel_def.label().empty()) {
        strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
      }
      for (int i = 0; i < kernel_def.constraint_size(); ++i) {
        strings::StrAppend(
            &ret, "; ", kernel_def.constraint(i).name(), " in ",
            SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
      }
      strings::StrAppend(&ret, "\n");
    }
  }
  if (ret.empty()) return "  <no registered kernels>\n";
  return ret;
}
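
// Illustrative only: for an op with one float CPU kernel and one float GPU
// kernel, the returned string might look like:
//
//     device='CPU'; T in [DT_FLOAT]
//     device='GPU'; T in [DT_FLOAT]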

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
                           node_def, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

  // Look up the Op registered for this op name.
  const OpDef* op_def = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
  if (!s.ok()) return s;

  // Validate node_def against OpDef.
  s = ValidateNodeDef(node_def, *op_def);
  if (!s.ok()) return s;

  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  if (!s.ok()) {
    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
    return s;
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for ", DeviceTypeString(device_type),
                              " devices compatible with node ",
                              SummarizeNodeDef(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match)");
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Get the signature from the OpDef & NodeDef.
  DataTypeVector inputs;
  DataTypeVector outputs;
  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!s.ok()) {
    errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide
  // the kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = (*registration->factory)(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}
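
// Example (hypothetical, for illustration): instantiating a kernel for a
// NodeDef on a CPU device; `device` and `flib` are assumed to come from the
// surrounding runtime.
//
//   OpKernel* kernel = nullptr;
//   TF_RETURN_IF_ERROR(CreateOpKernel(
//       DeviceType(DEVICE_CPU), device,
//       device->GetAllocator(AllocatorAttributes()), flib, node_def,
//       TF_GRAPH_DEF_VERSION, &kernel));
//   std::unique_ptr<OpKernel> owned(kernel);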

namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
  return eigen_sycl_device();
}
#endif

void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

}  // namespace tensorflow