/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <cstdlib>
#include <cstring>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform_strings.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

namespace {

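// Returns OK iff `inputs` and `outputs` are type-compatible with the expected
// signature (same arity, element-wise TypesCompatible); otherwise returns an
// InvalidArgument error describing both signatures.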
Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

}  // namespace

// OpKernel ------------------------------------------------------------------

OpKernel::OpKernel(OpKernelConstruction* context) : OpKernel(context, false) {}

OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
    : props_(context->props_),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie up very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
}

OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
                   bool is_deferred)
    : props_(std::make_shared<const NodeProperties>(
          context->props_->op_def, std::move(custom_def),
          context->props_->input_types, context->props_->output_types)),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()),
      name_view_(props_->node_def.name()),
      type_string_view_(props_->node_def.op()),
      graph_def_version_(context->graph_def_version()),
      is_deferred_(is_deferred) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(props_->node_def, *props_->op_def,
                                   &input_name_map_, &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
                                             context->graph_def_version()));

  // Kernels executing on GPU tie up very few resources on the CPU where the
  // scheduler runs: we consider them as inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
}

OpKernel::~OpKernel() {}

Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

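// Builds a compact "(dtype+shape;...)" summary of the kernel's inputs for
// profiler traces. Missing, resource, variant and ref inputs are recorded as
// empty placeholders.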
string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const {
  int num_inputs = ctx.num_inputs();
  if (num_inputs == 0) return "";
  std::vector<string> tensor_shapes;
  tensor_shapes.reserve(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    if (!ctx.has_input(i)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    DataType input_dtype = ctx.input_dtype(i);
    if (input_dtype == DataType::DT_RESOURCE ||
        input_dtype == DataType::DT_VARIANT || IsRefType(input_dtype)) {
      tensor_shapes.emplace_back();  // Placeholder
      continue;
    }
    tensor_shapes.emplace_back(strings::StrCat(
        DataTypeString(input_dtype), ctx.input(i).shape().DebugString()));
  }
  return strings::StrCat("(", absl::StrJoin(tensor_shapes, ";"), ")");
}

string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const {
  string trace_string = profiler::TraceMeOp(name_view(), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

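// Synchronous adapter: runs ComputeAsync and blocks on a Notification until
// the done callback fires, so an AsyncOpKernel can also be driven through the
// regular Compute() interface.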
void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}

// PersistentTensor ----------------------------------------------------------

Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
  // the caller has to have a valid context
  CHECK(context);
  return &tensor_;
}

Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
  return &tensor_;
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    FunctionLibraryRuntime* flib, ResourceMgr* resource_mgr,
    const std::shared_ptr<const NodeProperties>& props,
    const MemoryTypeSlice& input_memory_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      flib_(flib),
      resource_mgr_(resource_mgr),
      props_(props),
      input_memory_types_(input_memory_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs,
                              props_->input_types, props_->output_types);
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp,
                                           AllocatorAttributes allocator_attr) {
  if (allocator_attr.scope_id != 0) {
    return errors::InvalidArgument(
        "ScopedAllocator cannot be used via OpKernelConstruction.");
  }
  Allocator* a = device_->GetAllocator(allocator_attr);
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(a, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_persistent(
    DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
    Tensor** out_tensor) {
  // for now just do the same thing as allocate_temp
  // TODO(misard) add specific memory tracking for persistent tensors
  Tensor persistent;
  TF_RETURN_IF_ERROR(allocate_temp(type, shape, &persistent));

  *out_persistent = PersistentTensor(persistent);
  Tensor* allocated = out_persistent->AccessTensor(this);
  if (out_tensor) {
    *out_tensor = allocated;
  }
  return Status::OK();
}

// OpKernelContext -----------------------------------------------------------

const int OpKernelContext::Params::kNeverForward;
const int OpKernelContext::Params::kNoReservation;

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params), outputs_(num_outputs) {
  if (params_->track_allocations) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }

  params_->ensure_eigen_gpu_device();
  if (params_->eigen_gpu_device != nullptr) {
    Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
    Status s = params_->device->ReinitializeGpuDevice(
        this, params_->eigen_gpu_device, params_->op_device_context,
        eigen_gpu_allocator);
    if (!s.ok()) {
      SetStatus(s);
    }
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->track_allocations &&
      !tracking_state_->wrapped_allocators.empty()) {
    LOG(WARNING) << "OpKernelContext is tracking allocations but they are not "
                 << "being consumed by the StepStatsCollector.";
    for (auto& wrapped_allocator : tracking_state_->wrapped_allocators) {
      wrapped_allocator.second->GetRecordsAndUnRef();
    }
  }
}

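// Returns the allocator selected by `attr`: a ScopedAllocator when
// attr.scope_id > 0, otherwise the device allocator. When allocation tracking
// is enabled, the result is wrapped in (and cached as) a TrackingAllocator so
// its records can later be consumed by the StepStatsCollector.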
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator = nullptr;
  if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
    allocator = params_->device->GetScopedAllocator(attr, step_id());
    CHECK(allocator);
  } else {
    allocator = params_->device->GetAllocator(attr);
  }
  if (TF_PREDICT_FALSE(track_allocations())) {
    DCHECK(tracking_state_);
    mutex_lock lock(tracking_state_->mu);
    for (const auto& wrapped : tracking_state_->wrapped_allocators) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    tracking_state_->wrapped_allocators.push_back(
        std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[index].tensor;
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  const TensorValue& value((*params_->inputs)[index]);
  *dtype = value.dtype();
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  *out_mutex = input_ref_mutex(index);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
  CHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    return tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  CHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, output_index, expected_output_dtype(output_index),
      output_shape, output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index;
  TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
  TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return Status::OK();
}

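// Attempts to reuse the buffer of input `input_index` for an output. Returns
// a tensor aliasing the input buffer when forwarding is allowed, or nullptr
// when it is not (ref or missing input, dtype/num_elements/memory-type
// mismatch, refcount > 1, more restrictive output allocator attributes, or an
// output reservation for a different input made at graph-construction time).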
std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, int output_index, DataType output_dtype,
    const TensorShape& output_shape, MemoryType output_memory_type,
    const AllocatorAttributes& output_attr) {
  CHECK_GE(input_index, 0);
  CHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check whether at graph construction time this output was marked
  // either for no forwarding or with a reservation for this input.
  // If it's reserved for this input we'll skip the refcount and
  // AllocatorAttribute checks.
  // TODO(tucker): Maybe we should skip all of the checks?
  bool never_forward =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == Params::kNeverForward);
  if (never_forward) return nullptr;
  bool forward_expected =
      (params_->forward_from_array != nullptr && output_index >= 0 &&
       params_->forward_from_array[output_index] == input_index);
  if (!forward_expected && params_->forward_from_array != nullptr) {
    // Check for possibly conflicting forward.
    for (int i = 0; i < num_outputs(); ++i) {
      if (params_->forward_from_array[i] == input_index) {
        // This input is reserved for output i.
        return nullptr;
      }
    }
  }
  // Check that input tensor exists and is not a ref.
  if (input.tensor == nullptr || input.is_ref()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    CHECK(!forward_expected);
    return nullptr;
  }
  // Check that input and output memory types match, i.e.
  // that they either both live in host or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    CHECK(!forward_expected);
    return nullptr;
  }
  if (!forward_expected) {
    if (!input->RefCountIsOne()) {
      return nullptr;
    }
    // Check that output allocator attributes are not more restrictive than
    // input allocator attributes.
    const auto input_attr = params_->input_alloc_attrs == nullptr
                                ? AllocatorAttributes()
                                : input_alloc_attr(input_index);
    if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
      return nullptr;
    }
  }

  auto output_tensor = MakeUnique<Tensor>();
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

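// Tries each candidate input in order and forwards the first one whose buffer
// can be reused as `out_temp`; otherwise falls back to allocate_temp. A
// minimal usage sketch inside a kernel's Compute, assuming a float kernel
// that may reuse input 0 (values are illustrative only):
//
//   Tensor tmp;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_temp(
//                           {0}, DT_FLOAT, ctx->input(0).shape(),
//                           AllocatorAttributes(), &tmp));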
Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, Params::kNoReservation /*output_index*/,
                      type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  CHECK_GE(index, 0);
  CHECK_LT(index, num_inputs());
  CHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[index].tensor;
  } else {
    tf_shared_lock l(*input_ref_mutex(index));
    *tensor = *(*params_->inputs)[index].tensor;
  }
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  if (!input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(index, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

void OpKernelContext::maybe_initialize_scope_id_set() {
  if (allocated_scope_ids_ == nullptr) {
    allocated_scope_ids_ = absl::make_unique<std::unordered_set<int32>>();
  }
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** tensor) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", num_outputs(),
                            " kernel=", params_->op_kernel->name());
  }
  bool forward_expected =
      (params_->forward_from_array != nullptr && index >= 0 &&
       params_->forward_from_array[index] >= 0);
  if (forward_expected) {
    return errors::Internal(
        "Explicit allocate_output call where input forwarding required. Try "
        "turning off the ScopedAllocator optimizer.");
  }
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, tensor, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

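// Common allocation path shared by allocate_output/allocate_temp/
// allocate_persistent: allocates through get_allocator(attr), requests that
// the allocation be logged, and returns ResourceExhausted if the tensor could
// not be initialized.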
Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  Tensor new_tensor(
      a, type, shape,
      AllocationAttributes(
          /*retry_on_failure=*/allocation_attr.retry_on_failure,
          /*allocation_will_be_logged=*/true, allocation_attr.freed_by_func));

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape ", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  if (index < 0) {
    return errors::Internal("allocate_output with bad index=", index,
                            " kernel=", params_->op_kernel->name());
  }
  if (index >= num_outputs()) {
    return errors::Internal("allocate_output with bad index=", index,
                            " num_outputs=", outputs_.size(),
                            " kernel=", params_->op_kernel->name());
  }
  const DataType type = params_->op_kernel->output_type(index);
  if (IsRefType(type)) {
    return errors::Internal("allocate_output with ref type. index=", index,
                            " type=", type,
                            " kernel=", params_->op_kernel->name());
  }
  if (mutable_output(index) != nullptr) {
    return errors::Internal("allocate_output on same index multiple times.",
                            " index = ", index,
                            " mutable_output(index) = ", mutable_output(index),
                            " kernel=", params_->op_kernel->name());
  }
  if (attr.scope_id > 0) {
    maybe_initialize_scope_id_set();
    if (!allocated_scope_ids_->insert(attr.scope_id).second) {
      return errors::Internal(
          "OpKernel ", params_->op_kernel->name(),
          " called allocate_output at index ", index, " with scope_id ",
          attr.scope_id,
          " more than once. Try turning off the ScopedAllocator optimizer.");
    }
  }
  ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                            step_id(), "output", type, &shape);
  auto output_tensor = MakeUnique<Tensor>();
  Status s = allocate_tensor(type, shape, output_tensor.get(), attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor.release());
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  if (allocator_attr.scope_id > 0) {
    // We do not allow ScopedAllocator calls from allocate_temp. Unlike
    // allocate_persistent where we return an error if a kernel provides a
    // meaningful scope_id, here we clear the scope_id and return a temporary
    // buffer. This is because it is legal for a kernel to call allocate_temp
    // and then set_output with the temp tensor.
    //
    // We achieve memory correctness by forcing an allocation in set_output and
    // copying over the tensor from the temp buffer. Kernels which would like
    // to avoid this performance penalty should switch to calling
    // allocate_output.
    VLOG(2) << "Warning: OpKernel " << params_->op_kernel->name()
            << " called allocate_temp with scope_id " << allocator_attr.scope_id
            << ". Switch to allocate_output to avoid performance penalty.";
    allocator_attr.scope_id = -1;
  }
  ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                            step_id(), "temp", type, &shape);
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  } else if (record_memory_consumption_) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += out_temp->TotalBytes();
  }
  return s;
}

Status OpKernelContext::allocate_persistent(DataType type,
                                            const TensorShape& shape,
                                            PersistentTensor* out_persistent,
                                            Tensor** out_tensor,
                                            AllocatorAttributes attr) {
  if (attr.scope_id > 0) {
    // ScopedAllocator cannot be used for persistent tensors, because these
    // tensors may persist across kernel invocations/steps, whereas the backing
    // tensor for the scoped allocator will be reallocated every step.
    return errors::Internal(
        "Unexpected call to allocate_persistent with scope_id ", attr.scope_id);
  }
  ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                            step_id(), "persist", type, &shape);
  Tensor persistent;
  Status s = allocate_tensor(type, shape, &persistent, attr);
  if (s.ok()) {
    *out_persistent = PersistentTensor(persistent);
    Tensor* t = out_persistent->AccessTensor(this);

    if (out_tensor) {
      *out_tensor = t;
    }

    if (track_allocations()) {
      Allocator* a = get_allocator(attr);
      if (a->TracksAllocationSizes()) {
        // Zero-byte Tensors don't use allocators: check and skip tracking.
        AllocationDescription alloc_desc;
        TensorReference tensor_ref(*t);
        tensor_ref.FillDescription(&alloc_desc);
        tensor_ref.Unref();

        if (alloc_desc.allocated_bytes()) {  // Non-zero sized tensor.
          int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
          int64 alloc_id = a->AllocationId(t->tensor_data().data());
          record_persistent_memory_allocation(alloc_size, alloc_id);
        }
      }
    } else if (record_memory_consumption_) {
      record_persistent_memory_allocation(t->TotalBytes());
    }
  }
  return s;
}

Status OpKernelContext::get_input_index(StringPiece name,
                                        int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  *out_index = start;
  return Status::OK();
}

Status OpKernelContext::get_output_index(StringPiece name,
                                         int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *out_index = start;
  return Status::OK();
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, tensor);
  return Status::OK();
}

Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output(index, std::move(tensor));
  return Status::OK();
}

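// Helper for set_output: when the output at `index` was marked kNeverForward
// and was not already produced by allocate_output with a ScopedAllocator
// scope_id, allocates a fresh output tensor and copies `tensor` into it.
// Returns true iff this allocate-and-copy path was taken.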
bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
    int index, const Tensor& tensor) {
  bool allocate_and_copy = false;
  const bool never_forward =
      (params_->forward_from_array != nullptr &&
       params_->forward_from_array[index] == Params::kNeverForward);
  if (TF_PREDICT_FALSE(never_forward)) {
    maybe_initialize_scope_id_set();
    if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
        allocated_scope_ids_->end()) {
      allocate_and_copy = true;
    } else {
      // The output at `index` must have been previously allocated via a call
      // to `allocate_output(index, ...)`. That call would ensure that we
      // return the correct slice of the ScopedAllocated buffer, so we do not
      // re-allocate and copy here.
      LOG(WARNING)
          << "OpKernel " << params_->op_kernel->name()
          << " called both allocate_output and set_output with scope_id "
          << output_alloc_attr(index).scope_id;
    }
  }

  if (TF_PREDICT_FALSE(allocate_and_copy)) {
    // This output was marked to not be forwarded either during graph
    // construction or grappler passes. Force an allocation and copy input to
    // output.
    VLOG(1) << "OpKernelContext set_output index " << index << " tensor "
            << tensor.DebugString() << " never_forward " << never_forward
            << " params_->forward_from_array[index] "
            << params_->forward_from_array[index] << " alloc_attr.scope_id "
            << output_alloc_attr(index).scope_id;
    ScopedMemoryDebugAnnotation op_annotation(op_kernel().name_view().data(),
                                              step_id(), "output",
                                              tensor.dtype(), &tensor.shape());
    auto new_tensor = MakeUnique<Tensor>();
    Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
                               output_alloc_attr(index));
    TF_CHECK_OK(s);
    device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
                                     op_device_context(), [](const Status&) {});
    outputs_[index] = TensorValue(new_tensor.release());
  }
  return allocate_and_copy;
}

void OpKernelContext::maybe_track_allocations_for_set_output(
    const Tensor& tensor) {
  if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->stats_mu);
    const auto it = std::find_if(
        tracking_state_->temp_tensor_buffer_and_size.begin(),
        tracking_state_->temp_tensor_buffer_and_size.end(),
        [&tensor](const std::pair<const void*, int64>& e) {
          return e.first ==
                 static_cast<const void*>(tensor.tensor_data().data());
        });
    if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
      tracking_state_->temp_memory_allocated -= it->second;
      tracking_state_->temp_tensor_buffer_and_size.erase(it);
    }
  }
}

void OpKernelContext::set_output(int index, const Tensor& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; incref on `tensor` and set output at
    // `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(tensor));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output(int index, Tensor&& tensor) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  CHECK(!IsRefType(type));
  CHECK_EQ(outputs_[index].tensor, nullptr);
  if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
    // Input can be forwarded to output; set output at `index` to this tensor.
    outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
    maybe_track_allocations_for_set_output(*outputs_[index].tensor);
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  CHECK_GE(index, 0);
  CHECK_LT(index, outputs_.size());
  CHECK(IsRefType(params_->op_kernel->output_type(index)));
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  set_output_ref(index, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int index;
  TF_RETURN_IF_ERROR(get_output_index(name, &index));
  *tensor = mutable_output(index);
  return Status::OK();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape. Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                    const Tensor& t) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated += size;
    tracking_state_->temp_tensor_buffer_and_size.emplace_back(
        static_cast<const void*>(t.tensor_data().data()), size);
  }
}

int64 OpKernelContext::temp_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->temp_memory_allocated;
  } else {
    return 0;
  }
}

void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                          int64 alloc_id) {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->persistent_memory_allocated += size;
    if (alloc_id >= 0) {
      tracking_state_->persistent_alloc_ids.push_back(alloc_id);
    }
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return tracking_state_->persistent_memory_allocated;
  } else {
    return 0;
  }
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    return std::vector<int64>(tracking_state_->persistent_alloc_ids.begin(),
                              tracking_state_->persistent_alloc_ids.end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  if (tracking_state_) {
    mutex_lock l(tracking_state_->stats_mu);
    tracking_state_->temp_memory_allocated = 0;
    tracking_state_->persistent_memory_allocated = 0;
    tracking_state_->temp_tensor_buffer_and_size.clear();
    tracking_state_->persistent_alloc_ids.clear();
  }
}

void OpKernelContext::set_record_memory_consumption(bool v) {
  record_memory_consumption_ = v;
  if (v && !tracking_state_) {
    tracking_state_ = absl::make_unique<TrackingState>();
  }
}

const string& OpKernelContext::executor_type() const {
  if (params_->executor_type) {
    return *params_->executor_type;
  } else {
    static const string& kEmptyString = *new string("");
    return kEmptyString;
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     std::unique_ptr<kernel_factory::OpKernelFactory> f)
      : def(d), kernel_class_name(c), factory(std::move(f)) {}

  const KernelDef def;
  const string kernel_class_name;
  std::unique_ptr<kernel_factory::OpKernelFactory> factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};

#if defined(_WIN32)
static const char kKernelLibPattern[] = "libtfkernel*.dll";
#elif defined(__APPLE__)
static const char kKernelLibPattern[] = "libtfkernel*.dylib";
#else
static const char kKernelLibPattern[] = "libtfkernel*.so";
#endif

#define FEATURE(x) \
  { x, #x }

// Returns Status::OK if the dynamic library at the given path is safe to
// load with some level of confidence.
static Status IsProbablySafeToLoad(const string& path) {
  // A map of platform string to required CPU feature.
  using port::CPUFeature;
  static const auto* feature_map =
      new std::map<string, std::pair<CPUFeature, string>>{
          {"__AVX512VL__=1", FEATURE(CPUFeature::AVX512VL)},
      };

  std::vector<std::string> platform_strings;
  int result = GetPlatformStrings(path, &platform_strings);
  if (result) {
    return Status(error::Code::UNKNOWN, strerror(result));
  }
  if (platform_strings.empty()) {
    return Status(error::Code::FAILED_PRECONDITION,
                  "Didn't find any platform strings");
  }
  std::vector<std::string> missing_features;
  for (const auto& platform_string : platform_strings) {
    const auto& entry = feature_map->find(platform_string);
    if (entry != feature_map->end() &&
        !port::TestCPUFeature(entry->second.first)) {
      missing_features.emplace_back(entry->second.second);
    }
  }
  if (!missing_features.empty()) {
    string errmsg = "Missing CPU features: ";
    errmsg.append(absl::StrJoin(missing_features, ", "));
    return Status(errors::Code::FAILED_PRECONDITION, errmsg);
  }
  return Status::OK();
}

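// Scans the bazel runfiles kernel directory for libtfkernel* shared libraries
// and loads each one that passes IsProbablySafeToLoad, or all matching
// libraries when TF_REALLY_LOAD_UNSAFE_PACKAGES=1 overrides the ABI check.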
void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check = false;
  if (_abi_check_env_var != nullptr) {
    override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
  }

  string bazel_kernel_dir =
      io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(
              env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static absl::once_flag dll_loader_flag;
  absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
#ifdef AUTOLOAD_DYNAMIC_KERNELS
  LoadDynamicKernels();
#endif  // AUTOLOAD_DYNAMIC_KERNELS
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

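// Registry keys have the form "<op type>:<device type>:<label>".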
static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  const string key =
      Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
          kernel_def->label());

  // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
  // here.
  // InitInternal gets called by static initializers, so it ends up executing
  // before main. This causes LoadKernelLibraries function to get called
  // before some file libraries can initialize, which in turn crashes the
  // program flakily. Until we get rid of static initializers in kernel
  // registration mechanism, we have this workaround here.
  auto global_registry =
      reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
  mutex_lock l(global_registry->mu);
  global_registry->registry.emplace(
      key,
      KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
  delete kernel_def;
}

OpKernel* OpKernelRegistrar::PtrOpKernelFactory::Create(
    OpKernelConstruction* context) {
  return (*create_func_)(context);
}

}  // namespace kernel_factory

namespace {

// Label defaults to empty if not found in NodeDef.
const string& GetKernelLabelAttr(const AttrSlice& node_attrs) {
  static const string& kKernelAttr = *new string("_kernel");
  static const string& kEmptyString = *new string("");

  // NOTE: We inline the implementation of `GetNodeAttrString()` here in order
  // to use the `AttrSlice::FindByString()` overload, which does a more
  // efficient map lookup (instead of a linear scan) when the attribute name is
  // already a `const string&`.
  const AttrValue* attr_value = node_attrs.FindByString(kKernelAttr);
  if (attr_value == nullptr || attr_value->value_case() != AttrValue::kS)
    return kEmptyString;
  else
    return attr_value->s();
}

// TODO(irving): Replace with const Node& version below.
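// Looks up the kernel registered for (node_op, device_type, label) whose
// KernelDef constraints match `node_attrs`, preferring the highest-priority
// match and falling back to a DEVICE_DEFAULT registration when no
// device-specific kernel matches. *was_attr_mismatch is set when a
// registration existed but its attribute constraints did not match.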
FindKernelRegistration(const DeviceType & device_type,StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info,StringPiece node_op,AttrSlice node_attrs,const KernelRegistration ** reg,bool * was_attr_mismatch)1305 Status FindKernelRegistration(
1306 const DeviceType& device_type, StringPiece node_name,
1307 bool has_experimental_debug_info,
1308 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1309 StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
1310 bool* was_attr_mismatch) {
1311 *reg = nullptr;
1312 *was_attr_mismatch = false;
1313
1314 const string& label = GetKernelLabelAttr(node_attrs);
1315
1316 const string key = Key(node_op, device_type, label);
1317 auto typed_registry = GlobalKernelRegistryTyped();
1318 tf_shared_lock lock(typed_registry->mu);
1319 auto regs = typed_registry->registry.equal_range(key);
1320 for (auto iter = regs.first; iter != regs.second; ++iter) {
1321 // If there is a kernel registered for the op and device_type,
1322 // check that the attrs match.
1323 bool match;
1324 TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
1325 if (match) {
1326 if (*reg != nullptr) {
1327 if ((*reg)->def.priority() == iter->second.def.priority()) {
1328 return errors::InvalidArgument(
1329 "Multiple OpKernel registrations match NodeDef at the same "
1330 "priority '",
1331 FormatNodeDefForError(node_name, has_experimental_debug_info,
1332 experimental_debug_info),
1333 "': '", (*reg)->def.ShortDebugString(), "' and '",
1334 iter->second.def.ShortDebugString(), "'");
1335 } else if ((*reg)->def.priority() > iter->second.def.priority()) {
1336 continue;
1337 }
1338 // iter->second's priority is higher than *reg.
1339 }
1340 *reg = &iter->second;
1341 } else {
1342 *was_attr_mismatch = true;
1343 }
1344 }
1345 // Check if no device specific registrations found. If not, try finding a
1346 // default kernel.
1347 if (*reg == nullptr &&
1348 !IsSymbolicExecutionDevice(device_type.type_string())) {
1349 const string default_key = Key(node_op, DEVICE_DEFAULT, label);
1350 auto regs = typed_registry->registry.equal_range(default_key);
1351 for (auto iter = regs.first; iter != regs.second; ++iter) {
1352 // If there is a kernel registered for the op and device_type,
1353 // check that the attrs match.
1354 bool match;
1355 TF_RETURN_IF_ERROR(
1356 KernelAttrsMatch(iter->second.def, node_attrs, &match));
1357 if (match) {
1358 if (*reg != nullptr) {
1359 return errors::InvalidArgument(
1360 "Multiple Default OpKernel registrations match NodeDef '",
1361 FormatNodeDefForError(node_name, has_experimental_debug_info,
1362 experimental_debug_info),
1363 "': '", (*reg)->def.ShortDebugString(), "' and '",
1364 iter->second.def.ShortDebugString(), "'");
1365 }
1366 *reg = &iter->second;
1367 } else {
1368 *was_attr_mismatch = true;
1369 }
1370 }
1371
1372 if (*reg != nullptr) {
1373 VLOG(1) << "No device-specific kernels found for NodeDef '"
1374 << FormatNodeDefForError(node_name, has_experimental_debug_info,
1375 experimental_debug_info)
1376 << "'"
1377 << "Will fall back to a default kernel." << std::endl;
1378 }
1379 }
1380
1381 return Status::OK();
1382 }
1383
FindKernelRegistration(const DeviceType & device_type,const NodeDef & node_def,const KernelRegistration ** reg,bool * was_attr_mismatch)1384 Status FindKernelRegistration(const DeviceType& device_type,
1385 const NodeDef& node_def,
1386 const KernelRegistration** reg,
1387 bool* was_attr_mismatch) {
1388 return FindKernelRegistration(
1389 device_type, node_def.name(), node_def.has_experimental_debug_info(),
1390 node_def.experimental_debug_info(), node_def.op(),
1391 AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
1392 }
1393
1394 } // namespace
1395
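// Example usage sketch (hypothetical call site; 'ndef' is assumed to be a
// fully-populated NodeDef):
//
//   NodeDef ndef = ...;
//   if (!KernelDefAvailable(DeviceType(DEVICE_GPU), ndef)) {
//     // Fall back to placing the node on CPU, for instance.
//   }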
bool KernelDefAvailable(const DeviceType& device_type,
                        const NodeDef& node_def) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  Status result =
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch);
  return result.ok() && reg != nullptr;
}

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(FindKernelRegistration(
      device_type, node_name, has_experimental_debug_info,
      experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    const std::string device_str = DeviceTypeString(device_type);
    Status s = errors::NotFound(
        "No registered '", node_op, "' OpKernel for ", device_str,
        " devices compatible with node ",
        FormatNodeDefForError(node_name, has_experimental_debug_info,
                              experimental_debug_info));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ",
          SummarizeAttrsHelper(node_attrs, node_device));
    }

    // Do not print kernel registrations for other devices when using _JIT
    // devices for compilation.
    if (!absl::StrContains(device_str, "JIT")) {
      errors::AppendToMessage(
          &s, ". Registered:", KernelsRegisteredForOp(node_op));
    }

    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

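// Example usage sketch for the NodeDef-based overload below (hypothetical
// call site):
//
//   const KernelDef* kernel_def = nullptr;
//   string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_CPU), node_def, &kernel_def,
//                            &kernel_class_name);
//   if (s.ok()) {
//     VLOG(2) << "Node " << node_def.name() << " will run "
//             << kernel_class_name;
//   }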
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  return FindKernelDef(
      device_type, node_def.name(), node_def.has_experimental_debug_info(),
      node_def.experimental_debug_info(), node_def.op(), node_def.device(),
      AttrSlice(&node_def.attr()), def, kernel_class_name);
}

Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* prioritized_device_types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  // TODO(zhifengc): Change the callers (SimplePlacer and DynamicPlacer) to
  // consider the possibility that 'def' is a call to a user-defined function,
  // and only call SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    bool exists_attr_mismatch = false;
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch = false;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
      if (reg != nullptr) {
        int32 priority = reg->def.priority();
        prioritized_device_types->emplace_back(device_type, priority);
      }
    }
    // Add extra supported device types if the following conditions are
    // satisfied:
    // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
    // 2) A device is requested for this node which specifies job/replica/task
    // 3) A local device is provided which specifies job/replica/task
    // 4) The local device does not have the same (job, replica, task) as the
    //    requested device
    //
    // The goal is to address the issue where a graph includes an op (e.g.
    // PyFunc) whose kernel is known to a remote process but not to the
    // current process.
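    //
    // For example (illustrative device names only): a PyFunc node pinned to
    // "/job:worker/replica:0/task:0/device:CPU:0" has no kernel in a process
    // whose local devices live under "/job:ps/replica:0/task:0". Because the
    // two address spaces differ, the requested device type (CPU) is still
    // reported as supported so placement can defer to the remote process.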
    if (prioritized_device_types->empty() && !exists_attr_mismatch &&
        local_address_spec != nullptr) {
      DeviceNameUtils::ParsedName requested_device_name;
      DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
      if (DeviceNameUtils::IsDifferentAddressSpace(*local_address_spec,
                                                   requested_device_name)) {
        if (requested_device_name.has_type) {
          prioritized_device_types->push_back(
              std::make_pair(DeviceType(requested_device_name.type), 0));
        } else {
          for (const DeviceType& device_type : prioritized_types) {
            prioritized_device_types->push_back(std::make_pair(device_type, 0));
          }
        }
      }
    }

    // If we were unable to find any valid devices, check whether the node
    // itself is even valid.
    if (prioritized_device_types->empty()) {
      TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
    }

    std::sort(prioritized_device_types->begin(),
              prioritized_device_types->end(),
              [](const std::pair<DeviceType, int32>& a,
                 const std::pair<DeviceType, int32>& b) {
                return a.second > b.second;
              });
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  KernelList kernel_list = GetAllRegisteredKernels();
  for (const auto& kernel_def : kernel_list.kernel()) {
    LOG(INFO) << "OpKernel ('" << kernel_def.ShortDebugString() << "')";
  }
}

KernelList GetAllRegisteredKernels() {
  return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
}

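// Example usage sketch (hypothetical): list only kernels registered for GPU.
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == DEVICE_GPU; });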
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate) {
  KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
  KernelList kernel_list;
  tf_shared_lock lock(typed_registry->mu);
  kernel_list.mutable_kernel()->Reserve(typed_registry->registry.size());
  for (const auto& p : typed_registry->registry) {
    const KernelDef& kernel_def = p.second.def;
    if (predicate(kernel_def)) {
      *kernel_list.add_kernel() = kernel_def;
    }
  }
  return kernel_list;
}

KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
  auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
  return GetFilteredRegisteredKernels(op_pred);
}

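// Returns a human-readable summary, one registration per line. Illustrative
// output (format only; actual registrations vary by build):
//
//   device='CPU'; T in [DT_FLOAT, DT_DOUBLE]
//   device='GPU'; label='cudnn'; T in [DT_FLOAT]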
string KernelsRegisteredForOp(StringPiece op_name) {
  KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
  if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
  string ret;
  for (const auto& kernel_def : kernel_list.kernel()) {
    strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
    if (!kernel_def.label().empty()) {
      strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
    }
    for (int i = 0; i < kernel_def.constraint_size(); ++i) {
      strings::StrAppend(
          &ret, "; ", kernel_def.constraint(i).name(), " in ",
          SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
    }
    strings::StrAppend(&ret, "\n");
  }
  return ret;
}

/* TODO(rmlarsen): This API is deprecated. Remove it if possible to avoid
 * copying the NodeDef. */
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  // Look up the Op registered for this op name.
  std::shared_ptr<const NodeProperties> props;
  status->Update(NodeProperties::CreateFromNodeDef(
      node_def, OpRegistry::Global(), &props));
  if (!status->ok()) {
    errors::AppendToMessage(status,
                            " for node: ", FormatNodeDefForError(node_def));
    return nullptr;
  }
  return CreateOpKernel(device_type, device, allocator, props,
                        graph_def_version, status);
}

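// Example usage sketch for the overload below (hypothetical call site; the
// device, allocator, and NodeProperties are assumed to come from the caller's
// runtime):
//
//   Status status;
//   std::unique_ptr<OpKernel> kernel = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device,
//       device->GetAllocator(AllocatorAttributes()), props,
//       TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) {
//     // Handle the lookup/construction failure.
//   }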
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator,
                           /*flib=*/nullptr, props, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  return CreateOpKernel(std::move(device_type), device, allocator, flib,
                        /* resource_mgr= */ nullptr, props, graph_def_version,
                        kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel) {
  const NodeDef& node_def = props->node_def;
  bool was_attr_mismatch;
  const KernelRegistration* registration = nullptr;
  Status s;
  if (props != nullptr) {
    VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

    // Validate node_def against OpDef.
    TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *props->op_def));

    // Look up kernel registration.
    s = FindKernelRegistration(device_type, node_def, &registration,
                               &was_attr_mismatch);
    if (!s.ok()) {
      errors::AppendToMessage(&s, " when instantiating ", node_def.op());
      return s;
    }
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for '", DeviceTypeString(device_type),
                              "' devices compatible with node ",
                              FormatNodeDefForError(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match) ",
          "Requested Attributes: ", SummarizeAttrs(node_def));
    }
    errors::AppendToMessage(
        &s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide the
  // kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(std::move(device_type), device, allocator, flib,
                               resource_mgr, props, input_memory_types,
                               output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}

namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

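// Example of the constraint checked below (hypothetical op): a registration
// such as
//
//   REGISTER_KERNEL_BUILDER(
//       Name("MyReshape").Device(DEVICE_GPU).HostMemory("shape"),
//       MyReshapeOp);
//
// is only valid if "shape" names an input or output argument of the
// "MyReshape" OpDef; otherwise ValidateKernelRegistrations returns
// InvalidArgument.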
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  auto typed_registry = GlobalKernelRegistryTyped();
  tf_shared_lock lock(typed_registry->mu);
  for (const auto& key_registration : typed_registry->registry) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

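// Convenience specializations so kernel code can request the Eigen device
// generically, e.g. (sketch): ctx->eigen_device<Eigen::ThreadPoolDevice>()
// in a CPU kernel, or ctx->eigen_device<Eigen::GpuDevice>() in a GPU kernel.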
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

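// Helper used by the synchronous OP_REQUIRES* macros to detect misuse inside
// AsyncOpKernel implementations. A minimal sketch of the intended fix at a
// failing call site: inside AsyncOpKernel::ComputeAsync, prefer the *_ASYNC
// variants, e.g.
//
//   OP_REQUIRES_OK_ASYNC(ctx, SomeStatusReturningCall(), done);
//
// instead of OP_REQUIRES_OK(ctx, ...), so that `done` is still invoked on
// failure.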
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name) {
  CHECK_EQ(nullptr, ctx->params_->op_kernel->AsAsync())
      << "Use " << correct_macro_name << " in AsyncOpKernel implementations.";
}

}  // namespace tensorflow