/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tfd {
namespace {

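// Returns an error if the op definition declares ref-typed inputs or outputs,
// which are not supported by the TFRT kernel fallback.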
llvm::Error CheckOpDefCompatibility(const tensorflow::OpDef& op_def) {
  auto check_arg_def = [&](const auto& arg_def) -> llvm::Error {
    if (arg_def.is_ref())
      return tfrt::MakeStringError(
          "TFRT kernel fallback error: Unsupported ref args in ",
          op_def.name());
    return llvm::Error::success();
  };

  for (const auto& arg_def : op_def.input_arg())
    if (auto error = check_arg_def(arg_def)) return error;
  for (const auto& arg_def : op_def.output_arg())
    if (auto error = check_arg_def(arg_def)) return error;

  return llvm::Error::success();
}

// Creates a tensorflow::NodeDef from the given tensorflow::OpDef and the
// attributes provided by `attr_builder`.
tfrt::StatusOr<tensorflow::NodeDef> BuildNodeDef(
    const tensorflow::OpDef& op_def, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder) {
  tensorflow::NodeDef node_def;
  node_def.set_name(op_def.name());
  node_def.set_op(op_def.name());
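  // The input names are placeholders; only the number of inputs matters here,
  // since the actual input tensors are supplied through the OpKernelContext at
  // execution time.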
  for (int i = 0; i < num_args; ++i) {
    node_def.add_input("dummy_input");
  }

  auto* attr_value_map = node_def.mutable_attr();
  if (auto error = attr_builder(attr_value_map)) {
    return tensorflow::errors::InvalidArgument(tfrt::StrCat(error));
  }

  // For any attr-value pairs that exist in the op def (from op registry)
  // but not in `attr_value_map`, fill them into `attr_value_map`, so that we
  // can run a TFE_Op without having to specify all the default attr values
  // (e.g. for matmul, the `transpose_a` attr defaults to false).
  for (const auto& attr_def : op_def.attr()) {
    if (attr_def.has_default_value()) {
      // Insertion will fail if this attribute already has a value.
      attr_value_map->insert({attr_def.name(), attr_def.default_value()});
    }
  }
  return node_def;
}

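// Builds NodeProperties for `ndef` and creates the corresponding OpKernel
// through the given FunctionLibraryRuntime, returning it in `result`.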
tensorflow::Status CreateOpKernel(
    tensorflow::FunctionLibraryRuntime* flr, tensorflow::NodeDef ndef,
    std::unique_ptr<tensorflow::OpKernel>* result) {
  std::shared_ptr<const tensorflow::NodeProperties> props;
  TF_RETURN_IF_ERROR(tensorflow::NodeProperties::CreateFromNodeDef(
      ndef, flr->GetFunctionLibraryDefinition(), &props));
  tensorflow::OpKernel* k = nullptr;
  TF_RETURN_IF_ERROR(flr->CreateKernel(props, &k));
  result->reset(k);
  return Status::OK();
}

}  // namespace

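// Creates an OpKernelRunner for `op_name`: looks up its OpDef, builds a
// NodeDef using `attr_builder`, resolves the target device, and instantiates
// the kernel on that device's FunctionLibraryRuntime.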
tfrt::StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, absl::string_view device_name, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
    const KernelFallbackCompatRequestState& fallback_request_state) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::OpDefForOp(std::string(op_name), &op_def));
  if (auto error = CheckOpDefCompatibility(*op_def)) {
    return tensorflow::errors::Internal(tfrt::StrCat(error));
  }
  VLOG(1) << "KernelFallbackExecuteCompat creating op from OpDef: "
          << op_def->DebugString();

  TF_ASSIGN_OR_RETURN(auto node_def,
                      BuildNodeDef(*op_def, num_args, attr_builder));

  VLOG(1) << "KernelFallbackExecuteCompat created NodeDef: "
          << node_def.DebugString();

  tensorflow::Device* device = nullptr;
  tensorflow::FunctionLibraryRuntime* function_library_runtime = nullptr;

  // TODO(b/176451036): Device names that are not in TensorFlow format are
  // handled specially here. This is a workaround, as the compiler lowering
  // does not use the TensorFlow format in some cases. Ideally, fallback code
  // should always use device names in TensorFlow format.
  Status s = fallback_request_state.device_manager().LookupDevice(device_name,
                                                                  &device);

  // Fall back to the host device if the specified device cannot be found.
  if (!s.ok()) {
    LOG(ERROR) << "Failed to find device " << device_name
               << " when creating OpKernel: " << s;
    LOG(ERROR) << "Falling back to host device instead";
    device = fallback_request_state.device_manager().HostCPU();
  }

  function_library_runtime =
      fallback_request_state.process_function_library_runtime().GetFLR(
          device->name());

  std::unique_ptr<OpKernel> op_kernel;
  TF_RETURN_IF_ERROR(CreateOpKernel(function_library_runtime,
                                    std::move(node_def), &op_kernel));
  return OpKernelRunner(device, function_library_runtime, std::move(op_kernel));
}

OpKernelRunner::OpKernelRunner(
    tensorflow::Device* device,
    tensorflow::FunctionLibraryRuntime* function_library_runtime,
    std::unique_ptr<tensorflow::OpKernel> op_kernel)
    : device_(device),
      function_library_runtime_(function_library_runtime),
      resource_manager_(device->resource_manager()),
      op_kernel_(std::move(op_kernel)),
      is_async_(op_kernel_->AsAsync() != nullptr) {
  DCHECK(device_);
  DCHECK(function_library_runtime_);

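  // Precompute allocator attributes for every input and output once at
  // construction time: arguments whose memory type is HOST_MEMORY must be
  // placed in host memory.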
  const auto& input_memory_types = op_kernel_->input_memory_types();
  input_alloc_attrs_.resize(op_kernel_->num_inputs());
  for (size_t i = 0, e = op_kernel_->num_inputs(); i < e; ++i) {
    input_alloc_attrs_[i].set_on_host(input_memory_types[i] ==
                                      tensorflow::HOST_MEMORY);
  }
  const auto& output_memory_types = op_kernel_->output_memory_types();
  output_alloc_attrs_.resize(op_kernel_->num_outputs());
  for (size_t i = 0, e = output_alloc_attrs_.size(); i < e; ++i) {
    output_alloc_attrs_[i].set_on_host(output_memory_types[i] ==
                                       tensorflow::HOST_MEMORY);
  }
}

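// Dispatches ComputeAsync() on the wrapped kernel. The kernel must be
// asynchronous, i.e. AsAsync() must return a non-null pointer.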
void OpKernelRunner::RunAsync(OpKernelContext* context,
                              AsyncOpKernel::DoneCallback done_callback) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Async Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << device_->name();

  AsyncOpKernel* async = op_kernel_->AsAsync();
  DCHECK(async);

  async->ComputeAsync(context, std::move(done_callback));
}

OpKernelRunnerCache::OpKernelRunnerCache() {}

tfrt::StatusOr<OpKernelRunner*> OpKernelRunnerCache::GetOrCreate(
    tfrt::Location loc, absl::string_view op_name,
    absl::string_view device_name, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
    const KernelFallbackCompatRequestState& fallback_request_state) {
  OpLocationKey key(loc);
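  // Fast path: look up an existing runner under a shared (reader) lock.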
  {
    tf_shared_lock lock(mu_);
    auto it = map_.find(key);
    if (it != map_.end()) {
      DCHECK_EQ(it->second->op_kernel()->name(), op_name);
      return it->second.get();
    }
  }

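  // Slow path: take the exclusive lock and check again, since another thread
  // may have inserted a runner for this key after the shared lock above was
  // released.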
  mutex_lock lock(mu_);

  auto it = map_.find(key);
  if (it != map_.end()) {
    DCHECK_EQ(it->second->op_kernel()->name(), op_name);
    return it->second.get();
  }

  VLOG(1) << "KernelFallbackExecuteCompat creating op " << op_name
          << " at location " << loc.data << " on device " << device_name;

  TF_ASSIGN_OR_RETURN(auto runner, OpKernelRunner::Create(
                                       op_name, device_name, num_args,
                                       attr_builder, fallback_request_state));

  auto runner_uptr = std::make_unique<OpKernelRunner>(std::move(runner));

  auto* runner_ptr = runner_uptr.get();
  auto r = map_.emplace(key, std::move(runner_uptr)).second;
  DCHECK(r);

  return runner_ptr;
}

}  // namespace tfd
}  // namespace tensorflow