/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <cstdlib>
#include <cstring>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

const char* kJitKernelLabel = "JITCompiledKernel";
const char* kDisableJitKernelsEnvVar = "TF_DISABLE_JIT_KERNELS";

namespace {

Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return OkStatus();
}

const absl::flat_hash_set<std::string>* GetOpNodeDefsToLogFromEnv() {
  auto* result = new absl::flat_hash_set<std::string>;
  const char* env = getenv("TF_DEBUG_OPS_TO_LOG_NODEDEFS");
  if (!env) {
    return result;
  }

  std::vector<absl::string_view> ops = absl::StrSplit(env, ',');
  LOG(INFO) << "Will log NodeDefs from the following ops: ";
  for (absl::string_view op : ops) {
    result->insert(std::string(op));
    LOG(INFO) << " |" << op << "|";
  }

  return result;
}

// Returns true if the NodeDef for the OpKernel should be logged. The
// environment variable TF_DEBUG_OPS_TO_LOG_NODEDEFS can be set to a
// comma-separated list of op types. The NodeDef for each op of a listed type
// is printed, which is useful for debugging purposes.
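// For illustration (hypothetical op names), the variable could be set before
// starting a TensorFlow program, e.g.:
//
//   export TF_DEBUG_OPS_TO_LOG_NODEDEFS=MatMul,Conv2D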
bool ShouldLogNodeDef(OpKernel* op_kernel) {
  static const absl::flat_hash_set<std::string>& ops_to_log_nodedefs =
      *GetOpNodeDefsToLogFromEnv();
  return ops_to_log_nodedefs.count(op_kernel->type_string());
}

}  // namespace

// OpKernel ------------------------------------------------------------------

OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}

OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
    : props_(context->props_),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               !DeviceFactory::IsPluggableDevice(
                   DeviceTypeString(context->device_type()));

  if (ShouldLogNodeDef(this)) {
    LOG(INFO) << "NodeDef for " << name() << ":\n" << def().ShortDebugString();
  }
}

OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
                   bool is_deferred)
    : props_(std::make_shared<const NodeProperties>(
          context->props_->op_def, std::move(custom_def),
          context->props_->input_types, context->props_->output_types)),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               !DeviceFactory::IsPluggableDevice(
                   DeviceTypeString(context->device_type()));
}

OpKernel::~OpKernel() {}

Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return OkStatus();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return OkStatus();
  }
}

string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
  int num_inputs = ctx.num_inputs();
  if (num_inputs == 0) return "";
  std::vector<string> tensor_shapes;
  tensor_shapes.reserve(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    if (!ctx.has_input(i)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    DataType input_dtype = ctx.input_dtype(i);
    if (input_dtype == DataType::DT_RESOURCE ||
        input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    tensor_shapes.emplace_back(strings::StrCat(
        DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
  }
  return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
}

string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
  string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}
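
// Illustrative sketch (hypothetical kernel, not part of this file): a derived
// AsyncOpKernel implements ComputeAsync and is expected to invoke `done`
// exactly once when it has finished, e.g.
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     ctx->set_output(0, ctx->input(0));
//     done();
//   }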

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
    const std::shared_ptr<const NodeProperties>& props,
    const MemoryTypeSlice& input_memory_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      flib_(flib),
      resource_mgr_(resource_mgr),
      props_(props),
      input_memory_types_(input_memory_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs,
                              props_->input_types, props_->output_types);
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return OkStatus();
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp,
                                           AllocatorAttributes allocator_attr) {
  if (allocator_attr.scope_id != 0) {
    return errors::InvalidArgument(
        "ScopedAllocator cannot be used via OpKernelConstruction.");
  }
  Allocator* a = device_->GetAllocator(allocator_attr);
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(a, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return OkStatus();
}
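
// A minimal usage sketch (hypothetical kernel, for illustration only): kernels
// that need a scratch or constant tensor for their whole lifetime may allocate
// it at construction time, e.g.
//
//   explicit MyScaleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//     OP_REQUIRES_OK(ctx,
//                    ctx->allocate_temp(DT_FLOAT, TensorShape({}), &scale_));
//   }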

// OpKernelContext -----------------------------------------------------------

const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params), outputs_(num_outputs) {
  if (params_->track_allocations) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }

  params_->ensure_eigen_gpu_device();
  if (params_->eigen_gpu_device != nullptr) {
    Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
    Status s = params_->device->ReinitializeGpuDevice(
        this, params_->eigen_gpu_device, params_->op_device_context,
        eigen_gpu_allocator);
    if (!s.ok()) {
      SetStatus(s);
    }
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->track_allocations &&
      !tracking_state_->wrapped_allocators.empty()) {
    LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                 << "being consumed by the StepStatsCollector.";
    for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
      wrapped_allocator.second->GetRecordsAndUnRef();
    }
  }
}

Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator = nullptr;
  if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
    allocator = params_->device->GetScopedAllocator(attr, step_id());
    CHECK(allocator);
  } else {
    allocator = params_->device->GetAllocator(attr);
  }
  if (TF_PREDICT_FALSE(track_allocations())) {
    DCHECK(tracking_state_);
    mutex_lock lock(tracking_state_->mu);
    for (const auto& wrapped : tracking_state_->wrapped_allocators) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    tracking_state_->wrapped_allocators.push_back(
        std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = params_->inputs[index].tensor;
  return OkStatus();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  const TensorValue& value(params_->inputs[index]);
  *dtype = value.dtype();
  return OkStatus();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  *out_mutex = input_ref_mutex(index);
  return OkStatus();
}

const Tensor& OpKernelContext::input(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
  CHECK(!input_is_ref(index));
  const Tensor& tensor = *params_->inputs[index].tensor;
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *params_->inputs[index].tensor;
    return tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    Tensor& tensor = *params_->inputs[index].tensor;
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *params_->inputs[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *params_->inputs[index].tensor = tensor;
  }
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                       int output_index) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  CHECK(input_is_ref(input_index));
  set_output_ref(output_index, params_->inputs[input_index].mutex_if_ref,
                 params_->inputs[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, output_index, expected_output_dtype(output_index),
      output_shape, output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index;
  TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
  TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return OkStatus();
}

std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, int output_index, DataType output_dtype,
    const TensorShape& output_shape, MemoryType output_memory_type,
    const AllocatorAttributes& output_attr) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  const TensorValue& input = params_->inputs[input_index];
  // Check whether at graph construction time this output was marked
  // either for no forwarding or with a reservation for this input.
  // If it's reserved for this input we'll skip the refcount and
  // AllocatorAttribute checks.
  // TODO(tucker): Maybe we should skip all of the checks?
  bool never_forward =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == Params::kNeverForward);
  if (never_forward) return nullptr;
  bool forward_expected =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == input_index);
  if (!forward_expected && params_->forward_from_array != nullptr) {
    // Check for possibly conflicting forward.
    for (int i = 0; i < num_outputs(); ++i) {
      if (params_->forward_from_array[i] == input_index) {
        // This input is reserved for output i.
        return nullptr;
      }
    }
  }
  // Check that input tensor exists and is not a ref.
  if (input.tensor == nullptr || input.is_ref()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    CHECK(!forward_expected);
    return nullptr;
  }
  if (!forward_expected) {
    if (!input->RefCountIsOne()) {
      return nullptr;
    }
    // Check that output allocator attributes are not more restrictive than
    // input allocator attributes.
    const auto input_attr = params_->input_alloc_attrs.empty()
                                ? AllocatorAttributes()
                                : input_alloc_attr(input_index);
    if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
      return nullptr;
    }
  }

  auto output_tensor = MakeUnique<Tensor>();
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, Params::kNoReservation /*output_index*/,
                      type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return OkStatus();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      if (forwarded_input != nullptr) {
        *forwarded_input = input_index;
      }
      return OkStatus();
    }
  }
  if (forwarded_input != nullptr) {
    *forwarded_input = -1;
  }
  return allocate_output(output_index, output_shape, output);
}

Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return OkStatus();
    }
  }
  return allocate_output(output_name, output_shape, output);
}
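
// A usage sketch (hypothetical element-wise kernel, for illustration only):
// reuse the input buffer for the output when it is safe to do so, and fall
// back to a fresh allocation otherwise:
//
//   void Compute(OpKernelContext* ctx) override {
//     const Tensor& x = ctx->input(0);
//     Tensor* y = nullptr;
//     OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                             {0}, 0, x.shape(), &y));
//     // ... write the result into *y ...
//   }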

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete params_->inputs[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete params_->inputs[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *params_->inputs[index].tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    *tensor = *params_->inputs[index].tensor;
  }
  return OkStatus();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(index, tensor, lock_held);
  return OkStatus();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return OkStatus();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return OkStatus();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return OkStatus();
}

void OpKernelContext::maybe_initialize_scope_id_set() {
  if (allocated_scope_ids_ == nullptr) {
    allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
  }
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** tensor) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", num_outputs(),
                            " kernel=", params_->op_kernel->name());
  }
  bool forward_expected =
      (params_->forward_from_array != nullptr && index >= 0 &&
       params_->forward_from_array[index] >= 0);
  if (forward_expected) {
    return errors::Internal(
        "Explicit allocate_output call where input forwarding required. Try "
        "turning off the ScopedAllocator optimizer.");
  }
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, tensor, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  Tensor new_tensor(
      a, type, shape,
      AllocationAttributes(
          /*retry_on_failure=*/allocation_attr.retry_on_failure,
          /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape ", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  *out_tensor = std::move(new_tensor);
  return OkStatus();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", outputs_.size(),
                            " kernel=", params_->op_kernel->name());
  }
  const DataType type = params_->op_kernel->output_type(index);
  if (IsRefType(type)) {
    return errors::Internal("allocate_output with ref type. index=", index,
                            " type=", type,
                            " kernel=", params_->op_kernel->name());
  }
  if (mutable_output(index) != nullptr) {
    return errors::Internal("allocate_output on same index multiple times.",
                            " index = ", index,
                            " mutable_output(index) = ", mutable_output(index),
                            " kernel=", params_->op_kernel->name());
  }
  if (attr.scope_id > 0) {
    maybe_initialize_scope_id_set();
    if (!allocated_scope_ids_->insert(attr.scope_id).second) {
      return errors::Internal(
          "OpKernel ", params_->op_kernel->name(),
          " called allocate_output at index ", index, " with scope_id ",
          attr.scope_id,
          " more than once. Try turning off the ScopedAllocator optimizer.");
    }
  }
  profiler::ScopedMemoryDebugAnnotation op_annotation(
      op_kernel().name_view().data(), step_id(), "output", type,
      [&shape]() { return shape.DebugString(); });
  auto output_tensor = MakeUnique<Tensor>();
  Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor.release());
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  if (allocator_attr.scope_id > 0) {
    // We do not allow ScopedAllocator calls from allocate_temp.
    // Here we clear the scope_id and return a temporary buffer.
    // This is because it is legal for a kernel to call allocate_temp
    // and then set_output with the temp tensor.
    //
    // We achieve memory correctness by forcing an allocation in set_output and
    // copying over the tensor from the temp buffer. Kernels which would like
    // to avoid this performance penalty should switch to calling
    // allocate_output.
    VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
            << " called allocate_temp with scope_id " << allocator_attr.scope_id
            << ". Switch to allocate_output to avoid performance penalty.";
    allocator_attr.scope_id = -1;
  }
  profiler::ScopedMemoryDebugAnnotation op_annotation(
      op_kernel().name_view().data(), step_id(), "temp", type,
      [&shape]() { return shape.DebugString(); });
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64_t alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  } else if (record_memory_consumption_) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
  }
  return s;
}

Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape,
                                      Tensor* out_temp,
                                      AllocatorAttributes allocator_attr) {
  return allocate_temp(type, shape, out_temp, allocator_attr,
                       AllocationAttributes());
}

Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape,
                                      Tensor* out_temp) {
  return allocate_temp(type, shape, out_temp, AllocatorAttributes());
}
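
// A brief usage sketch (hypothetical kernel, for illustration only):
// allocating a scratch buffer inside Compute that lives only for the duration
// of the op:
//
//   Tensor scratch;
//   OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, TensorShape({64}),
//                                          &scratch));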

Status OpKernelContext::get_input_index(StringPiece name,
                                        int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  *out_index = start;
  return OkStatus();
}

Status OpKernelContext::get_output_index(StringPiece name,
                                         int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *out_index = start;
  return OkStatus();
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, tensor);
  return OkStatus();
}

Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, std::move(tensor));
  return OkStatus();
}

bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
    int index, const Tensor& tensor) {
  bool allocate_and_copy = false;
  const bool never_forward =
      (params_->forward_from_array != nullptr &&
       params_->forward_from_array[index] == Params::kNeverForward);
  if (TF_PREDICT_FALSE(never_forward)) {
    maybe_initialize_scope_id_set();
    if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
        allocated_scope_ids_->end()) {
      allocate_and_copy = true;
    } else {
      // The output at `index` must have been previously allocated via a call to
      // `allocate_output(index, ...)`. That call would ensure that we return
      // the correct slice of the ScopedAllocated buffer, so we do not
      // re-allocate and copy here.
      LOG(WARNING)
          << "OpKernel " << params_->op_kernel->name()
          << " called both allocate_output and set_output with scope_id "
          << output_alloc_attr(index).scope_id;
    }
  }

  if (TF_PREDICT_FALSE(allocate_and_copy)) {
    // This output was marked to not be forwarded either during graph
    // construction or grappler passes. Force an allocation and copy input to
    // output.
    VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
            << tensor.DebugString() << " never_forward " << never_forward
            << " params_->forward_from_array[index] "
            << params_->forward_from_array[index] << " alloc_attr.scope_id "
            << output_alloc_attr(index).scope_id;
    profiler::ScopedMemoryDebugAnnotation op_annotation(
        op_kernel().name_view().data(), step_id(), "output", tensor.dtype(),
        [&tensor]() { return tensor.shape().DebugString(); });
    auto new_tensor = MakeUnique<Tensor>();
    Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
                               output_alloc_attr(index));
    TF_CHECK_OK(s);
    device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
                                     op_device_context(), [](const Status&) {});
    outputs_[index] = TensorValue(new_tensor.release());
  }
  return allocate_and_copy;
}

void OpKernelContext::maybe_track_allocations_for_set_output(
    const Tensor& tensor) {
  if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    const auto it = std::find_if(
        tracking_state_->temp_tensor_buffer_and_size.begin(),
        tracking_state_->temp_tensor_buffer_and_size.end(),
        [&tensor](const std::pair<const void*, int64>& e) {
          return e.first ==
                 static_cast<const void*>(tensor.tensor_data().data());
        });
    if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
      tracking_state_->temp_memory_allocated -= it->second;
      tracking_state_->temp_tensor_buffer_and_size.erase(it);
    }
  }
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; incref on `tensor` and set output at
    // `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(tensor));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output(int index, Tensor&& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; set output at `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  CHECK(IsRefType(params_->op_kernel->output_type(index)));
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output_ref(index, mu, tensor_for_ref);
  return OkStatus();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  *tensor = mutable_output(index);
  return OkStatus();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape. Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : params_->inputs) {
    inputs.push_back(t.dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64_t size,
                                                    const Tensor& t) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += size;
    tracking_state_->temp_tensor_buffer_and_size.emplace_back(
        static_cast<const void*>(t.tensor_data().data()), size);
  }
}

int64_t OpKernelContext::temp_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->temp_memory_allocated;
  } else {
    return 0;
  }
}

void OpKernelContext::record_persistent_memory_allocation(int64_t size,
                                                          int64_t alloc_id) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->persistent_memory_allocated += size;
    if (alloc_id >= 0) {
      tracking_state_->persistent_alloc_ids.push_back(alloc_id);
    }
  }
}

int64_t OpKernelContext::persistent_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->persistent_memory_allocated;
  } else {
    return 0;
  }
}

std::vector<int64_t> OpKernelContext::persistent_alloc_ids() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return std::vector<int64_t>(tracking_state_->persistent_alloc_ids.begin(),
                                tracking_state_->persistent_alloc_ids.end());
  } else {
    return std::vector<int64_t>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated = 0;
    tracking_state_->persistent_memory_allocated = 0;
    tracking_state_->temp_tensor_buffer_and_size.clear();
    tracking_state_->persistent_alloc_ids.clear();
  }
}

void OpKernelContext::set_record_memory_consumption(bool v) {
  record_memory_consumption_ = v;
  if (v && !tracking_state_) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }
}

const string& OpKernelContext::executor_type() const {
  if (params_->executor_type) {
    return *params_->executor_type;
  } else {
    static const string& kEmptyString = *new string("");
    return kEmptyString;
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};
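
// For illustration (hypothetical kernel class and op name), entries in this
// registry are normally created through the static registration macro, e.g.:
//
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), MyOpKernel);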

#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif

#define FEATURE(x) \
  { x, #x }

// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
  // A map of platform string to required CPU feature.
  using port::CPUFeature;
  static const auto* feature_map =
      new std::map<string, std::pair<CPUFeature, string>>{
          {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
      };

  std::vector<std::string> platform_strings;
  int result = GetPlatformStrings(path, &platform_strings);
  if (result) {
    return Status(error::Code::UNKNOWN, strerror(result));
  }
  if (platform_strings.empty()) {
    return Status(error::Code::FAILED_PRECONDITION,
                  "Didn't find any platform strings");
  }
  std::vector<std::string> missing_features;
  for (const auto& platform_string : platform_strings) {
    const auto& entry = feature_map->find(platform_string);
    if (entry != feature_map->end() &&
        !port::TestCPUFeature(entry->second.first)) {
      missing_features.emplace_back(entry->second.second);
    }
  }
  if (!missing_features.empty()) {
    string errmsg = "Missing CPU features: ";
    errmsg.append(absl::StrJoin(missing_features, ", "));
    return errors::FailedPrecondition(errmsg);
  }
  return OkStatus();
}

void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check = false;
  if (_abi_check_env_var != nullptr) {
    override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
  }

  string bazel_kernel_dir =
      io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(
              env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static absl::once_flag dll_loader_flag;
  absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}
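
// For illustration, with a hypothetical op and an empty label the generated
// key would look like:
//
//   Key("MyOp", DeviceType(DEVICE_CPU), "")  ->  "MyOp:CPU:"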

// Provide a way for users to disable JIT kernels for a transitional period.
// Until this is removed, this function also removes the JIT label that is added
// to JIT kernels during the static registration, to allow them to be found
// during lookup as normal kernels.
void SetupOrDisableJit(KernelRegistry* registry) {
  std::unordered_multimap<string, KernelRegistration> jit_kernels;
  bool remove_jit_kernels = absl::StrContains(
      absl::NullSafeStringView(getenv(kDisableJitKernelsEnvVar)), "1");

  mutex_lock l(registry->mu);
  std::unordered_multimap<string, KernelRegistration>& all_kernels =
      registry->registry;
  auto it = all_kernels.begin();
  while (it != all_kernels.end()) {
    if (absl::StrContains(it->second.def.label(), kJitKernelLabel)) {
      // Remove all kernels that have the jit label. They will be added back
      // without the label if they are not to be disabled.
      KernelDef def_without_label = it->second.def;
      def_without_label.set_label("");

      if (!remove_jit_kernels) {
        jit_kernels.emplace(
            Key(def_without_label.op(),
                DeviceType(def_without_label.device_type()),
                def_without_label.label()),
            KernelRegistration(def_without_label, it->second.kernel_class_name,
                               std::move(it->second.factory)));
      }

      it = all_kernels.erase(it);
    } else {
      it++;
    }
  }

  // Add back kernels if they are not disabled. These new key-value pairs have
  // all references to the label removed.
  for (auto& jit_kernel : jit_kernels) {
    all_kernels.insert(std::move(jit_kernel));
  }
}
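
// For illustration, the transitional opt-out can be requested through the
// environment, e.g.:
//
//   export TF_DISABLE_JIT_KERNELS=1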

namespace register_kernel {

// Defined out of line to save code space
Name::Name(const char* op) : KernelDefBuilder(op) {}

}  // namespace register_kernel

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
  LoadDynamicKernels();
#endif  // AUTOLOAD_DYNAMIC_KERNELS
  auto* registry = reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
  // Update or disable JIT kernels based on user configuration. This is a
  // temporary fallback as part of the initial release of JIT kernels.
  static absl::once_flag setup_or_disable_jit;
  absl::call_once(setup_or_disable_jit, SetupOrDisableJit, registry);
  return registry;
}

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  const string key =
      Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
          kernel_def->label());

  // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
  // here.
  // InitInternal gets called by static initializers, so it ends up executing
  // before main. This causes the LoadKernelLibraries function to get called
  // before some file libraries can initialize, which in turn crashes the
  // program flakily. Until we get rid of static initializers in the kernel
  // registration mechanism, we have this workaround here.
  auto global_registry =
      reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
  mutex_lock l(global_registry->mu);
  global_registry->registry.emplace(
      key,
      KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
  delete kernel_def;
}

OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
    OpKernelConstruction* context) {
  return (*create_func_)(context);
}

}  // namespace kernel_factory

namespace {

// Label defaults to empty if not found in NodeDef.
const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
  static const string& kKernelAttr = *new string("_kernel");
  static const string& kEmptyString = *new string("");

  // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
  // to use the `AttrSlice::FindByString()` overload, which does a more
  // efficient map lookup (instead of a linear scan) when the attribute name is
  // already a `const string&`.
  const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
  if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
    return kEmptyString;
  else
    return attr_value->s();
}
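
// For illustration (hypothetical node), the label is carried in the NodeDef as
// a string attribute named "_kernel", e.g. in text form:
//
//   node {
//     name: "my_node"
//     op: "MyOp"
//     attr { key: "_kernel" value { s: "some_label" } }
//   }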

// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
    bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;

  const string& label = GetKernelLabelAttr(node_attrs);

  const string key = Key(node_op, device_type, label);
  auto typed_registry = GlobalKernelRegistryTyped();
  tf_shared_lock lock(typed_registry->mu);
  auto regs = typed_registry->registry.equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
    if (match) {
      if (*reg != nullptr) {
        if ((*reg)->def.priority() == iter->second.def.priority()) {
          return errors::InvalidArgument(
              "Multiple OpKernel registrations match NodeDef at the same "
              "priority '",
              FormatNodeDefForError(node_name, has_experimental_debug_info,
                                    experimental_debug_info),
              "': '", (*reg)->def.ShortDebugString(), "' and '",
              iter->second.def.ShortDebugString(), "'");
        } else if ((*reg)->def.priority() > iter->second.def.priority()) {
          continue;
        }
        // iter->second's priority is higher than *reg.
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
1410 // If no device-specific registration was found, try finding a default
1411 // kernel.
1412 if (*reg == nullptr &&
1413 !IsSymbolicExecutionDevice(device_type.type_string())) {
1414 const string default_key = Key(node_op, DEVICE_DEFAULT, label);
1415 auto regs = typed_registry->registry.equal_range(default_key);
1416 for (auto iter = regs.first; iter != regs.second; ++iter) {
1417 // If there is a default kernel registered for the op, check that the
1418 // attrs match.
1419 bool match;
1420 TF_RETURN_IF_ERROR(
1421 KernelAttrsMatch(iter->second.def, node_attrs, &match));
1422 if (match) {
1423 if (*reg != nullptr) {
1424 return errors::InvalidArgument(
1425 "Multiple Default OpKernel registrations match NodeDef '",
1426 FormatNodeDefForError(node_name, has_experimental_debug_info,
1427 experimental_debug_info),
1428 "': '", (*reg)->def.ShortDebugString(), "' and '",
1429 iter->second.def.ShortDebugString(), "'");
1430 }
1431 *reg = &iter->second;
1432 } else {
1433 *was_attr_mismatch = true;
1434 }
1435 }
1436
1437 if (*reg != nullptr) {
1438 VLOG(1) << "No device-specific kernels found for NodeDef '"
1439 << FormatNodeDefForError(node_name, has_experimental_debug_info,
1440 experimental_debug_info)
1441 << "'"
1442 << "Will fall back to a default kernel." << std::endl;
1443 }
1444 }
1445
1446 return OkStatus();
1447 }
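
// Illustrative sketch (hypothetical op/kernel names) of the priority rule
// implemented above: when two registrations match the same node,
//
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), DefaultMyOp);
//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU).Priority(1),
//                           FastMyOp);
//
// the higher-priority FastMyOp registration wins; two matches at equal
// priority produce the InvalidArgument error above.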
1448
1449 Status FindKernelRegistration(const DeviceType& device_type,
1450 const NodeDef& node_def,
1451 const KernelRegistration** reg,
1452 bool* was_attr_mismatch) {
1453 return FindKernelRegistration(
1454 device_type, node_def.name(), node_def.has_experimental_debug_info(),
1455 node_def.experimental_debug_info(), node_def.op(),
1456 AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
1457 }
1458
1459 } // namespace
1460
1461 bool KernelDefAvailable(const DeviceType& device_type,
1462 const NodeDef& node_def) {
1463 const KernelRegistration* reg = nullptr;
1464 bool was_attr_mismatch;
1465 Status result =
1466 FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
1467 return result.ok() && reg != nullptr;
1468 }
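
// Illustrative usage (hypothetical NodeDef): check up front whether a node
// can be instantiated on CPU before attempting placement.
//
//   NodeDef ndef;
//   ndef.set_op("MyOp");
//   const bool has_cpu_kernel =
//       KernelDefAvailable(DeviceType(DEVICE_CPU), ndef);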
1469
1470 // TODO(irving): Change const NodeDef& to const Node&
1471 Status FindKernelDef(
1472 const DeviceType& device_type, StringPiece node_name,
1473 bool has_experimental_debug_info,
1474 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1475 StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1476 const KernelDef** def, string* kernel_class_name) {
1477 const KernelRegistration* reg = nullptr;
1478 bool was_attr_mismatch;
1479 TF_RETURN_IF_ERROR(FindKernelRegistration(
1480 device_type, node_name, has_experimental_debug_info,
1481 experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
1482 if (reg == nullptr) {
1483 const std::string device_str = DeviceTypeString(device_type);
1484 Status s = errors::NotFound(
1485 "No registered '", node_op, "' OpKernel for ", device_str,
1486 " devices compatible with node ",
1487 FormatNodeDefForError(node_name, has_experimental_debug_info,
1488 experimental_debug_info));
1489 if (was_attr_mismatch) {
1490 errors::AppendToMessage(
1491 &s, " (OpKernel was found, but attributes didn't match) ",
1492 "Requested Attributes: ",
1493 SummarizeAttrsHelper(node_attrs, node_device));
1494 }
1495
1496 // Do not print kernel registrations for other devices when using _JIT
1497 // devices for compilation or for MKL ops.
1498 // TODO(intel-tf): Remove the check for MKL ops when support for
1499 // block format is removed.
1500 if (!absl::StrContains(device_str, "JIT") &&
1501 !absl::StartsWith(node_name, "_Mkl")) {
1502 errors::AppendToMessage(
1503 &s, ". Registered:", KernelsRegisteredForOp(node_op));
1504 }
1505
1506 return s;
1507 }
1508 if (def != nullptr) *def = &reg->def;
1509 if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
1510 return OkStatus();
1511 }
1512
1513 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1514 const KernelDef** def, string* kernel_class_name) {
1515 return FindKernelDef(
1516 device_type, node_def.name(), node_def.has_experimental_debug_info(),
1517 node_def.experimental_debug_info(), node_def.op(), node_def.device(),
1518 AttrSlice(&node_def.attr()), def, kernel_class_name);
1519 }
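
// Illustrative usage (hypothetical NodeDef `ndef`): retrieve the matching
// KernelDef and the registered C++ class name for a GPU placement.
//
//   const KernelDef* kdef = nullptr;
//   string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_GPU), ndef, &kdef,
//                            &kernel_class_name);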
1520
1521 Status SupportedDeviceTypesForNode(
1522 const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1523 PrioritizedDeviceTypeVector* prioritized_device_types,
1524 const DeviceNameUtils::ParsedName* local_address_spec) {
1525 // TODO(zhifengc): Change the callers (SimplePlacer and
1526 // DynamicPlacer) to consider the possibility that 'def' is a call to
1527 // a user-defined function and to only call
1528 // SupportedDeviceTypesForNode for primitive ops.
1529 const OpRegistrationData* op_reg_data;
1530 const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
1531 if (s.ok()) {
1532 bool exists_attr_mismatch = false;
1533 for (const DeviceType& device_type : prioritized_types) {
1534 const KernelRegistration* reg = nullptr;
1535 bool was_attr_mismatch = false;
1536 TF_RETURN_IF_ERROR(
1537 FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
1538 exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
1539 if (reg != nullptr) {
1540 int32_t priority = reg->def.priority();
1541 prioritized_device_types->emplace_back(device_type, priority);
1542 }
1543 }
1544 // Add extra supported device types if the following conditions are
1545 // satisfied:
1546 // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
1547 // 2) A device is requested for this node which specifies job/replica/task
1548 // 3) A local device is provided which specifies job/replica/task
1549 // 4) The local device does not have the same (job, replica, task) as the
1550 // requested device
1551 //
1552 // The goal is to address the issue where a graph includes an op (e.g. PyFunc)
1553 // whose kernel is known to a remote process but not to the current process.
1554 if (prioritized_device_types->empty() && !exists_attr_mismatch &&
1555 local_address_spec != nullptr) {
1556 DeviceNameUtils::ParsedName requested_device_name;
1557 DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
1558 if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
1559 requested_device_name)) {
1560 if (requested_device_name.has_type) {
1561 prioritized_device_types->push_back(
1562 std::make_pair(DeviceType(requested_device_name.type), 0));
1563 } else {
1564 for (const DeviceType& device_type : prioritized_types) {
1565 prioritized_device_types->push_back(std::make_pair(device_type, 0));
1566 }
1567 }
1568 }
1569 }
1570
1571 // If we were unable to find any valid devices, validate that the node def
1572 // itself is valid.
1573 if (prioritized_device_types->empty()) {
1574 TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
1575 }
1576
1577 std::stable_sort(prioritized_device_types->begin(),
1578 prioritized_device_types->end(),
1579 [](const std::pair<DeviceType, int32>& a,
1580 const std::pair<DeviceType, int32>& b) {
1581 return a.second > b.second;
1582 });
1583 } else {
1584 // Assumes that all device types support this node.
1585 for (const DeviceType& device_type : prioritized_types) {
1586 prioritized_device_types->push_back(std::make_pair(device_type, 0));
1587 }
1588 }
1589 return OkStatus();
1590 }
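
// Illustrative usage (hypothetical NodeDef `ndef`): ask which of GPU/CPU can
// run the node, sorted by kernel priority (highest first).
//
//   PrioritizedDeviceTypeVector device_types;
//   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
//       {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}, ndef, &device_types,
//       /*local_address_spec=*/nullptr));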
1591
1592 void LogAllRegisteredKernels() {
1593 KernelList kernel_list = GetAllRegisteredKernels();
1594 for (const auto& kernel_def : kernel_list.kernel()) {
1595 LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
1596 }
1597 }
1598
1599 KernelList GetAllRegisteredKernels() {
1600 return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
1601 }
1602
1603 KernelList GetFilteredRegisteredKernels(
1604 const std::function<bool(const KernelDef&)>& predicate) {
1605 KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
1606 KernelList kernel_list;
1607 tf_shared_lock lock(typed_registry->mu);
1608 kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
1609 for (const auto& p : typed_registry->registry) {
1610 const KernelDef& kernel_def = p.second.def;
1611 if (predicate(kernel_def)) {
1612 *kernel_list.add_kernel() = kernel_def;
1613 }
1614 }
1615 return kernel_list;
1616 }
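
// Illustrative usage: the predicate sees every registered KernelDef, so
// arbitrary filters can be expressed, e.g. all kernels registered for GPU:
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == DEVICE_GPU; });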
1617
1618 KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
1619 auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
1620 return GetFilteredRegisteredKernels(op_pred);
1621 }
1622
1623 string KernelsRegisteredForOp(StringPiece op_name) {
1624 KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
1625 if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
1626 string ret;
1627 for (const auto& kernel_def : kernel_list.kernel()) {
1628 strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
1629 if (!kernel_def.label().empty()) {
1630 strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
1631 }
1632 for (int i = 0; i < kernel_def.constraint_size(); ++i) {
1633 strings::StrAppend(
1634 &ret, "; ", kernel_def.constraint(i).name(), " in ",
1635 SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
1636 }
1637 strings::StrAppend(&ret, "\n");
1638 }
1639 return ret;
1640 }
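
// Illustrative output shape (hypothetical op with one unlabeled CPU kernel,
// one labeled GPU kernel, and a type constraint):
//
//   device='CPU'; T in [DT_FLOAT, DT_DOUBLE]
//   device='GPU'; label='special'
//
// Actual entries depend on which kernels are linked into the binary.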
1641
1642 /* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
1643 * copying the NodeDef. */
1644 std::unique_ptr<OpKernel> CreateOpKernel(
1645 DeviceType device_type, DeviceBase* device, Allocator* allocator,
1646 const NodeDef& node_def, int graph_def_version, Status* status) {
1647 // Look up the Op registered for this op name.
1648 std::shared_ptr<const NodeProperties> props;
1649 status->Update(NodeProperties::CreateFromNodeDef(
1650 node_def, OpRegistry::Global(), &props));
1651 if (!status->ok()) {
1652 errors::AppendToMessage(status,
1653 " for node: ", FormatNodeDefForError(node_def));
1654 return nullptr;
1655 }
1656 return CreateOpKernel(device_type, device, allocator, props,
1657 graph_def_version, status);
1658 }
1659
1660 std::unique_ptr<OpKernel> CreateOpKernel(
1661 DeviceType device_type, DeviceBase* device, Allocator* allocator,
1662 const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1663 Status* status) {
1664 OpKernel* kernel = nullptr;
1665 *status = CreateOpKernel(std::move(device_type), device, allocator,
1666 /*flib=*/nullptr, props, graph_def_version, &kernel);
1667 return std::unique_ptr<OpKernel>(kernel);
1668 }
1669
1670 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1671 Allocator* allocator, FunctionLibraryRuntime* flib,
1672 const std::shared_ptr<const NodeProperties>& props,
1673 int graph_def_version, OpKernel** kernel) {
1674 return CreateOpKernel(std::move(device_type), device, allocator, flib,
1675 /* resource_mgr= */ nullptr, props, graph_def_version,
1676 kernel);
1677 }
1678
1679 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1680 Allocator* allocator, FunctionLibraryRuntime* flib,
1681 ResourceMgr* resource_mgr,
1682 const std::shared_ptr<const NodeProperties>& props,
1683 int graph_def_version, OpKernel** kernel) {
1684 const NodeDef& node_def = props->node_def;
1685 bool was_attr_mismatch;
1686 const KernelRegistration* registration = nullptr;
1687 Status s;
1688 if (props != nullptr) {
1689 VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
1690
1691 // Validate node_def against OpDef.
1692 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));
1693
1694 // Look up kernel registration.
1695 s = FindKernelRegistration(device_type, node_def, &registration,
1696 &was_attr_mismatch);
1697 if (!s.ok()) {
1698 errors::AppendToMessage(&s, " when instantiating ", node_def.op());
1699 return s;
1700 }
1701 }
1702 if (registration == nullptr) {
1703 s.Update(errors::NotFound("No registered '", node_def.op(),
1704 "' OpKernel for '", DeviceTypeString(device_type),
1705 "' devices compatible with node ",
1706 FormatNodeDefForError(node_def)));
1707 if (was_attr_mismatch) {
1708 errors::AppendToMessage(
1709 &s, " (OpKernel was found, but attributes didn't match) ",
1710 "Requested Attributes: ", SummarizeAttrs(node_def));
1711 }
1712 errors::AppendToMessage(
1713 &s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
1714 return s;
1715 }
1716
1717 // Since we are creating a kernel for an op registered in
1718 // OpRegistry::Global(), we consult the kernel registry to decide
1719 // the kernel's input and output memory types.
1720 MemoryTypeVector input_memory_types;
1721 MemoryTypeVector output_memory_types;
1722 TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
1723 node_def, &input_memory_types,
1724 &output_memory_types));
1725
1726 // Everything needed for OpKernel construction.
1727 OpKernelConstruction context(std::move(device_type), device, allocator, flib,
1728 resource_mgr, props, input_memory_types,
1729 output_memory_types, graph_def_version, &s);
1730 *kernel = registration->factory->Create(&context);
1731 if (!s.ok()) {
1732 delete *kernel;
1733 *kernel = nullptr;
1734 }
1735 return s;
1736 }
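
// Illustrative usage (the DeviceBase* `device` and the NodeDef `ndef` are
// assumed to exist): build a CPU kernel with no function library or resource
// manager.
//
//   std::shared_ptr<const NodeProperties> props;
//   TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
//       ndef, OpRegistry::Global(), &props));
//   OpKernel* kernel = nullptr;
//   TF_RETURN_IF_ERROR(CreateOpKernel(DeviceType(DEVICE_CPU), device,
//                                     cpu_allocator(), /*flib=*/nullptr, props,
//                                     TF_GRAPH_DEF_VERSION, &kernel));
//   std::unique_ptr<OpKernel> owned_kernel(kernel);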
1737
1738 namespace {
1739
1740 bool FindArgInOp(StringPiece arg_name,
1741 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
1742 for (const auto& arg : args) {
1743 if (arg_name == arg.name()) {
1744 return true;
1745 }
1746 }
1747 return false;
1748 }
1749
1750 } // namespace
1751
1752 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
1753 auto typed_registry = GlobalKernelRegistryTyped();
1754 tf_shared_lock lock(typed_registry->mu);
1755 for (const auto& key_registration : typed_registry->registry) {
1756 const KernelDef& kernel_def(key_registration.second.def);
1757 const OpRegistrationData* op_reg_data;
1758 const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
1759 if (!status.ok()) {
1760 // TODO(josh11b): Make this a hard error.
1761 LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
1762 << "') for unknown op: " << kernel_def.op();
1763 continue;
1764 }
1765 const OpDef& op_def = op_reg_data->op_def;
1766 for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
1767 if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
1768 !FindArgInOp(host_memory_arg, op_def.output_arg())) {
1769 return errors::InvalidArgument(
1770 "HostMemory arg '", host_memory_arg,
1771 "' not found in OpDef: ", SummarizeOpDef(op_def));
1772 }
1773 }
1774 }
1775 return OkStatus();
1776 }
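
// Illustrative example of what the check above catches (hypothetical op): a
// registration such as
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MyOp").Device(DEVICE_GPU).HostMemory("indicess"), MyOpKernel);
//
// is rejected because "indicess" (a typo for an argument named "indices")
// does not appear among MyOp's input or output arguments.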
1777
1778 template <>
1779 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
1780 return eigen_cpu_device();
1781 }
1782
1783 template <>
1784 const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
1785 return eigen_gpu_device();
1786 }
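
// Illustrative usage inside a kernel's Compute() (hypothetical tensors): the
// templated accessor selects the Eigen device matching the kernel's placement.
//
//   const Eigen::ThreadPoolDevice& d =
//       context->eigen_device<Eigen::ThreadPoolDevice>();
//   out.device(d) = in0 + in1;  // Evaluate an Eigen tensor expression on CPU.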
1787
1788 void OpKernelConstruction::CtxFailure(const Status& s) {
1789 VLOG(1) << s;
1790 SetStatus(s);
1791 }
1792
1793 void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
1794 LOG(WARNING) << s;
1795 SetStatus(s);
1796 }
1797
1798 void OpKernelConstruction::CtxFailure(const char* file, int line,
1799 const Status& s) {
1800 VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1801 << " : " << s;
1802 SetStatus(s);
1803 }
1804
1805 void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
1806 const Status& s) {
1807 LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1808 << " : " << s;
1809 SetStatus(s);
1810 }
1811
1812 void OpKernelContext::CtxFailure(const Status& s) {
1813 VLOG(1) << s;
1814 SetStatus(s);
1815 }
1816
1817 void OpKernelContext::CtxFailureWithWarning(const Status& s) {
1818 LOG(WARNING) << s;
1819 SetStatus(s);
1820 }
1821
1822 void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
1823 VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1824 << " : " << s;
1825 SetStatus(s);
1826 }
1827
1828 void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
1829 const Status& s) {
1830 LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
1831 << " : " << s;
1832 SetStatus(s);
1833 }
1834
1835 void CheckNotInComputeAsync(OpKernelContext* ctx,
1836 const char* correct_macro_name) {
1837 CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
1838 << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
1839 }
1840
1841 } // namespace tensorflow
1842