1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 18 19 #include <array> 20 #include <memory> 21 22 #include "mlir/IR/Attributes.h" // from @llvm-project 23 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 24 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" 25 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime 26 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime 27 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime 28 #include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime 29 30 namespace tensorflow { 31 32 struct ParseDeviceNameResult { 33 std::string device_type; 34 std::string device_name; 35 std::string op_handler_name; 36 }; 37 38 // A helper class for converting CoreRT types and attributes. 39 class CoreRTConverter : public mlir::TypeConverter { 40 public: 41 CoreRTConverter( 42 mlir::MLIRContext *context, 43 const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis); 44 // Materialize all derived attributes. Note that this is only needed by 45 // CoreRT ops and fallback ops. 46 void MaterializeDerivedAttributes(mlir::Operation *op); 47 48 bool IsSupportedNumericDType(mlir::Type type) const; 49 50 // Create a single attribute that contains the named attribute lists. It is an 51 // array of pairs. The key must be a string attribute, and the value can be 52 // any attribute that is supported by CoreRuntime. 53 mlir::ArrayAttr CreateOpAttrs(llvm::ArrayRef<mlir::NamedAttribute> attrs); 54 55 // Similar to CreateOpAttrs, create a single attribute that contains the 56 // named attribute lists, which is an array of pairs, with keys and values 57 // both being string attributes. The values represent function names. 58 // This method also populates a vector of attribute keys to be removed. 59 mlir::ArrayAttr CreateOpFuncAttrs( 60 llvm::ArrayRef<mlir::NamedAttribute> attrs, 61 llvm::SmallVector<mlir::StringAttr, 4> *func_attr_keys); 62 63 // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0" 64 // will be parsed as "cpu". Return None if no device is assigned. 65 llvm::Optional<ParseDeviceNameResult> ParseDeviceName( 66 llvm::StringRef device_name) const; 67 llvm::Optional<ParseDeviceNameResult> ParseDeviceName( 68 mlir::Operation *op) const; 69 70 // Convert the device name in a TF op to a op_handler value produced by the 71 // corresponding GetOpHandler in the current block. If there does not exist 72 // one, insert a GetOpHandler to the beginning of the block and return the 73 // device value. 74 mlir::Value ConvertOpHandler(mlir::Operation *op, llvm::StringRef device_name, 75 mlir::ConversionPatternRewriter *rewriter); 76 77 // Get a DistributedContext value to be used by the given op. The 78 // DistributedContext value should be shared by all operations in the body 79 // of the same FuncOp. If there does not exist one, insert a 80 // GetDistributedContext op right before the given op and return the result 81 // value. 82 mlir::Value GetDistributedContext(mlir::Operation *op, 83 mlir::ConversionPatternRewriter *rewriter); 84 85 // Get a RemoteChainManager value to be used by the given op. The 86 // RemoteChainManager value should be shared by all operations in the body 87 // of the same FuncOp. If there does not exist one, insert a 88 // tfrt_dist.test_create_remote_chain_manager op right before the given op and 89 // return the result value. 90 mlir::Value GetRemoteChainManager(mlir::Operation *op, 91 mlir::ConversionPatternRewriter *rewriter); 92 93 // Get a TaskHandle value with the given task name. If the TaskHandle value 94 // has already been created for the given task name within the same FuncOp, 95 // return this TaskHandle value. Otherwise, insert a tfrt_dist.get_task_handle 96 // op right before the given op and return the result value. 97 mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name, 98 mlir::ConversionPatternRewriter *rewriter); 99 100 // Any local operation which uses any result of the `op` should depend on the 101 // given `chain`. RegisterLocalSideEffectChain(mlir::Operation * op,mlir::Value chain)102 void RegisterLocalSideEffectChain(mlir::Operation *op, mlir::Value chain) { 103 local_side_effect_chains_[op] = chain; 104 } 105 106 // Return a local chain for side effects for `op`. If there are multiple 107 // chains, a merge_chains kernel will be inserted and the merged chain will be 108 // returned. 109 mlir::Value GetLocalSideEffectChain( 110 mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); 111 112 // Return a remote chain for side effects for `op`. 113 mlir::Value GetRemoteSideEffectChain( 114 mlir::Operation *op, StringRef remote_host, 115 mlir::ConversionPatternRewriter *rewriter); 116 op_handler_type()117 mlir::Type op_handler_type() { 118 return builder_.getType<::tfrt::corert::OpHandlerType>(); 119 } 120 tensor_handle_type()121 mlir::Type tensor_handle_type() { 122 return builder_.getType<::tfrt::corert::TensorHandleType>(); 123 } 124 chain_type()125 mlir::Type chain_type() { 126 return builder_.getType<::tfrt::compiler::ChainType>(); 127 } 128 distributed_context_type()129 mlir::Type distributed_context_type() { 130 return builder_.getType<::tfrt::dist::DistributedContextType>(); 131 } 132 builder()133 mlir::Builder &builder() { return builder_; } 134 135 private: 136 // TODO(chky): attributes "_output_shapes" should be removed by any tool that 137 // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this 138 // filtering logic once unused attributes are cleaned up in the upper layer. IsUnusedAttribute(llvm::StringRef name)139 bool IsUnusedAttribute(llvm::StringRef name) const { 140 // NOTE: attributes "f.*" are function attribute related and 141 // are added during importing graph to MLIR TF Executor dialect. These 142 // attributes are not actually used by TF ops with function attributes. 143 // TODO(b/180399811): Re-evaluate the usage of these attributes. 144 static const char *const kUnusedAttributes[] = { 145 "_output_shapes", 146 "result_segment_sizes", 147 "operand_segment_sizes", 148 }; 149 150 for (auto attr : kUnusedAttributes) { 151 if (name == attr) { 152 return true; 153 } 154 } 155 156 return name.contains("f."); 157 } 158 159 // Returns the converted attribute in TFRT dialect. If the conversion fails, 160 // returns a null attribute instead. 161 mlir::Attribute ConvertAttribute(mlir::Attribute attr); 162 163 mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr); 164 165 mlir::StringAttr ConvertSymbolAttrToStringAttr( 166 mlir::FlatSymbolRefAttr symbol_attr); 167 168 mlir::Builder builder_; 169 170 const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis_; 171 172 llvm::DenseMap<mlir::Operation *, mlir::Value> local_side_effect_chains_; 173 llvm::DenseMap<mlir::Operation *, mlir::Value> distributed_context_by_func_; 174 llvm::DenseMap<mlir::Operation *, mlir::Value> remote_chain_mgr_by_func_; 175 llvm::DenseMap<mlir::Operation *, llvm::StringMap<mlir::Value>> 176 task_handles_by_func_; 177 llvm::StringMap<mlir::Value> op_handler_by_name_; 178 }; 179 180 } // namespace tensorflow 181 182 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ 183