/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <mutex> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>

#include <cstdlib>
#include <cstring>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

} // namespace

// OpKernel ------------------------------------------------------------------

OpKernel::OpKernel(OpKernelConstruction* context)
    : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}

OpKernel::OpKernel(OpKernelConstruction* context,
                   std::unique_ptr<const NodeDef> node_def)
    : def_(std::move(node_def)),
      input_types_(context->input_types().begin(),
                   context->input_types().end()),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_types_(context->output_types().begin(),
                    context->output_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      graph_def_version_(context->graph_def_version()),
      is_internal_(str_util::StartsWith(type_string(), "_")),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
                                   &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
                                             context->graph_def_version()));
  // Kernels executing on GPU/SYCL tie up very few resources on the CPU where
  // the scheduler runs, so we consider them inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               context->device_type() != DeviceType(DEVICE_SYCL);
}

OpKernel::~OpKernel() {}

const uint64 OpKernel::kInitialCostEstimateCycles;
const uint64 OpKernel::kOpIsExpensiveThresholdCycles;
const uint64 OpKernel::kCostDecay;

const string& OpKernel::name() const { return def_->name(); }
const string& OpKernel::type_string() const { return def_->op(); }
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }

// This static function exists only because device_attributes.pb.h is
// already included here, and it can't be introduced elsewhere.
/*static*/ int OpKernel::DeviceNumaNode(const DeviceBase* device) {
  return device->attributes().locality().numa_node();
}

Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
  if (!IsLegacyVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}
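
// Usage sketch (hypothetical kernel, not part of this file's logic): a kernel
// whose second input is a shape vector might convert it like this, assuming a
// kernel context `ctx`:
//   TensorShape out_shape;
//   OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(ctx->input(1), &out_shape));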

void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}
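
// Note: this default Compute() blocks the calling thread until the
// ComputeAsync() implementation invokes its `done` callback, which is how an
// AsyncOpKernel can still be driven by a synchronous executor.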

// PersistentTensor ----------------------------------------------------------

Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
  // the caller has to have a valid context
  CHECK(context);
  return &tensor_;
}

Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
  context->NotifyUseOfPersistentTensor(tensor_);
  return &tensor_;
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
                              output_types_);
}
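
// Usage sketch (hypothetical kernel constructor, not part of this file): a
// kernel can verify the signature it was registered with early, e.g.
//   OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT}, {DT_FLOAT}));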

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape", shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_persistent(
    DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
    Tensor** out_tensor) {
  // for now just do the same thing as allocate_temp
  // TODO(misard) add specific memory tracking for persistent tensors
  Tensor persistent;
  Status s = allocate_temp(type, shape, &persistent);
  if (!s.ok()) {
    return s;
  }
  *out_persistent = PersistentTensor(persistent);
  Tensor* allocated = out_persistent->AccessTensor(this);
  if (out_tensor) {
    *out_tensor = allocated;
  }
  return s;
}

// OpKernelContext -----------------------------------------------------------

const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params),
      outputs_(num_outputs),
      temp_memory_allocated_(0),
      persistent_memory_allocated_(0) {
  params_->ensure_eigen_gpu_device();
  if (params_->eigen_gpu_device != nullptr) {
    Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
    Status s = params_->device->ReinitializeGpuDevice(
        this, params_->eigen_gpu_device, params_->op_device_context,
        eigen_gpu_allocator);
    if (!s.ok()) {
      SetStatus(s);
    }
  }
  if (params_->record_tensor_accesses) {
    referenced_tensors_.Init();
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
  if (params_->track_allocations && !wrapped_allocators_.empty()) {
    LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                 << "being consumed by the StepStatsCollector.";
    for (auto& wrapped_allocator : wrapped_allocators_) {
      wrapped_allocator.second->GetRecordsAndUnRef();
    }
  }
}

Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator = nullptr;
  if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
    allocator = params_->device->GetScopedAllocator(attr, step_id());
    CHECK(allocator);
  } else {
    allocator = params_->device->GetAllocator(attr);
  }
  if (TF_PREDICT_FALSE(track_allocations())) {
    mutex_lock lock(mu_);
    for (const auto& wrapped : wrapped_allocators_) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
  mutex_lock l(mu_);
  // Keep a reference to the underlying memory around.
  referenced_tensors_->Add(tensor);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  if (input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[start].tensor;
  record_tensor_reference(**tensor);
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  const TensorValue& value((*params_->inputs)[start]);
  if (value.is_ref()) {
    *dtype = MakeRefType(value->dtype());
  } else {
    *dtype = value->dtype();
  }
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  *out_mutex = input_ref_mutex(start);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
  DCHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  record_tensor_reference(tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
  record_tensor_reference(tensor);
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  DCHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, output_index, expected_output_dtype(output_index),
      output_shape, output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}
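
// Usage sketch (hypothetical element-wise kernel, not part of this file): an
// in-place-capable op typically tries to reuse its input buffer for output 0
// before allocating a fresh one:
//   Tensor* out = nullptr;
//   if (!ctx->forward_input_to_output_with_shape(0, 0, input.shape(), &out)) {
//     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
//   }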

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index, stop;
  TF_RETURN_IF_ERROR(
      params_->op_kernel->InputRange(input_name, &input_index, &stop));
  if (stop != input_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   input_name,
                                   "' when single-valued input was "
                                   "expected");
  }
  TF_RETURN_IF_ERROR(
      params_->op_kernel->OutputRange(output_name, &output_index, &stop));
  if (stop != output_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   output_name,
                                   "' when single-valued output was "
                                   "expected");
  }
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return Status::OK();
}

std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, int output_index, DataType output_dtype,
    const TensorShape& output_shape, MemoryType output_memory_type,
    const AllocatorAttributes& output_attr) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check whether at graph construction time this output was marked
  // either for no forwarding or with a reservation for this input.
  // If it's reserved for this input we'll skip the refcount and
  // AllocatorAttribute checks.
  // TODO(tucker): Maybe we should skip all of the checks?
  bool never_forward =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == Params::kNeverForward);
  if (never_forward) return nullptr;
  bool forward_expected =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == input_index);
  if (!forward_expected && params_->forward_from_array != nullptr) {
    // Check for possibly conflicting forward.
    for (int i = 0; i < num_outputs(); ++i) {
      if (params_->forward_from_array[i] == input_index) {
        // This input is reserved for output i.
        return nullptr;
      }
    }
  }
  // Check that input tensor exists and is not a ref.
  if (input.tensor == nullptr || input.is_ref()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    CHECK(!forward_expected);
    return nullptr;
  }
  if (!forward_expected) {
    if (!input->RefCountIsOne()) {
      return nullptr;
    }
    // Check that output allocator attributes are not more restrictive than
    // input allocator attributes.
    const auto input_attr = params_->input_alloc_attrs == nullptr
                                ? AllocatorAttributes()
                                : input_alloc_attr(input_index);
    if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
      return nullptr;
    }
  }

  auto output_tensor = MakeUnique<Tensor>();
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, Params::kNoReservation /*output_index*/,
                      type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}
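
// Usage sketch (hypothetical, not part of this file): a kernel that needs
// scratch space of the same type and shape as one of its inputs can try to
// steal that input's buffer before allocating:
//   Tensor scratch;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_temp(
//                           {0}, input.dtype(), input.shape(),
//                           AllocatorAttributes(), &scratch));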

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[start].tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(start));
    *tensor = *(*params_->inputs)[start].tensor;
  }
  record_tensor_reference(*tensor);
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(start, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  bool forward_expected =
      (params_->forward_from_array != nullptr && index >= 0 &&
       params_->forward_from_array[index] >= 0);
  if (forward_expected) {
    return errors::Internal(
        "Explicit allocate_output call where input forwarding required. Try "
        "turning off the ScopedAllocator optimizer.");
  }
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, output, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  AllocationAttributes logged_attr(allocation_attr);
  logged_attr.allocation_will_be_logged = true;
  Tensor new_tensor(a, type, shape, logged_attr);

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  record_tensor_reference(new_tensor);
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  DCHECK(!IsRefType(type));
  DCHECK(mutable_output(index) == nullptr);
  auto output_tensor = MakeUnique<Tensor>();
  Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor.release());
    *output = outputs_[index].tensor;
  }
  return s;
}
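
// Usage sketch (hypothetical Compute() body, not part of this file): the
// common pattern for producing output 0 with the same shape as input 0 is
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ctx->input(0).shape(), &output));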

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  }
  return s;
}

Status OpKernelContext::allocate_persistent(DataType type,
                                            const TensorShape& shape,
                                            PersistentTensor* out_persistent,
                                            Tensor** out_tensor,
                                            AllocatorAttributes attr) {
  Tensor persistent;
  Status s = allocate_tensor(type, shape, &persistent, attr);
  if (s.ok()) {
    *out_persistent = PersistentTensor(persistent);
    if (out_tensor) {
      *out_tensor = out_persistent->AccessTensor(this);
    }
    if (track_allocations()) {
      Tensor* t = out_persistent->AccessTensor(this);
      Allocator* a = get_allocator(attr);
      if (a->TracksAllocationSizes()) {
        int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
        int64 alloc_id = a->AllocationId(t->tensor_data().data());
        record_persistent_memory_allocation(alloc_size, alloc_id);
      }
    }
  }
  return s;
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output(start, tensor);
  return Status::OK();
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
  DCHECK_EQ(mutable_output(index), nullptr);
  record_tensor_reference(tensor);
  outputs_[index] = TensorValue(new Tensor(tensor));
  if (track_allocations() && tensor.TotalBytes() > 0) {
    mutex_lock l(stats_mu_);
    if (!temp_tensor_buffer_and_size_) {
      return;
    }
    auto it = std::find_if(temp_tensor_buffer_and_size_->begin(),
                           temp_tensor_buffer_and_size_->end(),
                           [&tensor](const std::pair<const void*, int64>& e) {
                             return e.first == static_cast<const void*>(
                                                   tensor.tensor_data().data());
                           });
    if (it != temp_tensor_buffer_and_size_->end()) {
      temp_memory_allocated_ -= it->second;
      temp_tensor_buffer_and_size_->erase(it);
    }
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(IsRefType(params_->op_kernel->output_type(index)));
  record_tensor_reference(*tensor_for_ref);
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output_ref(start, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *tensor = mutable_output(start);
  return Status::OK();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape. Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                    const Tensor& t) {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ += size;
  if (!temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_.reset(
        new gtl::InlinedVector<std::pair<const void*, int64>, 2>());
  }
  temp_tensor_buffer_and_size_->emplace_back(
      static_cast<const void*>(t.tensor_data().data()), size);
}

int64 OpKernelContext::temp_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return temp_memory_allocated_;
}

void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                          int64 alloc_id) {
  mutex_lock l(stats_mu_);
  persistent_memory_allocated_ += size;
  if (alloc_id >= 0) {
    if (!persistent_alloc_ids_) {
      persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>());
    }
    persistent_alloc_ids_->push_back(alloc_id);
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return persistent_memory_allocated_;
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  mutex_lock l(stats_mu_);
  if (persistent_alloc_ids_) {
    return std::vector<int64>(persistent_alloc_ids_->begin(),
                              persistent_alloc_ids_->end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ = 0;
  persistent_memory_allocated_ = 0;
  if (temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_->clear();
  }
  if (persistent_alloc_ids_) {
    persistent_alloc_ids_->clear();
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif

#define FEATURE(x) \
  { x, #x }

// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
  // A map of platform string to required CPU feature.
  using port::CPUFeature;
  static const auto* feature_map =
      new std::map<string, std::pair<CPUFeature, string>>{
          {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
      };

  std::vector<std::string> platform_strings;
  int result = GetPlatformStrings(path, &platform_strings);
  if (result) {
    return Status(error::Code::UNKNOWN, strerror(result));
  }
  if (platform_strings.empty()) {
    return Status(error::Code::FAILED_PRECONDITION,
                  "Didn't find any platform strings");
  }
  std::vector<std::string> missing_features;
  for (const auto& platform_string : platform_strings) {
    const auto& entry = feature_map->find(platform_string);
    if (entry != feature_map->end() &&
        !port::TestCPUFeature(entry->second.first)) {
      missing_features.emplace_back(entry->second.second);
    }
  }
  if (!missing_features.empty()) {
    string errmsg = "Missing CPU features: ";
    errmsg.append(str_util::Join(missing_features, ", "));
    return Status(errors::Code::FAILED_PRECONDITION, errmsg);
  }
  return Status::OK();
}
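
// Note (assumption about the producer side, not asserted by this file): the
// platform strings scanned above are typically embedded in a kernel library
// via the TF_PLATFORM_STRINGS machinery from platform_strings.h, so a library
// built with AVX512VL enabled would carry an entry like "__AVX512VL__=1".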

void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  const char* load_unsafe = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check =
      load_unsafe != nullptr && strcmp(load_unsafe, "1") == 0;

  string bazel_kernel_dir = io::JoinPath(env->GetRunfilesDir(),
                                         "tensorflow", "core", "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(env->LoadLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static std::once_flag dll_loader_flag;
  std::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
  LoadDynamicKernels();
#endif // AUTOLOAD_DYNAMIC_KERNELS
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}
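
// For example, Key("MatMul", DeviceType(DEVICE_GPU), "") produces the registry
// key "MatMul:GPU:"; an empty label still contributes the trailing colon.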

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

    // To avoid calling LoadDynamicKernels, DO NOT CALL GlobalKernelRegistryTyped
    // here.
    // InitInternal gets called by static initializers, so it ends up executing
    // before main. That would make the kernel-library loading run before some
    // of the libraries it depends on have initialized, which in turn crashes
    // the program flakily. Until we get rid of static initializers in the
    // kernel registration mechanism, we keep this workaround here.
    reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry())
        ->emplace(key, KernelRegistration(*kernel_def, kernel_class_name,
                                          std::move(factory)));
  }
  delete kernel_def;
}

OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
    OpKernelConstruction* context) {
  return (*create_func_)(context);
}
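
// Registration sketch (hypothetical kernel, not part of this file): kernels
// normally reach InitInternal through the REGISTER_KERNEL_BUILDER macro, e.g.
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), MyOpKernel);
// which constructs an OpKernelRegistrar holding a factory for MyOpKernel.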

} // namespace kernel_factory

namespace {

static const StringPiece kKernelAttr("_kernel");

// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
                              const NodeDef& node_def,
                              const KernelRegistration** reg,
                              bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;
  // Label defaults to empty if not found in NodeDef.
  const string& label = GetNodeAttrString(node_def, kKernelAttr);

  const string key = Key(node_def.op(), device_type, label);
  auto regs = GlobalKernelRegistryTyped()->equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
    if (match) {
      if (*reg != nullptr) {
        return errors::InvalidArgument(
            "Multiple OpKernel registrations match NodeDef '",
            FormatNodeDefForError(node_def), "': '",
            ProtoShortDebugString((*reg)->def), "' and '",
            ProtoShortDebugString(iter->second.def), "'");
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
  return Status::OK();
}

} // namespace

bool KernelDefAvailable(const DeviceType& device_type,
                        const NodeDef& node_def) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  Status result =
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
  return result.ok() && reg != nullptr;
}

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    Status s = errors::NotFound(
        "No registered '", node_def.op(), "' OpKernel for ",
        DeviceTypeString(device_type), " devices compatible with node ",
        FormatNodeDefForError(node_def));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* prioritized_device_types) {
  // TODO(zhifengc): Changes the callers (SimplePlacer and
  // DynamicPlacer) to consider the possibility that 'def' is call to
  // a user-defined function and only calls this
  // SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      if (reg != nullptr) {
        int32 priority = reg->def.priority();
        prioritized_device_types->emplace_back(device_type, priority);
      }
    }
    std::sort(prioritized_device_types->begin(),
              prioritized_device_types->end(),
              [](const std::pair<DeviceType, int32>& a,
                 const std::pair<DeviceType, int32>& b) {
                return a.second > b.second;
              });
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  KernelList kernel_list = GetAllRegisteredKernels();
  for (const auto& kernel_def : kernel_list.kernel()) {
    LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
  }
}

KernelList GetAllRegisteredKernels() {
  return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}

KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate) {
  const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
  KernelList kernel_list;
  kernel_list.mutable_kernel()->Reserve(typed_registry->size());
  for (const auto& p : *typed_registry) {
    const KernelDef& kernel_def = p.second.def;
    if (predicate(kernel_def)) {
      *kernel_list.add_kernel() = kernel_def;
    }
  }
  return kernel_list;
}

KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
  auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
  return GetFilteredRegisteredKernels(op_pred);
}

string KernelsRegisteredForOp(StringPiece op_name) {
  KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
  if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
  string ret;
  for (const auto& kernel_def : kernel_list.kernel()) {
    strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
    if (!kernel_def.label().empty()) {
      strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
    }
    for (int i = 0; i < kernel_def.constraint_size(); ++i) {
      strings::StrAppend(
          &ret, "; ", kernel_def.constraint(i).name(), " in ",
          SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
    }
    strings::StrAppend(&ret, "\n");
  }
  return ret;
}

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
                           node_def, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

  // Look up the Op registered for this op name.
  const OpDef* op_def = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
  if (!s.ok()) return s;

  // Validate node_def against OpDef.
  s = ValidateNodeDef(node_def, *op_def);
  if (!s.ok()) return s;

  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  if (!s.ok()) {
    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
    return s;
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for ", DeviceTypeString(device_type),
                              " devices compatible with node ",
                              FormatNodeDefForError(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Get signature from the OpDef & NodeDef
  DataTypeVector inputs;
  DataTypeVector outputs;
  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!s.ok()) {
    errors::AppendToMessage(&s, " for node: ", FormatNodeDefForError(node_def));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide
  // the kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}
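
// Usage sketch (hypothetical caller, not part of this file): given a NodeDef
// `ndef`, a device, and its allocator, a runtime might instantiate a kernel as
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device, device->GetAllocator(AllocatorAttributes()),
//       ndef, TF_GRAPH_DEF_VERSION, &status);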

namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

} // namespace

Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
  return eigen_sycl_device();
}
#endif

void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name) {
  CHECK_EQ(nullptr, ctx->op_kernel().AsAsync())
      << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}
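
// Note: this check backs the distinction between the synchronous OP_REQUIRES*
// macros and their *_ASYNC variants; inside ComputeAsync a kernel is expected
// to use e.g. OP_REQUIRES_OK_ASYNC(ctx, status, done) so that the done
// callback is still invoked on failure.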

} // namespace tensorflow