• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
18 
19 #include <type_traits>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 
32 namespace xla {
33 
34 class HloComputation;
35 class HloInstruction;
36 
37 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
38 // instruction, all its operands were already visited. User code can subclass
39 // this to iterate over an HloInstruction DAG. The Handle* routines have
40 // operands / data unpacked for ease of use in the visitor subclass.
41 //
42 // No instruction will ever be visited twice; however, the root instruction will
43 // be reported again when the traversal is done via a call to FinishVisit.
44 //
45 // A subclass must override at least
46 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
47 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
48 // The default Handle methods for (unary, binary) ops call
49 // (HandleElementwiseUnary, HandleElementwiseBinary).
50 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
51 // "unimplemented" error status.
52 //
53 // Note: this may change to an iterator in the future for flexibility purposes.
54 //
55 // Users should not use this class directly, but use the type-aliases
56 // DfsHloVisitor/ConstDfsHloVisitor instead.
57 template <typename HloInstructionPtr>
58 class DfsHloVisitorBase {
59   static_assert(
60       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
61           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
62       "Template argument expected to be HloInstruction* or const "
63       "HloInstruction*");
64 
65  public:
DfsHloVisitorBase()66   DfsHloVisitorBase() {}
~DfsHloVisitorBase()67   virtual ~DfsHloVisitorBase() {}
68 
69   // These routines are self-descriptive, see class comment for usage
70   // information.
71 
72   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
73   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
74 
75   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
76   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
HandleMaximum(HloInstructionPtr hlo)77   virtual Status HandleMaximum(HloInstructionPtr hlo) {
78     return HandleElementwiseBinary(hlo);
79   }
HandleMinimum(HloInstructionPtr hlo)80   virtual Status HandleMinimum(HloInstructionPtr hlo) {
81     return HandleElementwiseBinary(hlo);
82   }
83   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
HandleConvert(HloInstructionPtr hlo)84   virtual Status HandleConvert(HloInstructionPtr hlo) {
85     return HandleElementwiseUnary(hlo);
86   }
HandleBitcastConvert(HloInstructionPtr hlo)87   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
88     return HandleElementwiseUnary(hlo);
89   }
HandleCopy(HloInstructionPtr hlo)90   virtual Status HandleCopy(HloInstructionPtr hlo) {
91     return HandleElementwiseUnary(hlo);
92   }
HandleComplex(HloInstructionPtr hlo)93   virtual Status HandleComplex(HloInstructionPtr hlo) {
94     return HandleElementwiseBinary(hlo);
95   }
HandleMultiply(HloInstructionPtr hlo)96   virtual Status HandleMultiply(HloInstructionPtr hlo) {
97     return HandleElementwiseBinary(hlo);
98   }
99   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
HandlePower(HloInstructionPtr hlo)100   virtual Status HandlePower(HloInstructionPtr hlo) {
101     return HandleElementwiseBinary(hlo);
102   }
HandleSqrt(HloInstructionPtr hlo)103   virtual Status HandleSqrt(HloInstructionPtr hlo) {
104     return HandleElementwiseUnary(hlo);
105   }
HandleRsqrt(HloInstructionPtr hlo)106   virtual Status HandleRsqrt(HloInstructionPtr hlo) {
107     return HandleElementwiseUnary(hlo);
108   }
HandleCbrt(HloInstructionPtr hlo)109   virtual Status HandleCbrt(HloInstructionPtr hlo) {
110     return HandleElementwiseUnary(hlo);
111   }
112   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
113   virtual Status HandleFft(HloInstructionPtr fft) = 0;
114   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
115   virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
116   virtual Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0;
117   virtual Status HandleAllGather(HloInstructionPtr hlo) = 0;
118   virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0;
119   virtual Status HandleAllGatherDone(HloInstructionPtr hlo) = 0;
120   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
121   virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0;
122   virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0;
123   virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0;
124   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
125   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
126   virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
127   virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
128   virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
129   virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
130   virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
131   virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0;
HandleCompare(HloInstructionPtr hlo)132   virtual Status HandleCompare(HloInstructionPtr hlo) {
133     return HandleElementwiseBinary(hlo);
134   }
HandleAdd(HloInstructionPtr hlo)135   virtual Status HandleAdd(HloInstructionPtr hlo) {
136     return HandleElementwiseBinary(hlo);
137   }
HandleDivide(HloInstructionPtr hlo)138   virtual Status HandleDivide(HloInstructionPtr hlo) {
139     return HandleElementwiseBinary(hlo);
140   }
HandleRemainder(HloInstructionPtr hlo)141   virtual Status HandleRemainder(HloInstructionPtr hlo) {
142     return HandleElementwiseBinary(hlo);
143   }
HandleSubtract(HloInstructionPtr hlo)144   virtual Status HandleSubtract(HloInstructionPtr hlo) {
145     return HandleElementwiseBinary(hlo);
146   }
HandleAbs(HloInstructionPtr hlo)147   virtual Status HandleAbs(HloInstructionPtr hlo) {
148     return HandleElementwiseUnary(hlo);
149   }
HandleAtan2(HloInstructionPtr hlo)150   virtual Status HandleAtan2(HloInstructionPtr hlo) {
151     return HandleElementwiseBinary(hlo);
152   }
HandleRound(HloInstructionPtr hlo)153   virtual Status HandleRound(HloInstructionPtr hlo) {
154     return HandleElementwiseUnary(hlo);
155   }
HandleRoundNearestEven(HloInstructionPtr hlo)156   virtual Status HandleRoundNearestEven(HloInstructionPtr hlo) {
157     return HandleElementwiseUnary(hlo);
158   }
HandleLogistic(HloInstructionPtr hlo)159   virtual Status HandleLogistic(HloInstructionPtr hlo) {
160     return HandleElementwiseUnary(hlo);
161   }
HandleSign(HloInstructionPtr hlo)162   virtual Status HandleSign(HloInstructionPtr hlo) {
163     return HandleElementwiseUnary(hlo);
164   }
HandleNegate(HloInstructionPtr hlo)165   virtual Status HandleNegate(HloInstructionPtr hlo) {
166     return HandleElementwiseUnary(hlo);
167   }
HandleExp(HloInstructionPtr hlo)168   virtual Status HandleExp(HloInstructionPtr hlo) {
169     return HandleElementwiseUnary(hlo);
170   }
HandleExpm1(HloInstructionPtr hlo)171   virtual Status HandleExpm1(HloInstructionPtr hlo) {
172     return HandleElementwiseUnary(hlo);
173   }
HandleFloor(HloInstructionPtr hlo)174   virtual Status HandleFloor(HloInstructionPtr hlo) {
175     return HandleElementwiseUnary(hlo);
176   }
HandleCeil(HloInstructionPtr hlo)177   virtual Status HandleCeil(HloInstructionPtr hlo) {
178     return HandleElementwiseUnary(hlo);
179   }
HandleLog(HloInstructionPtr hlo)180   virtual Status HandleLog(HloInstructionPtr hlo) {
181     return HandleElementwiseUnary(hlo);
182   }
HandleClz(HloInstructionPtr hlo)183   virtual Status HandleClz(HloInstructionPtr hlo) {
184     return HandleElementwiseUnary(hlo);
185   }
HandleLog1p(HloInstructionPtr hlo)186   virtual Status HandleLog1p(HloInstructionPtr hlo) {
187     return HandleElementwiseUnary(hlo);
188   }
HandleCos(HloInstructionPtr hlo)189   virtual Status HandleCos(HloInstructionPtr hlo) {
190     return HandleElementwiseUnary(hlo);
191   }
HandleSin(HloInstructionPtr hlo)192   virtual Status HandleSin(HloInstructionPtr hlo) {
193     return HandleElementwiseUnary(hlo);
194   }
HandleTanh(HloInstructionPtr hlo)195   virtual Status HandleTanh(HloInstructionPtr hlo) {
196     return HandleElementwiseUnary(hlo);
197   }
HandleReal(HloInstructionPtr hlo)198   virtual Status HandleReal(HloInstructionPtr hlo) {
199     return HandleElementwiseUnary(hlo);
200   }
HandleImag(HloInstructionPtr hlo)201   virtual Status HandleImag(HloInstructionPtr hlo) {
202     return HandleElementwiseUnary(hlo);
203   }
HandleIsFinite(HloInstructionPtr hlo)204   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
205     return HandleElementwiseUnary(hlo);
206   }
HandleAnd(HloInstructionPtr hlo)207   virtual Status HandleAnd(HloInstructionPtr hlo) {
208     return HandleElementwiseBinary(hlo);
209   }
HandleNot(HloInstructionPtr hlo)210   virtual Status HandleNot(HloInstructionPtr hlo) {
211     return HandleElementwiseUnary(hlo);
212   }
HandleOr(HloInstructionPtr hlo)213   virtual Status HandleOr(HloInstructionPtr hlo) {
214     return HandleElementwiseBinary(hlo);
215   }
HandleXor(HloInstructionPtr hlo)216   virtual Status HandleXor(HloInstructionPtr hlo) {
217     return HandleElementwiseBinary(hlo);
218   }
HandlePopulationCount(HloInstructionPtr hlo)219   virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
220     return HandleElementwiseUnary(hlo);
221   }
HandleShiftLeft(HloInstructionPtr hlo)222   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
223     return HandleElementwiseBinary(hlo);
224   }
HandleShiftRightArithmetic(HloInstructionPtr hlo)225   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
226     return HandleElementwiseBinary(hlo);
227   }
HandleShiftRightLogical(HloInstructionPtr hlo)228   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
229     return HandleElementwiseBinary(hlo);
230   }
231 
HandleReducePrecision(HloInstructionPtr hlo)232   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
233     return HandleElementwiseUnary(hlo);
234   }
235 
HandleDomain(HloInstructionPtr hlo)236   virtual Status HandleDomain(HloInstructionPtr hlo) {
237     return HandleElementwiseUnary(hlo);
238   }
239 
240   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
241   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
242   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
243   virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0;
244   virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0;
245   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
246   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
247   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
248   virtual Status HandleIota(HloInstructionPtr hlo) = 0;
249   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
250   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
251   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
252   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
253   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
254   virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0;
255   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
256   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
257   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
258   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
259   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
260   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
261   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
262   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
263   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
264   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
265   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
266   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
267   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
268   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
269   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
270   virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
271 
272   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
273 
274   virtual Status HandleAsyncStart(HloInstructionPtr hlo) = 0;
275   virtual Status HandleAsyncUpdate(HloInstructionPtr hlo) = 0;
276   virtual Status HandleAsyncDone(HloInstructionPtr hlo) = 0;
277 
278   virtual Status HandleCopyStart(HloInstructionPtr copy_start) = 0;
279   virtual Status HandleCopyDone(HloInstructionPtr copy_done) = 0;
280 
281   virtual Status HandleSend(HloInstructionPtr send) = 0;
282   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
283 
284   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
285   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
286 
287   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
288 
289   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
290 
291   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
292 
293   virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
294   virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
295 
296   // Invoked to inform the visitor that the traversal has completed, and that
297   // the root was "root".
298   virtual Status FinishVisit(HloInstructionPtr root) = 0;
299 
300   // 3 possible visitation states of HLO instructions. Each instruction's
301   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
302   enum VisitState {
303     kNotVisited = 0,
304     kVisiting = 1,
305     kVisited = 2,
306   };
307 
GetVisitState(int id)308   VisitState GetVisitState(int id) {
309     auto iter = visit_state_.find(id);
310     if (iter == visit_state_.end()) {
311       return VisitState::kNotVisited;
312     }
313     return iter->second;
314   }
315   VisitState GetVisitState(const HloInstruction& instruction);
316 
317   // Resize internal state if necessary to hold state for ids <= num.
318   // This call is purely a performance hint and can be omitted without
319   // affecting correctness.
ReserveVisitStates(int num)320   void ReserveVisitStates(int num) { visit_state_.reserve(num); }
VisitStateCapacity()321   size_t VisitStateCapacity() const { return visit_state_.capacity(); }
322 
323   // Useful when we want to visit the same computation more than once with the
324   // same visitor.
ResetVisitStates()325   void ResetVisitStates() {
326     // Clear the map, but don't resize the capacity across uses -- Calculating
327     // and reserving space could be expensive, and we always use the same
328     // module->instruction_count() as the capacity.
329     visit_state_.erase(visit_state_.begin(), visit_state_.end());
330   }
331 
332   // Useful when we want to free up the memory used by the visit state without
333   // destroying the actual visitor subclass.
DestroyVisitState()334   void DestroyVisitState() {
335     visit_state_ = absl::flat_hash_map<int, VisitState>{};
336   }
337 
SetVisitState(int id,VisitState state)338   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
339 
340   // Sets the visitation state of the given instruction as kVisiting.
341   //
342   // Precondition: current state must be kNotVisited.
343   void SetVisiting(const HloInstruction& instruction);
344 
345   // Sets the visitation state of the given instruction as kVisited.
346   //
347   // Precondition: current state must be either kNotVisited or kVisiting.
348   void SetVisited(const HloInstruction& instruction);
349 
350   // Returns whether the state of the given instruction is kVisiting.
IsVisiting(const HloInstruction & instruction)351   bool IsVisiting(const HloInstruction& instruction) {
352     return GetVisitState(instruction) == kVisiting;
353   }
354 
355   // Returns whether the state of the given instruction is kVisited.
DidVisit(const HloInstruction & instruction)356   bool DidVisit(const HloInstruction& instruction) {
357     return GetVisitState(instruction) == kVisited;
358   }
359 
360   // Returns whether the state of the given instruction is kNotVisited.
NotVisited(const HloInstruction & instruction)361   bool NotVisited(const HloInstruction& instruction) {
362     return GetVisitState(instruction) == kNotVisited;
363   }
364 
365   // This method should be overridden by subclasses that wish to run some
366   // operation on an op before its Handle* visitor method is called.
367   //
368   // For any HLO op, the order of calls is:
369   //
370   //   Preprocess(op);
371   //   Handle/OpType/(op);
372   //   Postprocess(op);
373   //
374   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
375   // own preprocessing.
376   virtual Status Preprocess(HloInstructionPtr hlo);
377 
378   // This method should be overridden by subclasses that wish to run some
379   // operation on an op after its Handle* visitor method is called. See
380   // Preprocess for more details.
381   //
382   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
383   // own postprocessing.
384   virtual Status Postprocess(HloInstructionPtr hlo);
385 
386  private:
387   absl::flat_hash_map<int, VisitState> visit_state_;
388 
389   DfsHloVisitorBase(const DfsHloVisitorBase&) = delete;
390   DfsHloVisitorBase& operator=(const DfsHloVisitorBase&) = delete;
391 };
392 
// Explicit instantiation declarations: both instantiations are defined in
// dfs_hlo_visitor.cc, so translation units that include this header do not
// instantiate the template implicitly.
extern template class DfsHloVisitorBase<HloInstruction*>;
extern template class DfsHloVisitorBase<const HloInstruction*>;

// Users should use one of these two type aliases, which are the only two valid
// instantiations of DfsHloVisitorBase (the class static_asserts on its
// template argument). DfsHloVisitor takes mutable instruction pointers,
// ConstDfsHloVisitor takes pointers to const instructions.
using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
401 
402 }  // namespace xla
403 
404 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
405