• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
18 
#include <memory>
#include <string>

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
28 
29 namespace tensorflow {
30 
// Result of CoreRTConverter::ParseDeviceName: the pieces of a TF device name
// that the TFRT lowering needs. All fields are value-initialized to empty
// strings; member order is part of the aggregate-initialization interface.
struct ParseDeviceNameResult {
  // Device type extracted from the device name (exact casing/format is
  // decided by ParseDeviceName -- TODO confirm in the .cc).
  std::string device_type{};
  // The device name as recorded for the op.
  std::string device_name{};
  // Name of the CoreRT op handler to use for this device.
  std::string op_handler_name{};
};
36 
37 // A helper class for converting CoreRT types and attributes.
38 class CoreRTConverter : public mlir::TypeConverter {
39  public:
40   CoreRTConverter(
41       mlir::MLIRContext *context,
42       const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis);
43   // Materialize all derived attributes. Note that this is only needed by
44   // CoreRT ops and fallback ops.
45   void MaterializeDerivedAttributes(mlir::Operation *op);
46 
47   bool IsSupportedNumericDType(mlir::Type type) const;
48 
49   // Create a single attribute that contains the named attribute lists. It is an
50   // array of pairs. The key must be a string attribute, and the value can be
51   // any attribute that is supported by CoreRuntime.
52   mlir::ArrayAttr CreateOpAttrs(llvm::ArrayRef<mlir::NamedAttribute> attrs);
53 
54   // Similar to CreateOpAttrs, create a single attribute that contains the
55   // named attribute lists, which is an array of pairs, with keys and values
56   // both being string attributes. The values represent function names.
57   // This method also populates a vector of attribute keys to be removed.
58   mlir::ArrayAttr CreateOpFuncAttrs(
59       llvm::ArrayRef<mlir::NamedAttribute> attrs,
60       llvm::SmallVector<mlir::Identifier, 4> *func_attr_keys);
61 
62   // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0"
63   // will be parsed as "cpu". Return None if no device is assigned.
64   llvm::Optional<ParseDeviceNameResult> ParseDeviceName(
65       llvm::StringRef device_name) const;
66   llvm::Optional<ParseDeviceNameResult> ParseDeviceName(
67       mlir::Operation *op) const;
68 
69   // Convert the device name in a TF op to a op_handler value produced by the
70   // corresponding GetOpHandler in the current block. If there does not exist
71   // one, insert a GetOpHandler to the beginning of the block and return the
72   // device value.
73   mlir::Value ConvertOpHandler(mlir::Operation *op, llvm::StringRef device_name,
74                                mlir::ConversionPatternRewriter *rewriter);
75 
76   // Get a DistributedContext value to be used by the given op. The
77   // DistributedContext value should be shared by all operations in the body
78   // of the same FuncOp. If there does not exist one, insert a
79   // GetDistributedContext op right before the given op and return the result
80   // value.
81   mlir::Value GetDistributedContext(mlir::Operation *op,
82                                     mlir::ConversionPatternRewriter *rewriter);
83 
84   // Get a RemoteChainManager value to be used by the given op. The
85   // RemoteChainManager value should be shared by all operations in the body
86   // of the same FuncOp. If there does not exist one, insert a
87   // tfrt_dist.test_create_remote_chain_manager op right before the given op and
88   // return the result value.
89   mlir::Value GetRemoteChainManager(mlir::Operation *op,
90                                     mlir::ConversionPatternRewriter *rewriter);
91 
92   // Get a TaskHandle value with the given task name. If the TaskHandle value
93   // has already been created for the given task name within the same FuncOp,
94   // return this TaskHandle value. Otherwise, insert a tfrt_dist.get_task_handle
95   // op right before the given op and return the result value.
96   mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name,
97                             mlir::ConversionPatternRewriter *rewriter);
98 
99   // Any local operation which uses any result of the `op` should depend on the
100   // given `chain`.
RegisterLocalSideEffectChain(mlir::Operation * op,mlir::Value chain)101   void RegisterLocalSideEffectChain(mlir::Operation *op, mlir::Value chain) {
102     local_side_effect_chains_[op] = chain;
103   }
104 
105   // Return a local chain for side effects for `op`. If there are multiple
106   // chains, a merge_chains kernel will be inserted and the merged chain will be
107   // returned.
108   mlir::Value GetLocalSideEffectChain(
109       mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter);
110 
111   // Return a remote chain for side effects for `op`.
112   mlir::Value GetRemoteSideEffectChain(
113       mlir::Operation *op, StringRef remote_host,
114       mlir::ConversionPatternRewriter *rewriter);
115 
op_handler_type()116   mlir::Type op_handler_type() {
117     return builder_.getType<::tfrt::corert::OpHandlerType>();
118   }
119 
tensor_handle_type()120   mlir::Type tensor_handle_type() {
121     return builder_.getType<::tfrt::corert::TensorHandleType>();
122   }
123 
chain_type()124   mlir::Type chain_type() {
125     return builder_.getType<::tfrt::compiler::ChainType>();
126   }
127 
distributed_context_type()128   mlir::Type distributed_context_type() {
129     return builder_.getType<::tfrt::dist::DistributedContextType>();
130   }
131 
builder()132   mlir::Builder &builder() { return builder_; }
133 
134  private:
135   // TODO(chky): attributes "_output_shapes" should be removed by any tool that
136   // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this
137   // filtering logic once unused attributes are cleaned up in the upper layer.
IsUnusedAttribute(llvm::StringRef name)138   bool IsUnusedAttribute(llvm::StringRef name) const {
139     // NOTE: attributes "f.*" are function attribute related and
140     // are added during importing graph to MLIR TF Executor dialect. These
141     // attributes are not actually used by TF ops with function attributes.
142     // TODO(b/180399811): Re-evaluate the usage of these attributes.
143     return name == "_output_shapes" || name.contains("f.");
144   }
145 
146   // Returns the converted attribute in TFRT dialect. If the conversion fails,
147   // returns a null attribute instead.
148   mlir::Attribute ConvertAttribute(mlir::Attribute attr);
149 
150   mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr);
151 
152   mlir::StringAttr ConvertSymbolAttrToStringAttr(
153       mlir::FlatSymbolRefAttr symbol_attr);
154 
155   mlir::Builder builder_;
156 
157   const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis_;
158 
159   llvm::DenseMap<mlir::Operation *, mlir::Value> local_side_effect_chains_;
160   llvm::DenseMap<mlir::Operation *, mlir::Value> distributed_context_by_func_;
161   llvm::DenseMap<mlir::Operation *, mlir::Value> remote_chain_mgr_by_func_;
162   llvm::DenseMap<mlir::Operation *, llvm::StringMap<mlir::Value>>
163       task_handles_by_func_;
164   llvm::StringMap<mlir::Value> op_handler_by_name_;
165 };
166 
167 }  // namespace tensorflow
168 
169 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
170