1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
17
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Identifier.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/IR/Types.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30 #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
31 #include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime
32 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
33 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
34 #include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
35
36 namespace tensorflow {
37
CoreRTConverter(mlir::MLIRContext * context,const mlir::TF::SideEffectAnalysis::Info * side_effect_analysis)38 CoreRTConverter::CoreRTConverter(
39 mlir::MLIRContext *context,
40 const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis)
41 : builder_(context), side_effect_analysis_(*side_effect_analysis) {
42 addConversion([](tfrt::compiler::ChainType type) { return type; });
43 addConversion([](tfrt::corert::OpHandlerType type) { return type; });
44 addConversion([](tfrt::dist::DistributedContextType type) { return type; });
45 addConversion([](tfrt::corert::TensorHandleType type) { return type; });
46 addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
47 // Ref types are not supported in both compiler and runtime.
48 if (type.getElementType().isa<mlir::TF::TensorFlowRefType>())
49 return llvm::None;
50 return tensor_handle_type();
51 });
52 addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
53 if (type == builder_.getI1Type()) return type;
54 return llvm::None;
55 });
56 }
57
MaterializeDerivedAttributes(mlir::Operation * op)58 void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) {
59 if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
60 auto derived_attrs = interface.materializeDerivedAttributes();
61 for (auto named_attr : derived_attrs) {
62 op->setAttr(named_attr.first, named_attr.second);
63 }
64 }
65 }
66
IsSupportedNumericDType(mlir::Type type) const67 bool CoreRTConverter::IsSupportedNumericDType(mlir::Type type) const {
68 // Most of the tensorflow data types (eg. f32, i64) are supported and they
69 // are standard MLIR types that need no conversion here.
70 if (type.isBF16() || type.isF16() || type.isF32() || type.isF64() ||
71 type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
72 type.isInteger(32) || type.isInteger(64) || type.isUnsignedInteger(8) ||
73 type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
74 type.isUnsignedInteger(64))
75 return true;
76
77 if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
78 auto element_type = complex_type.getElementType();
79 if (element_type.isF32() || element_type.isF64()) return true;
80 }
81
82 return false;
83 }
84
CreateOpAttrs(ArrayRef<NamedAttribute> attrs)85 mlir::ArrayAttr CoreRTConverter::CreateOpAttrs(ArrayRef<NamedAttribute> attrs) {
86 llvm::SmallVector<mlir::Attribute, 4> attr_array;
87 for (auto key_and_value : attrs) {
88 if (!IsUnusedAttribute(key_and_value.first)) {
89 auto converted = ConvertAttribute(key_and_value.second);
90 if (!converted) return {};
91
92 mlir::StringAttr key =
93 builder_.getStringAttr(key_and_value.first.strref());
94 attr_array.push_back(builder_.getArrayAttr({key, converted}));
95 }
96 }
97 return builder_.getArrayAttr(attr_array);
98 }
99
CreateOpFuncAttrs(ArrayRef<NamedAttribute> attrs,llvm::SmallVector<mlir::Identifier,4> * func_attr_keys)100 mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs(
101 ArrayRef<NamedAttribute> attrs,
102 llvm::SmallVector<mlir::Identifier, 4> *func_attr_keys) {
103 llvm::SmallVector<mlir::Attribute, 4> attr_array;
104 for (auto key_and_value : attrs) {
105 auto attr_key = key_and_value.first;
106 auto attr_value = key_and_value.second;
107 if (!IsUnusedAttribute(attr_key) &&
108 attr_value.isa<mlir::FlatSymbolRefAttr, mlir::SymbolRefAttr>()) {
109 auto func_attr = attr_value.dyn_cast<mlir::FlatSymbolRefAttr>();
110 auto converted = ConvertSymbolAttrToStringAttr(func_attr);
111 mlir::StringAttr key = builder_.getStringAttr(attr_key.strref());
112 attr_array.push_back(builder_.getArrayAttr({key, converted}));
113
114 // Remove the attribute to avoid being converted again.
115 func_attr_keys->push_back(attr_key);
116 }
117 }
118 return builder_.getArrayAttr(attr_array);
119 }
120
121 // TODO(chky): Add support for multiple device instances.
ParseDeviceName(llvm::StringRef device_name) const122 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
123 llvm::StringRef device_name) const {
124 std::string tf_device_name = device_name.str();
125
126 if (tf_device_name.empty()) {
127 return llvm::None;
128 }
129
130 ParseDeviceNameResult result;
131 result.device_name = tf_device_name;
132
133 // Parse the device name in format of the current tensorflow.
134 DeviceNameUtils::ParsedName parsed_name;
135 if (!DeviceNameUtils::ParseFullName(result.device_name, &parsed_name)) {
136 return llvm::None;
137 }
138 if (!parsed_name.has_type) {
139 return llvm::None;
140 }
141 result.device_type = parsed_name.type;
142
143 result.op_handler_name = tf_device_name;
144
145 return result;
146 }
147
ParseDeviceName(mlir::Operation * op) const148 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
149 mlir::Operation *op) const {
150 auto device_attr = op->getAttr("device");
151 if (!device_attr) {
152 return llvm::None;
153 }
154
155 auto parsed_device_name =
156 ParseDeviceName(device_attr.cast<mlir::StringAttr>().getValue());
157 if (!parsed_device_name) op->emitWarning("failed to parse device name.");
158 return parsed_device_name;
159 }
160
ConvertOpHandler(mlir::Operation * op,llvm::StringRef op_handler_name,ConversionPatternRewriter * rewriter)161 mlir::Value CoreRTConverter::ConvertOpHandler(
162 mlir::Operation *op, llvm::StringRef op_handler_name,
163 ConversionPatternRewriter *rewriter) {
164 auto iter = op_handler_by_name_.find(op_handler_name);
165 if (iter != op_handler_by_name_.end()) return iter->second;
166
167 mlir::Block *block = op->getBlock();
168 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
169 rewriter->setInsertionPointToStart(block);
170
171 FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
172 mlir::Value in_chain = func_op.getArgument(0);
173 auto get_op_handler_op = rewriter->create<tfrt::corert::GetOpHandler>(
174 block->getParent()->getLoc(), op_handler_type(), in_chain,
175 op_handler_name);
176 op_handler_by_name_[op_handler_name] = get_op_handler_op.getResult();
177 return get_op_handler_op.getResult();
178 }
179
GetDistributedContext(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)180 mlir::Value CoreRTConverter::GetDistributedContext(
181 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
182 mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
183 auto iter = distributed_context_by_func_.find(func_op.getOperation());
184 if (iter != distributed_context_by_func_.end()) {
185 return iter->second;
186 }
187 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
188 rewriter->setInsertionPoint(op);
189 auto get_dist_ctx_op = rewriter->create<tfrt::dist::GetDistributedContextOp>(
190 op->getLoc(), distributed_context_type());
191
192 mlir::Value result = get_dist_ctx_op.result();
193 distributed_context_by_func_[func_op.getOperation()] = result;
194 return result;
195 }
196
GetRemoteChainManager(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)197 mlir::Value CoreRTConverter::GetRemoteChainManager(
198 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
199 mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
200 auto iter = remote_chain_mgr_by_func_.find(func_op.getOperation());
201 if (iter != remote_chain_mgr_by_func_.end()) {
202 return iter->second;
203 }
204 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
205 rewriter->setInsertionPoint(op);
206
207 mlir::Type remote_chain_mgr_type =
208 builder_.getType<::tfrt::dist::RemoteChainManagerType>();
209 mlir::Value dist_ctx = GetDistributedContext(op, rewriter);
210 auto create_mgr_op = rewriter->create<tfrt::dist::CreateRemoteChainManager>(
211 op->getLoc(), remote_chain_mgr_type, dist_ctx);
212
213 mlir::Value result = create_mgr_op.result();
214 remote_chain_mgr_by_func_[func_op.getOperation()] = result;
215 return result;
216 }
217
GetLocalSideEffectChain(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)218 mlir::Value CoreRTConverter::GetLocalSideEffectChain(
219 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
220 auto func_op = op->getParentOfType<mlir::FuncOp>();
221 auto predecessors = side_effect_analysis_.DirectControlPredecessors(op);
222
223 // If there is no side-effect predecessor, then the input side-effect chain
224 // is used.
225 if (predecessors.empty()) return func_op.getArgument(0);
226
227 llvm::SmallVector<mlir::Value, 2> chains;
228 for (auto *pred : predecessors) {
229 // TODO(chky): ReadVariableOp is removed in the pass and not converted.
230 // Ideally, every side-effecting op should be converted to a
231 // tfrt_fallback.executeop.seq op. The special rewrite logic of
232 // ReadVariableOp should be done in a previous pass.
233 if (auto chain = local_side_effect_chains_.lookup(pred))
234 chains.push_back(chain);
235 }
236
237 if (chains.empty()) return func_op.getArgument(0);
238
239 if (chains.size() == 1) return chains[0];
240
241 // If there are multiple side-effect predecessors, insert a merge_chains
242 // kernel and return the merged chain.
243 ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
244 rewriter->setInsertionPoint(op);
245 return rewriter->create<tfrt::compiler::MergeChainsOp>(op->getLoc(),
246 chain_type(), chains);
247 }
248
GetTaskHandle(mlir::Operation * op,StringRef task_name,mlir::ConversionPatternRewriter * rewriter)249 mlir::Value CoreRTConverter::GetTaskHandle(
250 mlir::Operation *op, StringRef task_name,
251 mlir::ConversionPatternRewriter *rewriter) {
252 mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
253 llvm::StringMap<mlir::Value> &task_handle_by_name =
254 task_handles_by_func_[func_op.getOperation()];
255 auto iter = task_handle_by_name.find(task_name);
256 if (iter != task_handle_by_name.end()) {
257 return iter->second;
258 }
259
260 mlir::Value distributed_context = GetDistributedContext(op, rewriter);
261 auto task_handle_op = rewriter->create<tfrt::dist::GetTaskHandleOp>(
262 op->getLoc(), rewriter->getType<tfrt::dist::TaskHandleType>(),
263 distributed_context, task_name);
264
265 task_handle_by_name[task_name] = task_handle_op.getResult();
266 return task_handle_op.getResult();
267 }
268
// Returns a chain value representing the side effects already issued to
// `remote_host`, obtained by querying the per-function remote chain manager
// with the host's task handle. The query itself is sequenced after this op's
// local side-effect chain.
mlir::Value CoreRTConverter::GetRemoteSideEffectChain(
    mlir::Operation *op, StringRef remote_host,
    mlir::ConversionPatternRewriter *rewriter) {
  // Gather inputs for the query. Each helper below caches and may create ops,
  // so the call order here also fixes the order of the created ops.
  mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter);
  mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter);
  mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter);
  mlir::Type remote_obj_id_ty =
      rewriter->getType<tfrt::dist::RemoteObjectIdType>();

  // Get the remote chain using the tfrt_dist.get_chain_for_task_handle op.
  auto get_chain_op = rewriter->create<tfrt::dist::GetChainForTaskHandleOp>(
      op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr,
      task_handle);
  return get_chain_op.getResult();
}
284
ConvertAttribute(mlir::Attribute attr)285 mlir::Attribute CoreRTConverter::ConvertAttribute(mlir::Attribute attr) {
286 // The supported attributes here should be kept consistent with
287 // //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h
288 //
289 // Currently, not all tensorflow data types are supported. Unranked shape
290 // attributes are not supported yet.
291
292 // Return directly if the attribute is already supported.
293 if (attr.isa<mlir::IntegerAttr, mlir::FloatAttr, mlir::BoolAttr,
294 mlir::StringAttr, mlir::DenseIntOrFPElementsAttr>())
295 return attr;
296
297 // For type attributes, we convert non-standard MLIR types to corresponding
298 // corert types.
299 if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
300 if (auto shape_type = type_attr.getValue().dyn_cast<mlir::TensorType>()) {
301 if (!shape_type.hasRank())
302 return tfrt::corert::ShapeAttr::get(builder_.getContext());
303
304 return tfrt::corert::ShapeAttr::get(builder_.getContext(),
305 shape_type.getShape());
306 }
307
308 return ConvertTypeAttribute(type_attr);
309 }
310
311 // Convert the attribute to the corresponding format in TFRT dialect if
312 // needed.
313 if (auto shape_attr = attr.dyn_cast<mlir::TF::ShapeAttr>()) {
314 if (!shape_attr.hasRank())
315 return tfrt::corert::ShapeAttr::get(builder_.getContext());
316 return tfrt::corert::ShapeAttr::get(builder_.getContext(),
317 shape_attr.getShape());
318 }
319
320 // For arrays, we recursively convert the elements.
321 if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
322 llvm::SmallVector<mlir::Attribute, 8> attrs;
323 attrs.reserve(array_attr.size());
324 for (auto attr : array_attr) {
325 auto converted = ConvertAttribute(attr);
326 if (!converted) return {};
327 attrs.push_back(converted);
328 }
329 return builder_.getArrayAttr(attrs);
330 }
331
332 return {};
333 }
334
ConvertSymbolAttrToStringAttr(mlir::FlatSymbolRefAttr symbol_attr)335 mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr(
336 mlir::FlatSymbolRefAttr symbol_attr) {
337 // Currently in TF graph to MLIR importing, a "0" is appended to the original
338 // function name, so we pop it here. The renaming is for TF/XLA v1 bridge
339 // use cases. Refer to b/142268695, b/141617294 for more context.
340 //
341 // In TFRT use cases, in almost every case "0" is the only literal
342 // appended since TF Graph already guarantee function name uniqueness.
343 // TODO(b/172092902): Investigate a better way to make the tf_func_name to
344 // mlir_tf_func_name conversion reversible.
345 auto func_name = symbol_attr.getValue().drop_back().str();
346
347 return mlir::StringAttr::get(builder_.getContext(), func_name);
348 }
349
ConvertTypeAttribute(mlir::TypeAttr type_attr)350 mlir::TypeAttr CoreRTConverter::ConvertTypeAttribute(mlir::TypeAttr type_attr) {
351 auto type = type_attr.getValue();
352
353 if (IsSupportedNumericDType(type)) return type_attr;
354
355 // For TF custom types, we convert it to custom corert types.
356 if (type.isa<mlir::TF::StringType>())
357 return mlir::TypeAttr::get(
358 tfrt::corert::StringType::get(builder_.getContext()));
359
360 if (type.isa<mlir::TF::ResourceType>())
361 return mlir::TypeAttr::get(
362 tfrt::corert::ResourceType::get(builder_.getContext()));
363
364 if (type.isa<mlir::TF::VariantType>())
365 return mlir::TypeAttr::get(
366 tfrt::corert::VariantType::get(builder_.getContext()));
367
368 if (type.isa<mlir::TF::Quint8Type>()) {
369 return mlir::TypeAttr::get(
370 tfrt::corert::Quint8Type::get(builder_.getContext()));
371 }
372
373 if (type.isa<mlir::TF::Quint16Type>()) {
374 return mlir::TypeAttr::get(
375 tfrt::corert::Quint16Type::get(builder_.getContext()));
376 }
377
378 if (type.isa<mlir::TF::Qint8Type>()) {
379 return mlir::TypeAttr::get(
380 tfrt::corert::Qint8Type::get(builder_.getContext()));
381 }
382
383 if (type.isa<mlir::TF::Qint16Type>()) {
384 return mlir::TypeAttr::get(
385 tfrt::corert::Qint16Type::get(builder_.getContext()));
386 }
387
388 if (type.isa<mlir::TF::Qint32Type>()) {
389 return mlir::TypeAttr::get(
390 tfrt::corert::Qint32Type::get(builder_.getContext()));
391 }
392
393 // Return invalid results to emit error for unsupported types.
394 return {};
395 }
396
397 } // namespace tensorflow
398