/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
17 
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Identifier.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "mlir/IR/Types.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
31 #include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
32 #include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
33 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
34 #include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
35 
36 namespace tensorflow {
37 
CoreRTConverter(mlir::MLIRContext * context,const mlir::TF::SideEffectAnalysis::Info * side_effect_analysis)38 CoreRTConverter::CoreRTConverter(
39     mlir::MLIRContext *context,
40     const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis)
41     : builder_(context), side_effect_analysis_(*side_effect_analysis) {
42   addConversion([](tfrt::compiler::ChainType type) { return type; });
43   addConversion([](tfrt::corert::OpHandlerType type) { return type; });
44   addConversion([](tfrt::dist::DistributedContextType type) { return type; });
45   addConversion([](tfrt::corert::TensorHandleType type) { return type; });
46   addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
47     // Ref types are not supported in both compiler and runtime.
48     if (type.getElementType().isa<mlir::TF::TensorFlowRefType>())
49       return llvm::None;
50     return tensor_handle_type();
51   });
52   addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
53     if (type == builder_.getI1Type()) return type;
54     return llvm::None;
55   });
56 }
57 
MaterializeDerivedAttributes(mlir::Operation * op)58 void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) {
59   if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
60     auto derived_attrs = interface.materializeDerivedAttributes();
61     for (auto named_attr : derived_attrs) {
62       op->setAttr(named_attr.first, named_attr.second);
63     }
64   }
65 }
66 
IsSupportedNumericDType(mlir::Type type) const67 bool CoreRTConverter::IsSupportedNumericDType(mlir::Type type) const {
68   // Most of the tensorflow data types (eg. f32, i64) are supported and they
69   // are standard MLIR types that need no conversion here.
70   if (type.isBF16() || type.isF16() || type.isF32() || type.isF64() ||
71       type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
72       type.isInteger(32) || type.isInteger(64) || type.isUnsignedInteger(8) ||
73       type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
74       type.isUnsignedInteger(64))
75     return true;
76 
77   if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
78     auto element_type = complex_type.getElementType();
79     if (element_type.isF32() || element_type.isF64()) return true;
80   }
81 
82   return false;
83 }
84 
CreateOpAttrs(ArrayRef<NamedAttribute> attrs)85 mlir::ArrayAttr CoreRTConverter::CreateOpAttrs(ArrayRef<NamedAttribute> attrs) {
86   llvm::SmallVector<mlir::Attribute, 4> attr_array;
87   for (auto key_and_value : attrs) {
88     if (!IsUnusedAttribute(key_and_value.first)) {
89       auto converted = ConvertAttribute(key_and_value.second);
90       if (!converted) return {};
91 
92       mlir::StringAttr key =
93           builder_.getStringAttr(key_and_value.first.strref());
94       attr_array.push_back(builder_.getArrayAttr({key, converted}));
95     }
96   }
97   return builder_.getArrayAttr(attr_array);
98 }
99 
CreateOpFuncAttrs(ArrayRef<NamedAttribute> attrs,llvm::SmallVector<mlir::Identifier,4> * func_attr_keys)100 mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs(
101     ArrayRef<NamedAttribute> attrs,
102     llvm::SmallVector<mlir::Identifier, 4> *func_attr_keys) {
103   llvm::SmallVector<mlir::Attribute, 4> attr_array;
104   for (auto key_and_value : attrs) {
105     auto attr_key = key_and_value.first;
106     auto attr_value = key_and_value.second;
107     if (!IsUnusedAttribute(attr_key) &&
108         attr_value.isa<mlir::FlatSymbolRefAttr, mlir::SymbolRefAttr>()) {
109       auto func_attr = attr_value.dyn_cast<mlir::FlatSymbolRefAttr>();
110       auto converted = ConvertSymbolAttrToStringAttr(func_attr);
111       mlir::StringAttr key = builder_.getStringAttr(attr_key.strref());
112       attr_array.push_back(builder_.getArrayAttr({key, converted}));
113 
114       // Remove the attribute to avoid being converted again.
115       func_attr_keys->push_back(attr_key);
116     }
117   }
118   return builder_.getArrayAttr(attr_array);
119 }
120 
121 // TODO(chky): Add support for multiple device instances.
ParseDeviceName(llvm::StringRef device_name) const122 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
123     llvm::StringRef device_name) const {
124   std::string tf_device_name = device_name.str();
125 
126   if (tf_device_name.empty()) {
127     return llvm::None;
128   }
129 
130   ParseDeviceNameResult result;
131   result.device_name = tf_device_name;
132 
133   // Parse the device name in format of the current tensorflow.
134   DeviceNameUtils::ParsedName parsed_name;
135   if (!DeviceNameUtils::ParseFullName(result.device_name, &parsed_name)) {
136     return llvm::None;
137   }
138   if (!parsed_name.has_type) {
139     return llvm::None;
140   }
141   result.device_type = parsed_name.type;
142 
143   result.op_handler_name = tf_device_name;
144 
145   return result;
146 }
147 
ParseDeviceName(mlir::Operation * op) const148 llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
149     mlir::Operation *op) const {
150   auto device_attr = op->getAttr("device");
151   if (!device_attr) {
152     return llvm::None;
153   }
154 
155   auto parsed_device_name =
156       ParseDeviceName(device_attr.cast<mlir::StringAttr>().getValue());
157   if (!parsed_device_name) op->emitWarning("failed to parse device name.");
158   return parsed_device_name;
159 }
160 
ConvertOpHandler(mlir::Operation * op,llvm::StringRef op_handler_name,ConversionPatternRewriter * rewriter)161 mlir::Value CoreRTConverter::ConvertOpHandler(
162     mlir::Operation *op, llvm::StringRef op_handler_name,
163     ConversionPatternRewriter *rewriter) {
164   auto iter = op_handler_by_name_.find(op_handler_name);
165   if (iter != op_handler_by_name_.end()) return iter->second;
166 
167   mlir::Block *block = op->getBlock();
168   ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
169   rewriter->setInsertionPointToStart(block);
170 
171   FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
172   mlir::Value in_chain = func_op.getArgument(0);
173   auto get_op_handler_op = rewriter->create<tfrt::corert::GetOpHandler>(
174       block->getParent()->getLoc(), op_handler_type(), in_chain,
175       op_handler_name);
176   op_handler_by_name_[op_handler_name] = get_op_handler_op.getResult();
177   return get_op_handler_op.getResult();
178 }
179 
GetDistributedContext(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)180 mlir::Value CoreRTConverter::GetDistributedContext(
181     mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
182   mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
183   auto iter = distributed_context_by_func_.find(func_op.getOperation());
184   if (iter != distributed_context_by_func_.end()) {
185     return iter->second;
186   }
187   ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
188   rewriter->setInsertionPoint(op);
189   auto get_dist_ctx_op = rewriter->create<tfrt::dist::GetDistributedContextOp>(
190       op->getLoc(), distributed_context_type());
191 
192   mlir::Value result = get_dist_ctx_op.result();
193   distributed_context_by_func_[func_op.getOperation()] = result;
194   return result;
195 }
196 
GetRemoteChainManager(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)197 mlir::Value CoreRTConverter::GetRemoteChainManager(
198     mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
199   mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
200   auto iter = remote_chain_mgr_by_func_.find(func_op.getOperation());
201   if (iter != remote_chain_mgr_by_func_.end()) {
202     return iter->second;
203   }
204   ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
205   rewriter->setInsertionPoint(op);
206 
207   mlir::Type remote_chain_mgr_type =
208       builder_.getType<::tfrt::dist::RemoteChainManagerType>();
209   mlir::Value dist_ctx = GetDistributedContext(op, rewriter);
210   auto create_mgr_op = rewriter->create<tfrt::dist::CreateRemoteChainManager>(
211       op->getLoc(), remote_chain_mgr_type, dist_ctx);
212 
213   mlir::Value result = create_mgr_op.result();
214   remote_chain_mgr_by_func_[func_op.getOperation()] = result;
215   return result;
216 }
217 
GetLocalSideEffectChain(mlir::Operation * op,mlir::ConversionPatternRewriter * rewriter)218 mlir::Value CoreRTConverter::GetLocalSideEffectChain(
219     mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
220   auto func_op = op->getParentOfType<mlir::FuncOp>();
221   auto predecessors = side_effect_analysis_.DirectControlPredecessors(op);
222 
223   // If there is no side-effect predecessor, then the input side-effect chain
224   // is used.
225   if (predecessors.empty()) return func_op.getArgument(0);
226 
227   llvm::SmallVector<mlir::Value, 2> chains;
228   for (auto *pred : predecessors) {
229     // TODO(chky): ReadVariableOp is removed in the pass and not converted.
230     // Ideally, every side-effecting op should be converted to a
231     // tfrt_fallback.executeop.seq op. The special rewrite logic of
232     // ReadVariableOp should be done in a previous pass.
233     if (auto chain = local_side_effect_chains_.lookup(pred))
234       chains.push_back(chain);
235   }
236 
237   if (chains.empty()) return func_op.getArgument(0);
238 
239   if (chains.size() == 1) return chains[0];
240 
241   // If there are multiple side-effect predecessors, insert a merge_chains
242   // kernel and return the merged chain.
243   ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
244   rewriter->setInsertionPoint(op);
245   return rewriter->create<tfrt::compiler::MergeChainsOp>(op->getLoc(),
246                                                          chain_type(), chains);
247 }
248 
GetTaskHandle(mlir::Operation * op,StringRef task_name,mlir::ConversionPatternRewriter * rewriter)249 mlir::Value CoreRTConverter::GetTaskHandle(
250     mlir::Operation *op, StringRef task_name,
251     mlir::ConversionPatternRewriter *rewriter) {
252   mlir::FuncOp func_op = op->getParentOfType<mlir::FuncOp>();
253   llvm::StringMap<mlir::Value> &task_handle_by_name =
254       task_handles_by_func_[func_op.getOperation()];
255   auto iter = task_handle_by_name.find(task_name);
256   if (iter != task_handle_by_name.end()) {
257     return iter->second;
258   }
259 
260   mlir::Value distributed_context = GetDistributedContext(op, rewriter);
261   auto task_handle_op = rewriter->create<tfrt::dist::GetTaskHandleOp>(
262       op->getLoc(), rewriter->getType<tfrt::dist::TaskHandleType>(),
263       distributed_context, task_name);
264 
265   task_handle_by_name[task_name] = task_handle_op.getResult();
266   return task_handle_op.getResult();
267 }
268 
GetRemoteSideEffectChain(mlir::Operation * op,StringRef remote_host,mlir::ConversionPatternRewriter * rewriter)269 mlir::Value CoreRTConverter::GetRemoteSideEffectChain(
270     mlir::Operation *op, StringRef remote_host,
271     mlir::ConversionPatternRewriter *rewriter) {
272   mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter);
273   mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter);
274   mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter);
275   mlir::Type remote_obj_id_ty =
276       rewriter->getType<tfrt::dist::RemoteObjectIdType>();
277 
278   // Get the remote chain using the tfrt_dist.get_chain_for_task_handle op.
279   auto get_chain_op = rewriter->create<tfrt::dist::GetChainForTaskHandleOp>(
280       op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr,
281       task_handle);
282   return get_chain_op.getResult();
283 }
284 
ConvertAttribute(mlir::Attribute attr)285 mlir::Attribute CoreRTConverter::ConvertAttribute(mlir::Attribute attr) {
286   // The supported attributes here should be kept consistent with
287   // //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h
288   //
289   // Currently, not all tensorflow data types are supported. Unranked shape
290   // attributes are not supported yet.
291 
292   // Return directly if the attribute is already supported.
293   if (attr.isa<mlir::IntegerAttr, mlir::FloatAttr, mlir::BoolAttr,
294                mlir::StringAttr, mlir::DenseIntOrFPElementsAttr>())
295     return attr;
296 
297   // For type attributes, we convert non-standard MLIR types to corresponding
298   // corert types.
299   if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
300     if (auto shape_type = type_attr.getValue().dyn_cast<mlir::TensorType>()) {
301       if (!shape_type.hasRank())
302         return tfrt::corert::ShapeAttr::get(builder_.getContext());
303 
304       return tfrt::corert::ShapeAttr::get(builder_.getContext(),
305                                           shape_type.getShape());
306     }
307 
308     return ConvertTypeAttribute(type_attr);
309   }
310 
311   // Convert the attribute to the corresponding format in TFRT dialect if
312   // needed.
313   if (auto shape_attr = attr.dyn_cast<mlir::TF::ShapeAttr>()) {
314     if (!shape_attr.hasRank())
315       return tfrt::corert::ShapeAttr::get(builder_.getContext());
316     return tfrt::corert::ShapeAttr::get(builder_.getContext(),
317                                         shape_attr.getShape());
318   }
319 
320   // For arrays, we recursively convert the elements.
321   if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
322     llvm::SmallVector<mlir::Attribute, 8> attrs;
323     attrs.reserve(array_attr.size());
324     for (auto attr : array_attr) {
325       auto converted = ConvertAttribute(attr);
326       if (!converted) return {};
327       attrs.push_back(converted);
328     }
329     return builder_.getArrayAttr(attrs);
330   }
331 
332   return {};
333 }
334 
ConvertSymbolAttrToStringAttr(mlir::FlatSymbolRefAttr symbol_attr)335 mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr(
336     mlir::FlatSymbolRefAttr symbol_attr) {
337   // Currently in TF graph to MLIR importing, a "0" is appended to the original
338   // function name, so we pop it here. The renaming is for TF/XLA v1 bridge
339   // use cases. Refer to b/142268695, b/141617294 for more context.
340   //
341   // In TFRT use cases, in almost every case "0" is the only literal
342   // appended since TF Graph already guarantee function name uniqueness.
343   // TODO(b/172092902): Investigate a better way to make the tf_func_name to
344   // mlir_tf_func_name conversion reversible.
345   auto func_name = symbol_attr.getValue().drop_back().str();
346 
347   return mlir::StringAttr::get(builder_.getContext(), func_name);
348 }
349 
ConvertTypeAttribute(mlir::TypeAttr type_attr)350 mlir::TypeAttr CoreRTConverter::ConvertTypeAttribute(mlir::TypeAttr type_attr) {
351   auto type = type_attr.getValue();
352 
353   if (IsSupportedNumericDType(type)) return type_attr;
354 
355   // For TF custom types, we convert it to custom corert types.
356   if (type.isa<mlir::TF::StringType>())
357     return mlir::TypeAttr::get(
358         tfrt::corert::StringType::get(builder_.getContext()));
359 
360   if (type.isa<mlir::TF::ResourceType>())
361     return mlir::TypeAttr::get(
362         tfrt::corert::ResourceType::get(builder_.getContext()));
363 
364   if (type.isa<mlir::TF::VariantType>())
365     return mlir::TypeAttr::get(
366         tfrt::corert::VariantType::get(builder_.getContext()));
367 
368   if (type.isa<mlir::TF::Quint8Type>()) {
369     return mlir::TypeAttr::get(
370         tfrt::corert::Quint8Type::get(builder_.getContext()));
371   }
372 
373   if (type.isa<mlir::TF::Quint16Type>()) {
374     return mlir::TypeAttr::get(
375         tfrt::corert::Quint16Type::get(builder_.getContext()));
376   }
377 
378   if (type.isa<mlir::TF::Qint8Type>()) {
379     return mlir::TypeAttr::get(
380         tfrt::corert::Qint8Type::get(builder_.getContext()));
381   }
382 
383   if (type.isa<mlir::TF::Qint16Type>()) {
384     return mlir::TypeAttr::get(
385         tfrt::corert::Qint16Type::get(builder_.getContext()));
386   }
387 
388   if (type.isa<mlir::TF::Qint32Type>()) {
389     return mlir::TypeAttr::get(
390         tfrt::corert::Qint32Type::get(builder_.getContext()));
391   }
392 
393   // Return invalid results to emit error for unsupported types.
394   return {};
395 }
396 
397 }  // namespace tensorflow
398