/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_

#include <climits>
#include <type_traits>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

34 namespace xla {
35 
36 class HloComputation;
37 class HloInstruction;
38 
39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
40 // instruction, all its operands were already visited. User code can subclass
41 // this to iterate over an HloInstruction DAG. The Handle* routines have
42 // operands / data unpacked for ease of use in the visitor subclass.
43 //
44 // No instruction will ever be visited twice; however, the root instruction will
45 // be reported again when the traversal is done via a call to FinishVisit.
46 //
47 // A subclass must override at least
48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
50 // The default Handle methods for (unary, binary) ops call
51 // (HandleElementwiseUnary, HandleElementwiseBinary).
52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
53 // "unimplemented" error status.
54 //
55 // Note: this may change to an iterator in the future for flexibility purposes.
56 //
57 // Users should not use this class directly, but use the type-aliases
58 // DfsHloVisitor/ConstDfsHloVisitor instead.
59 template <typename HloInstructionPtr>
60 class DfsHloVisitorBase {
61   static_assert(
62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
64       "Template argument expected to be HloInstruction* or const "
65       "HloInstruction*");
66 
67  public:
DfsHloVisitorBase()68   DfsHloVisitorBase() {}
~DfsHloVisitorBase()69   virtual ~DfsHloVisitorBase() {}
70 
71   // These routines are self-descriptive, see class comment for usage
72   // information.
73 
74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
76 
77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
HandleMaximum(HloInstructionPtr hlo)79   virtual Status HandleMaximum(HloInstructionPtr hlo) {
80     return HandleElementwiseBinary(hlo);
81   }
HandleMinimum(HloInstructionPtr hlo)82   virtual Status HandleMinimum(HloInstructionPtr hlo) {
83     return HandleElementwiseBinary(hlo);
84   }
85   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
HandleConvert(HloInstructionPtr hlo)86   virtual Status HandleConvert(HloInstructionPtr hlo) {
87     return HandleElementwiseUnary(hlo);
88   }
HandleBitcastConvert(HloInstructionPtr hlo)89   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
90     return HandleElementwiseUnary(hlo);
91   }
HandleCopy(HloInstructionPtr hlo)92   virtual Status HandleCopy(HloInstructionPtr hlo) {
93     return HandleElementwiseUnary(hlo);
94   }
HandleComplex(HloInstructionPtr hlo)95   virtual Status HandleComplex(HloInstructionPtr hlo) {
96     return HandleElementwiseBinary(hlo);
97   }
HandleMultiply(HloInstructionPtr hlo)98   virtual Status HandleMultiply(HloInstructionPtr hlo) {
99     return HandleElementwiseBinary(hlo);
100   }
101   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
HandlePower(HloInstructionPtr hlo)102   virtual Status HandlePower(HloInstructionPtr hlo) {
103     return HandleElementwiseBinary(hlo);
104   }
105   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
106   virtual Status HandleFft(HloInstructionPtr fft) = 0;
107   virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
HandleCompare(HloInstructionPtr hlo)108   virtual Status HandleCompare(HloInstructionPtr hlo) {
109     return HandleElementwiseBinary(hlo);
110   }
HandleAdd(HloInstructionPtr hlo)111   virtual Status HandleAdd(HloInstructionPtr hlo) {
112     return HandleElementwiseBinary(hlo);
113   }
HandleDivide(HloInstructionPtr hlo)114   virtual Status HandleDivide(HloInstructionPtr hlo) {
115     return HandleElementwiseBinary(hlo);
116   }
HandleRemainder(HloInstructionPtr hlo)117   virtual Status HandleRemainder(HloInstructionPtr hlo) {
118     return HandleElementwiseBinary(hlo);
119   }
HandleSubtract(HloInstructionPtr hlo)120   virtual Status HandleSubtract(HloInstructionPtr hlo) {
121     return HandleElementwiseBinary(hlo);
122   }
HandleAbs(HloInstructionPtr hlo)123   virtual Status HandleAbs(HloInstructionPtr hlo) {
124     return HandleElementwiseUnary(hlo);
125   }
HandleAtan2(HloInstructionPtr hlo)126   virtual Status HandleAtan2(HloInstructionPtr hlo) {
127     return HandleElementwiseBinary(hlo);
128   }
HandleRound(HloInstructionPtr hlo)129   virtual Status HandleRound(HloInstructionPtr hlo) {
130     return HandleElementwiseUnary(hlo);
131   }
HandleSign(HloInstructionPtr hlo)132   virtual Status HandleSign(HloInstructionPtr hlo) {
133     return HandleElementwiseUnary(hlo);
134   }
HandleNegate(HloInstructionPtr hlo)135   virtual Status HandleNegate(HloInstructionPtr hlo) {
136     return HandleElementwiseUnary(hlo);
137   }
HandleExp(HloInstructionPtr hlo)138   virtual Status HandleExp(HloInstructionPtr hlo) {
139     return HandleElementwiseUnary(hlo);
140   }
HandleFloor(HloInstructionPtr hlo)141   virtual Status HandleFloor(HloInstructionPtr hlo) {
142     return HandleElementwiseUnary(hlo);
143   }
HandleCeil(HloInstructionPtr hlo)144   virtual Status HandleCeil(HloInstructionPtr hlo) {
145     return HandleElementwiseUnary(hlo);
146   }
HandleLog(HloInstructionPtr hlo)147   virtual Status HandleLog(HloInstructionPtr hlo) {
148     return HandleElementwiseUnary(hlo);
149   }
HandleCos(HloInstructionPtr hlo)150   virtual Status HandleCos(HloInstructionPtr hlo) {
151     return HandleElementwiseUnary(hlo);
152   }
HandleSin(HloInstructionPtr hlo)153   virtual Status HandleSin(HloInstructionPtr hlo) {
154     return HandleElementwiseUnary(hlo);
155   }
HandleTanh(HloInstructionPtr hlo)156   virtual Status HandleTanh(HloInstructionPtr hlo) {
157     return HandleElementwiseUnary(hlo);
158   }
HandleReal(HloInstructionPtr hlo)159   virtual Status HandleReal(HloInstructionPtr hlo) {
160     return HandleElementwiseUnary(hlo);
161   }
HandleImag(HloInstructionPtr hlo)162   virtual Status HandleImag(HloInstructionPtr hlo) {
163     return HandleElementwiseUnary(hlo);
164   }
HandleIsFinite(HloInstructionPtr hlo)165   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
166     return HandleElementwiseUnary(hlo);
167   }
HandleAnd(HloInstructionPtr hlo)168   virtual Status HandleAnd(HloInstructionPtr hlo) {
169     return HandleElementwiseBinary(hlo);
170   }
HandleNot(HloInstructionPtr hlo)171   virtual Status HandleNot(HloInstructionPtr hlo) {
172     return HandleElementwiseUnary(hlo);
173   }
HandleOr(HloInstructionPtr hlo)174   virtual Status HandleOr(HloInstructionPtr hlo) {
175     return HandleElementwiseBinary(hlo);
176   }
HandleShiftLeft(HloInstructionPtr hlo)177   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
178     return HandleElementwiseBinary(hlo);
179   }
HandleShiftRightArithmetic(HloInstructionPtr hlo)180   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
181     return HandleElementwiseBinary(hlo);
182   }
HandleShiftRightLogical(HloInstructionPtr hlo)183   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
184     return HandleElementwiseBinary(hlo);
185   }
186 
HandleReducePrecision(HloInstructionPtr hlo)187   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
188     return HandleElementwiseUnary(hlo);
189   }
190 
191   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
192   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
193   virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
194   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
195   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
196   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
197   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
198   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
199   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
200   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
201   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
202   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
203   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
204   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
205   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
206   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
207   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
208   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
209   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
210   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
211   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
212   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
213   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
214   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
215   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
216   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
217   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
218 
219   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
220 
221   virtual Status HandleSend(HloInstructionPtr send) = 0;
222   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
223 
224   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
225   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
226 
227   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
228 
229   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
230 
231   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
232 
233   // Invoked to inform the visitor that the traversal has completed, and that
234   // the root was "root".
235   virtual Status FinishVisit(HloInstructionPtr root) = 0;
236 
237   // 3 possible visitation states of HLO instructions. Each instruction's
238   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
239   enum VisitState {
240     kNotVisited = 0,
241     kVisiting = 1,
242     kVisited = 2,
243   };
244 
GetVisitState(int id)245   VisitState GetVisitState(int id) { return visit_state_.GetState(id); }
246   VisitState GetVisitState(const HloInstruction& instruction);
247 
248   // Resize internal state if necessary to hold state for ids <= num.
249   // This call is purely a performance hint and can be omitted without
250   // affecting correctness.
ReserveVisitStates(int num)251   void ReserveVisitStates(int num) { visit_state_.Reserve(num); }
252 
253   // Useful when we want to visit the same computation more than once with the
254   // same visitor.
ResetVisitStates()255   void ResetVisitStates() { visit_state_.Reset(); }
256 
SetVisitState(int id,VisitState state)257   void SetVisitState(int id, VisitState state) {
258     visit_state_.SetState(id, state);
259   }
260 
261   // Sets the visitation state of the given instruction as kVisiting.
262   //
263   // Precondition: current state must be kNotVisited.
264   void SetVisiting(const HloInstruction& instruction);
265 
266   // Sets the visitation state of the given instruction as kVisited.
267   //
268   // Precondition: current state must be either kNotVisited or kVisiting.
269   void SetVisited(const HloInstruction& instruction);
270 
271   // Returns whether the state of the given instruction is kVisiting.
IsVisiting(const HloInstruction & instruction)272   bool IsVisiting(const HloInstruction& instruction) {
273     return GetVisitState(instruction) == kVisiting;
274   }
275 
276   // Returns whether the state of the given instruction is kVisited.
DidVisit(const HloInstruction & instruction)277   bool DidVisit(const HloInstruction& instruction) {
278     return GetVisitState(instruction) == kVisited;
279   }
280 
281   // Returns whether the state of the given instruction is kNotVisited.
NotVisited(const HloInstruction & instruction)282   bool NotVisited(const HloInstruction& instruction) {
283     return GetVisitState(instruction) == kNotVisited;
284   }
285 
286   // This method should be overridden by subclasses that wish to run some
287   // operation on an op before its Handle* visitor method is called.
288   //
289   // For any HLO op, the order of calls is:
290   //
291   //   Preprocess(op);
292   //   Handle/OpType/(op);
293   //   Postprocess(op);
294   //
295   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
296   // own preprocessing.
297   virtual Status Preprocess(HloInstructionPtr hlo);
298 
299   // This method should be overridden by subclasses that wish to run some
300   // operation on an op after its Handle* visitor method is called. See
301   // Preprocess for more details.
302   //
303   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
304   // own postprocessing.
305   virtual Status Postprocess(HloInstructionPtr hlo);
306 
307  private:
308   class DFSVisitStates {
309    public:
DFSVisitStates()310     DFSVisitStates() {}
Reserve(uint64 num)311     void Reserve(uint64 num) {
312       states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord);
313     }
GetState(uint64 id)314     VisitState GetState(uint64 id) {
315       uint64 word_index = id / kStatesPerWord;
316       if (word_index >= states_.size()) {
317         return VisitState::kNotVisited;
318       }
319       static_assert(static_cast<int>(VisitState::kVisited) < 3,
320                     "VisitState must fit in two bits");
321       uint64 w = states_[word_index];
322       uint32 shift = 2 * (id % kStatesPerWord);  // 2 bits per state
323       return static_cast<VisitState>((w >> shift) & 0x3);
324     }
SetState(uint64 id,VisitState state)325     void SetState(uint64 id, VisitState state) {
326       uint64 word_index = id / kStatesPerWord;
327       if (word_index >= states_.size()) {
328         states_.resize(word_index + 1, 0);
329       }
330       uint64* w = &states_[word_index];
331       uint32 shift = 2 * (id % kStatesPerWord);  // 2 bits per state
332       uint64 mask = 0x3ull << shift;
333       *w = (*w & ~mask) | (static_cast<uint64>(state) << shift);
334       DCHECK_EQ(GetState(id), state);
335     }
Reset()336     void Reset() { states_.clear(); }
337 
338    private:
339     static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
340     // Map from id to two-bit states.  We store 32 such states per 64-bit
341     // value
342     std::vector<uint64> states_;
343   };
344 
345   DFSVisitStates visit_state_;
346 
347   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
348 };
349 
350 // Users should use one of these two type aliases, which are the only two valid
351 // instantiations of DfsHloVisitorBase.
352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
354 
355 }  // namespace xla
356 
357 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
358