1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
17
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/OperationSupport.h"
22 #include "mlir/IR/Types.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/Passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28 #include "tensorflow/core/util/device_name_utils.h"
29 #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
30 #include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime
31 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
32 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
33 #include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
34
35 namespace tensorflow {
36
CoreRTConverter(mlir::MLIRContext * context,const mlir::TF::SideEffectAnalysis::Info * side_effect_analysis)37 CoreRTConverter::CoreRTConverter(
38 mlir::MLIRContext *context,
39 const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis)
40 : builder_(context), side_effect_analysis_(*side_effect_analysis) {
41 addConversion([](tfrt::compiler::ChainType type) { return type; });
42 addConversion([](tfrt::corert::OpHandlerType type) { return type; });
43 addConversion([](tfrt::dist::DistributedContextType type) { return type; });
44 addConversion([](tfrt::corert::TensorHandleType type) { return type; });
45 addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
46 // Ref types are not supported in both compiler and runtime.
47 if (type.getElementType().isa<mlir::TF::TensorFlowRefType>())
48 return llvm::None;
49 return tensor_handle_type();
50 });
51 addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
52 if (type == builder_.getI1Type()) return type;
53 return llvm::None;
54 });
55 }
56
MaterializeDerivedAttributes(mlir::Operation * op)57 void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) {
58 if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
59 auto derived_attrs = interface.materializeDerivedAttributes();
60 for (auto named_attr : derived_attrs) {
61 op->setAttr(named_attr.getName(), named_attr.getValue());
62 }
63 }
64 }
65
IsSupportedNumericDType(mlir::Type type) const66 bool CoreRTConverter::IsSupportedNumericDType(mlir::Type type) const {
67 // Most of the tensorflow data types (eg. f32, i64) are supported and they
68 // are standard MLIR types that need no conversion here.
69 if (type.isBF16() || type.isF16() || type.isF32() || type.isF64() ||
70 type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
71 type.isInteger(32) || type.isInteger(64) || type.isUnsignedInteger(8) ||
72 type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
73 type.isUnsignedInteger(64))
74 return true;
75
76 if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
77 auto element_type = complex_type.getElementType();
78 if (element_type.isF32() || element_type.isF64()) return true;
79 }
80
81 return false;
82 }
83
CreateOpAttrs(ArrayRef<NamedAttribute> attrs)84 mlir::ArrayAttr CoreRTConverter::CreateOpAttrs(ArrayRef<NamedAttribute> attrs) {
85 llvm::SmallVector<mlir::Attribute, 4> attr_array;
86 for (auto key_and_value : attrs) {
87 if (!IsUnusedAttribute(key_and_value.getName())) {
88 auto converted = ConvertAttribute(key_and_value.getValue());
89 if (!converted) return {};
90
91 mlir::StringAttr key =
92 builder_.getStringAttr(key_and_value.getName().strref());
93 attr_array.push_back(builder_.getArrayAttr({key, converted}));
94 }
95 }
96 return builder_.getArrayAttr(attr_array);
97 }
98
// Converts function-valued (symbol reference) attributes into the corert
// encoding: an ArrayAttr of [key, string-name] pairs. Each converted key is
// appended to `func_attr_keys` so the caller can remove those attributes from
// the op and keep them from being converted again by CreateOpAttrs.
mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs(
    ArrayRef<NamedAttribute> attrs,
    llvm::SmallVector<mlir::StringAttr, 4> *func_attr_keys) {
  llvm::SmallVector<mlir::Attribute, 4> attr_array;
  for (auto key_and_value : attrs) {
    auto attr_key = key_and_value.getName();
    auto attr_value = key_and_value.getValue();
    if (!IsUnusedAttribute(attr_key) &&
        attr_value.isa<mlir::FlatSymbolRefAttr, mlir::SymbolRefAttr>()) {
      // NOTE(review): the isa check above accepts any SymbolRefAttr, but the
      // dyn_cast below only succeeds for FlatSymbolRefAttr — a nested symbol
      // ref would produce a null `func_attr` here. Presumably callers only
      // ever pass flat refs; confirm before relying on this path.
      auto func_attr = attr_value.dyn_cast<mlir::FlatSymbolRefAttr>();
      auto converted = ConvertSymbolAttrToStringAttr(func_attr);
      mlir::StringAttr key = builder_.getStringAttr(attr_key.strref());
      attr_array.push_back(builder_.getArrayAttr({key, converted}));

      // Remove the attribute to avoid being converted again.
      func_attr_keys->push_back(attr_key);
    }
  }
  return builder_.getArrayAttr(attr_array);
}
119
120 // TODO(chky): Add support for multiple device instances.
ParseDeviceName(llvm::StringRef device_name) const121 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
122 llvm::StringRef device_name) const {
123 std::string tf_device_name = device_name.str();
124
125 if (tf_device_name.empty()) {
126 return llvm::None;
127 }
128
129 ParseDeviceNameResult result;
130 result.device_name = tf_device_name;
131
132 // Parse the device name in format of the current tensorflow.
133 DeviceNameUtils::ParsedName parsed_name;
134 if (!DeviceNameUtils::ParseFullName(result.device_name, &parsed_name)) {
135 return llvm::None;
136 }
137 if (!parsed_name.has_type) {
138 return llvm::None;
139 }
140 result.device_type = parsed_name.type;
141
142 result.op_handler_name = tf_device_name;
143
144 return result;
145 }
146
ParseDeviceName(mlir::Operation * op) const147 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
148 mlir::Operation *op) const {
149 auto device_attr = op->getAttr("device");
150 if (!device_attr) {
151 return llvm::None;
152 }
153
154 auto parsed_device_name =
155 ParseDeviceName(device_attr.cast<mlir::StringAttr>().getValue());
156 if (!parsed_device_name) op->emitWarning("failed to parse device name.");
157 return parsed_device_name;
158 }
159
ConvertOpHandler(mlir::Operation * op,llvm::StringRef op_handler_name,ConversionPatternRewriter * rewriter)160 mlir::Value CoreRTConverter::ConvertOpHandler(
161 mlir::Operation *op, llvm::StringRef op_handler_name,
162 ConversionPatternRewriter *rewriter) {
163 auto iter = op_handler_by_name_.find(op_handler_name);
164 if (iter != op_handler_by_name_.end()) return iter->second;
165
166 mlir::Block *block = op->getBlock();
167 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
168 rewriter->setInsertionPointToStart(block);
169
170 func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
171 mlir::Value in_chain = func_op.getArgument(0);
172 auto get_op_handler_op = rewriter->create<tfrt::corert::GetOpHandler>(
173 block->getParent()->getLoc(), op_handler_type(), in_chain,
174 op_handler_name);
175 op_handler_by_name_[op_handler_name] = get_op_handler_op.getResult();
176 return get_op_handler_op.getResult();
177 }
178
GetDistributedContext(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)179 mlir::Value CoreRTConverter::GetDistributedContext(
180 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
181 mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
182 auto iter = distributed_context_by_func_.find(func_op.getOperation());
183 if (iter != distributed_context_by_func_.end()) {
184 return iter->second;
185 }
186 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
187 rewriter->setInsertionPoint(op);
188 auto get_dist_ctx_op = rewriter->create<tfrt::dist::GetDistributedContextOp>(
189 op->getLoc(), distributed_context_type());
190
191 mlir::Value result = get_dist_ctx_op.result();
192 distributed_context_by_func_[func_op.getOperation()] = result;
193 return result;
194 }
195
GetRemoteChainManager(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)196 mlir::Value CoreRTConverter::GetRemoteChainManager(
197 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
198 mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
199 auto iter = remote_chain_mgr_by_func_.find(func_op.getOperation());
200 if (iter != remote_chain_mgr_by_func_.end()) {
201 return iter->second;
202 }
203 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
204 rewriter->setInsertionPoint(op);
205
206 mlir::Type remote_chain_mgr_type =
207 builder_.getType<::tfrt::dist::RemoteChainManagerType>();
208 mlir::Value dist_ctx = GetDistributedContext(op, rewriter);
209 auto create_mgr_op = rewriter->create<tfrt::dist::CreateRemoteChainManager>(
210 op->getLoc(), remote_chain_mgr_type, dist_ctx);
211
212 mlir::Value result = create_mgr_op.result();
213 remote_chain_mgr_by_func_[func_op.getOperation()] = result;
214 return result;
215 }
216
// Returns the chain value that represents the side effects `op` must wait on.
// For a function return the relevant predecessors are the analysis' control
// sinks (every side-effecting op in the function); otherwise they are op's
// direct control predecessors. Multiple predecessor chains are merged with a
// tfrt.merge.chains op inserted just before `op`.
mlir::Value CoreRTConverter::GetLocalSideEffectChain(
    mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
  auto func_op = op->getParentOfType<mlir::func::FuncOp>();

  llvm::SmallVector<mlir::Operation *, 4> predecessors;
  if (llvm::isa<mlir::func::ReturnOp>(op)) {
    auto sinks = side_effect_analysis_.ControlSinks();
    predecessors.assign(sinks.begin(), sinks.end());
  } else {
    predecessors = side_effect_analysis_.DirectControlPredecessors(op);
  }

  llvm::SmallVector<mlir::Value, 2> chains;
  for (auto *pred : predecessors) {
    // TODO(chky): ReadVariableOp is removed in the pass and not converted.
    // Ideally, every side-effecting op should be converted to a
    // tfrt_fallback.executeop.seq op. The special rewrite logic of
    // ReadVariableOp should be done in a previous pass.
    // Predecessors without a recorded chain (lookup returns null) are
    // silently skipped for the reason above.
    if (auto chain = local_side_effect_chains_.lookup(pred))
      chains.push_back(chain);
  }

  // If there is no side-effect predecessor, then the input side-effect chain
  // is used.
  if (chains.empty()) return func_op.getArgument(0);

  if (chains.size() == 1) return chains[0];

  // If there are multiple side-effect predecessors, insert a merge_chains
  // kernel and return the merged chain.
  ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
  rewriter->setInsertionPoint(op);
  return rewriter->create<tfrt::compiler::MergeChainsOp>(op->getLoc(),
                                                         chain_type(), chains);
}
252
GetTaskHandle(mlir::Operation * op,StringRef task_name,mlir::ConversionPatternRewriter * rewriter)253 mlir::Value CoreRTConverter::GetTaskHandle(
254 mlir::Operation *op, StringRef task_name,
255 mlir::ConversionPatternRewriter *rewriter) {
256 mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
257 llvm::StringMap<mlir::Value> &task_handle_by_name =
258 task_handles_by_func_[func_op.getOperation()];
259 auto iter = task_handle_by_name.find(task_name);
260 if (iter != task_handle_by_name.end()) {
261 return iter->second;
262 }
263
264 mlir::Value distributed_context = GetDistributedContext(op, rewriter);
265 auto task_handle_op = rewriter->create<tfrt::dist::GetTaskHandleOp>(
266 op->getLoc(), rewriter->getType<tfrt::dist::TaskHandleType>(),
267 distributed_context, task_name);
268
269 task_handle_by_name[task_name] = task_handle_op.getResult();
270 return task_handle_op.getResult();
271 }
272
GetRemoteSideEffectChain(mlir::Operation * op,StringRef remote_host,mlir::ConversionPatternRewriter * rewriter)273 mlir::Value CoreRTConverter::GetRemoteSideEffectChain(
274 mlir::Operation *op, StringRef remote_host,
275 mlir::ConversionPatternRewriter *rewriter) {
276 mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter);
277 mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter);
278 mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter);
279 mlir::Type remote_obj_id_ty =
280 rewriter->getType<tfrt::dist::RemoteObjectIdType>();
281
282 // Get the remote chain using the tfrt_dist.get_chain_for_task_handle op.
283 auto get_chain_op = rewriter->create<tfrt::dist::GetChainForTaskHandleOp>(
284 op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr,
285 task_handle);
286 return get_chain_op.getResult();
287 }
288
// Converts a TF dialect attribute to its corert equivalent. Returns a null
// attribute for anything unsupported, which callers treat as a conversion
// failure.
mlir::Attribute CoreRTConverter::ConvertAttribute(mlir::Attribute attr) {
  // The supported attributes here should be kept consistent with
  // //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h
  //
  // Currently, not all tensorflow data types are supported. Unranked shape
  // attributes are not supported yet.

  // Return directly if the attribute is already supported.
  if (attr.isa<mlir::IntegerAttr, mlir::FloatAttr, mlir::BoolAttr,
               mlir::StringAttr, mlir::DenseIntOrFPElementsAttr,
               mlir::DenseI32ArrayAttr>())
    return attr;

  // For type attributes, we convert non-standard MLIR types to corresponding
  // corert types.
  if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
    // Tensor-typed TypeAttrs become corert shape attrs (unranked shapes map
    // to the rankless ShapeAttr form).
    if (auto shape_type = type_attr.getValue().dyn_cast<mlir::TensorType>()) {
      if (!shape_type.hasRank())
        return tfrt::corert::ShapeAttr::get(builder_.getContext());

      return tfrt::corert::ShapeAttr::get(builder_.getContext(),
                                          shape_type.getShape());
    }

    return ConvertTypeAttribute(type_attr);
  }

  // Convert the attribute to the corresponding format in TFRT dialect if
  // needed.
  if (auto shape_attr = attr.dyn_cast<mlir::TF::ShapeAttr>()) {
    if (!shape_attr.hasRank())
      return tfrt::corert::ShapeAttr::get(builder_.getContext());
    return tfrt::corert::ShapeAttr::get(builder_.getContext(),
                                        shape_attr.getShape());
  }

  // For arrays, we recursively convert the elements. A single unsupported
  // element fails the whole array.
  if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
    llvm::SmallVector<mlir::Attribute, 8> attrs;
    attrs.reserve(array_attr.size());
    for (auto attr : array_attr) {
      auto converted = ConvertAttribute(attr);
      if (!converted) return {};
      attrs.push_back(converted);
    }
    return builder_.getArrayAttr(attrs);
  }

  // Everything else is unsupported.
  return {};
}
338
ConvertSymbolAttrToStringAttr(mlir::FlatSymbolRefAttr symbol_attr)339 mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr(
340 mlir::FlatSymbolRefAttr symbol_attr) {
341 // Currently in TF graph to MLIR importing, a "0" is appended to the original
342 // function name, so we pop it here. The renaming is for TF/XLA v1 bridge
343 // use cases. Refer to b/142268695, b/141617294 for more context.
344 //
345 // In TFRT use cases, in almost every case "0" is the only literal
346 // appended since TF Graph already guarantee function name uniqueness.
347 // TODO(b/172092902): Investigate a better way to make the tf_func_name to
348 // mlir_tf_func_name conversion reversible.
349 auto func_name = symbol_attr.getValue().drop_back().str();
350
351 return mlir::StringAttr::get(builder_.getContext(), func_name);
352 }
353
ConvertTypeAttribute(mlir::TypeAttr type_attr)354 mlir::TypeAttr CoreRTConverter::ConvertTypeAttribute(mlir::TypeAttr type_attr) {
355 auto type = type_attr.getValue();
356
357 if (IsSupportedNumericDType(type)) return type_attr;
358
359 // For TF custom types, we convert it to custom corert types.
360 if (type.isa<mlir::TF::StringType>())
361 return mlir::TypeAttr::get(
362 tfrt::corert::StringType::get(builder_.getContext()));
363
364 if (type.isa<mlir::TF::ResourceType>())
365 return mlir::TypeAttr::get(
366 tfrt::corert::ResourceType::get(builder_.getContext()));
367
368 if (type.isa<mlir::TF::VariantType>())
369 return mlir::TypeAttr::get(
370 tfrt::corert::VariantType::get(builder_.getContext()));
371
372 if (type.isa<mlir::TF::Quint8Type>()) {
373 return mlir::TypeAttr::get(
374 tfrt::corert::Quint8Type::get(builder_.getContext()));
375 }
376
377 if (type.isa<mlir::TF::Quint16Type>()) {
378 return mlir::TypeAttr::get(
379 tfrt::corert::Quint16Type::get(builder_.getContext()));
380 }
381
382 if (type.isa<mlir::TF::Qint8Type>()) {
383 return mlir::TypeAttr::get(
384 tfrt::corert::Qint8Type::get(builder_.getContext()));
385 }
386
387 if (type.isa<mlir::TF::Qint16Type>()) {
388 return mlir::TypeAttr::get(
389 tfrt::corert::Qint16Type::get(builder_.getContext()));
390 }
391
392 if (type.isa<mlir::TF::Qint32Type>()) {
393 return mlir::TypeAttr::get(
394 tfrt::corert::Qint32Type::get(builder_.getContext()));
395 }
396
397 // Return invalid results to emit error for unsupported types.
398 return {};
399 }
400
401 } // namespace tensorflow
402