1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 18 19 #include <type_traits> 20 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include "mlir/IR/Attributes.h" // from @llvm-project 24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 25 #include "mlir/IR/Operation.h" // from @llvm-project 26 #include "mlir/IR/Region.h" // from @llvm-project 27 #include "mlir/IR/Value.h" // from @llvm-project 28 #include "mlir/Support/LLVM.h" // from @llvm-project 29 #include "mlir/Support/LogicalResult.h" // from @llvm-project 30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" 31 32 namespace mlir { 33 namespace TFDevice { 34 35 // -------------------------------------------------------------------------- // 36 // ValueConstraint. 37 // -------------------------------------------------------------------------- // 38 39 // In order to be clustered operation can require its operands to satisfy 40 // some constraints (e.g. reduction operation can require reduction dimension 41 // operand to be a constant value). 42 enum class ValueConstraint { 43 // Operand must have statically known rank. 44 kRank = 0, 45 // Operand must have statically known shape (all dimensions are known at 46 // compile time). 47 kShape = 1, 48 // Operand must have statically known value (operand must be defined by a 49 // constant operation). 50 kValue = 2, 51 }; 52 53 // Returns the more restrictive constraint of `a` and `b`: 54 // 55 // Value >> Shape >> Rank 56 // 57 // If you know the value, you always know the shape and the rank. If you know 58 // the shape, you always know the rank. 59 ValueConstraint Merge(ValueConstraint a, ValueConstraint b); 60 61 // Returns success if constraint can be resolved statically based on the value 62 // type, e.g. `shape` constraint can be resolved if the value is a tensor of 63 // statically known shape. 64 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint); 65 66 raw_ostream& operator<<(raw_ostream& os, const ValueConstraint& constraint); 67 68 // -------------------------------------------------------------------------- // 69 // ValuesConstraintSet. 70 // -------------------------------------------------------------------------- // 71 72 // A set of constraints for values, that either operation results or operands. 73 class ValuesConstraintSet { 74 using ConstraintsMap = llvm::SmallDenseMap<Value, ValueConstraint>; 75 using ConstIterator = typename ConstraintsMap::const_iterator; 76 77 public: 78 ValuesConstraintSet() = default; 79 80 // Inserts a new constraint for the `value`. If the `value` already has some 81 // constraint, it will merge it with a new one, and will return a new 82 // constraint value. Returned pair has a constraint value that was set for 83 // a value, and a boolean flag that is true if the constraint was updated. 84 std::pair<ValueConstraint, bool> Insert(Value value, 85 ValueConstraint constraint); 86 87 // Inserts constraints for multiple values. 88 void Insert(ValueRange value, ValueConstraint constraint); 89 90 // Walk all the constraints owned by this set. 91 void Walk(llvm::function_ref<void(Value, ValueConstraint)> walk) const; 92 93 // Returns the constraint of the value if it exists, or None otherwise. 94 Optional<ValueConstraint> GetConstraint(Value value) const; 95 bool HasConstraint(Value value) const; 96 97 // Merges all constrains from the other constraints set into this one. 98 void MergeAll(const ValuesConstraintSet& other); 99 100 // Remove constraints that can be statically resolved from the type of the 101 // constrained value (see `IsStaticallyResolved` defined above). 102 ValuesConstraintSet& Resolve(); 103 104 // Reset all constraints. 105 ValuesConstraintSet& Reset(); 106 107 // Return the number of constrained values in the set. 108 size_t Size() const; 109 110 // Returns true if the constraint set is empty. 111 bool Empty() const; 112 begin()113 ConstIterator begin() const { return constraints_.begin(); } end()114 ConstIterator end() const { return constraints_.end(); } 115 116 private: 117 llvm::SmallDenseMap<Value, ValueConstraint> constraints_; 118 }; 119 120 // -------------------------------------------------------------------------- // 121 // ClusteringPolicy. 122 // -------------------------------------------------------------------------- // 123 124 // Clustering policy specifies if the operation can be clustered (in practice it 125 // usually means that operation can be added to a cluster that will be later 126 // compiled) given the set of constraints on its results, and might propagate or 127 // create new constraints on the operation operands. 128 // 129 // Clustering policy must make a local decision just for a single operation. It 130 // is the responsibility of a clustering pass to combine all these individual 131 // operations constraints to form a valid cluster. 132 // 133 // Example: compilation using XLA (MHLO) lowering 134 // 135 // %0 = "tf.Transpose"(%input, %perm) 136 // : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> 137 // 138 // XLAs `mhlo.transpose` operation requires permutation to be an attribute 139 // (compile time value), so it means that if we want to put `tf.Transpose` 140 // into a cluster that will be compiled with XLA, the `%perm` operand must 141 // be a known compiled time value, e.g. result of a `tf.Const` operation. 142 // 143 class ClusteringPolicy { 144 public: 145 virtual ~ClusteringPolicy() = default; 146 147 // Returns success if an operation can be clustered given the constraints on 148 // the operation results. Updates operands constraits to satisfy all the 149 // results constraints. 150 virtual LogicalResult MatchAndUpdateConstraints( 151 Operation* operation, const ValuesConstraintSet& results, 152 ValuesConstraintSet& operands) const = 0; 153 }; 154 155 // Clustering policy for a specific operation type. 156 template <typename OpTy> 157 class OpClusteringPolicy : public ClusteringPolicy { 158 public: MatchAndUpdateConstraints(Operation * operation,const ValuesConstraintSet & results,ValuesConstraintSet & operands)159 LogicalResult MatchAndUpdateConstraints( 160 Operation* operation, const ValuesConstraintSet& results, 161 ValuesConstraintSet& operands) const final { 162 if (auto op = dyn_cast<OpTy>(operation)) 163 return MatchAndUpdateConstraints(op, results, operands); 164 return failure(); 165 } 166 167 virtual LogicalResult MatchAndUpdateConstraints( 168 OpTy op, const ValuesConstraintSet& results, 169 ValuesConstraintSet& operands) const = 0; 170 }; 171 172 // -------------------------------------------------------------------------- // 173 // ClusteringPolicySet. 174 // -------------------------------------------------------------------------- // 175 176 // A set of clustering policies for different operations. 177 class ClusteringPolicySet { 178 public: 179 using Policies = std::vector<std::unique_ptr<ClusteringPolicy>>; 180 policies()181 const Policies& policies() const { return policies_; } 182 183 // Add an instance of each of the policy types 'Ts'. Return a reference to 184 // `this` for chaining insertions. 185 template <typename... Ts> Add()186 ClusteringPolicySet& Add() { 187 (void)std::initializer_list<int>{0, (AddImpl<Ts>(), 0)...}; 188 return *this; 189 } 190 191 // ClusteringPolicySet is move only type. 192 ClusteringPolicySet() = default; 193 ClusteringPolicySet(const ClusteringPolicySet&) = delete; 194 ClusteringPolicySet(ClusteringPolicySet&&) = default; 195 ClusteringPolicySet& operator=(const ClusteringPolicySet&) = delete; 196 ClusteringPolicySet& operator=(ClusteringPolicySet&&) = default; 197 198 private: 199 template <typename T, typename... Args> AddImpl(Args &&...args)200 void AddImpl(Args&&... args) { 201 static_assert(std::is_base_of<ClusteringPolicy, T>::value, 202 "T must implement ClusteringPolicy"); 203 policies_.emplace_back(std::make_unique<T>(std::forward<Args>(args)...)); 204 } 205 206 std::vector<std::unique_ptr<ClusteringPolicy>> policies_; 207 }; 208 209 // -------------------------------------------------------------------------- // 210 // Discovering clusters of operations based on the policy. 211 // -------------------------------------------------------------------------- // 212 213 // Cluster groups together operations in the single basic block based on the 214 // given clustering policy set. Clusters can be outlined into nested modules 215 // later device specific compilation (e.g. for TFRT JIT compiler). 216 struct Cluster { 217 llvm::SmallVector<Operation*> operations; 218 ValuesConstraintSet constraints; 219 }; 220 221 // Returns clusters of operations in the given `block` based on the provided 222 // clustering policy. If `filter` is defined, it will be used to filter 223 // operations that can be considered for clustering based on the policy. 224 // 225 // TODO(ezhulenev): Additional filter function is a workaround for customizing 226 // clustering policies at runtime for experimentation. In the long term, 227 // clustering policy should be enough. 228 llvm::SmallVector<Cluster> FindClustersInTheBlock( 229 Block* block, const ClusteringPolicySet& policies, 230 std::function<bool(Operation* op)> filter = {}); 231 232 // Creates a `tf_device.cluster` operation from the clustered operations. 233 tf_device::ClusterOp CreateClusterOp(Cluster& cluster, StringAttr policy = {}); 234 235 // -------------------------------------------------------------------------- // 236 // Helper functions for value constraints propagations and analysis. 237 // -------------------------------------------------------------------------- // 238 239 // Propagates initial constraints on the values defined by the `constraints` set 240 // with operations in the `root` as a starting point, using user provided set of 241 // clustering policies. 242 // 243 // Filter predicate specifies if constraints should be propagated across the 244 // given operation. Operations in the root set will be also filtered using 245 // the `filter` predicate. 246 // 247 // Optionally resolve constraints that can be statically satisfied by the 248 // value type, and stop constraints propagation early. 249 // 250 // Returns failure if constraints can't be propagated through some of the 251 // operations accepted by the filter (there is no clustering policy for an 252 // operation, or constraints can't be satisfied by the policy), and attaches 253 // error diagnostics to the operation that prevented constraints propagation. 254 mlir::LogicalResult PropagateValuesConstraints( 255 llvm::ArrayRef<Operation*> root, std::function<bool(Operation*)> filter, 256 const ClusteringPolicySet& policies, ValuesConstraintSet& constraints, 257 bool resolve = false); 258 259 // Propagates initial constraints on the values in the `region` to the other 260 // values in the same region, using user provided set of clustering policies. 261 mlir::LogicalResult PropagateValuesConstraints( 262 mlir::Region& region, const ClusteringPolicySet& policies, 263 ValuesConstraintSet& constraints, bool resolve = false); 264 265 // Emits constraints remarks for all operations that use constrained values. 266 void EmitValueConstraintsRemarks(const ValuesConstraintSet& constraints); 267 268 // Emits constraints remarks for function inputs that are in the constraints 269 // set (entry block arguments have constraints). 270 void EmitInputsConstraintsRemarks(FuncOp func, 271 const ValuesConstraintSet& constraints); 272 273 // Infers constraints for the values in the function body from the function 274 // results attributes. 275 // 276 // Example: 277 // func @test(...) -> (tensor<?x?xf32> {tf.constraint = "shape"}) { 278 // ..... 279 // %v = "some_operation"() : () -> tensor<?x?xf32> 280 // return %v : tensor<?x?xf32> 281 // } 282 LogicalResult InferFunctionBodyValuesConstraints( 283 FuncOp func, ValuesConstraintSet& constraints); 284 285 } // namespace TFDevice 286 } // namespace mlir 287 288 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ 289