• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
18 
19 #include <type_traits>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 class HloComputation;
37 class HloInstruction;
38 
39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
40 // instruction, all its operands were already visited. User code can subclass
41 // this to iterate over an HloInstruction DAG. The Handle* routines have
42 // operands / data unpacked for ease of use in the visitor subclass.
43 //
44 // No instruction will ever be visited twice; however, the root instruction will
45 // be reported again when the traversal is done via a call to FinishVisit.
46 //
47 // A subclass must override at least
48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
50 // The default Handle methods for (unary, binary) ops call
51 // (HandleElementwiseUnary, HandleElementwiseBinary).
52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
53 // "unimplemented" error status.
54 //
55 // Note: this may change to an iterator in the future for flexibility purposes.
56 //
57 // Users should not use this class directly, but use the type-aliases
58 // DfsHloVisitor/ConstDfsHloVisitor instead.
59 template <typename HloInstructionPtr>
60 class DfsHloVisitorBase {
61   static_assert(
62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
64       "Template argument expected to be HloInstruction* or const "
65       "HloInstruction*");
66 
67  public:
DfsHloVisitorBase()68   DfsHloVisitorBase() {}
~DfsHloVisitorBase()69   virtual ~DfsHloVisitorBase() {}
70 
71   // These routines are self-descriptive, see class comment for usage
72   // information.
73 
74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
76 
77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
79   virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
HandleMaximum(HloInstructionPtr hlo)80   virtual Status HandleMaximum(HloInstructionPtr hlo) {
81     return HandleElementwiseBinary(hlo);
82   }
HandleMinimum(HloInstructionPtr hlo)83   virtual Status HandleMinimum(HloInstructionPtr hlo) {
84     return HandleElementwiseBinary(hlo);
85   }
86   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
HandleConvert(HloInstructionPtr hlo)87   virtual Status HandleConvert(HloInstructionPtr hlo) {
88     return HandleElementwiseUnary(hlo);
89   }
HandleBitcastConvert(HloInstructionPtr hlo)90   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
91     return HandleElementwiseUnary(hlo);
92   }
HandleCopy(HloInstructionPtr hlo)93   virtual Status HandleCopy(HloInstructionPtr hlo) {
94     return HandleElementwiseUnary(hlo);
95   }
HandleComplex(HloInstructionPtr hlo)96   virtual Status HandleComplex(HloInstructionPtr hlo) {
97     return HandleElementwiseBinary(hlo);
98   }
HandleMultiply(HloInstructionPtr hlo)99   virtual Status HandleMultiply(HloInstructionPtr hlo) {
100     return HandleElementwiseBinary(hlo);
101   }
102   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
HandlePower(HloInstructionPtr hlo)103   virtual Status HandlePower(HloInstructionPtr hlo) {
104     return HandleElementwiseBinary(hlo);
105   }
HandleSqrt(HloInstructionPtr hlo)106   virtual Status HandleSqrt(HloInstructionPtr hlo) {
107     return HandleElementwiseUnary(hlo);
108   }
HandleRsqrt(HloInstructionPtr hlo)109   virtual Status HandleRsqrt(HloInstructionPtr hlo) {
110     return HandleElementwiseUnary(hlo);
111   }
HandleCbrt(HloInstructionPtr hlo)112   virtual Status HandleCbrt(HloInstructionPtr hlo) {
113     return HandleElementwiseUnary(hlo);
114   }
115   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
116   virtual Status HandleFft(HloInstructionPtr fft) = 0;
117   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
118   virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
119   virtual Status HandleAllGather(HloInstructionPtr hlo) = 0;
120   virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0;
121   virtual Status HandleAllGatherDone(HloInstructionPtr hlo) = 0;
122   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
123   virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0;
124   virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0;
125   virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0;
126   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
127   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
128   virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0;
129   virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0;
130   virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
131   virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0;
132   virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
133   virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0;
HandleCompare(HloInstructionPtr hlo)134   virtual Status HandleCompare(HloInstructionPtr hlo) {
135     return HandleElementwiseBinary(hlo);
136   }
HandleAdd(HloInstructionPtr hlo)137   virtual Status HandleAdd(HloInstructionPtr hlo) {
138     return HandleElementwiseBinary(hlo);
139   }
HandleDivide(HloInstructionPtr hlo)140   virtual Status HandleDivide(HloInstructionPtr hlo) {
141     return HandleElementwiseBinary(hlo);
142   }
HandleRemainder(HloInstructionPtr hlo)143   virtual Status HandleRemainder(HloInstructionPtr hlo) {
144     return HandleElementwiseBinary(hlo);
145   }
HandleSubtract(HloInstructionPtr hlo)146   virtual Status HandleSubtract(HloInstructionPtr hlo) {
147     return HandleElementwiseBinary(hlo);
148   }
HandleAbs(HloInstructionPtr hlo)149   virtual Status HandleAbs(HloInstructionPtr hlo) {
150     return HandleElementwiseUnary(hlo);
151   }
HandleAtan2(HloInstructionPtr hlo)152   virtual Status HandleAtan2(HloInstructionPtr hlo) {
153     return HandleElementwiseBinary(hlo);
154   }
HandleRound(HloInstructionPtr hlo)155   virtual Status HandleRound(HloInstructionPtr hlo) {
156     return HandleElementwiseUnary(hlo);
157   }
HandleLogistic(HloInstructionPtr hlo)158   virtual Status HandleLogistic(HloInstructionPtr hlo) {
159     return HandleElementwiseUnary(hlo);
160   }
HandleSign(HloInstructionPtr hlo)161   virtual Status HandleSign(HloInstructionPtr hlo) {
162     return HandleElementwiseUnary(hlo);
163   }
HandleNegate(HloInstructionPtr hlo)164   virtual Status HandleNegate(HloInstructionPtr hlo) {
165     return HandleElementwiseUnary(hlo);
166   }
HandleExp(HloInstructionPtr hlo)167   virtual Status HandleExp(HloInstructionPtr hlo) {
168     return HandleElementwiseUnary(hlo);
169   }
HandleExpm1(HloInstructionPtr hlo)170   virtual Status HandleExpm1(HloInstructionPtr hlo) {
171     return HandleElementwiseUnary(hlo);
172   }
HandleFloor(HloInstructionPtr hlo)173   virtual Status HandleFloor(HloInstructionPtr hlo) {
174     return HandleElementwiseUnary(hlo);
175   }
HandleCeil(HloInstructionPtr hlo)176   virtual Status HandleCeil(HloInstructionPtr hlo) {
177     return HandleElementwiseUnary(hlo);
178   }
HandleLog(HloInstructionPtr hlo)179   virtual Status HandleLog(HloInstructionPtr hlo) {
180     return HandleElementwiseUnary(hlo);
181   }
HandleClz(HloInstructionPtr hlo)182   virtual Status HandleClz(HloInstructionPtr hlo) {
183     return HandleElementwiseUnary(hlo);
184   }
HandleLog1p(HloInstructionPtr hlo)185   virtual Status HandleLog1p(HloInstructionPtr hlo) {
186     return HandleElementwiseUnary(hlo);
187   }
HandleCos(HloInstructionPtr hlo)188   virtual Status HandleCos(HloInstructionPtr hlo) {
189     return HandleElementwiseUnary(hlo);
190   }
HandleSin(HloInstructionPtr hlo)191   virtual Status HandleSin(HloInstructionPtr hlo) {
192     return HandleElementwiseUnary(hlo);
193   }
HandleTanh(HloInstructionPtr hlo)194   virtual Status HandleTanh(HloInstructionPtr hlo) {
195     return HandleElementwiseUnary(hlo);
196   }
HandleReal(HloInstructionPtr hlo)197   virtual Status HandleReal(HloInstructionPtr hlo) {
198     return HandleElementwiseUnary(hlo);
199   }
HandleImag(HloInstructionPtr hlo)200   virtual Status HandleImag(HloInstructionPtr hlo) {
201     return HandleElementwiseUnary(hlo);
202   }
HandleIsFinite(HloInstructionPtr hlo)203   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
204     return HandleElementwiseUnary(hlo);
205   }
HandleAnd(HloInstructionPtr hlo)206   virtual Status HandleAnd(HloInstructionPtr hlo) {
207     return HandleElementwiseBinary(hlo);
208   }
HandleNot(HloInstructionPtr hlo)209   virtual Status HandleNot(HloInstructionPtr hlo) {
210     return HandleElementwiseUnary(hlo);
211   }
HandleOr(HloInstructionPtr hlo)212   virtual Status HandleOr(HloInstructionPtr hlo) {
213     return HandleElementwiseBinary(hlo);
214   }
HandleXor(HloInstructionPtr hlo)215   virtual Status HandleXor(HloInstructionPtr hlo) {
216     return HandleElementwiseBinary(hlo);
217   }
HandlePopulationCount(HloInstructionPtr hlo)218   virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
219     return HandleElementwiseUnary(hlo);
220   }
HandleShiftLeft(HloInstructionPtr hlo)221   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
222     return HandleElementwiseBinary(hlo);
223   }
HandleShiftRightArithmetic(HloInstructionPtr hlo)224   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
225     return HandleElementwiseBinary(hlo);
226   }
HandleShiftRightLogical(HloInstructionPtr hlo)227   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
228     return HandleElementwiseBinary(hlo);
229   }
230 
HandleReducePrecision(HloInstructionPtr hlo)231   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
232     return HandleElementwiseUnary(hlo);
233   }
234 
HandleDomain(HloInstructionPtr hlo)235   virtual Status HandleDomain(HloInstructionPtr hlo) {
236     return HandleElementwiseUnary(hlo);
237   }
238 
239   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
240   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
241   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
242   virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0;
243   virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0;
244   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
245   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
246   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
247   virtual Status HandleIota(HloInstructionPtr hlo) = 0;
248   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
249   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
250   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
251   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
252   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
253   virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0;
254   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
255   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
256   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
257   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
258   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
259   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
260   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
261   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
262   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
263   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
264   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
265   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
266   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
267   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
268   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
269   virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
270 
271   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
272 
273   virtual Status HandleCopyStart(HloInstructionPtr copy_start) = 0;
274   virtual Status HandleCopyDone(HloInstructionPtr copy_done) = 0;
275 
276   virtual Status HandleSend(HloInstructionPtr send) = 0;
277   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
278 
279   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
280   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
281 
282   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
283 
284   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
285 
286   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
287 
288   virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
289   virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
290 
291   // Invoked to inform the visitor that the traversal has completed, and that
292   // the root was "root".
293   virtual Status FinishVisit(HloInstructionPtr root) = 0;
294 
295   // 3 possible visitation states of HLO instructions. Each instruction's
296   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
297   enum VisitState {
298     kNotVisited = 0,
299     kVisiting = 1,
300     kVisited = 2,
301   };
302 
GetVisitState(int id)303   VisitState GetVisitState(int id) {
304     auto iter = visit_state_.find(id);
305     if (iter == visit_state_.end()) {
306       return VisitState::kNotVisited;
307     }
308     return iter->second;
309   }
310   VisitState GetVisitState(const HloInstruction& instruction);
311 
312   // Resize internal state if necessary to hold state for ids <= num.
313   // This call is purely a performance hint and can be omitted without
314   // affecting correctness.
ReserveVisitStates(int num)315   void ReserveVisitStates(int num) { visit_state_.reserve(num); }
VisitStateCapacity()316   size_t VisitStateCapacity() const { return visit_state_.capacity(); }
317 
318   // Useful when we want to visit the same computation more than once with the
319   // same visitor.
ResetVisitStates()320   void ResetVisitStates() {
321     // Clear the map, but don't resize the capacity across uses -- Calculating
322     // and reserving space could be expensive, and we always use the same
323     // module->instruction_count() as the capacity.
324     visit_state_.erase(visit_state_.begin(), visit_state_.end());
325   }
326 
327   // Useful when we want to free up the memory used by the visit state without
328   // destroying the actual visitor subclass.
DestroyVisitState()329   void DestroyVisitState() {
330     visit_state_ = absl::flat_hash_map<int, VisitState>{};
331   }
332 
SetVisitState(int id,VisitState state)333   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
334 
335   // Sets the visitation state of the given instruction as kVisiting.
336   //
337   // Precondition: current state must be kNotVisited.
338   void SetVisiting(const HloInstruction& instruction);
339 
340   // Sets the visitation state of the given instruction as kVisited.
341   //
342   // Precondition: current state must be either kNotVisited or kVisiting.
343   void SetVisited(const HloInstruction& instruction);
344 
345   // Returns whether the state of the given instruction is kVisiting.
IsVisiting(const HloInstruction & instruction)346   bool IsVisiting(const HloInstruction& instruction) {
347     return GetVisitState(instruction) == kVisiting;
348   }
349 
350   // Returns whether the state of the given instruction is kVisited.
DidVisit(const HloInstruction & instruction)351   bool DidVisit(const HloInstruction& instruction) {
352     return GetVisitState(instruction) == kVisited;
353   }
354 
355   // Returns whether the state of the given instruction is kNotVisited.
NotVisited(const HloInstruction & instruction)356   bool NotVisited(const HloInstruction& instruction) {
357     return GetVisitState(instruction) == kNotVisited;
358   }
359 
360   // This method should be overridden by subclasses that wish to run some
361   // operation on an op before its Handle* visitor method is called.
362   //
363   // For any HLO op, the order of calls is:
364   //
365   //   Preprocess(op);
366   //   Handle/OpType/(op);
367   //   Postprocess(op);
368   //
369   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
370   // own preprocessing.
371   virtual Status Preprocess(HloInstructionPtr hlo);
372 
373   // This method should be overridden by subclasses that wish to run some
374   // operation on an op after its Handle* visitor method is called. See
375   // Preprocess for more details.
376   //
377   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
378   // own postprocessing.
379   virtual Status Postprocess(HloInstructionPtr hlo);
380 
381  private:
382   absl::flat_hash_map<int, VisitState> visit_state_;
383 
384   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
385 };
386 
387 // Explicit instantiations in dfs_hlo_visitor.cc.
388 extern template class DfsHloVisitorBase<HloInstruction*>;
389 extern template class DfsHloVisitorBase<const HloInstruction*>;
390 
391 // Users should use one of these two type aliases, which are the only two valid
392 // instantiations of DfsHloVisitorBase.
393 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
394 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
395 
396 }  // namespace xla
397 
398 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
399