• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // DfsHloVisitor with default action based on the HloInstruction being visited.
37 // Users should not use this class directly, but use the type aliases
38 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
39 //
40 // Do *not* add an override to this class if the opcode is covered by
41 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
42 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
43 // here will break passes which rely on the HandleElementwiseUnary/Binary
44 // handling these opcodes.
45 template <typename HloInstructionPtr>
46 class DfsHloVisitorWithDefaultBase
47     : public DfsHloVisitorBase<HloInstructionPtr> {
48  public:
DfsHloVisitorWithDefaultBase()49   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()50   ~DfsHloVisitorWithDefaultBase() override {}
51 
52   // Default action performed on HloInstruction.
53   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
54 
HandleElementwiseUnary(HloInstructionPtr hlo)55   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
56     return DefaultAction(hlo);
57   }
HandleElementwiseBinary(HloInstructionPtr hlo)58   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
59     return DefaultAction(hlo);
60   }
61 
HandleBatchNormTraining(HloInstructionPtr hlo)62   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
63     return DefaultAction(hlo);
64   }
65 
HandleBatchNormInference(HloInstructionPtr hlo)66   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
67     return DefaultAction(hlo);
68   }
69 
HandleBatchNormGrad(HloInstructionPtr hlo)70   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
71     return DefaultAction(hlo);
72   }
73 
HandleClamp(HloInstructionPtr clamp)74   Status HandleClamp(HloInstructionPtr clamp) override {
75     return DefaultAction(clamp);
76   }
HandleConcatenate(HloInstructionPtr concatenate)77   Status HandleConcatenate(HloInstructionPtr concatenate) override {
78     return DefaultAction(concatenate);
79   }
HandleSelect(HloInstructionPtr select)80   Status HandleSelect(HloInstructionPtr select) override {
81     return DefaultAction(select);
82   }
HandleTupleSelect(HloInstructionPtr tuple_select)83   Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
84     return DefaultAction(tuple_select);
85   }
HandleDot(HloInstructionPtr dot)86   Status HandleDot(HloInstructionPtr dot) override {
87     return DefaultAction(dot);
88   }
HandleConvolution(HloInstructionPtr convolution)89   Status HandleConvolution(HloInstructionPtr convolution) override {
90     return DefaultAction(convolution);
91   }
HandleFft(HloInstructionPtr fft)92   Status HandleFft(HloInstructionPtr fft) override {
93     return DefaultAction(fft);
94   }
HandleTriangularSolve(HloInstructionPtr hlo)95   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
96     return DefaultAction(hlo);
97   }
HandleCholesky(HloInstructionPtr hlo)98   Status HandleCholesky(HloInstructionPtr hlo) override {
99     return DefaultAction(hlo);
100   }
HandleAllGather(HloInstructionPtr crs)101   Status HandleAllGather(HloInstructionPtr crs) override {
102     return DefaultAction(crs);
103   }
HandleAllReduce(HloInstructionPtr crs)104   Status HandleAllReduce(HloInstructionPtr crs) override {
105     return DefaultAction(crs);
106   }
HandleAllToAll(HloInstructionPtr hlo)107   Status HandleAllToAll(HloInstructionPtr hlo) override {
108     return DefaultAction(hlo);
109   }
HandleCollectivePermute(HloInstructionPtr hlo)110   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
111     return DefaultAction(hlo);
112   }
HandleCollectivePermuteStart(HloInstructionPtr hlo)113   Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
114     return DefaultAction(hlo);
115   }
HandleCollectivePermuteDone(HloInstructionPtr hlo)116   Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
117     return DefaultAction(hlo);
118   }
HandleReplicaId(HloInstructionPtr hlo)119   Status HandleReplicaId(HloInstructionPtr hlo) override {
120     return DefaultAction(hlo);
121   }
HandlePartitionId(HloInstructionPtr hlo)122   Status HandlePartitionId(HloInstructionPtr hlo) override {
123     return DefaultAction(hlo);
124   }
HandleRng(HloInstructionPtr random)125   Status HandleRng(HloInstructionPtr random) override {
126     return DefaultAction(random);
127   }
HandleRngBitGenerator(HloInstructionPtr random)128   Status HandleRngBitGenerator(HloInstructionPtr random) override {
129     return DefaultAction(random);
130   }
HandleRngGetAndUpdateState(HloInstructionPtr random)131   Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
132     return DefaultAction(random);
133   }
HandleInfeed(HloInstructionPtr infeed)134   Status HandleInfeed(HloInstructionPtr infeed) override {
135     return DefaultAction(infeed);
136   }
HandleOutfeed(HloInstructionPtr outfeed)137   Status HandleOutfeed(HloInstructionPtr outfeed) override {
138     return DefaultAction(outfeed);
139   }
HandleReverse(HloInstructionPtr reverse)140   Status HandleReverse(HloInstructionPtr reverse) override {
141     return DefaultAction(reverse);
142   }
HandleSort(HloInstructionPtr sort)143   Status HandleSort(HloInstructionPtr sort) override {
144     return DefaultAction(sort);
145   }
HandleConstant(HloInstructionPtr constant)146   Status HandleConstant(HloInstructionPtr constant) override {
147     return DefaultAction(constant);
148   }
HandleIota(HloInstructionPtr iota)149   Status HandleIota(HloInstructionPtr iota) override {
150     return DefaultAction(iota);
151   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)152   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
153     return DefaultAction(get_tuple_element);
154   }
HandleParameter(HloInstructionPtr parameter)155   Status HandleParameter(HloInstructionPtr parameter) override {
156     return DefaultAction(parameter);
157   }
HandleFusion(HloInstructionPtr fusion)158   Status HandleFusion(HloInstructionPtr fusion) override {
159     return DefaultAction(fusion);
160   }
HandleCall(HloInstructionPtr call)161   Status HandleCall(HloInstructionPtr call) override {
162     return DefaultAction(call);
163   }
HandleCustomCall(HloInstructionPtr custom_call)164   Status HandleCustomCall(HloInstructionPtr custom_call) override {
165     return DefaultAction(custom_call);
166   }
HandleSlice(HloInstructionPtr slice)167   Status HandleSlice(HloInstructionPtr slice) override {
168     return DefaultAction(slice);
169   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)170   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
171     return DefaultAction(dynamic_slice);
172   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)173   Status HandleDynamicUpdateSlice(
174       HloInstructionPtr dynamic_update_slice) override {
175     return DefaultAction(dynamic_update_slice);
176   }
HandleTuple(HloInstructionPtr tuple)177   Status HandleTuple(HloInstructionPtr tuple) override {
178     return DefaultAction(tuple);
179   }
HandleMap(HloInstructionPtr map)180   Status HandleMap(HloInstructionPtr map) override {
181     return DefaultAction(map);
182   }
HandleReduce(HloInstructionPtr reduce)183   Status HandleReduce(HloInstructionPtr reduce) override {
184     return DefaultAction(reduce);
185   }
HandleReduceWindow(HloInstructionPtr reduce_window)186   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
187     return DefaultAction(reduce_window);
188   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)189   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
190     return DefaultAction(select_and_scatter);
191   }
HandleBitcast(HloInstructionPtr bitcast)192   Status HandleBitcast(HloInstructionPtr bitcast) override {
193     return DefaultAction(bitcast);
194   }
HandleBroadcast(HloInstructionPtr broadcast)195   Status HandleBroadcast(HloInstructionPtr broadcast) override {
196     return DefaultAction(broadcast);
197   }
HandlePad(HloInstructionPtr pad)198   Status HandlePad(HloInstructionPtr pad) override {
199     return DefaultAction(pad);
200   }
HandleDynamicReshape(HloInstructionPtr dynamic_reshape)201   Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override {
202     return DefaultAction(dynamic_reshape);
203   }
HandleReshape(HloInstructionPtr reshape)204   Status HandleReshape(HloInstructionPtr reshape) override {
205     return DefaultAction(reshape);
206   }
HandleTranspose(HloInstructionPtr transpose)207   Status HandleTranspose(HloInstructionPtr transpose) override {
208     return DefaultAction(transpose);
209   }
HandleWhile(HloInstructionPtr xla_while)210   Status HandleWhile(HloInstructionPtr xla_while) override {
211     return DefaultAction(xla_while);
212   }
HandleConditional(HloInstructionPtr conditional)213   Status HandleConditional(HloInstructionPtr conditional) override {
214     return DefaultAction(conditional);
215   }
HandleCopyStart(HloInstructionPtr copy_start)216   Status HandleCopyStart(HloInstructionPtr copy_start) override {
217     return DefaultAction(copy_start);
218   }
HandleCopyDone(HloInstructionPtr copy_done)219   Status HandleCopyDone(HloInstructionPtr copy_done) override {
220     return DefaultAction(copy_done);
221   }
HandleRecv(HloInstructionPtr recv)222   Status HandleRecv(HloInstructionPtr recv) override {
223     return DefaultAction(recv);
224   }
HandleRecvDone(HloInstructionPtr recv_done)225   Status HandleRecvDone(HloInstructionPtr recv_done) override {
226     return DefaultAction(recv_done);
227   }
HandleSend(HloInstructionPtr send)228   Status HandleSend(HloInstructionPtr send) override {
229     return DefaultAction(send);
230   }
HandleSendDone(HloInstructionPtr send_done)231   Status HandleSendDone(HloInstructionPtr send_done) override {
232     return DefaultAction(send_done);
233   }
HandleGather(HloInstructionPtr gather)234   Status HandleGather(HloInstructionPtr gather) override {
235     return DefaultAction(gather);
236   }
HandleScatter(HloInstructionPtr scatter)237   Status HandleScatter(HloInstructionPtr scatter) override {
238     return DefaultAction(scatter);
239   }
HandleAfterAll(HloInstructionPtr token)240   Status HandleAfterAll(HloInstructionPtr token) override {
241     return DefaultAction(token);
242   }
HandleGetDimensionSize(HloInstructionPtr get_size)243   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
244     return DefaultAction(get_size);
245   }
HandleSetDimensionSize(HloInstructionPtr get_size)246   Status HandleSetDimensionSize(HloInstructionPtr get_size) override {
247     return DefaultAction(get_size);
248   }
HandleAddDependency(HloInstructionPtr add_dependency)249   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
250     return DefaultAction(add_dependency);
251   }
252 
253   // Invoked to inform the visitor that the traversal has completed, and that
254   // the root was "root".
FinishVisit(HloInstructionPtr)255   Status FinishVisit(HloInstructionPtr /*root*/) override {
256     return Status::OK();
257   }
258 
259  private:
260   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
261 };
262 
263 // Users should use these type aliases which are only two valid instantiations.
264 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
265 using ConstDfsHloVisitorWithDefault =
266     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
267 
268 // A common base class for visitors performing rewriting operation.
269 //
270 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
271 // visiting.
272 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
273  public:
274   // Runs a visitor on the module and returns whether the module has changed.
RunOnModule(HloModule * module)275   StatusOr<bool> RunOnModule(HloModule* module) {
276     bool is_changed = false;
277     for (const auto& computation : module->computations()) {
278       TF_RETURN_IF_ERROR(computation->Accept(this));
279       is_changed |= changed();
280     }
281     return is_changed;
282   }
283 
284   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)285   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
286     return Status::OK();
287   }
288 
changed()289   bool changed() const { return changed_; }
290 
291  protected:
292   // Replaces the existing HLO instruction old_instruction, with
293   // new_instruction, and marks the optimizer status as changed.
294   // Returns the Status representing the result of the replace operation.
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)295   Status ReplaceWithNewInstruction(
296       HloInstruction* old_instruction,
297       std::unique_ptr<HloInstruction> new_instruction) {
298     VLOG(3) << "Replacing instruction:";
299     VLOG(3) << "  old: " << old_instruction->ToString();
300     VLOG(3) << "  new: " << new_instruction->ToString();
301     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
302         old_instruction, std::move(new_instruction)));
303     changed_ = true;
304     return Status::OK();
305   }
306 
307   // Replaces the existing HLO instruction old_instruction, with
308   // new_instruction, and marks the optimizer status as changed.
309   // Returns the Status representing the result of the replace operation.
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)310   Status ReplaceInstruction(HloInstruction* old_instruction,
311                             HloInstruction* new_instruction) {
312     VLOG(3) << "Replacing instruction:";
313     VLOG(3) << "  old: " << old_instruction->ToString();
314     VLOG(3) << "  new: " << new_instruction->ToString();
315     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction(
316         old_instruction, new_instruction));
317     changed_ = true;
318     return Status::OK();
319   }
320 
321   bool changed_ = false;
322 };
323 
324 // (Const)FunctionVisitor lets you transform an
325 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
326 //
327 // This is useful if you have code that needs to handle visitors in the form of
328 // both std::function and DfsHloVisitor.  You can wrap the function in a
329 // FunctionVisitor and then treat it like any other DfsHloVisitor.
330 template <typename HloInstructionPtr>
331 class FunctionVisitorBase
332     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
333  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)334   explicit FunctionVisitorBase(
335       std::function<Status(HloInstructionPtr)> visitor_func)
336       : visitor_func_(std::move(visitor_func)) {}
337 
DefaultAction(HloInstructionPtr hlo_instruction)338   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
339     return visitor_func_(hlo_instruction);
340   }
341 
342  private:
343   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
344 
345   std::function<Status(HloInstructionPtr)> visitor_func_;
346 };
347 
348 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
349 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
350 
351 }  // namespace xla
352 
353 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
354