/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/op_kernel_runner.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tfd {
namespace {

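// Returns an error if `op_def` uses features that the TFRT kernel fallback
// does not support (currently, ref-typed input or output arguments).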
llvm::Error CheckOpDefCompatibility(const tensorflow::OpDef& op_def) {
  auto check_arg_def = [&](const auto& arg_def) -> llvm::Error {
    if (arg_def.is_ref())
      return tfrt::MakeStringError(
          "TFRT kernel fallback error: Unsupported ref args in ",
          op_def.name());
    return llvm::Error::success();
  };

  for (const auto& arg_def : op_def.input_arg())
    if (auto error = check_arg_def(arg_def)) return error;
  for (const auto& arg_def : op_def.output_arg())
    if (auto error = check_arg_def(arg_def)) return error;

  return llvm::Error::success();
}

// Create a tensorflow::NodeDef from the tensorflow::OpDef and the attributes.
tfrt::StatusOr<tensorflow::NodeDef> BuildNodeDef(
    const tensorflow::OpDef& op_def, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder) {
  tensorflow::NodeDef node_def;
  node_def.set_name(op_def.name());
  node_def.set_op(op_def.name());
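  // This NodeDef is synthesized only for kernel construction and is not part
  // of a real graph, so the inputs added below are just placeholders.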
  for (int i = 0; i < num_args; ++i) {
    node_def.add_input("dummy_input");
  }

  auto* attr_value_map = node_def.mutable_attr();
  if (auto error = attr_builder(attr_value_map)) {
    return tensorflow::errors::InvalidArgument(tfrt::StrCat(error));
  }

  // For any attr-value pairs that exist in the op def (from op registry)
  // but not in `attr_value_map`, fill them into `attr_value_map`, so that we
  // can run a TFE_Op without having to specify all the default attr values
  // (e.g. for matmul, the `transpose_a` attr defaults to false).
  for (const auto& attr_def : op_def.attr()) {
    if (attr_def.has_default_value()) {
      // Insertion will fail if this attribute already has a value.
      attr_value_map->insert({attr_def.name(), attr_def.default_value()});
    }
  }
  return node_def;
}

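// Instantiates an OpKernel for `ndef` through the given
// FunctionLibraryRuntime and transfers ownership of the kernel to `result`.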
tensorflow::Status CreateOpKernel(
    tensorflow::FunctionLibraryRuntime* flr, tensorflow::NodeDef ndef,
    std::unique_ptr<tensorflow::OpKernel>* result) {
  std::shared_ptr<const tensorflow::NodeProperties> props;
  TF_RETURN_IF_ERROR(tensorflow::NodeProperties::CreateFromNodeDef(
      ndef, flr->GetFunctionLibraryDefinition(), &props));
  tensorflow::OpKernel* k = nullptr;
  TF_RETURN_IF_ERROR(flr->CreateKernel(props, &k));
  result->reset(k);
  return Status::OK();
}

}  // namespace

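// Creates an OpKernelRunner by looking up the OpDef for `op_name`, building a
// NodeDef from the given attributes, resolving the target device (falling
// back to the host CPU if the lookup fails), and constructing the kernel
// through that device's FunctionLibraryRuntime.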
tfrt::StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, absl::string_view device_name, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
    const KernelFallbackCompatRequestState& fallback_request_state) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::OpDefForOp(std::string(op_name), &op_def));
  if (auto error = CheckOpDefCompatibility(*op_def)) {
    return tensorflow::errors::Internal(tfrt::StrCat(error));
  }
  VLOG(1) << "KernelFallbackExecuteCompat creating op from OpDef: "
          << op_def->DebugString();

  TF_ASSIGN_OR_RETURN(auto node_def,
                      BuildNodeDef(*op_def, num_args, attr_builder));

  VLOG(1) << "KernelFallbackExecuteCompat created NodeDef: "
          << node_def.DebugString();

  tensorflow::Device* device = nullptr;
  tensorflow::FunctionLibraryRuntime* function_library_runtime = nullptr;

  // TODO(b/176451036): Device names that are not in tensorflow format are
  // handled specially here. This is a workaround, as the compiler lowering
  // does not use the tensorflow format in some cases. Ideally, fallback code
  // should always use device names in tensorflow format.
  Status s = fallback_request_state.device_manager().LookupDevice(device_name,
                                                                  &device);

  // Fall back to the host device if the specified device cannot be found.
  if (!s.ok()) {
    LOG(ERROR) << "Failed to find device " << device_name
               << " when creating OpKernel: " << s;
    LOG(ERROR) << "Fallback to host device instead";
    device = fallback_request_state.device_manager().HostCPU();
  }

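  // The kernel is created through the FunctionLibraryRuntime associated with
  // the resolved device.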
  function_library_runtime =
      fallback_request_state.process_function_library_runtime().GetFLR(
          device->name());

  std::unique_ptr<OpKernel> op_kernel;
  TF_RETURN_IF_ERROR(CreateOpKernel(function_library_runtime,
                                    std::move(node_def), &op_kernel));
  return OpKernelRunner(device, function_library_runtime, std::move(op_kernel));
}

OpKernelRunner::OpKernelRunner(
    tensorflow::Device* device,
    tensorflow::FunctionLibraryRuntime* function_library_runtime,
    std::unique_ptr<tensorflow::OpKernel> op_kernel)
    : device_(device),
      function_library_runtime_(function_library_runtime),
      resource_manager_(device->resource_manager()),
      op_kernel_(std::move(op_kernel)),
      is_async_(op_kernel_->AsAsync() != nullptr) {
  DCHECK(device_);
  DCHECK(function_library_runtime_);

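  // Precompute per-argument allocator attributes from the kernel's declared
  // memory types so they do not need to be rebuilt on every invocation.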
  const auto& input_memory_types = op_kernel_->input_memory_types();
  input_alloc_attrs_.resize(op_kernel_->num_inputs());
  for (size_t i = 0, e = op_kernel_->num_inputs(); i < e; ++i) {
    input_alloc_attrs_[i].set_on_host(input_memory_types[i] ==
                                      tensorflow::HOST_MEMORY);
  }
  const auto& output_memory_types = op_kernel_->output_memory_types();
  output_alloc_attrs_.resize(op_kernel_->num_outputs());
  for (size_t i = 0, e = output_alloc_attrs_.size(); i < e; ++i) {
    output_alloc_attrs_[i].set_on_host(output_memory_types[i] ==
                                       tensorflow::HOST_MEMORY);
  }
}

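// Runs an asynchronous kernel. Only kernels that implement AsyncOpKernel may
// be dispatched through RunAsync(), as enforced by the DCHECK below.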
void OpKernelRunner::RunAsync(OpKernelContext* context,
                              AsyncOpKernel::DoneCallback done_callback) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Async Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << device_->name();

  AsyncOpKernel* async = op_kernel_->AsAsync();
  DCHECK(async);

  async->ComputeAsync(context, std::move(done_callback));
}

OpKernelRunnerCache::OpKernelRunnerCache() {}

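// Returns the cached OpKernelRunner for the op at `loc`, creating and caching
// a new one on first use.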
tfrt::StatusOr<OpKernelRunner*> OpKernelRunnerCache::GetOrCreate(
    tfrt::Location loc, absl::string_view op_name,
    absl::string_view device_name, int num_args,
    const std::function<llvm::Error(tensorflow::AttrValueMap*)>& attr_builder,
    const KernelFallbackCompatRequestState& fallback_request_state) {
  OpLocationKey key(loc);
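  // Fast path: look up an existing runner under a shared (reader) lock.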
  {
    tf_shared_lock lock(mu_);
    auto it = map_.find(key);
    if (it != map_.end()) {
      DCHECK_EQ(it->second->op_kernel()->name(), op_name);
      return it->second.get();
    }
  }

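  // Slow path: take the exclusive lock and re-check the map, since another
  // thread may have inserted a runner for this key while we were unlocked.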
  mutex_lock lock(mu_);

  auto it = map_.find(key);
  if (it != map_.end()) {
    DCHECK_EQ(it->second->op_kernel()->name(), op_name);
    return it->second.get();
  }

  VLOG(1) << "KernelFallbackExecuteCompat creating op " << op_name
          << " at location " << loc.data << " on device " << device_name;

  TF_ASSIGN_OR_RETURN(auto runner, OpKernelRunner::Create(
                                       op_name, device_name, num_args,
                                       attr_builder, fallback_request_state));

  auto runner_uptr = std::make_unique<OpKernelRunner>(std::move(runner));

  auto* runner_ptr = runner_uptr.get();
  auto r = map_.emplace(key, std::move(runner_uptr)).second;
  DCHECK(r);

  return runner_ptr;
}

}  // namespace tfd
}  // namespace tensorflow