• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <cstdlib>
19 #include <cstring>
20 #include <mutex>  // NOLINT
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/base/call_once.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/framework/allocation_description.pb.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/attr_value_util.h"
31 #include "tensorflow/core/framework/device_attributes.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/kernel_def.pb.h"
34 #include "tensorflow/core/framework/kernel_def_util.h"
35 #include "tensorflow/core/framework/log_memory.h"
36 #include "tensorflow/core/framework/memory_types.h"
37 #include "tensorflow/core/framework/node_def.pb.h"
38 #include "tensorflow/core/framework/node_def_util.h"
39 #include "tensorflow/core/framework/node_properties.h"
40 #include "tensorflow/core/framework/op_def_util.h"
41 #include "tensorflow/core/framework/tensor_reference.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/notification.h"
45 #include "tensorflow/core/lib/core/stringpiece.h"
46 #include "tensorflow/core/lib/gtl/map_util.h"
47 #include "tensorflow/core/lib/io/path.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/platform/cpu_info.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/platform_strings.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/profiler/lib/traceme.h"
57 #include "tensorflow/core/util/ptr_util.h"
58 
59 namespace tensorflow {
60 
61 namespace {
62 
MatchSignatureHelper(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs,const DataTypeSlice inputs,const DataTypeSlice outputs)63 Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
64                             const DataTypeSlice expected_outputs,
65                             const DataTypeSlice inputs,
66                             const DataTypeSlice outputs) {
67   bool signature_mismatch = false;
68 
69   if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
70   for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
71     if (!TypesCompatible(expected_inputs[i], inputs[i])) {
72       signature_mismatch = true;
73     }
74   }
75 
76   if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
77   for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
78     if (!TypesCompatible(expected_outputs[i], outputs[i])) {
79       signature_mismatch = true;
80     }
81   }
82 
83   if (signature_mismatch) {
84     return errors::InvalidArgument(
85         "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
86         DataTypeSliceString(outputs),
87         " expected: ", DataTypeSliceString(expected_inputs), "->",
88         DataTypeSliceString(expected_outputs));
89   }
90   return Status::OK();
91 }
92 
93 }  // namespace
94 
95 // OpKernel ------------------------------------------------------------------
96 
OpKernel(OpKernelConstruction * context)97 OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}
98 
OpKernel(OpKernelConstruction * context,bool is_deferred)99 OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
100     : props_(context->props_),
101       input_memory_types_(context->input_memory_types().begin(),
102                           context->input_memory_types().end()),
103       output_memory_types_(context->output_memory_types().begin(),
104                            context->output_memory_types().end()),
105       input_name_map_(context->num_inputs()),
106       output_name_map_(context->num_outputs()),
107       name_view_(props_->node_def.name()),
108       type_string_view_(props_->node_def.op()),
109       graph_def_version_(context->graph_def_version()),
110       is_deferred_(is_deferred) {
111   OP_REQUIRES_OK(context,
112                  NameRangesForNode(props_->node_def, *props_->op_def,
113                                    &input_name_map_, &output_name_map_));
114   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
115                                              context->graph_def_version()));
116 
117   // Kernels executing on GPU tie very few resources on the CPU where the
118   // scheduler runs: we consider them as inexpensive.
119   expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
120 }
121 
OpKernel(OpKernelConstruction * context,NodeDef && custom_def,bool is_deferred)122 OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
123                    bool is_deferred)
124     : props_(std::make_shared<const NodeProperties>(
125           context->props_->op_def, std::move(custom_def),
126           context->props_->input_types, context->props_->output_types)),
127       input_memory_types_(context->input_memory_types().begin(),
128                           context->input_memory_types().end()),
129       output_memory_types_(context->output_memory_types().begin(),
130                            context->output_memory_types().end()),
131       input_name_map_(context->num_inputs()),
132       output_name_map_(context->num_outputs()),
133       name_view_(props_->node_def.name()),
134       type_string_view_(props_->node_def.op()),
135       graph_def_version_(context->graph_def_version()),
136       is_deferred_(is_deferred) {
137   OP_REQUIRES_OK(context,
138                  NameRangesForNode(props_->node_def, *props_->op_def,
139                                    &input_name_map_, &output_name_map_));
140   OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
141                                              context->graph_def_version()));
142 
143   // Kernels executing on GPU tie very few resources on the CPU where the
144   // scheduler runs: we consider them as inexpensive.
145   expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
146 }
147 
~OpKernel()148 OpKernel::~OpKernel() {}
149 
InputRange(StringPiece input_name,int * start,int * stop) const150 Status OpKernel::InputRange(StringPiece input_name, int* start,
151                             int* stop) const {
152   const auto result = input_name_map_.find(input_name);
153   if (result == input_name_map_.end()) {
154     return errors::InvalidArgument("Unknown input name: ", input_name);
155   } else {
156     *start = result->second.first;
157     *stop = result->second.second;
158     return Status::OK();
159   }
160 }
161 
OutputRange(StringPiece output_name,int * start,int * stop) const162 Status OpKernel::OutputRange(StringPiece output_name, int* start,
163                              int* stop) const {
164   const auto result = output_name_map_.find(output_name);
165   if (result == output_name_map_.end()) {
166     return errors::InvalidArgument("Unknown output name: ", output_name);
167   } else {
168     *start = result->second.first;
169     *stop = result->second.second;
170     return Status::OK();
171   }
172 }
173 
ShapeTraceString(const OpKernelContext & ctx) const174 string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
175   int num_inputs = ctx.num_inputs();
176   if (num_inputs == 0) return "";
177   std::vector<string> tensor_shapes;
178   tensor_shapes.reserve(num_inputs);
179   for (int i = 0; i < num_inputs; i++) {
180     if (!ctx.has_input(i)) {
181       tensor_shapes.emplace_back();  // Placeholder
182       continue;
183     }
184     DataType input_dtype = ctx.input_dtype(i);
185     if (input_dtype == DataType::DT_RESOURCE ||
186         input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
187       tensor_shapes.emplace_back();  // Placeholder
188       continue;
189     }
190     tensor_shapes.emplace_back(strings::StrCat(
191         DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
192   }
193   return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
194 }
195 
TraceString(const OpKernelContext & ctx,bool verbose) const196 string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
197   string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
198   if (verbose) {
199     string shape = ShapeTraceString(ctx);
200     if (!shape.empty()) {
201       trace_string =
202           profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
203     }
204   }
205   return trace_string;
206 }
207 
Compute(OpKernelContext * context)208 void AsyncOpKernel::Compute(OpKernelContext* context) {
209   Notification n;
210   ComputeAsync(context, [&n]() { n.Notify(); });
211   n.WaitForNotification();
212 }
213 
214 // PersistentTensor ----------------------------------------------------------
215 
AccessTensor(OpKernelConstruction * context)216 Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
217   // the caller has to have a valid context
218   CHECK(context);
219   return &tensor_;
220 }
221 
AccessTensor(OpKernelContext * context)222 Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
223   return &tensor_;
224 }
225 
226 // OpKernelConstruction ------------------------------------------------------
227 
OpKernelConstruction(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,ResourceMgr * resource_mgr,const std::shared_ptr<const NodeProperties> & props,const MemoryTypeSlice & input_memory_types,const MemoryTypeSlice & output_memory_types,int graph_def_version,Status * status)228 OpKernelConstruction::OpKernelConstruction(
229     DeviceType device_type, DeviceBase* device, Allocator* allocator,
230     FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
231     const std::shared_ptr<const NodeProperties>& props,
232     const MemoryTypeSlice& input_memory_types,
233     const MemoryTypeSlice& output_memory_types, int graph_def_version,
234     Status* status)
235     : device_type_(std::move(device_type)),
236       device_(device),
237       allocator_(allocator),
238       flib_(flib),
239       resource_mgr_(resource_mgr),
240       props_(props),
241       input_memory_types_(input_memory_types),
242       output_memory_types_(output_memory_types),
243       graph_def_version_(graph_def_version),
244       status_(status) {}
245 
HasAttr(StringPiece attr_name) const246 bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
247   return HasNodeAttr(def(), attr_name);
248 }
249 
SetStatus(const Status & status)250 void OpKernelConstruction::SetStatus(const Status& status) {
251   status_->Update(status);
252 }
253 
MatchSignature(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs)254 Status OpKernelConstruction::MatchSignature(
255     const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
256   return MatchSignatureHelper(expected_inputs, expected_outputs,
257                               props_->input_types, props_->output_types);
258 }
259 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)260 Status OpKernelConstruction::allocate_temp(DataType type,
261                                            const TensorShape& shape,
262                                            Tensor* out_temp) {
263   AllocationAttributes attr;
264   attr.allocation_will_be_logged = true;
265   Tensor new_temp(allocator_, type, shape, attr);
266 
267   if (!new_temp.IsInitialized()) {
268     return errors::ResourceExhausted(
269         "OOM when allocating temporary tensor with shape", shape.DebugString());
270   }
271   if (LogMemory::IsEnabled()) {
272     LogMemory::RecordTensorAllocation(
273         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
274   }
275   *out_temp = new_temp;
276   return Status::OK();
277 }
278 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)279 Status OpKernelConstruction::allocate_temp(DataType type,
280                                            const TensorShape& shape,
281                                            Tensor* out_temp,
282                                            AllocatorAttributes allocator_attr) {
283   if (allocator_attr.scope_id != 0) {
284     return errors::InvalidArgument(
285         "ScopedAllocator cannot be used via OpKernelConstruction.");
286   }
287   Allocator* a = device_->GetAllocator(allocator_attr);
288   AllocationAttributes attr;
289   attr.allocation_will_be_logged = true;
290   Tensor new_temp(a, type, shape, attr);
291 
292   if (!new_temp.IsInitialized()) {
293     return errors::ResourceExhausted(
294         "OOM when allocating temporary tensor with shape", shape.DebugString());
295   }
296   if (LogMemory::IsEnabled()) {
297     LogMemory::RecordTensorAllocation(
298         def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
299   }
300   *out_temp = new_temp;
301   return Status::OK();
302 }
303 
allocate_persistent(DataType type,const TensorShape & shape,PersistentTensor * out_persistent,Tensor ** out_tensor)304 Status OpKernelConstruction::allocate_persistent(
305     DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
306     Tensor** out_tensor) {
307   // for now just do the same thing as allocate_temp
308   // TODO(misard) add specific memory tracking for persistent tensors
309   Tensor persistent;
310   TF_RETURN_IF_ERROR(allocate_temp(type, shape, &persistent));
311 
312   *out_persistent = PersistentTensor(persistent);
313   Tensor* allocated = out_persistent->AccessTensor(this);
314   if (out_tensor) {
315     *out_tensor = allocated;
316   }
317   return Status::OK();
318 }
319 
320 // OpKernelContext -----------------------------------------------------------
321 
322 const int OpKernelContext::Params::kNeverForward;
323 const int OpKernelContext::Params::kNoReservation;
324 
OpKernelContext(Params * params)325 OpKernelContext::OpKernelContext(Params* params)
326     : OpKernelContext(
327           params, static_cast<int>(params->op_kernel->output_types().size())) {}
328 
OpKernelContext(Params * params,int num_outputs)329 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
330     : params_(params), outputs_(num_outputs) {
331   if (params_->track_allocations) {
332     tracking_state_ = absl::make_unique<TrackingState>();
333   }
334 
335   params_->ensure_eigen_gpu_device();
336   if (params_->eigen_gpu_device != nullptr) {
337     Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
338     Status s = params_->device->ReinitializeGpuDevice(
339         this, params_->eigen_gpu_device, params_->op_device_context,
340         eigen_gpu_allocator);
341     if (!s.ok()) {
342       SetStatus(s);
343     }
344   }
345 }
346 
~OpKernelContext()347 OpKernelContext::~OpKernelContext() {
348   for (TensorValue& value : outputs_) {
349     if (!value.is_ref()) {
350       delete value.tensor;
351     }
352   }
353   if (params_->track_allocations &&
354       !tracking_state_->wrapped_allocators.empty()) {
355     LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
356                  << "being consumed by the StepStatsCollector.";
357     for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
358       wrapped_allocator.second->GetRecordsAndUnRef();
359     }
360   }
361 }
362 
get_allocator(AllocatorAttributes attr)363 Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
364   Allocator* allocator = nullptr;
365   if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
366     allocator = params_->device->GetScopedAllocator(attr, step_id());
367     CHECK(allocator);
368   } else {
369     allocator = params_->device->GetAllocator(attr);
370   }
371   if (TF_PREDICT_FALSE(track_allocations())) {
372     DCHECK(tracking_state_);
373     mutex_lock lock(tracking_state_->mu);
374     for (const auto& wrapped : tracking_state_->wrapped_allocators) {
375       if (wrapped.first == allocator) {
376         return wrapped.second;
377       }
378     }
379     TrackingAllocator* wrapped_allocator =
380         new TrackingAllocator(allocator, params_->track_allocations);
381     tracking_state_->wrapped_allocators.push_back(
382         std::make_pair(allocator, wrapped_allocator));
383     return wrapped_allocator;
384   } else {
385     return allocator;
386   }
387 }
388 
SetStatus(const Status & status)389 void OpKernelContext::SetStatus(const Status& status) {
390   status_.Update(status);
391 }
392 
input(StringPiece name,const Tensor ** tensor)393 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
394   int index;
395   TF_RETURN_IF_ERROR(get_input_index(name, &index));
396   if (input_is_ref(index)) {
397     return errors::InvalidArgument("OpKernel used ref input name '", name,
398                                    "' when non-ref input was expected");
399   }
400   *tensor = (*params_->inputs)[index].tensor;
401   return Status::OK();
402 }
403 
input_dtype(StringPiece name,DataType * dtype) const404 Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
405   int index;
406   TF_RETURN_IF_ERROR(get_input_index(name, &index));
407   const TensorValue& value((*params_->inputs)[index]);
408   *dtype = value.dtype();
409   return Status::OK();
410 }
411 
input_ref_mutex(StringPiece name,mutex ** out_mutex)412 Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
413   int index;
414   TF_RETURN_IF_ERROR(get_input_index(name, &index));
415   *out_mutex = input_ref_mutex(index);
416   return Status::OK();
417 }
418 
input(int index) const419 const Tensor& OpKernelContext::input(int index) const {
420   CHECK_GE(index, 0);
421   CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
422   CHECK(!input_is_ref(index));
423   const Tensor& tensor = *((*params_->inputs)[index].tensor);
424   return tensor;
425 }
426 
mutable_input(int index,bool lock_held)427 Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
428   CHECK_GE(index, 0);
429   CHECK_LT(index, num_inputs());
430   CHECK(input_is_ref(index));
431   // return a copy of the Ref acquired while holding the mutex
432   if (lock_held) {
433     Tensor& tensor = *((*params_->inputs)[index].tensor);
434     return tensor;
435   } else {
436     tf_shared_lock l(*input_ref_mutex(index));
437     Tensor& tensor = *((*params_->inputs)[index].tensor);
438     return tensor;
439   }
440 }
441 
replace_ref_input(int index,const Tensor & tensor,bool lock_held)442 void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
443                                         bool lock_held) {
444   CHECK_GE(index, 0);
445   CHECK_LT(index, num_inputs());
446   CHECK(input_is_ref(index));
447   // should only modify the tensor while holding the mutex
448   if (lock_held) {
449     *(*params_->inputs)[index].tensor = tensor;
450   } else {
451     mutex_lock l(*input_ref_mutex(index));
452     *(*params_->inputs)[index].tensor = tensor;
453   }
454 }
455 
forward_ref_input_to_ref_output(int input_index,int output_index)456 void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
457                                                       int output_index) {
458   CHECK_GE(input_index, 0);
459   CHECK_LT(input_index, num_inputs());
460   CHECK(input_is_ref(input_index));
461   set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
462                  (*params_->inputs)[input_index].tensor);
463 }
464 
forward_input_to_output_with_shape(int input_index,int output_index,const TensorShape & output_shape,Tensor ** output)465 bool OpKernelContext::forward_input_to_output_with_shape(
466     int input_index, int output_index, const TensorShape& output_shape,
467     Tensor** output) {
468   const auto output_attr = params_->output_attr_array == nullptr
469                                ? AllocatorAttributes()
470                                : output_alloc_attr(output_index);
471   std::unique_ptr<Tensor> new_tensor = forward_input(
472       input_index, output_index, expected_output_dtype(output_index),
473       output_shape, output_memory_type(output_index), output_attr);
474   if (new_tensor != nullptr) {
475     // Transfer ownership to the output slot in OpKernelContext.
476     outputs_[output_index] = TensorValue(new_tensor.release());
477     *output = outputs_[output_index].tensor;
478     return true;
479   } else {
480     return false;
481   }
482 }
483 
forward_input_to_output_with_shape(StringPiece input_name,StringPiece output_name,const TensorShape & output_shape,Tensor ** output)484 Status OpKernelContext::forward_input_to_output_with_shape(
485     StringPiece input_name, StringPiece output_name,
486     const TensorShape& output_shape, Tensor** output) {
487   int input_index, output_index;
488   TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
489   TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
490   if (!forward_input_to_output_with_shape(input_index, output_index,
491                                           output_shape, output)) {
492     return errors::FailedPrecondition("OpKernel could not forward input '",
493                                       input_name, "' to output '", output_name);
494   }
495   return Status::OK();
496 }
497 
forward_input(int input_index,int output_index,DataType output_dtype,const TensorShape & output_shape,MemoryType output_memory_type,const AllocatorAttributes & output_attr)498 std::unique_ptr<Tensor> OpKernelContext::forward_input(
499     int input_index, int output_index, DataType output_dtype,
500     const TensorShape& output_shape, MemoryType output_memory_type,
501     const AllocatorAttributes& output_attr) {
502   CHECK_GE(input_index, 0);
503   CHECK_LT(input_index, num_inputs());
504   const TensorValue& input = (*params_->inputs)[input_index];
505   // Check whether at graph construction time this output was marked
506   // either for no forwarding or with a reservation for this input.
507   // If it's reserved for this input we'll skip the refcount and
508   // AllocatorAttribute checks.
509   // TODO(tucker): Maybe we should skip all of the checks?
510   bool never_forward =
511       (params_->forward_from_array != nullptr && output_index >= 0 &&
512        params_->forward_from_array[output_index] == Params::kNeverForward);
513   if (never_forward) return nullptr;
514   bool forward_expected =
515       (params_->forward_from_array != nullptr && output_index >= 0 &&
516        params_->forward_from_array[output_index] == input_index);
517   if (!forward_expected && params_->forward_from_array != nullptr) {
518     // Check for possibly conflicting forward.
519     for (int i = 0; i < num_outputs(); ++i) {
520       if (params_->forward_from_array[i] == input_index) {
521         // This input is reserved for output i.
522         return nullptr;
523       }
524     }
525   }
526   // Check that input tensor exists and is not a ref.
527   if (input.tensor == nullptr || input.is_ref()) {
528     CHECK(!forward_expected);
529     return nullptr;
530   }
531   // Check that input type matches.
532   if (input_dtype(input_index) != output_dtype) {
533     CHECK(!forward_expected);
534     return nullptr;
535   }
536   // Check that the input and output sizes are compatible.
537   if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
538     CHECK(!forward_expected);
539     return nullptr;
540   }
541   // Check that input and output memory types match, i.e.
542   // that they either both live in host or both live in device memory.
543   if (input_memory_type(input_index) != output_memory_type) {
544     CHECK(!forward_expected);
545     return nullptr;
546   }
547   if (!forward_expected) {
548     if (!input->RefCountIsOne()) {
549       return nullptr;
550     }
551     // Check that output allocator attributes are not more restrictive than
552     // input allocator attributes.
553     const auto input_attr = params_->input_alloc_attrs == nullptr
554                                 ? AllocatorAttributes()
555                                 : input_alloc_attr(input_index);
556     if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
557       return nullptr;
558     }
559   }
560 
561   auto output_tensor = MakeUnique<Tensor>();
562   CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
563   return output_tensor;
564 }
565 
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,const AllocatorAttributes & allocator_attr,Tensor * out_temp)566 Status OpKernelContext::forward_input_or_allocate_temp(
567     gtl::ArraySlice<int> candidate_input_indices, DataType type,
568     const TensorShape& shape, const AllocatorAttributes& allocator_attr,
569     Tensor* out_temp) {
570   for (int input_index : candidate_input_indices) {
571     std::unique_ptr<Tensor> new_tensor =
572         forward_input(input_index, Params::kNoReservation /*output_index*/,
573                       type, shape, DEVICE_MEMORY, allocator_attr);
574     if (new_tensor != nullptr) {
575       *out_temp = std::move(*new_tensor);
576       return Status::OK();
577     }
578   }
579   return allocate_temp(type, shape, out_temp, allocator_attr);
580 }
581 
delete_ref_input(int index,bool lock_held)582 void OpKernelContext::delete_ref_input(int index, bool lock_held) {
583   CHECK_GE(index, 0);
584   CHECK_LT(index, num_inputs());
585   CHECK(input_is_ref(index));
586   // should only modify the tensor while holding the mutex
587   if (lock_held) {
588     delete (*params_->inputs)[index].tensor;
589   } else {
590     mutex_lock l(*input_ref_mutex(index));
591     delete (*params_->inputs)[index].tensor;
592   }
593 }
594 
mutable_input(StringPiece name,Tensor * tensor,bool lock_held)595 Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
596                                       bool lock_held) {
597   int index;
598   TF_RETURN_IF_ERROR(get_input_index(name, &index));
599   if (!input_is_ref(index)) {
600     return errors::InvalidArgument("OpKernel used non-ref input name '", name,
601                                    "' when ref input was expected");
602   }
603   // return a copy of the Ref acquired while holding the mutex
604   if (lock_held) {
605     *tensor = *(*params_->inputs)[index].tensor;
606   } else {
607     tf_shared_lock l(*input_ref_mutex(index));
608     *tensor = *(*params_->inputs)[index].tensor;
609   }
610   return Status::OK();
611 }
612 
replace_ref_input(StringPiece name,const Tensor & tensor,bool lock_held)613 Status OpKernelContext::replace_ref_input(StringPiece name,
614                                           const Tensor& tensor,
615                                           bool lock_held) {
616   int index;
617   TF_RETURN_IF_ERROR(get_input_index(name, &index));
618   if (!input_is_ref(index)) {
619     return errors::InvalidArgument("OpKernel used immutable input name '", name,
620                                    "' when ref input was expected");
621   }
622   replace_ref_input(index, tensor, lock_held);
623   return Status::OK();
624 }
625 
input_list(StringPiece name,OpInputList * list)626 Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
627   int start, stop;
628   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
629   *list = OpInputList(this, start, stop);
630   return Status::OK();
631 }
632 
mutable_input_list(StringPiece name,OpMutableInputList * list)633 Status OpKernelContext::mutable_input_list(StringPiece name,
634                                            OpMutableInputList* list) {
635   int start, stop;
636   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
637   *list = OpMutableInputList(this, start, stop);
638   return Status::OK();
639 }
640 
output_list(StringPiece name,OpOutputList * list)641 Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
642   int start, stop;
643   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
644   *list = OpOutputList(this, start, stop);
645   return Status::OK();
646 }
647 
maybe_initialize_scope_id_set()648 void OpKernelContext::maybe_initialize_scope_id_set() {
649   if (allocated_scope_ids_ == nullptr) {
650     allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
651   }
652 }
653 
allocate_output(int index,const TensorShape & shape,Tensor ** tensor)654 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
655                                         Tensor** tensor) {
656   if (index < 0) {
657     return errors::Internal("allocate_output with bad index=", index,
658                             " kernel=", params_->op_kernel->name());
659   }
660   if (index >= num_outputs()) {
661     return errors::Internal("allocate_output with bad index=", index,
662                             " num_outputs=", num_outputs(),
663                             " kernel=", params_->op_kernel->name());
664   }
665   bool forward_expected =
666       (params_->forward_from_array != nullptr && index >= 0 &&
667        params_->forward_from_array[index] >= 0);
668   if (forward_expected) {
669     return errors::Internal(
670         "Explicit allocate_output call where input forwarding required.  Try "
671         "turning off the ScopedAllocator optimizer.");
672   }
673   AllocatorAttributes attr = output_alloc_attr(index);
674   return allocate_output(index, shape, tensor, attr);
675 }
676 
allocate_output(StringPiece name,const TensorShape & shape,Tensor ** tensor)677 Status OpKernelContext::allocate_output(StringPiece name,
678                                         const TensorShape& shape,
679                                         Tensor** tensor) {
680   int start, stop;
681   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
682   if (stop != start + 1) {
683     return errors::InvalidArgument("OpKernel used list-valued output name '",
684                                    name,
685                                    "' when single-valued output was "
686                                    "expected");
687   }
688   return allocate_output(start, shape, tensor);
689 }
690 
allocate_output(StringPiece name,const TensorShape & shape,Tensor ** tensor,AllocatorAttributes attr)691 Status OpKernelContext::allocate_output(StringPiece name,
692                                         const TensorShape& shape,
693                                         Tensor** tensor,
694                                         AllocatorAttributes attr) {
695   int start, stop;
696   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
697   if (stop != start + 1) {
698     return errors::InvalidArgument("OpKernel used list-valued output name '",
699                                    name,
700                                    "' when single-valued output was "
701                                    "expected");
702   }
703   return allocate_output(start, shape, tensor, attr);
704 }
705 
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes attr,const AllocationAttributes & allocation_attr)706 Status OpKernelContext::allocate_tensor(
707     DataType type, const TensorShape& shape, Tensor* out_tensor,
708     AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
709   Allocator* a = get_allocator(attr);
710   Tensor new_tensor(
711       a, type, shape,
712       AllocationAttributes(
713           /*retry_on_failure=*/allocation_attr.retry_on_failure,
714           /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));
715 
716   if (!new_tensor.IsInitialized()) {
717     return errors::ResourceExhausted(
718         "OOM when allocating tensor with shape", shape.DebugString(),
719         " and type ", DataTypeString(type), " on ", params_->device->name(),
720         " by allocator ", a->Name());
721   }
722   if (params_->log_memory) {
723     LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
724                                       params_->step_id, new_tensor);
725   }
726   *out_tensor = std::move(new_tensor);
727   return Status::OK();
728 }
729 
allocate_output(int index,const TensorShape & shape,Tensor ** output,AllocatorAttributes attr)730 Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
731                                         Tensor** output,
732                                         AllocatorAttributes attr) {
733   if (index < 0) {
734     return errors::Internal("allocate_output with bad index=", index,
735                             " kernel=", params_->op_kernel->name());
736   }
737   if (index >= num_outputs()) {
738     return errors::Internal("allocate_output with bad index=", index,
739                             " num_outputs=", outputs_.size(),
740                             " kernel=", params_->op_kernel->name());
741   }
742   const DataType type = params_->op_kernel->output_type(index);
743   if (IsRefType(type)) {
744     return errors::Internal("allocate_output with ref type. index=", index,
745                             " type=", type,
746                             " kernel=", params_->op_kernel->name());
747   }
748   if (mutable_output(index) != nullptr) {
749     return errors::Internal("allocate_output on same index multiple times.",
750                             " index = ", index,
751                             " mutable_output(index) = ", mutable_output(index),
752                             " kernel=", params_->op_kernel->name());
753   }
754   if (attr.scope_id > 0) {
755     maybe_initialize_scope_id_set();
756     if (!allocated_scope_ids_->insert(attr.scope_id).second) {
757       return errors::Internal(
758           "OpKernel ", params_->op_kernel->name(),
759           " called allocate_output at index ", index, " with scope_id ",
760           attr.scope_id,
761           " more than once.  Try turning off the ScopedAllocator optimizer.");
762     }
763   }
764   ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
765                                             step_id(), "output", type, &shape);
766   auto output_tensor = MakeUnique<Tensor>();
767   Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
768   if (s.ok()) {
769     outputs_[index] = TensorValue(output_tensor.release());
770     *output = outputs_[index].tensor;
771   }
772   return s;
773 }
774 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr,const AllocationAttributes & allocation_attr)775 Status OpKernelContext::allocate_temp(
776     DataType type, const TensorShape& shape, Tensor* out_temp,
777     AllocatorAttributes allocator_attr,
778     const AllocationAttributes& allocation_attr) {
779   if (allocator_attr.scope_id > 0) {
780     // We do not allow ScopedAllocator calls from allocate_temp.  Unlike
781     // allocate_persistent where we return an error if a kernel provides a
782     // meaningful scope_id, here we clear the scope_id and return a temporary
783     // buffer.  This is because it is legal for a kernel to call allocate_temp
784     // and then set_output with the temp tensor.
785     //
786     // We achieve memory correctness by forcing an allocation in set_output and
787     // copying over the tensor from the temp buffer.  Kernels which would like
788     // to avoid this performance penalty should switch to calling
789     // allocate_output.
790     VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
791             << " called allocate_temp with scope_id " << allocator_attr.scope_id
792             << ".  Switch to allocate_output to avoid performance penalty.";
793     allocator_attr.scope_id = -1;
794   }
795   ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
796                                             step_id(), "temp", type, &shape);
797   Status s =
798       allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
799   if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
800     Allocator* a = get_allocator(allocator_attr);
801     if (a->TracksAllocationSizes()) {
802       int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
803       record_temp_memory_allocation(alloc_size, *out_temp);
804     }
805   } else if (record_memory_consumption_) {
806     DCHECK(tracking_state_);
807     mutex_lock l(tracking_state_->stats_mu);
808     tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
809   }
810   return s;
811 }
812 
allocate_persistent(DataType type,const TensorShape & shape,PersistentTensor * out_persistent,Tensor ** out_tensor,AllocatorAttributes attr)813 Status OpKernelContext::allocate_persistent(DataType type,
814                                             const TensorShape& shape,
815                                             PersistentTensor* out_persistent,
816                                             Tensor** out_tensor,
817                                             AllocatorAttributes attr) {
818   if (attr.scope_id > 0) {
819     // ScopedAllocator cannot be used for persistent tensors, because these
820     // tensors may persist across kernel invocations/steps, whereas the backing
821     // tensor for the scoped allocator will be reallocated every step.
822     return errors::Internal(
823         "Unexpected call to allocate_persistent with scope_id ", attr.scope_id);
824   }
825   ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
826                                             step_id(), "persist", type, &shape);
827   Tensor persistent;
828   Status s = allocate_tensor(type, shape, &persistent, attr);
829   if (s.ok()) {
830     *out_persistent = PersistentTensor(persistent);
831     Tensor* t = out_persistent->AccessTensor(this);
832 
833     if (out_tensor) {
834       *out_tensor = t;
835     }
836 
837     if (track_allocations()) {
838       Allocator* a = get_allocator(attr);
839       if (a->TracksAllocationSizes()) {
840         // Zero-byte Tensors don't use allocators: check and skip tracking.
841         AllocationDescription alloc_desc;
842         TensorReference tensor_ref(*t);
843         tensor_ref.FillDescription(&alloc_desc);
844         tensor_ref.Unref();
845 
846         if (alloc_desc.allocated_bytes()) {  // Non-zero sized tensor.
847           int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
848           int64 alloc_id = a->AllocationId(t->tensor_data().data());
849           record_persistent_memory_allocation(alloc_size, alloc_id);
850         }
851       }
852     } else if (record_memory_consumption_) {
853       record_persistent_memory_allocation(t->TotalBytes());
854     }
855   }
856   return s;
857 }
858 
get_input_index(StringPiece name,int * out_index) const859 Status OpKernelContext::get_input_index(StringPiece name,
860                                         int* out_index) const {
861   int start, stop;
862   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
863   if (stop != start + 1) {
864     return errors::InvalidArgument("OpKernel used list-valued input name '",
865                                    name,
866                                    "' when single-valued input was "
867                                    "expected");
868   }
869   *out_index = start;
870   return Status::OK();
871 }
872 
get_output_index(StringPiece name,int * out_index) const873 Status OpKernelContext::get_output_index(StringPiece name,
874                                          int* out_index) const {
875   int start, stop;
876   TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
877   if (stop != start + 1) {
878     return errors::InvalidArgument("OpKernel used list-valued output name '",
879                                    name,
880                                    "' when single-valued output was "
881                                    "expected");
882   }
883   *out_index = start;
884   return Status::OK();
885 }
886 
set_output(StringPiece name,const Tensor & tensor)887 Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
888   int index;
889   TF_RETURN_IF_ERROR(get_output_index(name, &index));
890   set_output(index, tensor);
891   return Status::OK();
892 }
893 
set_output(StringPiece name,Tensor && tensor)894 Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
895   int index;
896   TF_RETURN_IF_ERROR(get_output_index(name, &index));
897   set_output(index, std::move(tensor));
898   return Status::OK();
899 }
900 
maybe_set_output_by_allocate_and_copy(int index,const Tensor & tensor)901 bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
902     int index, const Tensor& tensor) {
903   bool allocate_and_copy = false;
904   const bool never_forward =
905       (params_->forward_from_array != nullptr &&
906        params_->forward_from_array[index] == Params::kNeverForward);
907   if (TF_PREDICT_FALSE(never_forward)) {
908     maybe_initialize_scope_id_set();
909     if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
910         allocated_scope_ids_->end()) {
911       allocate_and_copy = true;
912     } else {
913       // The output at `index` must have been previously allocated via a call to
914       // `allocate_output(index, ...)`.  That call would ensure that we return
915       // the correct slice of the ScopedAllocated buffer, so we do not
916       // re-allocate and copy here.
917       LOG(WARNING)
918           << "OpKernel " << params_->op_kernel->name()
919           << " called both allocate_output and set_output with scope_id "
920           << output_alloc_attr(index).scope_id;
921     }
922   }
923 
924   if (TF_PREDICT_FALSE(allocate_and_copy)) {
925     // This output was marked to not be forwarded either during graph
926     // construction or grappler passes.  Force an allocation and copy input to
927     // output.
928     VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
929             << tensor.DebugString() << " never_forward " << never_forward
930             << " params_->forward_from_array[index] "
931             << params_->forward_from_array[index] << " alloc_attr.scope_id "
932             << output_alloc_attr(index).scope_id;
933     ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
934                                               step_id(), "output",
935                                               tensor.dtype(), &tensor.shape());
936     auto new_tensor = MakeUnique<Tensor>();
937     Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
938                                output_alloc_attr(index));
939     TF_CHECK_OK(s);
940     device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
941                                      op_device_context(), [](const Status&) {});
942     outputs_[index] = TensorValue(new_tensor.release());
943   }
944   return allocate_and_copy;
945 }
946 
maybe_track_allocations_for_set_output(const Tensor & tensor)947 void OpKernelContext::maybe_track_allocations_for_set_output(
948     const Tensor& tensor) {
949   if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
950     DCHECK(tracking_state_);
951     mutex_lock l(tracking_state_->stats_mu);
952     const auto it = std::find_if(
953         tracking_state_->temp_tensor_buffer_and_size.begin(),
954         tracking_state_->temp_tensor_buffer_and_size.end(),
955         [&tensor](const std::pair<const void*, int64>& e) {
956           return e.first ==
957                  static_cast<const void*>(tensor.tensor_data().data());
958         });
959     if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
960       tracking_state_->temp_memory_allocated -= it->second;
961       tracking_state_->temp_tensor_buffer_and_size.erase(it);
962     }
963   }
964 }
965 
set_output(int index,const Tensor & tensor)966 void OpKernelContext::set_output(int index, const Tensor& tensor) {
967   CHECK_GE(index, 0);
968   CHECK_LT(index, outputs_.size());
969   const DataType type = params_->op_kernel->output_type(index);
970   CHECK(!IsRefType(type));
971   CHECK_EQ(outputs_[index].tensor, nullptr);
972   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
973     // Input can be forwarded to output; incref on `tensor` and set output at
974     // `index` to this tensor.
975     outputs_[index] = TensorValue(new Tensor(tensor));
976     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
977   }
978 }
979 
set_output(int index,Tensor && tensor)980 void OpKernelContext::set_output(int index, Tensor&& tensor) {
981   CHECK_GE(index, 0);
982   CHECK_LT(index, outputs_.size());
983   const DataType type = params_->op_kernel->output_type(index);
984   CHECK(!IsRefType(type));
985   CHECK_EQ(outputs_[index].tensor, nullptr);
986   if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
987     // Input can be forwarded to output; set output at `index` to this tensor.
988     outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
989     maybe_track_allocations_for_set_output(*outputs_[index].tensor);
990   }
991 }
992 
set_output_ref(int index,mutex * mu,Tensor * tensor_for_ref)993 void OpKernelContext::set_output_ref(int index, mutex* mu,
994                                      Tensor* tensor_for_ref) {
995   CHECK_GE(index, 0);
996   CHECK_LT(index, outputs_.size());
997   CHECK(IsRefType(params_->op_kernel->output_type(index)));
998   outputs_[index] = TensorValue(mu, tensor_for_ref);
999 }
1000 
set_output_ref(StringPiece name,mutex * mu,Tensor * tensor_for_ref)1001 Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
1002                                        Tensor* tensor_for_ref) {
1003   int index;
1004   TF_RETURN_IF_ERROR(get_output_index(name, &index));
1005   set_output_ref(index, mu, tensor_for_ref);
1006   return Status::OK();
1007 }
1008 
mutable_output(StringPiece name,Tensor ** tensor)1009 Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
1010   int index;
1011   TF_RETURN_IF_ERROR(get_output_index(name, &index));
1012   *tensor = mutable_output(index);
1013   return Status::OK();
1014 }
1015 
ValidateInputsAreSameShape(OpKernel * op)1016 bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
1017   const auto& inputs = *params_->inputs;
1018   for (size_t i = 1; i < inputs.size(); ++i) {
1019     if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
1020       SetStatus(errors::InvalidArgument(
1021           "Inputs to operation ", op->name(), " of type ", op->type_string(),
1022           " must have the same size and shape.  Input 0: ",
1023           inputs[0]->shape().DebugString(), " != input ", i, ": ",
1024           inputs[i]->shape().DebugString()));
1025       return false;
1026     }
1027   }
1028   return true;
1029 }
1030 
MatchSignature(const DataTypeSlice expected_inputs,const DataTypeSlice expected_outputs)1031 Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
1032                                        const DataTypeSlice expected_outputs) {
1033   DataTypeVector inputs;
1034   for (const TensorValue& t : *params_->inputs) {
1035     inputs.push_back(t.dtype());
1036   }
1037   DataTypeVector outputs = params_->op_kernel->output_types();
1038   return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
1039                               outputs);
1040 }
1041 
record_temp_memory_allocation(int64 size,const Tensor & t)1042 void OpKernelContext::record_temp_memory_allocation(int64 size,
1043                                                     const Tensor& t) {
1044   if (tracking_state_) {
1045     mutex_lock l(tracking_state_->stats_mu);
1046     tracking_state_->temp_memory_allocated += size;
1047     tracking_state_->temp_tensor_buffer_and_size.emplace_back(
1048         static_cast<const void*>(t.tensor_data().data()), size);
1049   }
1050 }
1051 
temp_memory_allocated() const1052 int64 OpKernelContext::temp_memory_allocated() const {
1053   if (tracking_state_) {
1054     mutex_lock l(tracking_state_->stats_mu);
1055     return tracking_state_->temp_memory_allocated;
1056   } else {
1057     return 0;
1058   }
1059 }
1060 
record_persistent_memory_allocation(int64 size,int64 alloc_id)1061 void OpKernelContext::record_persistent_memory_allocation(int64 size,
1062                                                           int64 alloc_id) {
1063   if (tracking_state_) {
1064     mutex_lock l(tracking_state_->stats_mu);
1065     tracking_state_->persistent_memory_allocated += size;
1066     if (alloc_id >= 0) {
1067       tracking_state_->persistent_alloc_ids.push_back(alloc_id);
1068     }
1069   }
1070 }
1071 
persistent_memory_allocated() const1072 int64 OpKernelContext::persistent_memory_allocated() const {
1073   if (tracking_state_) {
1074     mutex_lock l(tracking_state_->stats_mu);
1075     return tracking_state_->persistent_memory_allocated;
1076   } else {
1077     return 0;
1078   }
1079 }
1080 
persistent_alloc_ids() const1081 std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
1082   if (tracking_state_) {
1083     mutex_lock l(tracking_state_->stats_mu);
1084     return std::vector<int64>(tracking_state_->persistent_alloc_ids.begin(),
1085                               tracking_state_->persistent_alloc_ids.end());
1086   } else {
1087     return std::vector<int64>();
1088   }
1089 }
1090 
clear_recorded_memory()1091 void OpKernelContext::clear_recorded_memory() {
1092   if (tracking_state_) {
1093     mutex_lock l(tracking_state_->stats_mu);
1094     tracking_state_->temp_memory_allocated = 0;
1095     tracking_state_->persistent_memory_allocated = 0;
1096     tracking_state_->temp_tensor_buffer_and_size.clear();
1097     tracking_state_->persistent_alloc_ids.clear();
1098   }
1099 }
1100 
set_record_memory_consumption(bool v)1101 void OpKernelContext::set_record_memory_consumption(bool v) {
1102   record_memory_consumption_ = v;
1103   if (v && !tracking_state_) {
1104     tracking_state_ = absl::make_unique<TrackingState>();
1105   }
1106 }
1107 
executor_type() const1108 const string& OpKernelContext::executor_type() const {
1109   if (params_->executor_type) {
1110     return *params_->executor_type;
1111   } else {
1112     static const string& kEmptyString = *new string("");
1113     return kEmptyString;
1114   }
1115 }
1116 
1117 // OpKernel registration ------------------------------------------------------
1118 
1119 struct KernelRegistration {
KernelRegistrationtensorflow::KernelRegistration1120   KernelRegistration(const KernelDef& d, StringPiece c,
1121                      std::unique_ptr<kernel_factory::OpKernelFactory> f)
1122       : def(d), kernel_class_name(c), factory(std::move(f)) {}
1123 
1124   const KernelDef def;
1125   const string kernel_class_name;
1126   std::unique_ptr<kernel_factory::OpKernelFactory> factory;
1127 };
1128 
1129 // This maps from 'op_type' + DeviceType to the set of KernelDefs and
1130 // factory functions for instantiating the OpKernel that matches the
1131 // KernelDef.
1132 struct KernelRegistry {
1133   mutex mu;
1134   std::unordered_multimap<string, KernelRegistration> registry
1135       TF_GUARDED_BY(mu);
1136 };
1137 
// File-name glob identifying dynamically loadable kernel libraries. The
// extension follows the platform's shared-library naming convention.
#if defined(_WIN32)
static constexpr char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static constexpr char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static constexpr char kKernelLibPattern[] = "libtfkernel*.so";
#endif
1145 
// Expands to a brace initializer pairing a value with its own spelling.
#define FEATURE(x) {x, #x}
1148 
1149 // Returns Status::OK if the dynamic library at the given path is safe to
1150 // load with some level of confidence.
IsProbablySafeToLoad(const string & path)1151 static Status IsProbablySafeToLoad(const string& path) {
1152   // A map of platform string to required CPU feature.
1153   using port::CPUFeature;
1154   static const auto* feature_map =
1155       new std::map<string, std::pair<CPUFeature, string>>{
1156           {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
1157       };
1158 
1159   std::vector<std::string> platform_strings;
1160   int result = GetPlatformStrings(path, &platform_strings);
1161   if (result) {
1162     return Status(error::Code::UNKNOWN, strerror(result));
1163   }
1164   if (platform_strings.empty()) {
1165     return Status(error::Code::FAILED_PRECONDITION,
1166                   "Didn't find any platform strings");
1167   }
1168   std::vector<std::string> missing_features;
1169   for (const auto& platform_string : platform_strings) {
1170     const auto& entry = feature_map->find(platform_string);
1171     if (entry != feature_map->end() &&
1172         !port::TestCPUFeature(entry->second.first)) {
1173       missing_features.emplace_back(entry->second.second);
1174     }
1175   }
1176   if (!missing_features.empty()) {
1177     string errmsg = "Missing CPU features: ";
1178     errmsg.append(absl::StrJoin(missing_features, ", "));
1179     return Status(errors::Code::FAILED_PRECONDITION, errmsg);
1180   }
1181   return Status::OK();
1182 }
1183 
LoadDynamicKernelsInternal()1184 void LoadDynamicKernelsInternal() {
1185   Env* env = Env::Default();
1186 
1187   // Override to allow loading unsafe packages for development.
1188   // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
1189   char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
1190   bool override_abi_check = false;
1191   if (_abi_check_env_var != nullptr) {
1192     override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
1193   }
1194 
1195   string bazel_kernel_dir =
1196       io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
1197   std::vector<string> files;
1198   Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
1199   if (s_kernel_dir.ok()) {
1200     string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
1201     for (const auto& file : files) {
1202       string fullpath = io::JoinPath(bazel_kernel_dir, file);
1203       if (env->MatchPath(fullpath, dll_spec)) {
1204         Status s = IsProbablySafeToLoad(fullpath);
1205         if (!s.ok() && override_abi_check) {
1206           LOG(WARNING) << "Loading UNSAFE library " << fullpath
1207                        << " because ABI check override is set: "
1208                        << s.error_message();
1209         }
1210         if (s.ok() || override_abi_check) {
1211           // TODO(gunan): Store the handles to the opened files.
1212           void* unused_filehandle;
1213           TF_CHECK_OK(
1214               env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
1215         } else {
1216           LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
1217                        << s.error_message();
1218         }
1219       }
1220     }
1221   }
1222 }
1223 
1224 // Mechanism for loading existing kernel libraries.
LoadDynamicKernels()1225 void LoadDynamicKernels() {
1226   // TODO(gunan): As more features are available, add intelligent kernel
1227   // selection, and dropping unsuitable kernel logic here.
1228   static absl::once_flag dll_loader_flag;
1229   absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
1230 }
1231 
GlobalKernelRegistry()1232 void* GlobalKernelRegistry() {
1233   static KernelRegistry* global_kernel_registry = []() {
1234     KernelRegistry* registry = new KernelRegistry;
1235     OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
1236     return registry;
1237   }();
1238   return global_kernel_registry;
1239 }
1240 
GlobalKernelRegistryTyped()1241 static KernelRegistry* GlobalKernelRegistryTyped() {
1242 #ifdef AUTOLOAD_DYNAMIC_KERNELS
1243   LoadDynamicKernels();
1244 #endif  // AUTOLOAD_DYNAMIC_KERNELS
1245   return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1246 }
1247 
Key(StringPiece op_type,const DeviceType & device_type,StringPiece label)1248 static string Key(StringPiece op_type, const DeviceType& device_type,
1249                   StringPiece label) {
1250   return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
1251                          label);
1252 }
1253 
1254 namespace kernel_factory {
1255 
InitInternal(const KernelDef * kernel_def,StringPiece kernel_class_name,std::unique_ptr<OpKernelFactory> factory)1256 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
1257                                      StringPiece kernel_class_name,
1258                                      std::unique_ptr<OpKernelFactory> factory) {
1259   const string key =
1260       Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
1261           kernel_def->label());
1262 
1263   // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
1264   // here.
1265   // InitInternal gets called by static initializers, so it ends up executing
1266   // before main. This causes LoadKernelLibraries function to get called
1267   // before some file libraries can initialize, which in turn crashes the
1268   // program flakily. Until we get rid of static initializers in kernel
1269   // registration mechanism, we have this workaround here.
1270   auto global_registry =
1271       reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1272   mutex_lock l(global_registry->mu);
1273   global_registry->registry.emplace(
1274       key,
1275       KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
1276   delete kernel_def;
1277 }
1278 
Create(OpKernelConstruction * context)1279 OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
1280     OpKernelConstruction* context) {
1281   return (*create_func_)(context);
1282 }
1283 
1284 }  // namespace kernel_factory
1285 
1286 namespace {
1287 
1288 // Label defaults to empty if not found in NodeDef.
GetKernelLabelAttr(const AttrSlice & node_attrs)1289 const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
1290   static const string& kKernelAttr = *new string("_kernel");
1291   static const string& kEmptyString = *new string("");
1292 
1293   // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
1294   // to use the `AttrSlice::FindByString()` overload, which does a more
1295   // efficient map lookup (instead of a linear scan) when the attribute name is
1296   // already a `const string&`.
1297   const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
1298   if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
1299     return kEmptyString;
1300   else
1301     return attr_value->s();
1302 }
1303 
1304 // TODO(irving): Replace with const Node& version below.
FindKernelRegistration(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,AttrSlice node_attrs,const KernelRegistration ** reg,bool * was_attr_mismatch)1305 Status FindKernelRegistration(
1306     const DeviceType& device_type, StringPiece node_name,
1307     bool has_experimental_debug_info,
1308     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1309     StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
1310     bool* was_attr_mismatch) {
1311   *reg = nullptr;
1312   *was_attr_mismatch = false;
1313 
1314   const string& label = GetKernelLabelAttr(node_attrs);
1315 
1316   const string key = Key(node_op, device_type, label);
1317   auto typed_registry = GlobalKernelRegistryTyped();
1318   tf_shared_lock lock(typed_registry->mu);
1319   auto regs = typed_registry->registry.equal_range(key);
1320   for (auto iter = regs.first; iter != regs.second; ++iter) {
1321     // If there is a kernel registered for the op and device_type,
1322     // check that the attrs match.
1323     bool match;
1324     TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
1325     if (match) {
1326       if (*reg != nullptr) {
1327         if ((*reg)->def.priority() == iter->second.def.priority()) {
1328           return errors::InvalidArgument(
1329               "Multiple OpKernel registrations match NodeDef at the same "
1330               "priority '",
1331               FormatNodeDefForError(node_name, has_experimental_debug_info,
1332                                     experimental_debug_info),
1333               "': '", (*reg)->def.ShortDebugString(), "' and '",
1334               iter->second.def.ShortDebugString(), "'");
1335         } else if ((*reg)->def.priority() > iter->second.def.priority()) {
1336           continue;
1337         }
1338         // iter->second's priority is higher than *reg.
1339       }
1340       *reg = &iter->second;
1341     } else {
1342       *was_attr_mismatch = true;
1343     }
1344   }
1345   // Check if no device specific registrations found. If not, try finding a
1346   // default kernel.
1347   if (*reg == nullptr &&
1348       !IsSymbolicExecutionDevice(device_type.type_string())) {
1349     const string default_key = Key(node_op, DEVICE_DEFAULT, label);
1350     auto regs = typed_registry->registry.equal_range(default_key);
1351     for (auto iter = regs.first; iter != regs.second; ++iter) {
1352       // If there is a kernel registered for the op and device_type,
1353       // check that the attrs match.
1354       bool match;
1355       TF_RETURN_IF_ERROR(
1356           KernelAttrsMatch(iter->second.def, node_attrs, &match));
1357       if (match) {
1358         if (*reg != nullptr) {
1359           return errors::InvalidArgument(
1360               "Multiple Default OpKernel registrations match NodeDef '",
1361               FormatNodeDefForError(node_name, has_experimental_debug_info,
1362                                     experimental_debug_info),
1363               "': '", (*reg)->def.ShortDebugString(), "' and '",
1364               iter->second.def.ShortDebugString(), "'");
1365         }
1366         *reg = &iter->second;
1367       } else {
1368         *was_attr_mismatch = true;
1369       }
1370     }
1371 
1372     if (*reg != nullptr) {
1373       VLOG(1) << "No device-specific kernels found for NodeDef '"
1374               << FormatNodeDefForError(node_name, has_experimental_debug_info,
1375                                        experimental_debug_info)
1376               << "'"
1377               << "Will fall back to a default kernel." << std::endl;
1378     }
1379   }
1380 
1381   return Status::OK();
1382 }
1383 
FindKernelRegistration(const DeviceType & device_type,const NodeDef & node_def,const KernelRegistration ** reg,bool * was_attr_mismatch)1384 Status FindKernelRegistration(const DeviceType& device_type,
1385                               const NodeDef& node_def,
1386                               const KernelRegistration** reg,
1387                               bool* was_attr_mismatch) {
1388   return FindKernelRegistration(
1389       device_type, node_def.name(), node_def.has_experimental_debug_info(),
1390       node_def.experimental_debug_info(), node_def.op(),
1391       AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
1392 }
1393 
1394 }  // namespace
1395 
KernelDefAvailable(const DeviceType & device_type,const NodeDef & node_def)1396 bool KernelDefAvailable(const DeviceType& device_type,
1397                         const NodeDef& node_def) {
1398   const KernelRegistration* reg = nullptr;
1399   bool was_attr_mismatch;
1400   Status result =
1401       FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
1402   return result.ok() && reg != nullptr;
1403 }
1404 
1405 // TODO(irving): Change const NodeDef& to const Node&
FindKernelDef(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,StringPiece node_device,AttrSlice node_attrs,const KernelDef ** def,string * kernel_class_name)1406 Status FindKernelDef(
1407     const DeviceType& device_type, StringPiece node_name,
1408     bool has_experimental_debug_info,
1409     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1410     StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1411     const KernelDef** def, string* kernel_class_name) {
1412   const KernelRegistration* reg = nullptr;
1413   bool was_attr_mismatch;
1414   TF_RETURN_IF_ERROR(FindKernelRegistration(
1415       device_type, node_name, has_experimental_debug_info,
1416       experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
1417   if (reg == nullptr) {
1418     const std::string device_str = DeviceTypeString(device_type);
1419     Status s = errors::NotFound(
1420         "No registered '", node_op, "' OpKernel for ", device_str,
1421         " devices compatible with node ",
1422         FormatNodeDefForError(node_name, has_experimental_debug_info,
1423                               experimental_debug_info));
1424     if (was_attr_mismatch) {
1425       errors::AppendToMessage(
1426           &s, " (OpKernel was found, but attributes didn't match) ",
1427           "Requested Attributes: ",
1428           SummarizeAttrsHelper(node_attrs, node_device));
1429     }
1430 
1431     // Do not print kernel registrations for other devices when using _JIT
1432     // devices for compilation.
1433     if (!absl::StrContains(device_str, "JIT")) {
1434       errors::AppendToMessage(
1435           &s, ".  Registered:", KernelsRegisteredForOp(node_op));
1436     }
1437 
1438     return s;
1439   }
1440   if (def != nullptr) *def = &reg->def;
1441   if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
1442   return Status::OK();
1443 }
1444 
FindKernelDef(const DeviceType & device_type,const NodeDef & node_def,const KernelDef ** def,string * kernel_class_name)1445 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1446                      const KernelDef** def, string* kernel_class_name) {
1447   return FindKernelDef(
1448       device_type, node_def.name(), node_def.has_experimental_debug_info(),
1449       node_def.experimental_debug_info(), node_def.op(), node_def.device(),
1450       AttrSlice(&node_def.attr()), def, kernel_class_name);
1451 }
1452 
SupportedDeviceTypesForNode(const std::vector<DeviceType> & prioritized_types,const NodeDef & def,PrioritizedDeviceTypeVector * prioritized_device_types,const DeviceNameUtils::ParsedName * local_address_spec)1453 Status SupportedDeviceTypesForNode(
1454     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1455     PrioritizedDeviceTypeVector* prioritized_device_types,
1456     const DeviceNameUtils::ParsedName* local_address_spec) {
1457   // TODO(zhifengc): Changes the callers (SimplePlacer and
1458   // DynamicPlacer) to consider the possibility that 'def' is call to
1459   // a user-defined function and only calls this
1460   // SupportedDeviceTypesForNode for primitive ops.
1461   const OpRegistrationData* op_reg_data;
1462   const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
1463   if (s.ok()) {
1464     bool exists_attr_mismatch = false;
1465     for (const DeviceType& device_type : prioritized_types) {
1466       const KernelRegistration* reg = nullptr;
1467       bool was_attr_mismatch = false;
1468       TF_RETURN_IF_ERROR(
1469           FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
1470       exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
1471       if (reg != nullptr) {
1472         int32 priority = reg->def.priority();
1473         prioritized_device_types->emplace_back(device_type, priority);
1474       }
1475     }
1476     // Add extra supported device types if the following conditions are
1477     // satisfied:
1478     // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
1479     // 2) A device is requested for this node which specifies job/replica/task
1480     // 3) A local device is provided which specifies job/replica/task
1481     // 4) The local device does not have the same (job, replica, task) as the
1482     //    requested device
1483     //
1484     // The goal is to address the issue where a graph includes op (e.g. PyFunc)
1485     // whose kernel is known to a remote process but not to the current process.
1486     if (prioritized_device_types->empty() && !exists_attr_mismatch &&
1487         local_address_spec != nullptr) {
1488       DeviceNameUtils::ParsedName requested_device_name;
1489       DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
1490       if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
1491                                                    requested_device_name)) {
1492         if (requested_device_name.has_type) {
1493           prioritized_device_types->push_back(
1494               std::make_pair(DeviceType(requested_device_name.type), 0));
1495         } else {
1496           for (const DeviceType& device_type : prioritized_types) {
1497             prioritized_device_types->push_back(std::make_pair(device_type, 0));
1498           }
1499         }
1500       }
1501     }
1502 
1503     // If we were unable to find any valid devices let's validate if the node is
1504     // even valid.
1505     if (prioritized_device_types->empty()) {
1506       TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
1507     }
1508 
1509     std::sort(prioritized_device_types->begin(),
1510               prioritized_device_types->end(),
1511               [](const std::pair<DeviceType, int32>& a,
1512                  const std::pair<DeviceType, int32>& b) {
1513                 return a.second > b.second;
1514               });
1515   } else {
1516     // Assumes that all device types support this node.
1517     for (const DeviceType& device_type : prioritized_types) {
1518       prioritized_device_types->push_back(std::make_pair(device_type, 0));
1519     }
1520   }
1521   return Status::OK();
1522 }
1523 
LogAllRegisteredKernels()1524 void LogAllRegisteredKernels() {
1525   KernelList kernel_list = GetAllRegisteredKernels();
1526   for (const auto& kernel_def : kernel_list.kernel()) {
1527     LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
1528   }
1529 }
1530 
GetAllRegisteredKernels()1531 KernelList GetAllRegisteredKernels() {
1532   return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
1533 }
1534 
GetFilteredRegisteredKernels(const std::function<bool (const KernelDef &)> & predicate)1535 KernelList GetFilteredRegisteredKernels(
1536     const std::function<bool(const KernelDef&)>& predicate) {
1537   KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
1538   KernelList kernel_list;
1539   tf_shared_lock lock(typed_registry->mu);
1540   kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
1541   for (const auto& p : typed_registry->registry) {
1542     const KernelDef& kernel_def = p.second.def;
1543     if (predicate(kernel_def)) {
1544       *kernel_list.add_kernel() = kernel_def;
1545     }
1546   }
1547   return kernel_list;
1548 }
1549 
GetRegisteredKernelsForOp(StringPiece op_name)1550 KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
1551   auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
1552   return GetFilteredRegisteredKernels(op_pred);
1553 }
1554 
KernelsRegisteredForOp(StringPiece op_name)1555 string KernelsRegisteredForOp(StringPiece op_name) {
1556   KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
1557   if (kernel_list.kernel_size() == 0) return "  <no registered kernels>\n";
1558   string ret;
1559   for (const auto& kernel_def : kernel_list.kernel()) {
1560     strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
1561     if (!kernel_def.label().empty()) {
1562       strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
1563     }
1564     for (int i = 0; i < kernel_def.constraint_size(); ++i) {
1565       strings::StrAppend(
1566           &ret, "; ", kernel_def.constraint(i).name(), " in ",
1567           SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
1568     }
1569     strings::StrAppend(&ret, "\n");
1570   }
1571   return ret;
1572 }
1573 
1574 /* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
1575  * copying the NodeDef. */
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,const NodeDef & node_def,int graph_def_version,Status * status)1576 std::unique_ptr<OpKernel> CreateOpKernel(
1577     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1578     const NodeDef& node_def, int graph_def_version, Status* status) {
1579   // Look up the Op registered for this op name.
1580   std::shared_ptr<const NodeProperties> props;
1581   status->Update(NodeProperties::CreateFromNodeDef(
1582       node_def, OpRegistry::Global(), &props));
1583   if (!status->ok()) {
1584     errors::AppendToMessage(status,
1585                             " for node: ", FormatNodeDefForError(node_def));
1586     return nullptr;
1587   }
1588   return CreateOpKernel(device_type, device, allocator, props,
1589                         graph_def_version, status);
1590 }
1591 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,Status * status)1592 std::unique_ptr<OpKernel> CreateOpKernel(
1593     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1594     const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1595     Status* status) {
1596   OpKernel* kernel = nullptr;
1597   *status = CreateOpKernel(std::move(device_type), device, allocator,
1598                            /*flib=*/nullptr, props, graph_def_version, &kernel);
1599   return std::unique_ptr<OpKernel>(kernel);
1600 }
1601 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1602 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1603                       Allocator* allocator, FunctionLibraryRuntime* flib,
1604                       const std::shared_ptr<const NodeProperties>& props,
1605                       int graph_def_version, OpKernel** kernel) {
1606   return CreateOpKernel(std::move(device_type), device, allocator, flib,
1607                         /* resource_mgr= */ nullptr, props, graph_def_version,
1608                         kernel);
1609 }
1610 
CreateOpKernel(DeviceType device_type,DeviceBase * device,Allocator * allocator,FunctionLibraryRuntime * flib,ResourceMgr * resource_mgr,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1611 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1612                       Allocator* allocator, FunctionLibraryRuntime* flib,
1613                       ResourceMgr* resource_mgr,
1614                       const std::shared_ptr<const NodeProperties>& props,
1615                       int graph_def_version, OpKernel** kernel) {
1616   const NodeDef& node_def = props->node_def;
1617   bool was_attr_mismatch;
1618   const KernelRegistration* registration = nullptr;
1619   Status s;
1620   if (props != nullptr) {
1621     VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
1622 
1623     // Validate node_def against OpDef.
1624     TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));
1625 
1626     // Look up kernel registration.
1627     s = FindKernelRegistration(device_type, node_def, &registration,
1628                                &was_attr_mismatch);
1629     if (!s.ok()) {
1630       errors::AppendToMessage(&s, " when instantiating ", node_def.op());
1631       return s;
1632     }
1633   }
1634   if (registration == nullptr) {
1635     s.Update(errors::NotFound("No registered '", node_def.op(),
1636                               "' OpKernel for '", DeviceTypeString(device_type),
1637                               "' devices compatible with node ",
1638                               FormatNodeDefForError(node_def)));
1639     if (was_attr_mismatch) {
1640       errors::AppendToMessage(
1641           &s, " (OpKernel was found, but attributes didn't match) ",
1642           "Requested Attributes: ", SummarizeAttrs(node_def));
1643     }
1644     errors::AppendToMessage(
1645         &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
1646     return s;
1647   }
1648 
1649   // We are creating a kernel for an op registered in
1650   // OpRegistry::Global(), we consult the kernel registry to decide
1651   // the kernel's input and output memory types.
1652   MemoryTypeVector input_memory_types;
1653   MemoryTypeVector output_memory_types;
1654   TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
1655                                         node_def, &input_memory_types,
1656                                         &output_memory_types));
1657 
1658   // Everything needed for OpKernel construction.
1659   OpKernelConstruction context(std::move(device_type), device, allocator, flib,
1660                                resource_mgr, props, input_memory_types,
1661                                output_memory_types, graph_def_version, &s);
1662   *kernel = registration->factory->Create(&context);
1663   if (!s.ok()) {
1664     delete *kernel;
1665     *kernel = nullptr;
1666   }
1667   return s;
1668 }
1669 
1670 namespace {
1671 
FindArgInOp(StringPiece arg_name,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)1672 bool FindArgInOp(StringPiece arg_name,
1673                  const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
1674   for (const auto& arg : args) {
1675     if (arg_name == arg.name()) {
1676       return true;
1677     }
1678   }
1679   return false;
1680 }
1681 
1682 }  // namespace
1683 
ValidateKernelRegistrations(const OpRegistryInterface & op_registry)1684 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
1685   auto typed_registry = GlobalKernelRegistryTyped();
1686   tf_shared_lock lock(typed_registry->mu);
1687   for (const auto& key_registration : typed_registry->registry) {
1688     const KernelDef& kernel_def(key_registration.second.def);
1689     const OpRegistrationData* op_reg_data;
1690     const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
1691     if (!status.ok()) {
1692       // TODO(josh11b): Make this a hard error.
1693       LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
1694                  << "') for unknown op: " << kernel_def.op();
1695       continue;
1696     }
1697     const OpDef& op_def = op_reg_data->op_def;
1698     for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
1699       if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
1700           !FindArgInOp(host_memory_arg, op_def.output_arg())) {
1701         return errors::InvalidArgument(
1702             "HostMemory arg '", host_memory_arg,
1703             "' not found in OpDef: ", SummarizeOpDef(op_def));
1704       }
1705     }
1706   }
1707   return Status::OK();
1708 }
1709 
1710 template <>
eigen_device() const1711 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
1712   return eigen_cpu_device();
1713 }
1714 
1715 template <>
eigen_device() const1716 const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
1717   return eigen_gpu_device();
1718 }
1719 
1720 
CtxFailure(const Status & s)1721 void OpKernelConstruction::CtxFailure(const Status& s) {
1722   VLOG(1) << s;
1723   SetStatus(s);
1724 }
1725 
CtxFailureWithWarning(const Status & s)1726 void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
1727   LOG(WARNING) << s;
1728   SetStatus(s);
1729 }
1730 
CtxFailure(const char * file,int line,const Status & s)1731 void OpKernelConstruction::CtxFailure(const char* file, int line,
1732                                       const Status& s) {
1733   VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1734           << " : " << s;
1735   SetStatus(s);
1736 }
1737 
CtxFailureWithWarning(const char * file,int line,const Status & s)1738 void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
1739                                                  const Status& s) {
1740   LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1741                << " : " << s;
1742   SetStatus(s);
1743 }
1744 
CtxFailure(const Status & s)1745 void OpKernelContext::CtxFailure(const Status& s) {
1746   VLOG(1) << s;
1747   SetStatus(s);
1748 }
1749 
CtxFailureWithWarning(const Status & s)1750 void OpKernelContext::CtxFailureWithWarning(const Status& s) {
1751   LOG(WARNING) << s;
1752   SetStatus(s);
1753 }
1754 
CtxFailure(const char * file,int line,const Status & s)1755 void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
1756   VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1757           << " : " << s;
1758   SetStatus(s);
1759 }
1760 
CtxFailureWithWarning(const char * file,int line,const Status & s)1761 void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
1762                                             const Status& s) {
1763   LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1764                << " : " << s;
1765   SetStatus(s);
1766 }
1767 
CheckNotInComputeAsync(OpKernelContext * ctx,const char * correct_macro_name)1768 void CheckNotInComputeAsync(OpKernelContext* ctx,
1769                             const char* correct_macro_name) {
1770   CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
1771       << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
1772 }
1773 
1774 }  // namespace tensorflow
1775