1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 18 19 #include <memory> 20 21 #include "mlir/IR/Attributes.h" // from @llvm-project 22 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 23 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" 24 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime 25 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime 26 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime 27 #include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime 28 29 namespace tensorflow { 30 31 struct ParseDeviceNameResult { 32 std::string device_type; 33 std::string device_name; 34 std::string op_handler_name; 35 }; 36 37 // A helper class for converting CoreRT types and attributes. 38 class CoreRTConverter : public mlir::TypeConverter { 39 public: 40 CoreRTConverter( 41 mlir::MLIRContext *context, 42 const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis); 43 // Materialize all derived attributes. Note that this is only needed by 44 // CoreRT ops and fallback ops. 45 void MaterializeDerivedAttributes(mlir::Operation *op); 46 47 bool IsSupportedNumericDType(mlir::Type type) const; 48 49 // Create a single attribute that contains the named attribute lists. It is an 50 // array of pairs. The key must be a string attribute, and the value can be 51 // any attribute that is supported by CoreRuntime. 52 mlir::ArrayAttr CreateOpAttrs(llvm::ArrayRef<mlir::NamedAttribute> attrs); 53 54 // Similar to CreateOpAttrs, create a single attribute that contains the 55 // named attribute lists, which is an array of pairs, with keys and values 56 // both being string attributes. The values represent function names. 57 // This method also populates a vector of attribute keys to be removed. 58 mlir::ArrayAttr CreateOpFuncAttrs( 59 llvm::ArrayRef<mlir::NamedAttribute> attrs, 60 llvm::SmallVector<mlir::Identifier, 4> *func_attr_keys); 61 62 // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0" 63 // will be parsed as "cpu". Return None if no device is assigned. 64 llvm::Optional<ParseDeviceNameResult> ParseDeviceName( 65 llvm::StringRef device_name) const; 66 llvm::Optional<ParseDeviceNameResult> ParseDeviceName( 67 mlir::Operation *op) const; 68 69 // Convert the device name in a TF op to a op_handler value produced by the 70 // corresponding GetOpHandler in the current block. If there does not exist 71 // one, insert a GetOpHandler to the beginning of the block and return the 72 // device value. 73 mlir::Value ConvertOpHandler(mlir::Operation *op, llvm::StringRef device_name, 74 mlir::ConversionPatternRewriter *rewriter); 75 76 // Get a DistributedContext value to be used by the given op. The 77 // DistributedContext value should be shared by all operations in the body 78 // of the same FuncOp. If there does not exist one, insert a 79 // GetDistributedContext op right before the given op and return the result 80 // value. 81 mlir::Value GetDistributedContext(mlir::Operation *op, 82 mlir::ConversionPatternRewriter *rewriter); 83 84 // Get a RemoteChainManager value to be used by the given op. The 85 // RemoteChainManager value should be shared by all operations in the body 86 // of the same FuncOp. If there does not exist one, insert a 87 // tfrt_dist.test_create_remote_chain_manager op right before the given op and 88 // return the result value. 89 mlir::Value GetRemoteChainManager(mlir::Operation *op, 90 mlir::ConversionPatternRewriter *rewriter); 91 92 // Get a TaskHandle value with the given task name. If the TaskHandle value 93 // has already been created for the given task name within the same FuncOp, 94 // return this TaskHandle value. Otherwise, insert a tfrt_dist.get_task_handle 95 // op right before the given op and return the result value. 96 mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name, 97 mlir::ConversionPatternRewriter *rewriter); 98 99 // Any local operation which uses any result of the `op` should depend on the 100 // given `chain`. RegisterLocalSideEffectChain(mlir::Operation * op,mlir::Value chain)101 void RegisterLocalSideEffectChain(mlir::Operation *op, mlir::Value chain) { 102 local_side_effect_chains_[op] = chain; 103 } 104 105 // Return a local chain for side effects for `op`. If there are multiple 106 // chains, a merge_chains kernel will be inserted and the merged chain will be 107 // returned. 108 mlir::Value GetLocalSideEffectChain( 109 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); 110 111 // Return a remote chain for side effects for `op`. 112 mlir::Value GetRemoteSideEffectChain( 113 mlir::Operation *op, StringRef remote_host, 114 mlir::ConversionPatternRewriter *rewriter); 115 op_handler_type()116 mlir::Type op_handler_type() { 117 return builder_.getType<::tfrt::corert::OpHandlerType>(); 118 } 119 tensor_handle_type()120 mlir::Type tensor_handle_type() { 121 return builder_.getType<::tfrt::corert::TensorHandleType>(); 122 } 123 chain_type()124 mlir::Type chain_type() { 125 return builder_.getType<::tfrt::compiler::ChainType>(); 126 } 127 distributed_context_type()128 mlir::Type distributed_context_type() { 129 return builder_.getType<::tfrt::dist::DistributedContextType>(); 130 } 131 builder()132 mlir::Builder &builder() { return builder_; } 133 134 private: 135 // TODO(chky): attributes "_output_shapes" should be removed by any tool that 136 // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this 137 // filtering logic once unused attributes are cleaned up in the upper layer. IsUnusedAttribute(llvm::StringRef name)138 bool IsUnusedAttribute(llvm::StringRef name) const { 139 // NOTE: attributes "f.*" are function attribute related and 140 // are added during importing graph to MLIR TF Executor dialect. These 141 // attributes are not actually used by TF ops with function attributes. 142 // TODO(b/180399811): Re-evaluate the usage of these attributes. 143 return name == "_output_shapes" || name.contains("f."); 144 } 145 146 // Returns the converted attribute in TFRT dialect. If the conversion fails, 147 // returns a null attribute instead. 148 mlir::Attribute ConvertAttribute(mlir::Attribute attr); 149 150 mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr); 151 152 mlir::StringAttr ConvertSymbolAttrToStringAttr( 153 mlir::FlatSymbolRefAttr symbol_attr); 154 155 mlir::Builder builder_; 156 157 const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis_; 158 159 llvm::DenseMap<mlir::Operation *, mlir::Value> local_side_effect_chains_; 160 llvm::DenseMap<mlir::Operation *, mlir::Value> distributed_context_by_func_; 161 llvm::DenseMap<mlir::Operation *, mlir::Value> remote_chain_mgr_by_func_; 162 llvm::DenseMap<mlir::Operation *, llvm::StringMap<mlir::Value>> 163 task_handles_by_func_; 164 llvm::StringMap<mlir::Value> op_handler_by_name_; 165 }; 166 167 } // namespace tensorflow 168 169 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 170