• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
// HloInstruction only appears in this header as a pointer type (in the visitor
// aliases), so an incomplete declaration is sufficient here.
class HloInstruction;
// Not referenced by this header. NOTE(review): possibly vestigial -- confirm no
// includer depends on this declaration before removing it.
class HloComputation;
34 
35 // DfsHloVisitor with default action based on the HloInstruction being visited.
36 // Users should not use this class directly, but use the type aliases
37 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
38 template <typename HloInstructionPtr>
39 class DfsHloVisitorWithDefaultBase
40     : public DfsHloVisitorBase<HloInstructionPtr> {
41  public:
DfsHloVisitorWithDefaultBase()42   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()43   ~DfsHloVisitorWithDefaultBase() override {}
44 
45   // Default action performed on HloInstruction.
46   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
47 
HandleElementwiseUnary(HloInstructionPtr hlo)48   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
49     return DefaultAction(hlo);
50   }
HandleElementwiseBinary(HloInstructionPtr hlo)51   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
52     return DefaultAction(hlo);
53   }
54 
HandleBatchNormTraining(HloInstructionPtr hlo)55   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
56     return DefaultAction(hlo);
57   }
58 
HandleBatchNormInference(HloInstructionPtr hlo)59   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
60     return DefaultAction(hlo);
61   }
62 
HandleBatchNormGrad(HloInstructionPtr hlo)63   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
64     return DefaultAction(hlo);
65   }
66 
HandleClamp(HloInstructionPtr clamp)67   Status HandleClamp(HloInstructionPtr clamp) override {
68     return DefaultAction(clamp);
69   }
HandleConcatenate(HloInstructionPtr concatenate)70   Status HandleConcatenate(HloInstructionPtr concatenate) override {
71     return DefaultAction(concatenate);
72   }
HandleConvert(HloInstructionPtr convert)73   Status HandleConvert(HloInstructionPtr convert) override {
74     return DefaultAction(convert);
75   }
HandleCopy(HloInstructionPtr copy)76   Status HandleCopy(HloInstructionPtr copy) override {
77     return DefaultAction(copy);
78   }
HandleSelect(HloInstructionPtr select)79   Status HandleSelect(HloInstructionPtr select) override {
80     return DefaultAction(select);
81   }
HandleDot(HloInstructionPtr dot)82   Status HandleDot(HloInstructionPtr dot) override {
83     return DefaultAction(dot);
84   }
HandleConvolution(HloInstructionPtr convolution)85   Status HandleConvolution(HloInstructionPtr convolution) override {
86     return DefaultAction(convolution);
87   }
HandleFft(HloInstructionPtr fft)88   Status HandleFft(HloInstructionPtr fft) override {
89     return DefaultAction(fft);
90   }
HandleCrossReplicaSum(HloInstructionPtr crs)91   Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
92     return DefaultAction(crs);
93   }
HandleCompare(HloInstructionPtr compare)94   Status HandleCompare(HloInstructionPtr compare) override {
95     return DefaultAction(compare);
96   }
HandleRng(HloInstructionPtr random)97   Status HandleRng(HloInstructionPtr random) override {
98     return DefaultAction(random);
99   }
HandleInfeed(HloInstructionPtr infeed)100   Status HandleInfeed(HloInstructionPtr infeed) override {
101     return DefaultAction(infeed);
102   }
HandleOutfeed(HloInstructionPtr outfeed)103   Status HandleOutfeed(HloInstructionPtr outfeed) override {
104     return DefaultAction(outfeed);
105   }
HandleHostCompute(HloInstructionPtr host_compute)106   Status HandleHostCompute(HloInstructionPtr host_compute) override {
107     return DefaultAction(host_compute);
108   }
HandleReverse(HloInstructionPtr reverse)109   Status HandleReverse(HloInstructionPtr reverse) override {
110     return DefaultAction(reverse);
111   }
HandleSort(HloInstructionPtr sort)112   Status HandleSort(HloInstructionPtr sort) override {
113     return DefaultAction(sort);
114   }
HandleConstant(HloInstructionPtr constant)115   Status HandleConstant(HloInstructionPtr constant) override {
116     return DefaultAction(constant);
117   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)118   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
119     return DefaultAction(get_tuple_element);
120   }
HandleParameter(HloInstructionPtr parameter)121   Status HandleParameter(HloInstructionPtr parameter) override {
122     return DefaultAction(parameter);
123   }
HandleFusion(HloInstructionPtr fusion)124   Status HandleFusion(HloInstructionPtr fusion) override {
125     return DefaultAction(fusion);
126   }
HandleCall(HloInstructionPtr call)127   Status HandleCall(HloInstructionPtr call) override {
128     return DefaultAction(call);
129   }
HandleCustomCall(HloInstructionPtr custom_call)130   Status HandleCustomCall(HloInstructionPtr custom_call) override {
131     return DefaultAction(custom_call);
132   }
HandleSlice(HloInstructionPtr slice)133   Status HandleSlice(HloInstructionPtr slice) override {
134     return DefaultAction(slice);
135   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)136   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
137     return DefaultAction(dynamic_slice);
138   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)139   Status HandleDynamicUpdateSlice(
140       HloInstructionPtr dynamic_update_slice) override {
141     return DefaultAction(dynamic_update_slice);
142   }
HandleTuple(HloInstructionPtr tuple)143   Status HandleTuple(HloInstructionPtr tuple) override {
144     return DefaultAction(tuple);
145   }
HandleMap(HloInstructionPtr map)146   Status HandleMap(HloInstructionPtr map) override {
147     return DefaultAction(map);
148   }
HandleReduce(HloInstructionPtr reduce)149   Status HandleReduce(HloInstructionPtr reduce) override {
150     return DefaultAction(reduce);
151   }
HandleReduceWindow(HloInstructionPtr reduce_window)152   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
153     return DefaultAction(reduce_window);
154   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)155   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
156     return DefaultAction(select_and_scatter);
157   }
HandleBitcast(HloInstructionPtr bitcast)158   Status HandleBitcast(HloInstructionPtr bitcast) override {
159     return DefaultAction(bitcast);
160   }
HandleBroadcast(HloInstructionPtr broadcast)161   Status HandleBroadcast(HloInstructionPtr broadcast) override {
162     return DefaultAction(broadcast);
163   }
HandlePad(HloInstructionPtr pad)164   Status HandlePad(HloInstructionPtr pad) override {
165     return DefaultAction(pad);
166   }
HandleReshape(HloInstructionPtr reshape)167   Status HandleReshape(HloInstructionPtr reshape) override {
168     return DefaultAction(reshape);
169   }
HandleTranspose(HloInstructionPtr transpose)170   Status HandleTranspose(HloInstructionPtr transpose) override {
171     return DefaultAction(transpose);
172   }
HandleWhile(HloInstructionPtr xla_while)173   Status HandleWhile(HloInstructionPtr xla_while) override {
174     return DefaultAction(xla_while);
175   }
HandleConditional(HloInstructionPtr conditional)176   Status HandleConditional(HloInstructionPtr conditional) override {
177     return DefaultAction(conditional);
178   }
HandleRecv(HloInstructionPtr recv)179   Status HandleRecv(HloInstructionPtr recv) override {
180     return DefaultAction(recv);
181   }
HandleRecvDone(HloInstructionPtr recv_done)182   Status HandleRecvDone(HloInstructionPtr recv_done) override {
183     return DefaultAction(recv_done);
184   }
HandleSend(HloInstructionPtr send)185   Status HandleSend(HloInstructionPtr send) override {
186     return DefaultAction(send);
187   }
HandleSendDone(HloInstructionPtr send_done)188   Status HandleSendDone(HloInstructionPtr send_done) override {
189     return DefaultAction(send_done);
190   }
HandleGather(HloInstructionPtr gather)191   Status HandleGather(HloInstructionPtr gather) override {
192     return DefaultAction(gather);
193   }
194 
195   // Invoked to inform the visitor that the traversal has completed, and that
196   // the root was "root".
FinishVisit(HloInstructionPtr)197   Status FinishVisit(HloInstructionPtr /*root*/) override {
198     return Status::OK();
199   }
200 
201  private:
202   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
203 };
204 
205 // Users should use these type aliases which are only two valid instantiations.
206 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
207 using ConstDfsHloVisitorWithDefault =
208     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
209 
210 // (Const)FunctionVisitor lets you transform an
211 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
212 //
213 // This is useful if you have code that needs to handle visitors in the form of
214 // both std::function and DfsHloVisitor.  You can wrap the function in a
215 // FunctionVisitor and then treat it like any other DfsHloVisitor.
216 template <typename HloInstructionPtr>
217 class FunctionVisitorBase
218     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
219  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)220   explicit FunctionVisitorBase(
221       std::function<Status(HloInstructionPtr)> visitor_func)
222       : visitor_func_(std::move(visitor_func)) {}
223 
DefaultAction(HloInstructionPtr hlo_instruction)224   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
225     return visitor_func_(hlo_instruction);
226   }
227 
228  private:
229   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
230 
231   std::function<Status(HloInstructionPtr)> visitor_func_;
232 };
233 
234 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
235 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
236 
237 }  // namespace xla
238 
239 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
240