• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
18 
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
32 
33 namespace mlir {
34 namespace TFDevice {
35 
36 // -------------------------------------------------------------------------- //
37 // ValueConstraint.
38 // -------------------------------------------------------------------------- //
39 
// A requirement that an operation may place on one of its operands before the
// operation can be clustered (for example, a reduction operation may need its
// reduction dimension operand to be a constant value).
//
// Enumerators are listed from the least to the most restrictive requirement.
enum class ValueConstraint {
  // The rank of the operand is known statically.
  kRank = 0,
  // The shape of the operand is known statically, i.e. every dimension is
  // known at compile time.
  kShape = 1,
  // The value of the operand is known statically, i.e. the operand is defined
  // by a constant operation.
  kValue = 2,
};
53 
54 // Returns the more restrictive constraint of `a` and `b`:
55 //
56 //    Value >> Shape >> Rank
57 //
58 // If you know the value, you always know the shape and the rank. If you know
59 // the shape, you always know the rank.
60 ValueConstraint Merge(ValueConstraint a, ValueConstraint b);
61 
62 // Returns success if constraint can be resolved statically based on the value
63 // type, e.g. `shape` constraint can be resolved if the value is a tensor of
64 // statically known shape.
65 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint);
66 
67 raw_ostream& operator<<(raw_ostream& os, const ValueConstraint& constraint);
68 
69 // -------------------------------------------------------------------------- //
70 // ValuesConstraintSet.
71 // -------------------------------------------------------------------------- //
72 
73 // A set of constraints for values, that either operation results or operands.
74 class ValuesConstraintSet {
75   using ConstraintsMap = llvm::SmallDenseMap<Value, ValueConstraint>;
76   using ConstIterator = typename ConstraintsMap::const_iterator;
77 
78  public:
79   ValuesConstraintSet() = default;
80 
81   // Inserts a new constraint for the `value`. If the `value` already has some
82   // constraint, it will merge it with a new one, and will return a new
83   // constraint value. Returned pair has a constraint value that was set for
84   // a value, and a boolean flag that is true if the constraint was updated.
85   std::pair<ValueConstraint, bool> Insert(Value value,
86                                           ValueConstraint constraint);
87 
88   // Inserts constraints for multiple values.
89   void Insert(ValueRange value, ValueConstraint constraint);
90 
91   // Walk all the constraints owned by this set.
92   void Walk(llvm::function_ref<void(Value, ValueConstraint)> walk) const;
93 
94   // Returns the constraint of the value if it exists, or None otherwise.
95   Optional<ValueConstraint> GetConstraint(Value value) const;
96   bool HasConstraint(Value value) const;
97 
98   // Merges all constrains from the other constraints set into this one.
99   void MergeAll(const ValuesConstraintSet& other);
100 
101   // Remove constraints that can be statically resolved from the type of the
102   // constrained value (see `IsStaticallyResolved` defined above).
103   ValuesConstraintSet& Resolve();
104 
105   // Reset all constraints.
106   ValuesConstraintSet& Reset();
107 
108   // Return the number of constrained values in the set.
109   size_t Size() const;
110 
111   // Returns true if the constraint set is empty.
112   bool Empty() const;
113 
begin()114   ConstIterator begin() const { return constraints_.begin(); }
end()115   ConstIterator end() const { return constraints_.end(); }
116 
117  private:
118   llvm::SmallDenseMap<Value, ValueConstraint> constraints_;
119 };
120 
121 // -------------------------------------------------------------------------- //
122 // ClusteringPolicy.
123 // -------------------------------------------------------------------------- //
124 
125 // Clustering policy specifies if the operation can be clustered (in practice it
126 // usually means that operation can be added to a cluster that will be later
127 // compiled) given the set of constraints on its results, and might propagate or
128 // create new constraints on the operation operands.
129 //
130 // Clustering policy must make a local decision just for a single operation. It
131 // is the responsibility of a clustering pass to combine all these individual
132 // operations constraints to form a valid cluster.
133 //
134 // Example: compilation using XLA (MHLO) lowering
135 //
136 //   %0 = "tf.Transpose"(%input, %perm)
137 //        : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
138 //
139 //   XLAs `mhlo.transpose` operation requires permutation to be an attribute
140 //   (compile time value), so it means that if we want to put `tf.Transpose`
141 //   into a cluster that will be compiled with XLA, the `%perm` operand must
142 //   be a known compiled time value, e.g. result of a `tf.Const` operation.
143 //
144 class ClusteringPolicy {
145  public:
146   virtual ~ClusteringPolicy() = default;
147 
148   // Returns success if an operation can be clustered given the constraints on
149   // the operation results. Updates operands constraits to satisfy all the
150   // results constraints.
151   virtual LogicalResult MatchAndUpdateConstraints(
152       Operation* operation, const ValuesConstraintSet& results,
153       ValuesConstraintSet& operands) const = 0;
154 };
155 
156 // Clustering policy for a specific operation type.
157 template <typename OpTy>
158 class OpClusteringPolicy : public ClusteringPolicy {
159  public:
MatchAndUpdateConstraints(Operation * operation,const ValuesConstraintSet & results,ValuesConstraintSet & operands)160   LogicalResult MatchAndUpdateConstraints(
161       Operation* operation, const ValuesConstraintSet& results,
162       ValuesConstraintSet& operands) const final {
163     if (auto op = dyn_cast<OpTy>(operation))
164       return MatchAndUpdateConstraints(op, results, operands);
165     return failure();
166   }
167 
168   virtual LogicalResult MatchAndUpdateConstraints(
169       OpTy op, const ValuesConstraintSet& results,
170       ValuesConstraintSet& operands) const = 0;
171 };
172 
173 // -------------------------------------------------------------------------- //
174 // ClusteringPolicySet.
175 // -------------------------------------------------------------------------- //
176 
177 // A set of clustering policies for different operations.
178 class ClusteringPolicySet {
179  public:
180   using Policies = std::vector<std::unique_ptr<ClusteringPolicy>>;
181 
policies()182   const Policies& policies() const { return policies_; }
183 
184   // Add an instance of each of the policy types 'Ts'. Return a reference to
185   // `this` for chaining insertions.
186   template <typename... Ts>
Add()187   ClusteringPolicySet& Add() {
188     (void)std::initializer_list<int>{0, (AddImpl<Ts>(), 0)...};
189     return *this;
190   }
191 
192   // ClusteringPolicySet is move only type.
193   ClusteringPolicySet() = default;
194   ClusteringPolicySet(const ClusteringPolicySet&) = delete;
195   ClusteringPolicySet(ClusteringPolicySet&&) = default;
196   ClusteringPolicySet& operator=(const ClusteringPolicySet&) = delete;
197   ClusteringPolicySet& operator=(ClusteringPolicySet&&) = default;
198 
199  private:
200   template <typename T, typename... Args>
AddImpl(Args &&...args)201   void AddImpl(Args&&... args) {
202     static_assert(std::is_base_of<ClusteringPolicy, T>::value,
203                   "T must implement ClusteringPolicy");
204     policies_.emplace_back(std::make_unique<T>(std::forward<Args>(args)...));
205   }
206 
207   std::vector<std::unique_ptr<ClusteringPolicy>> policies_;
208 };
209 
210 // -------------------------------------------------------------------------- //
211 // Discovering clusters of operations based on the policy.
212 // -------------------------------------------------------------------------- //
213 
214 // Cluster groups together operations in the single basic block based on the
215 // given clustering policy set. Clusters can be outlined into nested modules
216 // later device specific compilation (e.g. for TFRT JIT compiler).
217 struct Cluster {
218   llvm::SmallVector<Operation*> operations;
219   ValuesConstraintSet constraints;
220 };
221 
222 // Returns clusters of operations in the given `block` based on the provided
223 // clustering policy. If `filter` is defined, it will be used to filter
224 // operations that can be considered for clustering based on the policy.
225 //
226 // TODO(ezhulenev): Additional filter function is a workaround for customizing
227 // clustering policies at runtime for experimentation. In the long term,
228 // clustering policy should be enough.
229 llvm::SmallVector<Cluster> FindClustersInTheBlock(
230     Block* block, const ClusteringPolicySet& policies,
231     std::function<bool(Operation* op)> filter = {});
232 
233 // Creates a `tf_device.cluster` operation from the clustered operations.
234 tf_device::ClusterOp CreateClusterOp(Cluster& cluster, StringAttr policy = {});
235 
236 // -------------------------------------------------------------------------- //
237 // Helper functions for value constraints propagations and analysis.
238 // -------------------------------------------------------------------------- //
239 
240 // Propagates initial constraints on the values defined by the `constraints` set
241 // with operations in the `root` as a starting point, using user provided set of
242 // clustering policies.
243 //
244 // Filter predicate specifies if constraints should be propagated across the
245 // given operation. Operations in the root set will be also filtered using
246 // the `filter` predicate.
247 //
248 // Optionally resolve constraints that can be statically satisfied by the
249 // value type, and stop constraints propagation early.
250 //
251 // Optionally emits remarks attached to operation that failed to propagate
252 // results constraints to its operands (for testing purpose).
253 //
254 // Returns failure if constraints can't be propagated through some of the
255 // operations accepted by the filter (there is no clustering policy for an
256 // operation, or constraints can't be satisfied by the policy), and attaches
257 // error diagnostics to the operation that prevented constraints propagation.
258 mlir::LogicalResult PropagateValuesConstraints(
259     llvm::ArrayRef<Operation*> root, std::function<bool(Operation*)> filter,
260     const ClusteringPolicySet& policies, ValuesConstraintSet& constraints,
261     bool resolve = false, bool emit_remarks = false);
262 
263 // Propagates initial constraints on the values in the `region` to the other
264 // values in the same region, using user provided set of clustering policies.
265 mlir::LogicalResult PropagateValuesConstraints(
266     mlir::Region& region, const ClusteringPolicySet& policies,
267     ValuesConstraintSet& constraints, bool resolve = false,
268     bool emit_remarks = false);
269 
270 // Emits constraints remarks for all operations that use constrained values.
271 void EmitValueConstraintsRemarks(const ValuesConstraintSet& constraints);
272 
273 // Emits constraints remarks for function inputs that are in the constraints
274 // set (entry block arguments have constraints).
275 void EmitInputsConstraintsRemarks(func::FuncOp func,
276                                   const ValuesConstraintSet& constraints);
277 
278 // Infers constraints for the values in the function body from the function
279 // results attributes.
280 //
281 // Example:
282 //   func @test(...) -> (tensor<?x?xf32> {tf.constraint = "shape"}) {
283 //     .....
284 //     %v = "some_operation"() : () -> tensor<?x?xf32>
285 //     return %v : tensor<?x?xf32>
286 //   }
287 LogicalResult InferFunctionBodyValuesConstraints(
288     func::FuncOp func, ValuesConstraintSet& constraints);
289 
290 }  // namespace TFDevice
291 }  // namespace mlir
292 
293 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
294