1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 18 19 #include <type_traits> 20 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project 24 #include "mlir/IR/Attributes.h" // from @llvm-project 25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 26 #include "mlir/IR/Operation.h" // from @llvm-project 27 #include "mlir/IR/Region.h" // from @llvm-project 28 #include "mlir/IR/Value.h" // from @llvm-project 29 #include "mlir/Support/LLVM.h" // from @llvm-project 30 #include "mlir/Support/LogicalResult.h" // from @llvm-project 31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" 32 33 namespace mlir { 34 namespace TFDevice { 35 36 // -------------------------------------------------------------------------- // 37 // ValueConstraint. 38 // -------------------------------------------------------------------------- // 39 40 // In order to be clustered operation can require its operands to satisfy 41 // some constraints (e.g. reduction operation can require reduction dimension 42 // operand to be a constant value). 43 enum class ValueConstraint { 44 // Operand must have statically known rank. 45 kRank = 0, 46 // Operand must have statically known shape (all dimensions are known at 47 // compile time). 48 kShape = 1, 49 // Operand must have statically known value (operand must be defined by a 50 // constant operation). 51 kValue = 2, 52 }; 53 54 // Returns the more restrictive constraint of `a` and `b`: 55 // 56 // Value >> Shape >> Rank 57 // 58 // If you know the value, you always know the shape and the rank. If you know 59 // the shape, you always know the rank. 60 ValueConstraint Merge(ValueConstraint a, ValueConstraint b); 61 62 // Returns success if constraint can be resolved statically based on the value 63 // type, e.g. `shape` constraint can be resolved if the value is a tensor of 64 // statically known shape. 65 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint); 66 67 raw_ostream& operator<<(raw_ostream& os, const ValueConstraint& constraint); 68 69 // -------------------------------------------------------------------------- // 70 // ValuesConstraintSet. 71 // -------------------------------------------------------------------------- // 72 73 // A set of constraints for values, that either operation results or operands. 74 class ValuesConstraintSet { 75 using ConstraintsMap = llvm::SmallDenseMap<Value, ValueConstraint>; 76 using ConstIterator = typename ConstraintsMap::const_iterator; 77 78 public: 79 ValuesConstraintSet() = default; 80 81 // Inserts a new constraint for the `value`. If the `value` already has some 82 // constraint, it will merge it with a new one, and will return a new 83 // constraint value. Returned pair has a constraint value that was set for 84 // a value, and a boolean flag that is true if the constraint was updated. 85 std::pair<ValueConstraint, bool> Insert(Value value, 86 ValueConstraint constraint); 87 88 // Inserts constraints for multiple values. 89 void Insert(ValueRange value, ValueConstraint constraint); 90 91 // Walk all the constraints owned by this set. 92 void Walk(llvm::function_ref<void(Value, ValueConstraint)> walk) const; 93 94 // Returns the constraint of the value if it exists, or None otherwise. 95 Optional<ValueConstraint> GetConstraint(Value value) const; 96 bool HasConstraint(Value value) const; 97 98 // Merges all constrains from the other constraints set into this one. 99 void MergeAll(const ValuesConstraintSet& other); 100 101 // Remove constraints that can be statically resolved from the type of the 102 // constrained value (see `IsStaticallyResolved` defined above). 103 ValuesConstraintSet& Resolve(); 104 105 // Reset all constraints. 106 ValuesConstraintSet& Reset(); 107 108 // Return the number of constrained values in the set. 109 size_t Size() const; 110 111 // Returns true if the constraint set is empty. 112 bool Empty() const; 113 begin()114 ConstIterator begin() const { return constraints_.begin(); } end()115 ConstIterator end() const { return constraints_.end(); } 116 117 private: 118 llvm::SmallDenseMap<Value, ValueConstraint> constraints_; 119 }; 120 121 // -------------------------------------------------------------------------- // 122 // ClusteringPolicy. 123 // -------------------------------------------------------------------------- // 124 125 // Clustering policy specifies if the operation can be clustered (in practice it 126 // usually means that operation can be added to a cluster that will be later 127 // compiled) given the set of constraints on its results, and might propagate or 128 // create new constraints on the operation operands. 129 // 130 // Clustering policy must make a local decision just for a single operation. It 131 // is the responsibility of a clustering pass to combine all these individual 132 // operations constraints to form a valid cluster. 133 // 134 // Example: compilation using XLA (MHLO) lowering 135 // 136 // %0 = "tf.Transpose"(%input, %perm) 137 // : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> 138 // 139 // XLAs `mhlo.transpose` operation requires permutation to be an attribute 140 // (compile time value), so it means that if we want to put `tf.Transpose` 141 // into a cluster that will be compiled with XLA, the `%perm` operand must 142 // be a known compiled time value, e.g. result of a `tf.Const` operation. 143 // 144 class ClusteringPolicy { 145 public: 146 virtual ~ClusteringPolicy() = default; 147 148 // Returns success if an operation can be clustered given the constraints on 149 // the operation results. Updates operands constraits to satisfy all the 150 // results constraints. 151 virtual LogicalResult MatchAndUpdateConstraints( 152 Operation* operation, const ValuesConstraintSet& results, 153 ValuesConstraintSet& operands) const = 0; 154 }; 155 156 // Clustering policy for a specific operation type. 157 template <typename OpTy> 158 class OpClusteringPolicy : public ClusteringPolicy { 159 public: MatchAndUpdateConstraints(Operation * operation,const ValuesConstraintSet & results,ValuesConstraintSet & operands)160 LogicalResult MatchAndUpdateConstraints( 161 Operation* operation, const ValuesConstraintSet& results, 162 ValuesConstraintSet& operands) const final { 163 if (auto op = dyn_cast<OpTy>(operation)) 164 return MatchAndUpdateConstraints(op, results, operands); 165 return failure(); 166 } 167 168 virtual LogicalResult MatchAndUpdateConstraints( 169 OpTy op, const ValuesConstraintSet& results, 170 ValuesConstraintSet& operands) const = 0; 171 }; 172 173 // -------------------------------------------------------------------------- // 174 // ClusteringPolicySet. 175 // -------------------------------------------------------------------------- // 176 177 // A set of clustering policies for different operations. 178 class ClusteringPolicySet { 179 public: 180 using Policies = std::vector<std::unique_ptr<ClusteringPolicy>>; 181 policies()182 const Policies& policies() const { return policies_; } 183 184 // Add an instance of each of the policy types 'Ts'. Return a reference to 185 // `this` for chaining insertions. 186 template <typename... Ts> Add()187 ClusteringPolicySet& Add() { 188 (void)std::initializer_list<int>{0, (AddImpl<Ts>(), 0)...}; 189 return *this; 190 } 191 192 // ClusteringPolicySet is move only type. 193 ClusteringPolicySet() = default; 194 ClusteringPolicySet(const ClusteringPolicySet&) = delete; 195 ClusteringPolicySet(ClusteringPolicySet&&) = default; 196 ClusteringPolicySet& operator=(const ClusteringPolicySet&) = delete; 197 ClusteringPolicySet& operator=(ClusteringPolicySet&&) = default; 198 199 private: 200 template <typename T, typename... Args> AddImpl(Args &&...args)201 void AddImpl(Args&&... args) { 202 static_assert(std::is_base_of<ClusteringPolicy, T>::value, 203 "T must implement ClusteringPolicy"); 204 policies_.emplace_back(std::make_unique<T>(std::forward<Args>(args)...)); 205 } 206 207 std::vector<std::unique_ptr<ClusteringPolicy>> policies_; 208 }; 209 210 // -------------------------------------------------------------------------- // 211 // Discovering clusters of operations based on the policy. 212 // -------------------------------------------------------------------------- // 213 214 // Cluster groups together operations in the single basic block based on the 215 // given clustering policy set. Clusters can be outlined into nested modules 216 // later device specific compilation (e.g. for TFRT JIT compiler). 217 struct Cluster { 218 llvm::SmallVector<Operation*> operations; 219 ValuesConstraintSet constraints; 220 }; 221 222 // Returns clusters of operations in the given `block` based on the provided 223 // clustering policy. If `filter` is defined, it will be used to filter 224 // operations that can be considered for clustering based on the policy. 225 // 226 // TODO(ezhulenev): Additional filter function is a workaround for customizing 227 // clustering policies at runtime for experimentation. In the long term, 228 // clustering policy should be enough. 229 llvm::SmallVector<Cluster> FindClustersInTheBlock( 230 Block* block, const ClusteringPolicySet& policies, 231 std::function<bool(Operation* op)> filter = {}); 232 233 // Creates a `tf_device.cluster` operation from the clustered operations. 234 tf_device::ClusterOp CreateClusterOp(Cluster& cluster, StringAttr policy = {}); 235 236 // -------------------------------------------------------------------------- // 237 // Helper functions for value constraints propagations and analysis. 238 // -------------------------------------------------------------------------- // 239 240 // Propagates initial constraints on the values defined by the `constraints` set 241 // with operations in the `root` as a starting point, using user provided set of 242 // clustering policies. 243 // 244 // Filter predicate specifies if constraints should be propagated across the 245 // given operation. Operations in the root set will be also filtered using 246 // the `filter` predicate. 247 // 248 // Optionally resolve constraints that can be statically satisfied by the 249 // value type, and stop constraints propagation early. 250 // 251 // Optionally emits remarks attached to operation that failed to propagate 252 // results constraints to its operands (for testing purpose). 253 // 254 // Returns failure if constraints can't be propagated through some of the 255 // operations accepted by the filter (there is no clustering policy for an 256 // operation, or constraints can't be satisfied by the policy), and attaches 257 // error diagnostics to the operation that prevented constraints propagation. 258 mlir::LogicalResult PropagateValuesConstraints( 259 llvm::ArrayRef<Operation*> root, std::function<bool(Operation*)> filter, 260 const ClusteringPolicySet& policies, ValuesConstraintSet& constraints, 261 bool resolve = false, bool emit_remarks = false); 262 263 // Propagates initial constraints on the values in the `region` to the other 264 // values in the same region, using user provided set of clustering policies. 265 mlir::LogicalResult PropagateValuesConstraints( 266 mlir::Region& region, const ClusteringPolicySet& policies, 267 ValuesConstraintSet& constraints, bool resolve = false, 268 bool emit_remarks = false); 269 270 // Emits constraints remarks for all operations that use constrained values. 271 void EmitValueConstraintsRemarks(const ValuesConstraintSet& constraints); 272 273 // Emits constraints remarks for function inputs that are in the constraints 274 // set (entry block arguments have constraints). 275 void EmitInputsConstraintsRemarks(func::FuncOp func, 276 const ValuesConstraintSet& constraints); 277 278 // Infers constraints for the values in the function body from the function 279 // results attributes. 280 // 281 // Example: 282 // func @test(...) -> (tensor<?x?xf32> {tf.constraint = "shape"}) { 283 // ..... 284 // %v = "some_operation"() : () -> tensor<?x?xf32> 285 // return %v : tensor<?x?xf32> 286 // } 287 LogicalResult InferFunctionBodyValuesConstraints( 288 func::FuncOp func, ValuesConstraintSet& constraints); 289 290 } // namespace TFDevice 291 } // namespace mlir 292 293 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 294