• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
18 
19 #include <type_traits>
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Region.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Support/LLVM.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 
32 namespace mlir {
33 namespace TFDevice {
34 
35 // -------------------------------------------------------------------------- //
36 // ValueConstraint.
37 // -------------------------------------------------------------------------- //
38 
// Constraints that an operation may require its operands to satisfy before it
// can be clustered. For example, a reduction operation may require the operand
// that holds the reduction dimensions to be a constant value.
//
// Enumerators are numbered in order of increasing strictness
// (Rank < Shape < Value), matching the order documented for `Merge`.
// NOTE(review): the implementation may rely on this numbering; confirm before
// renumbering.
enum class ValueConstraint : int {
  // The operand's rank must be known statically.
  kRank = 0,
  // The operand's shape must be fully static: every dimension is known at
  // compile time.
  kShape = 1,
  // The operand's value must be known statically: the operand must be defined
  // by a constant operation.
  kValue = 2,
};
52 
53 // Returns the more restrictive constraint of `a` and `b`:
54 //
55 //    Value >> Shape >> Rank
56 //
57 // If you know the value, you always know the shape and the rank. If you know
58 // the shape, you always know the rank.
59 ValueConstraint Merge(ValueConstraint a, ValueConstraint b);
60 
61 // Returns success if constraint can be resolved statically based on the value
62 // type, e.g. `shape` constraint can be resolved if the value is a tensor of
63 // statically known shape.
64 LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint);
65 
66 raw_ostream& operator<<(raw_ostream& os, const ValueConstraint& constraint);
67 
68 // -------------------------------------------------------------------------- //
69 // ValuesConstraintSet.
70 // -------------------------------------------------------------------------- //
71 
72 // A set of constraints for values, that either operation results or operands.
73 class ValuesConstraintSet {
74   using ConstraintsMap = llvm::SmallDenseMap<Value, ValueConstraint>;
75   using ConstIterator = typename ConstraintsMap::const_iterator;
76 
77  public:
78   ValuesConstraintSet() = default;
79 
80   // Inserts a new constraint for the `value`. If the `value` already has some
81   // constraint, it will merge it with a new one, and will return a new
82   // constraint value. Returned pair has a constraint value that was set for
83   // a value, and a boolean flag that is true if the constraint was updated.
84   std::pair<ValueConstraint, bool> Insert(Value value,
85                                           ValueConstraint constraint);
86 
87   // Inserts constraints for multiple values.
88   void Insert(ValueRange value, ValueConstraint constraint);
89 
90   // Walk all the constraints owned by this set.
91   void Walk(llvm::function_ref<void(Value, ValueConstraint)> walk) const;
92 
93   // Returns the constraint of the value if it exists, or None otherwise.
94   Optional<ValueConstraint> GetConstraint(Value value) const;
95   bool HasConstraint(Value value) const;
96 
97   // Merges all constrains from the other constraints set into this one.
98   void MergeAll(const ValuesConstraintSet& other);
99 
100   // Remove constraints that can be statically resolved from the type of the
101   // constrained value (see `IsStaticallyResolved` defined above).
102   ValuesConstraintSet& Resolve();
103 
104   // Reset all constraints.
105   ValuesConstraintSet& Reset();
106 
107   // Return the number of constrained values in the set.
108   size_t Size() const;
109 
110   // Returns true if the constraint set is empty.
111   bool Empty() const;
112 
begin()113   ConstIterator begin() const { return constraints_.begin(); }
end()114   ConstIterator end() const { return constraints_.end(); }
115 
116  private:
117   llvm::SmallDenseMap<Value, ValueConstraint> constraints_;
118 };
119 
120 // -------------------------------------------------------------------------- //
121 // ClusteringPolicy.
122 // -------------------------------------------------------------------------- //
123 
124 // Clustering policy specifies if the operation can be clustered (in practice it
125 // usually means that operation can be added to a cluster that will be later
126 // compiled) given the set of constraints on its results, and might propagate or
127 // create new constraints on the operation operands.
128 //
129 // Clustering policy must make a local decision just for a single operation. It
130 // is the responsibility of a clustering pass to combine all these individual
131 // operations constraints to form a valid cluster.
132 //
133 // Example: compilation using XLA (MHLO) lowering
134 //
135 //   %0 = "tf.Transpose"(%input, %perm)
136 //        : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
137 //
138 //   XLAs `mhlo.transpose` operation requires permutation to be an attribute
139 //   (compile time value), so it means that if we want to put `tf.Transpose`
140 //   into a cluster that will be compiled with XLA, the `%perm` operand must
141 //   be a known compiled time value, e.g. result of a `tf.Const` operation.
142 //
143 class ClusteringPolicy {
144  public:
145   virtual ~ClusteringPolicy() = default;
146 
147   // Returns success if an operation can be clustered given the constraints on
148   // the operation results. Updates operands constraits to satisfy all the
149   // results constraints.
150   virtual LogicalResult MatchAndUpdateConstraints(
151       Operation* operation, const ValuesConstraintSet& results,
152       ValuesConstraintSet& operands) const = 0;
153 };
154 
155 // Clustering policy for a specific operation type.
156 template <typename OpTy>
157 class OpClusteringPolicy : public ClusteringPolicy {
158  public:
MatchAndUpdateConstraints(Operation * operation,const ValuesConstraintSet & results,ValuesConstraintSet & operands)159   LogicalResult MatchAndUpdateConstraints(
160       Operation* operation, const ValuesConstraintSet& results,
161       ValuesConstraintSet& operands) const final {
162     if (auto op = dyn_cast<OpTy>(operation))
163       return MatchAndUpdateConstraints(op, results, operands);
164     return failure();
165   }
166 
167   virtual LogicalResult MatchAndUpdateConstraints(
168       OpTy op, const ValuesConstraintSet& results,
169       ValuesConstraintSet& operands) const = 0;
170 };
171 
172 // -------------------------------------------------------------------------- //
173 // ClusteringPolicySet.
174 // -------------------------------------------------------------------------- //
175 
176 // A set of clustering policies for different operations.
177 class ClusteringPolicySet {
178  public:
179   using Policies = std::vector<std::unique_ptr<ClusteringPolicy>>;
180 
policies()181   const Policies& policies() const { return policies_; }
182 
183   // Add an instance of each of the policy types 'Ts'. Return a reference to
184   // `this` for chaining insertions.
185   template <typename... Ts>
Add()186   ClusteringPolicySet& Add() {
187     (void)std::initializer_list<int>{0, (AddImpl<Ts>(), 0)...};
188     return *this;
189   }
190 
191   // ClusteringPolicySet is move only type.
192   ClusteringPolicySet() = default;
193   ClusteringPolicySet(const ClusteringPolicySet&) = delete;
194   ClusteringPolicySet(ClusteringPolicySet&&) = default;
195   ClusteringPolicySet& operator=(const ClusteringPolicySet&) = delete;
196   ClusteringPolicySet& operator=(ClusteringPolicySet&&) = default;
197 
198  private:
199   template <typename T, typename... Args>
AddImpl(Args &&...args)200   void AddImpl(Args&&... args) {
201     static_assert(std::is_base_of<ClusteringPolicy, T>::value,
202                   "T must implement ClusteringPolicy");
203     policies_.emplace_back(std::make_unique<T>(std::forward<Args>(args)...));
204   }
205 
206   std::vector<std::unique_ptr<ClusteringPolicy>> policies_;
207 };
208 
209 // -------------------------------------------------------------------------- //
210 // Discovering clusters of operations based on the policy.
211 // -------------------------------------------------------------------------- //
212 
213 // Cluster groups together operations in the single basic block based on the
214 // given clustering policy set. Clusters can be outlined into nested modules
215 // later device specific compilation (e.g. for TFRT JIT compiler).
216 struct Cluster {
217   llvm::SmallVector<Operation*> operations;
218   ValuesConstraintSet constraints;
219 };
220 
221 // Returns clusters of operations in the given `block` based on the provided
222 // clustering policy. If `filter` is defined, it will be used to filter
223 // operations that can be considered for clustering based on the policy.
224 //
225 // TODO(ezhulenev): Additional filter function is a workaround for customizing
226 // clustering policies at runtime for experimentation. In the long term,
227 // clustering policy should be enough.
228 llvm::SmallVector<Cluster> FindClustersInTheBlock(
229     Block* block, const ClusteringPolicySet& policies,
230     std::function<bool(Operation* op)> filter = {});
231 
232 // Creates a `tf_device.cluster` operation from the clustered operations.
233 tf_device::ClusterOp CreateClusterOp(Cluster& cluster, StringAttr policy = {});
234 
235 // -------------------------------------------------------------------------- //
236 // Helper functions for value constraints propagations and analysis.
237 // -------------------------------------------------------------------------- //
238 
239 // Propagates initial constraints on the values defined by the `constraints` set
240 // with operations in the `root` as a starting point, using user provided set of
241 // clustering policies.
242 //
243 // Filter predicate specifies if constraints should be propagated across the
244 // given operation. Operations in the root set will be also filtered using
245 // the `filter` predicate.
246 //
247 // Optionally resolve constraints that can be statically satisfied by the
248 // value type, and stop constraints propagation early.
249 //
250 // Returns failure if constraints can't be propagated through some of the
251 // operations accepted by the filter (there is no clustering policy for an
252 // operation, or constraints can't be satisfied by the policy), and attaches
253 // error diagnostics to the operation that prevented constraints propagation.
254 mlir::LogicalResult PropagateValuesConstraints(
255     llvm::ArrayRef<Operation*> root, std::function<bool(Operation*)> filter,
256     const ClusteringPolicySet& policies, ValuesConstraintSet& constraints,
257     bool resolve = false);
258 
259 // Propagates initial constraints on the values in the `region` to the other
260 // values in the same region, using user provided set of clustering policies.
261 mlir::LogicalResult PropagateValuesConstraints(
262     mlir::Region& region, const ClusteringPolicySet& policies,
263     ValuesConstraintSet& constraints, bool resolve = false);
264 
265 // Emits constraints remarks for all operations that use constrained values.
266 void EmitValueConstraintsRemarks(const ValuesConstraintSet& constraints);
267 
268 // Emits constraints remarks for function inputs that are in the constraints
269 // set (entry block arguments have constraints).
270 void EmitInputsConstraintsRemarks(FuncOp func,
271                                   const ValuesConstraintSet& constraints);
272 
273 // Infers constraints for the values in the function body from the function
274 // results attributes.
275 //
276 // Example:
277 //   func @test(...) -> (tensor<?x?xf32> {tf.constraint = "shape"}) {
278 //     .....
279 //     %v = "some_operation"() : () -> tensor<?x?xf32>
280 //     return %v : tensor<?x?xf32>
281 //   }
282 LogicalResult InferFunctionBodyValuesConstraints(
283     FuncOp func, ValuesConstraintSet& constraints);
284 
285 }  // namespace TFDevice
286 }  // namespace mlir
287 
288 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_
289