/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <cstdlib>
#include <cstring>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

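// Returns OK iff the given input/output types are compatible with the
// expected signature; otherwise returns InvalidArgument describing both the
// actual and the expected signatures.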
Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

}  // namespace

// OpKernel ------------------------------------------------------------------

OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}

OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
    : props_(context->props_),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               !DeviceFactory::IsPluggableDevice(
                   DeviceTypeString(context->device_type()));
}

OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
                   bool is_deferred)
    : props_(std::make_shared<const NodeProperties>(
          context->props_->op_def, std::move(custom_def),
          context->props_->input_types, context->props_->output_types)),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               !DeviceFactory::IsPluggableDevice(
                   DeviceTypeString(context->device_type()));
}

OpKernel::~OpKernel() {}

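// InputRange()/OutputRange() map an argument name to the half-open index
// range [*start, *stop) within the kernel's flattened input/output lists.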
Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

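// Builds a "(dtype+shape;...)" summary of the kernel's inputs for profiler
// traces; missing, resource, variant, and ref inputs get empty placeholders.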
string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
  int num_inputs = ctx.num_inputs();
  if (num_inputs == 0) return "";
  std::vector<string> tensor_shapes;
  tensor_shapes.reserve(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    if (!ctx.has_input(i)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    DataType input_dtype = ctx.input_dtype(i);
    if (input_dtype == DataType::DT_RESOURCE ||
        input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    tensor_shapes.emplace_back(strings::StrCat(
        DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
  }
  return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
}

string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
  string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

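// Synchronous entry point for async kernels: runs ComputeAsync and blocks
// until the done callback fires.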
void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
    const std::shared_ptr<const NodeProperties>& props,
    const MemoryTypeSlice& input_memory_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      flib_(flib),
      resource_mgr_(resource_mgr),
      props_(props),
      input_memory_types_(input_memory_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs,
                              props_->input_types, props_->output_types);
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape", shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp,
                                           AllocatorAttributes allocator_attr) {
  if (allocator_attr.scope_id != 0) {
    return errors::InvalidArgument(
        "ScopedAllocator cannot be used via OpKernelConstruction.");
  }
  Allocator* a = device_->GetAllocator(allocator_attr);
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(a, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape", shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

// OpKernelContext -----------------------------------------------------------

const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params), outputs_(num_outputs) {
  if (params_->track_allocations) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }

  params_->ensure_eigen_gpu_device();
  if (params_->eigen_gpu_device != nullptr) {
    Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
    Status s = params_->device->ReinitializeGpuDevice(
        this, params_->eigen_gpu_device, params_->op_device_context,
        eigen_gpu_allocator);
    if (!s.ok()) {
      SetStatus(s);
    }
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->track_allocations &&
      !tracking_state_->wrapped_allocators.empty()) {
    LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                 << "being consumed by the StepStatsCollector.";
    for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
      wrapped_allocator.second->GetRecordsAndUnRef();
    }
  }
}

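// Returns the allocator selected by `attr` (the scoped allocator when
// scope_id > 0), wrapped in a per-allocator TrackingAllocator when allocation
// tracking is enabled.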
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator = nullptr;
  if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
    allocator = params_->device->GetScopedAllocator(attr, step_id());
    CHECK(allocator);
  } else {
    allocator = params_->device->GetAllocator(attr);
  }
  if (TF_PREDICT_FALSE(track_allocations())) {
    DCHECK(tracking_state_);
    mutex_lock lock(tracking_state_->mu);
    for (const auto& wrapped : tracking_state_->wrapped_allocators) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    tracking_state_->wrapped_allocators.push_back(
        std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[index].tensor;
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  const TensorValue& value((*params_->inputs)[index]);
  *dtype = value.dtype();
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  *out_mutex = input_ref_mutex(index);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
  CHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    return tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  CHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, output_index, expected_output_dtype(output_index),
      output_shape, output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index;
  TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
  TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name);
  }
  return Status::OK();
}

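// Returns a tensor that aliases the buffer of input `input_index` if it can
// safely be reused for the described output, or nullptr when forwarding is
// not possible (never-forward annotation, ref or missing input, dtype,
// element-count, or memory-type mismatch, shared buffer, or more restrictive
// output allocator attributes).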
std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, int output_index, DataType output_dtype,
    const TensorShape& output_shape, MemoryType output_memory_type,
    const AllocatorAttributes& output_attr) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check whether at graph construction time this output was marked
  // either for no forwarding or with a reservation for this input.
  // If it's reserved for this input we'll skip the refcount and
  // AllocatorAttribute checks.
  // TODO(tucker): Maybe we should skip all of the checks?
  bool never_forward =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == Params::kNeverForward);
  if (never_forward) return nullptr;
  bool forward_expected =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == input_index);
  if (!forward_expected && params_->forward_from_array != nullptr) {
    // Check for possibly conflicting forward.
    for (int i = 0; i < num_outputs(); ++i) {
      if (params_->forward_from_array[i] == input_index) {
        // This input is reserved for output i.
        return nullptr;
      }
    }
  }
  // Check that input tensor exists and is not a ref.
  if (input.tensor == nullptr || input.is_ref()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    CHECK(!forward_expected);
    return nullptr;
  }
  if (!forward_expected) {
    if (!input->RefCountIsOne()) {
      return nullptr;
    }
    // Check that output allocator attributes are not more restrictive than
    // input allocator attributes.
    const auto input_attr = params_->input_alloc_attrs == nullptr
                                ? AllocatorAttributes()
                                : input_alloc_attr(input_index);
    if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
      return nullptr;
    }
  }

  auto output_tensor = MakeUnique<Tensor>();
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, Params::kNoReservation /*output_index*/,
                      type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[index].tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    *tensor = *(*params_->inputs)[index].tensor;
  }
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(index, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

void OpKernelContext::maybe_initialize_scope_id_set() {
  if (allocated_scope_ids_ == nullptr) {
    allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
  }
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** tensor) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", num_outputs(),
                            " kernel=", params_->op_kernel->name());
  }
  bool forward_expected =
      (params_->forward_from_array != nullptr && index >= 0 &&
       params_->forward_from_array[index] >= 0);
  if (forward_expected) {
    return errors::Internal(
        "Explicit allocate_output call where input forwarding required. Try "
        "turning off the ScopedAllocator optimizer.");
  }
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, tensor, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

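// Core allocation helper shared by allocate_output and allocate_temp:
// allocates `out_tensor` from the allocator selected by `attr` and records
// the allocation when memory logging is enabled.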
Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  Tensor new_tensor(
      a, type, shape,
      AllocationAttributes(
          /*retry_on_failure=*/allocation_attr.retry_on_failure,
          /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", outputs_.size(),
                            " kernel=", params_->op_kernel->name());
  }
  const DataType type = params_->op_kernel->output_type(index);
  if (IsRefType(type)) {
    return errors::Internal("allocate_output with ref type. index=", index,
                            " type=", type,
                            " kernel=", params_->op_kernel->name());
  }
  if (mutable_output(index) != nullptr) {
    return errors::Internal("allocate_output on same index multiple times.",
                            " index = ", index,
                            " mutable_output(index) = ", mutable_output(index),
                            " kernel=", params_->op_kernel->name());
  }
  if (attr.scope_id > 0) {
    maybe_initialize_scope_id_set();
    if (!allocated_scope_ids_->insert(attr.scope_id).second) {
      return errors::Internal(
          "OpKernel ", params_->op_kernel->name(),
          " called allocate_output at index ", index, " with scope_id ",
          attr.scope_id,
          " more than once. Try turning off the ScopedAllocator optimizer.");
    }
  }
  ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                            step_id(), "output", type, &shape);
  auto output_tensor = MakeUnique<Tensor>();
  Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor.release());
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  if (allocator_attr.scope_id > 0) {
    // We do not allow ScopedAllocator calls from allocate_temp.
    // Here we clear the scope_id and return a temporary buffer.
    // This is because it is legal for a kernel to call allocate_temp
    // and then set_output with the temp tensor.
    //
    // We achieve memory correctness by forcing an allocation in set_output and
    // copying over the tensor from the temp buffer. Kernels which would like
    // to avoid this performance penalty should switch to calling
    // allocate_output.
    VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
            << " called allocate_temp with scope_id " << allocator_attr.scope_id
            << ". Switch to allocate_output to avoid performance penalty.";
    allocator_attr.scope_id = -1;
  }
  ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                            step_id(), "temp", type, &shape);
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64_t alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  } else if (record_memory_consumption_) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
  }
  return s;
}

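// get_input_index()/get_output_index() resolve a single-valued argument name
// to its flat index; list-valued names yield an InvalidArgument error.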
Status OpKernelContext::get_input_index(StringPiece name,
                                        int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  *out_index = start;
  return Status::OK();
}

Status OpKernelContext::get_output_index(StringPiece name,
                                         int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *out_index = start;
  return Status::OK();
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, tensor);
  return Status::OK();
}

Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, std::move(tensor));
  return Status::OK();
}

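// If output `index` was marked kNeverForward and has not already been
// allocated under its scope_id, allocates a fresh output tensor and copies
// `tensor` into it; returns true iff that allocate-and-copy path was taken.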
bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
    int index, const Tensor& tensor) {
  bool allocate_and_copy = false;
  const bool never_forward =
      (params_->forward_from_array != nullptr &&
       params_->forward_from_array[index] == Params::kNeverForward);
  if (TF_PREDICT_FALSE(never_forward)) {
    maybe_initialize_scope_id_set();
    if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
        allocated_scope_ids_->end()) {
      allocate_and_copy = true;
    } else {
      // The output at `index` must have been previously allocated via a call to
      // `allocate_output(index, ...)`. That call would ensure that we return
      // the correct slice of the ScopedAllocated buffer, so we do not
      // re-allocate and copy here.
      LOG(WARNING)
          << "OpKernel " << params_->op_kernel->name()
          << " called both allocate_output and set_output with scope_id "
          << output_alloc_attr(index).scope_id;
    }
  }

  if (TF_PREDICT_FALSE(allocate_and_copy)) {
    // This output was marked to not be forwarded either during graph
    // construction or grappler passes. Force an allocation and copy input to
    // output.
    VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
            << tensor.DebugString() << " never_forward " << never_forward
            << " params_->forward_from_array[index] "
            << params_->forward_from_array[index] << " alloc_attr.scope_id "
            << output_alloc_attr(index).scope_id;
    ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                              step_id(), "output",
                                              tensor.dtype(), &tensor.shape());
    auto new_tensor = MakeUnique<Tensor>();
    Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
                               output_alloc_attr(index));
    TF_CHECK_OK(s);
    device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
                                     op_device_context(), [](const Status&) {});
    outputs_[index] = TensorValue(new_tensor.release());
  }
  return allocate_and_copy;
}

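// If allocation tracking is enabled, removes `tensor` from the temp-memory
// bookkeeping now that it has been promoted to an output.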
void OpKernelContext::maybe_track_allocations_for_set_output(
    const Tensor& tensor) {
  if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    const auto it = std::find_if(
        tracking_state_->temp_tensor_buffer_and_size.begin(),
        tracking_state_->temp_tensor_buffer_and_size.end(),
        [&tensor](const std::pair<const void*, int64>& e) {
          return e.first ==
                 static_cast<const void*>(tensor.tensor_data().data());
        });
    if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
      tracking_state_->temp_memory_allocated -= it->second;
      tracking_state_->temp_tensor_buffer_and_size.erase(it);
    }
  }
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; incref on `tensor` and set output at
    // `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(tensor));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output(int index, Tensor&& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; set output at `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  CHECK(IsRefType(params_->op_kernel->output_type(index)));
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output_ref(index, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  *tensor = mutable_output(index);
  return Status::OK();
}

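// Returns true iff all inputs have the same shape; otherwise records an
// InvalidArgument status and returns false.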
bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape. Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64_t size,
                                                    const Tensor& t) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += size;
    tracking_state_->temp_tensor_buffer_and_size.emplace_back(
        static_cast<const void*>(t.tensor_data().data()), size);
  }
}

int64 OpKernelContext::temp_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->temp_memory_allocated;
  } else {
    return 0;
  }
}

void OpKernelContext::record_persistent_memory_allocation(int64_t size,
                                                          int64_t alloc_id) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->persistent_memory_allocated += size;
    if (alloc_id >= 0) {
      tracking_state_->persistent_alloc_ids.push_back(alloc_id);
    }
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->persistent_memory_allocated;
  } else {
    return 0;
  }
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return std::vector<int64>(tracking_state_->persistent_alloc_ids.begin(),
                              tracking_state_->persistent_alloc_ids.end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated = 0;
    tracking_state_->persistent_memory_allocated = 0;
    tracking_state_->temp_tensor_buffer_and_size.clear();
    tracking_state_->persistent_alloc_ids.clear();
  }
}

void OpKernelContext::set_record_memory_consumption(bool v) {
  record_memory_consumption_ = v;
  if (v && !tracking_state_) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }
}

const string& OpKernelContext::executor_type() const {
  if (params_->executor_type) {
    return *params_->executor_type;
  } else {
    static const string& kEmptyString = *new string("");
    return kEmptyString;
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};

#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif

#define FEATURE(x) \
  { x, #x }

// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
  // A map of platform string to required CPU feature.
  using port::CPUFeature;
  static const auto* feature_map =
      new std::map<string, std::pair<CPUFeature, string>>{
          {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
      };

  std::vector<std::string> platform_strings;
  int result = GetPlatformStrings(path, &platform_strings);
  if (result) {
    return Status(error::Code::UNKNOWN, strerror(result));
  }
  if (platform_strings.empty()) {
    return Status(error::Code::FAILED_PRECONDITION,
                  "Didn't find any platform strings");
  }
  std::vector<std::string> missing_features;
  for (const auto& platform_string : platform_strings) {
    const auto& entry = feature_map->find(platform_string);
    if (entry != feature_map->end() &&
        !port::TestCPUFeature(entry->second.first)) {
      missing_features.emplace_back(entry->second.second);
    }
  }
  if (!missing_features.empty()) {
    string errmsg = "Missing CPU features: ";
    errmsg.append(absl::StrJoin(missing_features, ", "));
    return Status(errors::Code::FAILED_PRECONDITION, errmsg);
  }
  return Status::OK();
}

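// Scans the bazel runfiles kernel directory for libtfkernel* libraries and
// loads each one that passes (or is explicitly allowed to bypass) the
// platform-string ABI check above.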
void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check = false;
  if (_abi_check_env_var != nullptr) {
    override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
  }

  string bazel_kernel_dir =
      io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(
              env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static absl::once_flag dll_loader_flag;
  absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

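// Returns the process-wide KernelRegistry as an opaque pointer; the first
// call also registers ValidateKernelRegistrations with the global OpRegistry.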
void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
  LoadDynamicKernels();
#endif  // AUTOLOAD_DYNAMIC_KERNELS
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

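// Builds the registry lookup key in the form "<op>:<device type>:<label>".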
Key(StringPiece op_type,const DeviceType & device_type,StringPiece label)1178 static string Key(StringPiece op_type, const DeviceType& device_type,
1179 StringPiece label) {
1180 return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
1181 label);
1182 }
1183
1184 namespace kernel_factory {
1185
InitInternal(const KernelDef * kernel_def,StringPiece kernel_class_name,std::unique_ptr<OpKernelFactory> factory)1186 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
1187 StringPiece kernel_class_name,
1188 std::unique_ptr<OpKernelFactory> factory) {
1189 const string key =
1190 Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
1191 kernel_def->label());
1192
1193 // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
1194 // here.
1195 // InitInternal gets called by static initializers, so it ends up executing
1196 // before main. This causes LoadKernelLibraries function to get called
1197 // before some file libraries can initialize, which in turn crashes the
1198 // program flakily. Until we get rid of static initializers in kernel
1199 // registration mechanism, we have this workaround here.
1200 auto global_registry =
1201 reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
1202 mutex_lock l(global_registry->mu);
1203 global_registry->registry.emplace(
1204 key,
1205 KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
1206 delete kernel_def;
1207 }
1208
Create(OpKernelConstruction * context)1209 OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
1210 OpKernelConstruction* context) {
1211 return (*create_func_)(context);
1212 }
1213
1214 } // namespace kernel_factory
1215
1216 namespace {
1217
1218 // Label defaults to empty if not found in NodeDef.
GetKernelLabelAttr(const AttrSlice & node_attrs)1219 const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
1220 static const string& kKernelAttr = *new string("_kernel");
1221 static const string& kEmptyString = *new string("");
1222
1223 // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
1224 // to use the `AttrSlice::FindByString()` overload, which does a more
1225 // efficient map lookup (instead of a linear scan) when the attribute name is
1226 // already a `const string&`.
1227 const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
1228 if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
1229 return kEmptyString;
1230 else
1231 return attr_value->s();
1232 }
1233
1234 // TODO(irving): Replace with const Node& version below.
FindKernelRegistration(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,AttrSlice node_attrs,const KernelRegistration ** reg,bool * was_attr_mismatch)1235 Status FindKernelRegistration(
1236 const DeviceType& device_type, StringPiece node_name,
1237 bool has_experimental_debug_info,
1238 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1239 StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
1240 bool* was_attr_mismatch) {
1241 *reg = nullptr;
1242 *was_attr_mismatch = false;
1243
1244 const string& label = GetKernelLabelAttr(node_attrs);
1245
1246 const string key = Key(node_op, device_type, label);
1247 auto typed_registry = GlobalKernelRegistryTyped();
1248 tf_shared_lock lock(typed_registry->mu);
1249 auto regs = typed_registry->registry.equal_range(key);
1250 for (auto iter = regs.first; iter != regs.second; ++iter) {
1251 // If there is a kernel registered for the op and device_type,
1252 // check that the attrs match.
1253 bool match;
1254 TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
1255 if (match) {
1256 if (*reg != nullptr) {
1257 if ((*reg)->def.priority() == iter->second.def.priority()) {
1258 return errors::InvalidArgument(
1259 "Multiple OpKernel registrations match NodeDef at the same "
1260 "priority '",
1261 FormatNodeDefForError(node_name, has_experimental_debug_info,
1262 experimental_debug_info),
1263 "': '", (*reg)->def.ShortDebugString(), "' and '",
1264 iter->second.def.ShortDebugString(), "'");
1265 } else if ((*reg)->def.priority() > iter->second.def.priority()) {
1266 continue;
1267 }
1268 // iter->second's priority is higher than *reg.
1269 }
1270 *reg = &iter->second;
1271 } else {
1272 *was_attr_mismatch = true;
1273 }
1274 }
1275 // Check if no device specific registrations found. If not, try finding a
1276 // default kernel.
1277 if (*reg == nullptr &&
1278 !IsSymbolicExecutionDevice(device_type.type_string())) {
1279 const string default_key = Key(node_op, DEVICE_DEFAULT, label);
1280 auto regs = typed_registry->registry.equal_range(default_key);
1281 for (auto iter = regs.first; iter != regs.second; ++iter) {
1282 // If there is a kernel registered for the op and device_type,
1283 // check that the attrs match.
1284 bool match;
1285 TF_RETURN_IF_ERROR(
1286 KernelAttrsMatch(iter->second.def, node_attrs, &match));
1287 if (match) {
1288 if (*reg != nullptr) {
1289 return errors::InvalidArgument(
1290 "Multiple Default OpKernel registrations match NodeDef '",
1291 FormatNodeDefForError(node_name, has_experimental_debug_info,
1292 experimental_debug_info),
1293 "': '", (*reg)->def.ShortDebugString(), "' and '",
1294 iter->second.def.ShortDebugString(), "'");
1295 }
1296 *reg = &iter->second;
1297 } else {
1298 *was_attr_mismatch = true;
1299 }
1300 }
1301
1302 if (*reg != nullptr) {
1303 VLOG(1) << "No device-specific kernels found for NodeDef '"
1304 << FormatNodeDefForError(node_name, has_experimental_debug_info,
1305 experimental_debug_info)
1306 << "'"
1307 << "Will fall back to a default kernel." << std::endl;
1308 }
1309 }
1310
1311 return Status::OK();
1312 }
1313
FindKernelRegistration(const DeviceType & device_type,const NodeDef & node_def,const KernelRegistration ** reg,bool * was_attr_mismatch)1314 Status FindKernelRegistration(const DeviceType& device_type,
1315 const NodeDef& node_def,
1316 const KernelRegistration** reg,
1317 bool* was_attr_mismatch) {
1318 return FindKernelRegistration(
1319 device_type, node_def.name(), node_def.has_experimental_debug_info(),
1320 node_def.experimental_debug_info(), node_def.op(),
1321 AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
1322 }
1323
1324 } // namespace
1325
KernelDefAvailable(const DeviceType & device_type,const NodeDef & node_def)1326 bool KernelDefAvailable(const DeviceType& device_type,
1327 const NodeDef& node_def) {
1328 const KernelRegistration* reg = nullptr;
1329 bool was_attr_mismatch;
1330 Status result =
1331 FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch);
1332 return result.ok() && reg != nullptr;
1333 }
1334
1335 // TODO(irving): Change const NodeDef& to const Node&
FindKernelDef(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,StringPiece node_device,AttrSlice node_attrs,const KernelDef ** def,string * kernel_class_name)1336 Status FindKernelDef(
1337 const DeviceType& device_type, StringPiece node_name,
1338 bool has_experimental_debug_info,
1339 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1340 StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1341 const KernelDef** def, string* kernel_class_name) {
1342 const KernelRegistration* reg = nullptr;
1343 bool was_attr_mismatch;
1344 TF_RETURN_IF_ERROR(FindKernelRegistration(
1345 device_type, node_name, has_experimental_debug_info,
1346 experimental_debug_info, node_op, node_attrs, ®, &was_attr_mismatch));
1347 if (reg == nullptr) {
1348 const std::string device_str = DeviceTypeString(device_type);
1349 Status s = errors::NotFound(
1350 "No registered '", node_op, "' OpKernel for ", device_str,
1351 " devices compatible with node ",
1352 FormatNodeDefForError(node_name, has_experimental_debug_info,
1353 experimental_debug_info));
1354 if (was_attr_mismatch) {
1355 errors::AppendToMessage(
1356 &s, " (OpKernel was found, but attributes didn't match) ",
1357 "Requested Attributes: ",
1358 SummarizeAttrsHelper(node_attrs, node_device));
1359 }
1360
1361 // Do not print kernel registrations for other devices when using _JIT
1362 // devices for compilation.
1363 if (!absl::StrContains(device_str, "JIT")) {
1364 errors::AppendToMessage(
1365 &s, ". Registered:", KernelsRegisteredForOp(node_op));
1366 }
1367
1368 return s;
1369 }
1370 if (def != nullptr) *def = ®->def;
1371 if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
1372 return Status::OK();
1373 }
1374
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  return FindKernelDef(
      device_type, node_def.name(), node_def.has_experimental_debug_info(),
      node_def.experimental_debug_info(), node_def.op(), node_def.device(),
      AttrSlice(&node_def.attr()), def, kernel_class_name);
}

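// Fills `prioritized_device_types` with the device types (and registered
// kernel priorities) that can run `def`, sorted by priority. If the op itself
// is not registered in this process, every entry of `prioritized_types` is
// assumed to support the node.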
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* prioritized_device_types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  // TODO(zhifengc): Change the callers (SimplePlacer and DynamicPlacer) to
  // consider the possibility that 'def' is a call to a user-defined function
  // and to only call SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    bool exists_attr_mismatch = false;
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch = false;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
      if (reg != nullptr) {
        int32_t priority = reg->def.priority();
        prioritized_device_types->emplace_back(device_type, priority);
      }
    }
    // Add extra supported device types if the following conditions are
    // satisfied:
    // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
    // 2) A device is requested for this node which specifies job/replica/task
    // 3) A local device is provided which specifies job/replica/task
    // 4) The local device does not have the same (job, replica, task) as the
    //    requested device
    //
    // The goal is to address the issue where a graph includes an op (e.g.
    // PyFunc) whose kernel is known to a remote process but not to the
    // current process.
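    // For example (illustrative): a node whose requested device is
    // "/job:ps/replica:0/task:0/device:CPU:0" while the local address spec
    // names "/job:worker/replica:0/task:1" falls under this case.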
    if (prioritized_device_types->empty() && !exists_attr_mismatch &&
        local_address_spec != nullptr) {
      DeviceNameUtils::ParsedName requested_device_name;
      DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
      if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
                                                   requested_device_name)) {
        if (requested_device_name.has_type) {
          prioritized_device_types->push_back(
              std::make_pair(DeviceType(requested_device_name.type), 0));
        } else {
          for (const DeviceType& device_type : prioritized_types) {
            prioritized_device_types->push_back(std::make_pair(device_type, 0));
          }
        }
      }
    }

    // If we were unable to find any valid devices, check whether the node is
    // even a valid NodeDef for this op.
    if (prioritized_device_types->empty()) {
      TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
    }

    std::stable_sort(prioritized_device_types->begin(),
                     prioritized_device_types->end(),
                     [](const std::pair<DeviceType, int32>& a,
                        const std::pair<DeviceType, int32>& b) {
                       return a.second > b.second;
                     });
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();
}

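// Writes one INFO log line per kernel currently registered in this process.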
void LogAllRegisteredKernels() {
  KernelList kernel_list = GetAllRegisteredKernels();
  for (const auto& kernel_def : kernel_list.kernel()) {
    LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
  }
}

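// Returns every registered kernel as a KernelList proto.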
KernelList GetAllRegisteredKernels() {
  return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}

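// Returns the registered kernels whose KernelDef satisfies `predicate`. For
// example (illustrative), a predicate such as
//   [](const KernelDef& k) { return k.device_type() == "GPU"; }
// selects only GPU kernels.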
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate) {
  KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
  KernelList kernel_list;
  tf_shared_lock lock(typed_registry->mu);
  kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
  for (const auto& p : typed_registry->registry) {
    const KernelDef& kernel_def = p.second.def;
    if (predicate(kernel_def)) {
      *kernel_list.add_kernel() = kernel_def;
    }
  }
  return kernel_list;
}

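// Returns all kernels registered for the op named `op_name`.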
KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
  auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
  return GetFilteredRegisteredKernels(op_pred);
}

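// Renders the kernels registered for `op_name` as a human-readable list, one
// line per kernel, e.g. (illustrative):
//   device='CPU'; T in [DT_FLOAT, DT_DOUBLE]
// This string is appended to the "No registered ... OpKernel" error messages
// in this file.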
string KernelsRegisteredForOp(StringPiece op_name) {
  KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
  if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
  string ret;
  for (const auto& kernel_def : kernel_list.kernel()) {
    strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
    if (!kernel_def.label().empty()) {
      strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
    }
    for (int i = 0; i < kernel_def.constraint_size(); ++i) {
      strings::StrAppend(
          &ret, "; ", kernel_def.constraint(i).name(), " in ",
          SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
    }
    strings::StrAppend(&ret, "\n");
  }
  return ret;
}

/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
 * copying the NodeDef. */
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  // Look up the Op registered for this op name.
  std::shared_ptr<const NodeProperties> props;
  status->Update(NodeProperties::CreateFromNodeDef(
      node_def, OpRegistry::Global(), &props));
  if (!status->ok()) {
    errors::AppendToMessage(status,
                            " for node: ", FormatNodeDefForError(node_def));
    return nullptr;
  }
  return CreateOpKernel(device_type, device, allocator, props,
                        graph_def_version, status);
}

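// Creates an OpKernel from pre-computed NodeProperties. Errors are reported
// through `status`; the returned pointer is null on failure.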
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator,
                           /*flib=*/nullptr, props, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

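// Overload without a ResourceMgr; forwards to the full factory below.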
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  return CreateOpKernel(std::move(device_type), device, allocator, flib,
                        /* resource_mgr= */ nullptr, props, graph_def_version,
                        kernel);
}

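// Main kernel factory: validates the NodeDef against its OpDef, locates a
// matching kernel registration, resolves input/output memory types, and then
// invokes the registered factory to construct the kernel.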
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  const NodeDef& node_def = props->node_def;
  bool was_attr_mismatch = false;
  const KernelRegistration* registration = nullptr;
  Status s;
  if (props != nullptr) {
    VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

    // Validate node_def against OpDef.
    TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));

    // Look up kernel registration.
    s = FindKernelRegistration(device_type, node_def, &registration,
                               &was_attr_mismatch);
    if (!s.ok()) {
      errors::AppendToMessage(&s, " when instantiating ", node_def.op());
      return s;
    }
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for '", DeviceTypeString(device_type),
                              "' devices compatible with node ",
                              FormatNodeDefForError(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // We are creating a kernel for an op registered in OpRegistry::Global(), so
  // we consult the kernel registry to decide the kernel's input and output
  // memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(std::move(device_type), device, allocator, flib,
                               resource_mgr, props, input_memory_types,
                               output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}

namespace {

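// Returns true if `args` contains an argument named `arg_name`.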
bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

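// Checks every registered kernel against `op_registry`: kernels for unknown
// ops are logged as errors (not yet fatal; see the TODO below), and each
// HostMemory constraint must name an input or output argument of the op.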
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  auto typed_registry = GlobalKernelRegistryTyped();
  tf_shared_lock lock(typed_registry->mu);
  for (const auto& key_registration : typed_registry->registry) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

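// eigen_device<Device>() specializations map the template parameter onto the
// CPU or GPU Eigen device already stored on the OpKernelContext.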
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

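// The CtxFailure family records a failed Status on the construction/context
// object. The plain variants log at VLOG(1); the *WithWarning variants log a
// WARNING. The file/line overloads are used by the OP_REQUIRES* macros.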
void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

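// CHECK-fails when called from inside an AsyncOpKernel, i.e. when a
// synchronous OP_REQUIRES*-style macro is used where the *_ASYNC variant
// (named by `correct_macro_name`) is required.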
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name) {
  CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
      << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}

}  // namespace tensorflow