/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime

namespace tensorflow {

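// Registers the type conversions used when lowering to the corert dialect:
// TFRT chain, op handler, distributed context, and tensor handle types pass
// through unchanged; TF tensor types (except ref types) are converted to
// corert tensor handles; i1 is kept as-is.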
CoreRTConverter::CoreRTConverter(
    mlir::MLIRContext *context,
    const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis)
    : builder_(context), side_effect_analysis_(*side_effect_analysis) {
  addConversion([](tfrt::compiler::ChainType type) { return type; });
  addConversion([](tfrt::corert::OpHandlerType type) { return type; });
  addConversion([](tfrt::dist::DistributedContextType type) { return type; });
  addConversion([](tfrt::corert::TensorHandleType type) { return type; });
  addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
    // Ref types are not supported by either the compiler or the runtime.
    if (type.getElementType().isa<mlir::TF::TensorFlowRefType>())
      return llvm::None;
    return tensor_handle_type();
  });
  addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
    if (type == builder_.getI1Type()) return type;
    return llvm::None;
  });
}

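// Materializes an op's derived attributes (if it implements
// DerivedAttributeOpInterface) and attaches them as regular attributes so
// they can be converted along with the other attributes.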
void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) {
  if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
    auto derived_attrs = interface.materializeDerivedAttributes();
    for (auto named_attr : derived_attrs) {
      op->setAttr(named_attr.getName(), named_attr.getValue());
    }
  }
}

bool CoreRTConverter::IsSupportedNumericDType(mlir::Type type) const {
  // Most TensorFlow data types (e.g. f32, i64) are standard MLIR types that
  // are supported as-is and need no conversion here.
  if (type.isBF16() || type.isF16() || type.isF32() || type.isF64() ||
      type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
      type.isInteger(32) || type.isInteger(64) || type.isUnsignedInteger(8) ||
      type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
      type.isUnsignedInteger(64))
    return true;

  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    auto element_type = complex_type.getElementType();
    if (element_type.isF32() || element_type.isF64()) return true;
  }

  return false;
}

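// Converts each used attribute into a [key, value] pair and packs the pairs
// into a single ArrayAttr. Returns a null attribute if any value cannot be
// converted.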
mlir::ArrayAttr CoreRTConverter::CreateOpAttrs(ArrayRef<NamedAttribute> attrs) {
  llvm::SmallVector<mlir::Attribute, 4> attr_array;
  for (auto key_and_value : attrs) {
    if (!IsUnusedAttribute(key_and_value.getName())) {
      auto converted = ConvertAttribute(key_and_value.getValue());
      if (!converted) return {};

      mlir::StringAttr key =
          builder_.getStringAttr(key_and_value.getName().strref());
      attr_array.push_back(builder_.getArrayAttr({key, converted}));
    }
  }
  return builder_.getArrayAttr(attr_array);
}

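// Like CreateOpAttrs, but handles function (symbol reference) attributes:
// each symbol is converted to a string attribute, and its key is recorded in
// `func_attr_keys` so the caller can strip the original attribute afterwards.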
mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs(
    ArrayRef<NamedAttribute> attrs,
    llvm::SmallVector<mlir::StringAttr, 4> *func_attr_keys) {
  llvm::SmallVector<mlir::Attribute, 4> attr_array;
  for (auto key_and_value : attrs) {
    auto attr_key = key_and_value.getName();
    auto attr_value = key_and_value.getValue();
    if (!IsUnusedAttribute(attr_key) &&
        attr_value.isa<mlir::FlatSymbolRefAttr, mlir::SymbolRefAttr>()) {
      auto func_attr = attr_value.dyn_cast<mlir::FlatSymbolRefAttr>();
      auto converted = ConvertSymbolAttrToStringAttr(func_attr);
      mlir::StringAttr key = builder_.getStringAttr(attr_key.strref());
      attr_array.push_back(builder_.getArrayAttr({key, converted}));

      // Record the attribute key so the caller can remove the original
      // attribute and avoid converting it again.
      func_attr_keys->push_back(attr_key);
    }
  }
  return builder_.getArrayAttr(attr_array);
}

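// Parses a TensorFlow device name string into a ParseDeviceNameResult holding
// the device name, device type, and the op handler name to use for the device.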
// TODO(chky): Add support for multiple device instances.
llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
    llvm::StringRef device_name) const {
  std::string tf_device_name = device_name.str();

  if (tf_device_name.empty()) {
    return llvm::None;
  }

  ParseDeviceNameResult result;
  result.device_name = tf_device_name;

  // Parse the device name using the current TensorFlow full-name format.
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(result.device_name, &parsed_name)) {
    return llvm::None;
  }
  if (!parsed_name.has_type) {
    return llvm::None;
  }
  result.device_type = parsed_name.type;

  result.op_handler_name = tf_device_name;

  return result;
}

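// Parses the device name from the op's `device` attribute. Returns llvm::None
// if the attribute is missing, and emits a warning if the name cannot be
// parsed.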
llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
    mlir::Operation *op) const {
  auto device_attr = op->getAttr("device");
  if (!device_attr) {
    return llvm::None;
  }

  auto parsed_device_name =
      ParseDeviceName(device_attr.cast<mlir::StringAttr>().getValue());
  if (!parsed_device_name) op->emitWarning("failed to parse device name.");
  return parsed_device_name;
}

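// Returns the op handler value for `op_handler_name`, creating a
// tfrt::corert::GetOpHandler op at the start of the block on first use and
// caching the result per handler name.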
mlir::Value CoreRTConverter::ConvertOpHandler(
    mlir::Operation *op, llvm::StringRef op_handler_name,
    ConversionPatternRewriter *rewriter) {
  auto iter = op_handler_by_name_.find(op_handler_name);
  if (iter != op_handler_by_name_.end()) return iter->second;

  mlir::Block *block = op->getBlock();
  ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
  rewriter->setInsertionPointToStart(block);

  func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
  mlir::Value in_chain = func_op.getArgument(0);
  auto get_op_handler_op = rewriter->create<tfrt::corert::GetOpHandler>(
      block->getParent()->getLoc(), op_handler_type(), in_chain,
      op_handler_name);
  op_handler_by_name_[op_handler_name] = get_op_handler_op.getResult();
  return get_op_handler_op.getResult();
}

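// Returns the distributed context value for the function enclosing `op`,
// creating a tfrt::dist::GetDistributedContextOp on first use and caching the
// result per function.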
mlir::Value CoreRTConverter::GetDistributedContext(
    mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
  mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
  auto iter = distributed_context_by_func_.find(func_op.getOperation());
  if (iter != distributed_context_by_func_.end()) {
    return iter->second;
  }
  ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
  rewriter->setInsertionPoint(op);
  auto get_dist_ctx_op = rewriter->create<tfrt::dist::GetDistributedContextOp>(
      op->getLoc(), distributed_context_type());

  mlir::Value result = get_dist_ctx_op.result();
  distributed_context_by_func_[func_op.getOperation()] = result;
  return result;
}

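// Returns the remote chain manager for the function enclosing `op`, creating
// a tfrt::dist::CreateRemoteChainManager op on first use and caching the
// result per function.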
mlir::Value CoreRTConverter::GetRemoteChainManager(
    mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
  mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
  auto iter = remote_chain_mgr_by_func_.find(func_op.getOperation());
  if (iter != remote_chain_mgr_by_func_.end()) {
    return iter->second;
  }
  ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
  rewriter->setInsertionPoint(op);

  mlir::Type remote_chain_mgr_type =
      builder_.getType<::tfrt::dist::RemoteChainManagerType>();
  mlir::Value dist_ctx = GetDistributedContext(op, rewriter);
  auto create_mgr_op = rewriter->create<tfrt::dist::CreateRemoteChainManager>(
      op->getLoc(), remote_chain_mgr_type, dist_ctx);

  mlir::Value result = create_mgr_op.result();
  remote_chain_mgr_by_func_[func_op.getOperation()] = result;
  return result;
}

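// Returns the chain value representing the local side effects that `op` must
// wait on, based on its direct control predecessors (or the control sinks for
// a return op). Multiple predecessor chains are merged with a MergeChainsOp.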
mlir::Value CoreRTConverter::GetLocalSideEffectChain(
    mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
  auto func_op = op->getParentOfType<mlir::func::FuncOp>();

  llvm::SmallVector<mlir::Operation *, 4> predecessors;
  if (llvm::isa<mlir::func::ReturnOp>(op)) {
    auto sinks = side_effect_analysis_.ControlSinks();
    predecessors.assign(sinks.begin(), sinks.end());
  } else {
    predecessors = side_effect_analysis_.DirectControlPredecessors(op);
  }

  llvm::SmallVector<mlir::Value, 2> chains;
  for (auto *pred : predecessors) {
    // TODO(chky): ReadVariableOp is removed in the pass and not converted.
    // Ideally, every side-effecting op should be converted to a
    // tfrt_fallback.executeop.seq op. The special rewrite logic of
    // ReadVariableOp should be done in a previous pass.
    if (auto chain = local_side_effect_chains_.lookup(pred))
      chains.push_back(chain);
  }

  // If there is no side-effecting predecessor, use the function's input
  // side-effect chain.
  if (chains.empty()) return func_op.getArgument(0);

  if (chains.size() == 1) return chains[0];

  // If there are multiple side-effecting predecessors, insert a merge_chains
  // kernel and return the merged chain.
  ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
  rewriter->setInsertionPoint(op);
  return rewriter->create<tfrt::compiler::MergeChainsOp>(op->getLoc(),
                                                         chain_type(), chains);
}

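// Returns the task handle for `task_name` in the function enclosing `op`,
// creating a tfrt::dist::GetTaskHandleOp on first use and caching the result
// per (function, task name).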
mlir::Value CoreRTConverter::GetTaskHandle(
    mlir::Operation *op, StringRef task_name,
    mlir::ConversionPatternRewriter *rewriter) {
  mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
  llvm::StringMap<mlir::Value> &task_handle_by_name =
      task_handles_by_func_[func_op.getOperation()];
  auto iter = task_handle_by_name.find(task_name);
  if (iter != task_handle_by_name.end()) {
    return iter->second;
  }

  mlir::Value distributed_context = GetDistributedContext(op, rewriter);
  auto task_handle_op = rewriter->create<tfrt::dist::GetTaskHandleOp>(
      op->getLoc(), rewriter->getType<tfrt::dist::TaskHandleType>(),
      distributed_context, task_name);

  task_handle_by_name[task_name] = task_handle_op.getResult();
  return task_handle_op.getResult();
}

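// Returns the remote side-effect chain for `remote_host`: a remote object id
// fetched from the remote chain manager using the host's task handle,
// sequenced after the local side-effect chain.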
mlir::Value CoreRTConverter::GetRemoteSideEffectChain(
    mlir::Operation *op, StringRef remote_host,
    mlir::ConversionPatternRewriter *rewriter) {
  mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter);
  mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter);
  mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter);
  mlir::Type remote_obj_id_ty =
      rewriter->getType<tfrt::dist::RemoteObjectIdType>();

  // Get the remote chain using the tfrt_dist.get_chain_for_task_handle op.
  auto get_chain_op = rewriter->create<tfrt::dist::GetChainForTaskHandleOp>(
      op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr,
      task_handle);
  return get_chain_op.getResult();
}

mlir::Attribute CoreRTConverter::ConvertAttribute(mlir::Attribute attr) {
  // The supported attributes here should be kept consistent with
  // //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h
  //
  // Currently, not all tensorflow data types are supported. Unranked shape
  // attributes are not supported yet.

  // Return directly if the attribute is already supported.
  if (attr.isa<mlir::IntegerAttr, mlir::FloatAttr, mlir::BoolAttr,
               mlir::StringAttr, mlir::DenseIntOrFPElementsAttr,
               mlir::DenseI32ArrayAttr>())
    return attr;

  // For type attributes, we convert non-standard MLIR types to corresponding
  // corert types.
  if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
    if (auto shape_type = type_attr.getValue().dyn_cast<mlir::TensorType>()) {
      if (!shape_type.hasRank())
        return tfrt::corert::ShapeAttr::get(builder_.getContext());

      return tfrt::corert::ShapeAttr::get(builder_.getContext(),
                                          shape_type.getShape());
    }

    return ConvertTypeAttribute(type_attr);
  }

  // Convert the attribute to the corresponding format in the TFRT dialect if
  // needed.
  if (auto shape_attr = attr.dyn_cast<mlir::TF::ShapeAttr>()) {
    if (!shape_attr.hasRank())
      return tfrt::corert::ShapeAttr::get(builder_.getContext());
    return tfrt::corert::ShapeAttr::get(builder_.getContext(),
                                        shape_attr.getShape());
  }

  // For arrays, we recursively convert the elements.
  if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
    llvm::SmallVector<mlir::Attribute, 8> attrs;
    attrs.reserve(array_attr.size());
    for (auto attr : array_attr) {
      auto converted = ConvertAttribute(attr);
      if (!converted) return {};
      attrs.push_back(converted);
    }
    return builder_.getArrayAttr(attrs);
  }

  return {};
}

mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr(
    mlir::FlatSymbolRefAttr symbol_attr) {
  // Currently, in TF graph to MLIR importing, a "0" is appended to the
  // original function name, so we drop it here. The renaming is for TF/XLA v1
  // bridge use cases. Refer to b/142268695, b/141617294 for more context.
  //
  // In TFRT use cases, "0" is almost always the only appended literal, since
  // the TF graph already guarantees function name uniqueness.
  // TODO(b/172092902): Investigate a better way to make the tf_func_name to
  // mlir_tf_func_name conversion reversible.
  auto func_name = symbol_attr.getValue().drop_back().str();

  return mlir::StringAttr::get(builder_.getContext(), func_name);
}

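// Converts a type attribute whose type is not a supported numeric MLIR type
// (e.g. TF string, resource, variant, and quantized types) to the
// corresponding corert type attribute. Returns a null attribute for
// unsupported types.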
mlir::TypeAttr CoreRTConverter::ConvertTypeAttribute(mlir::TypeAttr type_attr) {
  auto type = type_attr.getValue();

  if (IsSupportedNumericDType(type)) return type_attr;

  // For TF custom types, we convert them to the corresponding corert types.
  if (type.isa<mlir::TF::StringType>())
    return mlir::TypeAttr::get(
        tfrt::corert::StringType::get(builder_.getContext()));

  if (type.isa<mlir::TF::ResourceType>())
    return mlir::TypeAttr::get(
        tfrt::corert::ResourceType::get(builder_.getContext()));

  if (type.isa<mlir::TF::VariantType>())
    return mlir::TypeAttr::get(
        tfrt::corert::VariantType::get(builder_.getContext()));

  if (type.isa<mlir::TF::Quint8Type>()) {
    return mlir::TypeAttr::get(
        tfrt::corert::Quint8Type::get(builder_.getContext()));
  }

  if (type.isa<mlir::TF::Quint16Type>()) {
    return mlir::TypeAttr::get(
        tfrt::corert::Quint16Type::get(builder_.getContext()));
  }

  if (type.isa<mlir::TF::Qint8Type>()) {
    return mlir::TypeAttr::get(
        tfrt::corert::Qint8Type::get(builder_.getContext()));
  }

  if (type.isa<mlir::TF::Qint16Type>()) {
    return mlir::TypeAttr::get(
        tfrt::corert::Qint16Type::get(builder_.getContext()));
  }

  if (type.isa<mlir::TF::Qint32Type>()) {
    return mlir::TypeAttr::get(
        tfrt::corert::Qint32Type::get(builder_.getContext()));
  }

  // Return an invalid result so the caller emits an error for unsupported
  // types.
  return {};
}

}  // namespace tensorflow