• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
18 
#include <array>
#include <memory>
#include <string>

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
29 
30 namespace tensorflow {
31 
// Result of CoreRTConverter::ParseDeviceName(): the pieces of a TF device
// name that the converter needs. Plain aggregate; all fields default to the
// empty string.
struct ParseDeviceNameResult {
  // Device type as understood by TFRT (e.g. "cpu" for "/CPU:0", per the
  // ParseDeviceName() documentation).
  std::string device_type;
  // Presumably the (possibly canonicalized) full device name — confirm in
  // the ParseDeviceName() implementation.
  std::string device_name;
  // Name used to look up / create the corresponding op handler.
  std::string op_handler_name;
};
37 
38 // A helper class for converting CoreRT types and attributes.
39 class CoreRTConverter : public mlir::TypeConverter {
40  public:
41   CoreRTConverter(
42       mlir::MLIRContext *context,
43       const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis);
44   // Materialize all derived attributes. Note that this is only needed by
45   // CoreRT ops and fallback ops.
46   void MaterializeDerivedAttributes(mlir::Operation *op);
47 
48   bool IsSupportedNumericDType(mlir::Type type) const;
49 
50   // Create a single attribute that contains the named attribute lists. It is an
51   // array of pairs. The key must be a string attribute, and the value can be
52   // any attribute that is supported by CoreRuntime.
53   mlir::ArrayAttr CreateOpAttrs(llvm::ArrayRef<mlir::NamedAttribute> attrs);
54 
55   // Similar to CreateOpAttrs, create a single attribute that contains the
56   // named attribute lists, which is an array of pairs, with keys and values
57   // both being string attributes. The values represent function names.
58   // This method also populates a vector of attribute keys to be removed.
59   mlir::ArrayAttr CreateOpFuncAttrs(
60       llvm::ArrayRef<mlir::NamedAttribute> attrs,
61       llvm::SmallVector<mlir::StringAttr, 4> *func_attr_keys);
62 
63   // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0"
64   // will be parsed as "cpu". Return None if no device is assigned.
65   llvm::Optional<ParseDeviceNameResult> ParseDeviceName(
66       llvm::StringRef device_name) const;
67   llvm::Optional<ParseDeviceNameResult> ParseDeviceName(
68       mlir::Operation *op) const;
69 
70   // Convert the device name in a TF op to a op_handler value produced by the
71   // corresponding GetOpHandler in the current block. If there does not exist
72   // one, insert a GetOpHandler to the beginning of the block and return the
73   // device value.
74   mlir::Value ConvertOpHandler(mlir::Operation *op, llvm::StringRef device_name,
75                                mlir::ConversionPatternRewriter *rewriter);
76 
77   // Get a DistributedContext value to be used by the given op. The
78   // DistributedContext value should be shared by all operations in the body
79   // of the same FuncOp. If there does not exist one, insert a
80   // GetDistributedContext op right before the given op and return the result
81   // value.
82   mlir::Value GetDistributedContext(mlir::Operation *op,
83                                     mlir::ConversionPatternRewriter *rewriter);
84 
85   // Get a RemoteChainManager value to be used by the given op. The
86   // RemoteChainManager value should be shared by all operations in the body
87   // of the same FuncOp. If there does not exist one, insert a
88   // tfrt_dist.test_create_remote_chain_manager op right before the given op and
89   // return the result value.
90   mlir::Value GetRemoteChainManager(mlir::Operation *op,
91                                     mlir::ConversionPatternRewriter *rewriter);
92 
93   // Get a TaskHandle value with the given task name. If the TaskHandle value
94   // has already been created for the given task name within the same FuncOp,
95   // return this TaskHandle value. Otherwise, insert a tfrt_dist.get_task_handle
96   // op right before the given op and return the result value.
97   mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name,
98                             mlir::ConversionPatternRewriter *rewriter);
99 
100   // Any local operation which uses any result of the `op` should depend on the
101   // given `chain`.
RegisterLocalSideEffectChain(mlir::Operation * op,mlir::Value chain)102   void RegisterLocalSideEffectChain(mlir::Operation *op, mlir::Value chain) {
103     local_side_effect_chains_[op] = chain;
104   }
105 
106   // Return a local chain for side effects for `op`. If there are multiple
107   // chains, a merge_chains kernel will be inserted and the merged chain will be
108   // returned.
109   mlir::Value GetLocalSideEffectChain(
110       mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter);
111 
112   // Return a remote chain for side effects for `op`.
113   mlir::Value GetRemoteSideEffectChain(
114       mlir::Operation *op, StringRef remote_host,
115       mlir::ConversionPatternRewriter *rewriter);
116 
op_handler_type()117   mlir::Type op_handler_type() {
118     return builder_.getType<::tfrt::corert::OpHandlerType>();
119   }
120 
tensor_handle_type()121   mlir::Type tensor_handle_type() {
122     return builder_.getType<::tfrt::corert::TensorHandleType>();
123   }
124 
chain_type()125   mlir::Type chain_type() {
126     return builder_.getType<::tfrt::compiler::ChainType>();
127   }
128 
distributed_context_type()129   mlir::Type distributed_context_type() {
130     return builder_.getType<::tfrt::dist::DistributedContextType>();
131   }
132 
builder()133   mlir::Builder &builder() { return builder_; }
134 
135  private:
136   // TODO(chky): attributes "_output_shapes" should be removed by any tool that
137   // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this
138   // filtering logic once unused attributes are cleaned up in the upper layer.
IsUnusedAttribute(llvm::StringRef name)139   bool IsUnusedAttribute(llvm::StringRef name) const {
140     // NOTE: attributes "f.*" are function attribute related and
141     // are added during importing graph to MLIR TF Executor dialect. These
142     // attributes are not actually used by TF ops with function attributes.
143     // TODO(b/180399811): Re-evaluate the usage of these attributes.
144     static const char *const kUnusedAttributes[] = {
145         "_output_shapes",
146         "result_segment_sizes",
147         "operand_segment_sizes",
148     };
149 
150     for (auto attr : kUnusedAttributes) {
151       if (name == attr) {
152         return true;
153       }
154     }
155 
156     return name.contains("f.");
157   }
158 
159   // Returns the converted attribute in TFRT dialect. If the conversion fails,
160   // returns a null attribute instead.
161   mlir::Attribute ConvertAttribute(mlir::Attribute attr);
162 
163   mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr);
164 
165   mlir::StringAttr ConvertSymbolAttrToStringAttr(
166       mlir::FlatSymbolRefAttr symbol_attr);
167 
168   mlir::Builder builder_;
169 
170   const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis_;
171 
172   llvm::DenseMap<mlir::Operation *, mlir::Value> local_side_effect_chains_;
173   llvm::DenseMap<mlir::Operation *, mlir::Value> distributed_context_by_func_;
174   llvm::DenseMap<mlir::Operation *, mlir::Value> remote_chain_mgr_by_func_;
175   llvm::DenseMap<mlir::Operation *, llvm::StringMap<mlir::Value>>
176       task_handles_by_func_;
177   llvm::StringMap<mlir::Value> op_handler_by_name_;
178 };
179 
180 }  // namespace tensorflow
181 
182 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_
183