• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
19 #include "absl/strings/string_view.h"
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // DfsHloVisitor with default action based on the HloInstruction being visited.
37 // Users should not use this class directly, but use the type aliases
38 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
39 //
40 // Do *not* add an override to this class if the opcode is covered by
41 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
42 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
43 // here will break passes which rely on the HandleElementwiseUnary/Binary
44 // handling these opcodes.
45 template <typename HloInstructionPtr>
46 class DfsHloVisitorWithDefaultBase
47     : public DfsHloVisitorBase<HloInstructionPtr> {
48  public:
DfsHloVisitorWithDefaultBase()49   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()50   ~DfsHloVisitorWithDefaultBase() override {}
51 
52   // Default action performed on HloInstruction.
53   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
54 
HandleElementwiseUnary(HloInstructionPtr hlo)55   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
56     return DefaultAction(hlo);
57   }
HandleElementwiseBinary(HloInstructionPtr hlo)58   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
59     return DefaultAction(hlo);
60   }
61 
HandleBatchNormTraining(HloInstructionPtr hlo)62   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
63     return DefaultAction(hlo);
64   }
65 
HandleBatchNormInference(HloInstructionPtr hlo)66   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
67     return DefaultAction(hlo);
68   }
69 
HandleBatchNormGrad(HloInstructionPtr hlo)70   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
71     return DefaultAction(hlo);
72   }
73 
HandleClamp(HloInstructionPtr clamp)74   Status HandleClamp(HloInstructionPtr clamp) override {
75     return DefaultAction(clamp);
76   }
HandleConcatenate(HloInstructionPtr concatenate)77   Status HandleConcatenate(HloInstructionPtr concatenate) override {
78     return DefaultAction(concatenate);
79   }
HandleSelect(HloInstructionPtr select)80   Status HandleSelect(HloInstructionPtr select) override {
81     return DefaultAction(select);
82   }
HandleTupleSelect(HloInstructionPtr tuple_select)83   Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
84     return DefaultAction(tuple_select);
85   }
HandleDot(HloInstructionPtr dot)86   Status HandleDot(HloInstructionPtr dot) override {
87     return DefaultAction(dot);
88   }
HandleConvolution(HloInstructionPtr convolution)89   Status HandleConvolution(HloInstructionPtr convolution) override {
90     return DefaultAction(convolution);
91   }
HandleFft(HloInstructionPtr fft)92   Status HandleFft(HloInstructionPtr fft) override {
93     return DefaultAction(fft);
94   }
HandleTriangularSolve(HloInstructionPtr hlo)95   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
96     return DefaultAction(hlo);
97   }
HandleCholesky(HloInstructionPtr hlo)98   Status HandleCholesky(HloInstructionPtr hlo) override {
99     return DefaultAction(hlo);
100   }
HandleAllGather(HloInstructionPtr crs)101   Status HandleAllGather(HloInstructionPtr crs) override {
102     return DefaultAction(crs);
103   }
HandleAllGatherStart(HloInstructionPtr crs)104   Status HandleAllGatherStart(HloInstructionPtr crs) override {
105     return DefaultAction(crs);
106   }
HandleAllGatherDone(HloInstructionPtr crs)107   Status HandleAllGatherDone(HloInstructionPtr crs) override {
108     return DefaultAction(crs);
109   }
HandleAllReduce(HloInstructionPtr crs)110   Status HandleAllReduce(HloInstructionPtr crs) override {
111     return DefaultAction(crs);
112   }
HandleReduceScatter(HloInstructionPtr hlo)113   Status HandleReduceScatter(HloInstructionPtr hlo) override {
114     return DefaultAction(hlo);
115   }
HandleAllReduceStart(HloInstructionPtr hlo)116   Status HandleAllReduceStart(HloInstructionPtr hlo) override {
117     return DefaultAction(hlo);
118   }
HandleAllReduceDone(HloInstructionPtr hlo)119   Status HandleAllReduceDone(HloInstructionPtr hlo) override {
120     return DefaultAction(hlo);
121   }
HandleAllToAll(HloInstructionPtr hlo)122   Status HandleAllToAll(HloInstructionPtr hlo) override {
123     return DefaultAction(hlo);
124   }
HandleCollectivePermute(HloInstructionPtr hlo)125   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
126     return DefaultAction(hlo);
127   }
HandleCollectivePermuteStart(HloInstructionPtr hlo)128   Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
129     return DefaultAction(hlo);
130   }
HandleCollectivePermuteDone(HloInstructionPtr hlo)131   Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
132     return DefaultAction(hlo);
133   }
HandleReplicaId(HloInstructionPtr hlo)134   Status HandleReplicaId(HloInstructionPtr hlo) override {
135     return DefaultAction(hlo);
136   }
HandlePartitionId(HloInstructionPtr hlo)137   Status HandlePartitionId(HloInstructionPtr hlo) override {
138     return DefaultAction(hlo);
139   }
HandleRng(HloInstructionPtr random)140   Status HandleRng(HloInstructionPtr random) override {
141     return DefaultAction(random);
142   }
HandleRngBitGenerator(HloInstructionPtr random)143   Status HandleRngBitGenerator(HloInstructionPtr random) override {
144     return DefaultAction(random);
145   }
HandleRngGetAndUpdateState(HloInstructionPtr random)146   Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
147     return DefaultAction(random);
148   }
HandleInfeed(HloInstructionPtr infeed)149   Status HandleInfeed(HloInstructionPtr infeed) override {
150     return DefaultAction(infeed);
151   }
HandleOutfeed(HloInstructionPtr outfeed)152   Status HandleOutfeed(HloInstructionPtr outfeed) override {
153     return DefaultAction(outfeed);
154   }
HandleReverse(HloInstructionPtr reverse)155   Status HandleReverse(HloInstructionPtr reverse) override {
156     return DefaultAction(reverse);
157   }
HandleSort(HloInstructionPtr sort)158   Status HandleSort(HloInstructionPtr sort) override {
159     return DefaultAction(sort);
160   }
HandleConstant(HloInstructionPtr constant)161   Status HandleConstant(HloInstructionPtr constant) override {
162     return DefaultAction(constant);
163   }
HandleIota(HloInstructionPtr iota)164   Status HandleIota(HloInstructionPtr iota) override {
165     return DefaultAction(iota);
166   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)167   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
168     return DefaultAction(get_tuple_element);
169   }
HandleParameter(HloInstructionPtr parameter)170   Status HandleParameter(HloInstructionPtr parameter) override {
171     return DefaultAction(parameter);
172   }
HandleFusion(HloInstructionPtr fusion)173   Status HandleFusion(HloInstructionPtr fusion) override {
174     return DefaultAction(fusion);
175   }
HandleCall(HloInstructionPtr call)176   Status HandleCall(HloInstructionPtr call) override {
177     return DefaultAction(call);
178   }
HandleCustomCall(HloInstructionPtr custom_call)179   Status HandleCustomCall(HloInstructionPtr custom_call) override {
180     return DefaultAction(custom_call);
181   }
HandleSlice(HloInstructionPtr slice)182   Status HandleSlice(HloInstructionPtr slice) override {
183     return DefaultAction(slice);
184   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)185   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
186     return DefaultAction(dynamic_slice);
187   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)188   Status HandleDynamicUpdateSlice(
189       HloInstructionPtr dynamic_update_slice) override {
190     return DefaultAction(dynamic_update_slice);
191   }
HandleTuple(HloInstructionPtr tuple)192   Status HandleTuple(HloInstructionPtr tuple) override {
193     return DefaultAction(tuple);
194   }
HandleMap(HloInstructionPtr map)195   Status HandleMap(HloInstructionPtr map) override {
196     return DefaultAction(map);
197   }
HandleReduce(HloInstructionPtr reduce)198   Status HandleReduce(HloInstructionPtr reduce) override {
199     return DefaultAction(reduce);
200   }
HandleReduceWindow(HloInstructionPtr reduce_window)201   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
202     return DefaultAction(reduce_window);
203   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)204   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
205     return DefaultAction(select_and_scatter);
206   }
HandleBitcast(HloInstructionPtr bitcast)207   Status HandleBitcast(HloInstructionPtr bitcast) override {
208     return DefaultAction(bitcast);
209   }
HandleBroadcast(HloInstructionPtr broadcast)210   Status HandleBroadcast(HloInstructionPtr broadcast) override {
211     return DefaultAction(broadcast);
212   }
HandlePad(HloInstructionPtr pad)213   Status HandlePad(HloInstructionPtr pad) override {
214     return DefaultAction(pad);
215   }
HandleDynamicReshape(HloInstructionPtr dynamic_reshape)216   Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override {
217     return DefaultAction(dynamic_reshape);
218   }
HandleReshape(HloInstructionPtr reshape)219   Status HandleReshape(HloInstructionPtr reshape) override {
220     return DefaultAction(reshape);
221   }
HandleTranspose(HloInstructionPtr transpose)222   Status HandleTranspose(HloInstructionPtr transpose) override {
223     return DefaultAction(transpose);
224   }
HandleWhile(HloInstructionPtr xla_while)225   Status HandleWhile(HloInstructionPtr xla_while) override {
226     return DefaultAction(xla_while);
227   }
HandleConditional(HloInstructionPtr conditional)228   Status HandleConditional(HloInstructionPtr conditional) override {
229     return DefaultAction(conditional);
230   }
HandleCopyStart(HloInstructionPtr copy_start)231   Status HandleCopyStart(HloInstructionPtr copy_start) override {
232     return DefaultAction(copy_start);
233   }
HandleCopyDone(HloInstructionPtr copy_done)234   Status HandleCopyDone(HloInstructionPtr copy_done) override {
235     return DefaultAction(copy_done);
236   }
HandleRecv(HloInstructionPtr recv)237   Status HandleRecv(HloInstructionPtr recv) override {
238     return DefaultAction(recv);
239   }
HandleRecvDone(HloInstructionPtr recv_done)240   Status HandleRecvDone(HloInstructionPtr recv_done) override {
241     return DefaultAction(recv_done);
242   }
HandleSend(HloInstructionPtr send)243   Status HandleSend(HloInstructionPtr send) override {
244     return DefaultAction(send);
245   }
HandleSendDone(HloInstructionPtr send_done)246   Status HandleSendDone(HloInstructionPtr send_done) override {
247     return DefaultAction(send_done);
248   }
HandleGather(HloInstructionPtr gather)249   Status HandleGather(HloInstructionPtr gather) override {
250     return DefaultAction(gather);
251   }
HandleScatter(HloInstructionPtr scatter)252   Status HandleScatter(HloInstructionPtr scatter) override {
253     return DefaultAction(scatter);
254   }
HandleAfterAll(HloInstructionPtr token)255   Status HandleAfterAll(HloInstructionPtr token) override {
256     return DefaultAction(token);
257   }
HandleGetDimensionSize(HloInstructionPtr get_size)258   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
259     return DefaultAction(get_size);
260   }
HandleSetDimensionSize(HloInstructionPtr get_size)261   Status HandleSetDimensionSize(HloInstructionPtr get_size) override {
262     return DefaultAction(get_size);
263   }
HandleAddDependency(HloInstructionPtr add_dependency)264   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
265     return DefaultAction(add_dependency);
266   }
267 
268   // Invoked to inform the visitor that the traversal has completed, and that
269   // the root was "root".
FinishVisit(HloInstructionPtr)270   Status FinishVisit(HloInstructionPtr /*root*/) override {
271     return Status::OK();
272   }
273 
274  private:
275   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
276 };
277 
278 // Users should use these type aliases which are only two valid instantiations.
279 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
280 using ConstDfsHloVisitorWithDefault =
281     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
282 
283 // A common base class for visitors performing rewriting operation.
284 //
285 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
286 // visiting.
287 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
288  public:
289   // Runs a visitor on the module and returns whether the module has changed.
RunOnModule(HloModule * module)290   StatusOr<bool> RunOnModule(HloModule* module) {
291     bool is_changed = false;
292     for (const auto& computation : module->computations()) {
293       TF_RETURN_IF_ERROR(computation->Accept(this));
294       is_changed |= changed();
295     }
296     return is_changed;
297   }
298 
299   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)300   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
301     return Status::OK();
302   }
303 
changed()304   bool changed() const { return changed_; }
305 
306  protected:
307   // Replaces the existing HLO instruction old_instruction, with
308   // new_instruction, and marks the optimizer status as changed.
309   // Returns the Status representing the result of the replace operation.
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)310   Status ReplaceWithNewInstruction(
311       HloInstruction* old_instruction,
312       std::unique_ptr<HloInstruction> new_instruction) {
313     VLOG(3) << "Replacing instruction:";
314     VLOG(3) << "  old: " << old_instruction->ToString();
315     VLOG(3) << "  new: " << new_instruction->ToString();
316     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
317         old_instruction, std::move(new_instruction)));
318     changed_ = true;
319     return Status::OK();
320   }
321 
322   // Replaces the existing HLO instruction old_instruction, with
323   // new_instruction, and marks the optimizer status as changed.
324   // Returns the Status representing the result of the replace operation.
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)325   Status ReplaceInstruction(HloInstruction* old_instruction,
326                             HloInstruction* new_instruction) {
327     VLOG(3) << "Replacing instruction:";
328     VLOG(3) << "  old: " << old_instruction->ToString();
329     VLOG(3) << "  new: " << new_instruction->ToString();
330     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction(
331         old_instruction, new_instruction));
332     changed_ = true;
333     return Status::OK();
334   }
335 
336   bool changed_ = false;
337 };
338 
339 // (Const)FunctionVisitor lets you transform an
340 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
341 //
342 // This is useful if you have code that needs to handle visitors in the form of
343 // both std::function and DfsHloVisitor.  You can wrap the function in a
344 // FunctionVisitor and then treat it like any other DfsHloVisitor.
345 template <typename HloInstructionPtr>
346 class FunctionVisitorBase
347     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
348  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)349   explicit FunctionVisitorBase(
350       std::function<Status(HloInstructionPtr)> visitor_func)
351       : visitor_func_(std::move(visitor_func)) {}
352 
DefaultAction(HloInstructionPtr hlo_instruction)353   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
354     return visitor_func_(hlo_instruction);
355   }
356 
357  private:
358   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
359 
360   std::function<Status(HloInstructionPtr)> visitor_func_;
361 };
362 
363 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
364 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
365 
366 }  // namespace xla
367 
368 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
369