• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
19 #include "absl/strings/string_view.h"
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // DfsHloVisitor with default action based on the HloInstruction being visited.
37 // Users should not use this class directly, but use the type aliases
38 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
39 //
40 // Do *not* add an override to this class if the opcode is covered by
41 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
42 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
43 // here will break passes which rely on the HandleElementwiseUnary/Binary
44 // handling these opcodes.
45 template <typename HloInstructionPtr>
46 class DfsHloVisitorWithDefaultBase
47     : public DfsHloVisitorBase<HloInstructionPtr> {
48  public:
DfsHloVisitorWithDefaultBase()49   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()50   ~DfsHloVisitorWithDefaultBase() override {}
51 
52   // Default action performed on HloInstruction.
53   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
54 
HandleElementwiseUnary(HloInstructionPtr hlo)55   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
56     return DefaultAction(hlo);
57   }
HandleElementwiseBinary(HloInstructionPtr hlo)58   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
59     return DefaultAction(hlo);
60   }
61 
HandleBatchNormTraining(HloInstructionPtr hlo)62   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
63     return DefaultAction(hlo);
64   }
65 
HandleBatchNormInference(HloInstructionPtr hlo)66   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
67     return DefaultAction(hlo);
68   }
69 
HandleBatchNormGrad(HloInstructionPtr hlo)70   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
71     return DefaultAction(hlo);
72   }
73 
HandleClamp(HloInstructionPtr clamp)74   Status HandleClamp(HloInstructionPtr clamp) override {
75     return DefaultAction(clamp);
76   }
HandleConcatenate(HloInstructionPtr concatenate)77   Status HandleConcatenate(HloInstructionPtr concatenate) override {
78     return DefaultAction(concatenate);
79   }
HandleSelect(HloInstructionPtr select)80   Status HandleSelect(HloInstructionPtr select) override {
81     return DefaultAction(select);
82   }
HandleTupleSelect(HloInstructionPtr tuple_select)83   Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
84     return DefaultAction(tuple_select);
85   }
HandleDot(HloInstructionPtr dot)86   Status HandleDot(HloInstructionPtr dot) override {
87     return DefaultAction(dot);
88   }
HandleConvolution(HloInstructionPtr convolution)89   Status HandleConvolution(HloInstructionPtr convolution) override {
90     return DefaultAction(convolution);
91   }
HandleFft(HloInstructionPtr fft)92   Status HandleFft(HloInstructionPtr fft) override {
93     return DefaultAction(fft);
94   }
HandleTriangularSolve(HloInstructionPtr hlo)95   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
96     return DefaultAction(hlo);
97   }
HandleCholesky(HloInstructionPtr hlo)98   Status HandleCholesky(HloInstructionPtr hlo) override {
99     return DefaultAction(hlo);
100   }
HandleAllReduce(HloInstructionPtr crs)101   Status HandleAllReduce(HloInstructionPtr crs) override {
102     return DefaultAction(crs);
103   }
HandleAllToAll(HloInstructionPtr hlo)104   Status HandleAllToAll(HloInstructionPtr hlo) override {
105     return DefaultAction(hlo);
106   }
HandleCollectivePermute(HloInstructionPtr hlo)107   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
108     return DefaultAction(hlo);
109   }
HandleReplicaId(HloInstructionPtr hlo)110   Status HandleReplicaId(HloInstructionPtr hlo) override {
111     return DefaultAction(hlo);
112   }
HandlePartitionId(HloInstructionPtr hlo)113   Status HandlePartitionId(HloInstructionPtr hlo) override {
114     return DefaultAction(hlo);
115   }
HandleRng(HloInstructionPtr random)116   Status HandleRng(HloInstructionPtr random) override {
117     return DefaultAction(random);
118   }
HandleRngGetAndUpdateState(HloInstructionPtr random)119   Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
120     return DefaultAction(random);
121   }
HandleInfeed(HloInstructionPtr infeed)122   Status HandleInfeed(HloInstructionPtr infeed) override {
123     return DefaultAction(infeed);
124   }
HandleOutfeed(HloInstructionPtr outfeed)125   Status HandleOutfeed(HloInstructionPtr outfeed) override {
126     return DefaultAction(outfeed);
127   }
HandleReverse(HloInstructionPtr reverse)128   Status HandleReverse(HloInstructionPtr reverse) override {
129     return DefaultAction(reverse);
130   }
HandleSort(HloInstructionPtr sort)131   Status HandleSort(HloInstructionPtr sort) override {
132     return DefaultAction(sort);
133   }
HandleConstant(HloInstructionPtr constant)134   Status HandleConstant(HloInstructionPtr constant) override {
135     return DefaultAction(constant);
136   }
HandleIota(HloInstructionPtr iota)137   Status HandleIota(HloInstructionPtr iota) override {
138     return DefaultAction(iota);
139   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)140   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
141     return DefaultAction(get_tuple_element);
142   }
HandleParameter(HloInstructionPtr parameter)143   Status HandleParameter(HloInstructionPtr parameter) override {
144     return DefaultAction(parameter);
145   }
HandleFusion(HloInstructionPtr fusion)146   Status HandleFusion(HloInstructionPtr fusion) override {
147     return DefaultAction(fusion);
148   }
HandleCall(HloInstructionPtr call)149   Status HandleCall(HloInstructionPtr call) override {
150     return DefaultAction(call);
151   }
HandleCustomCall(HloInstructionPtr custom_call)152   Status HandleCustomCall(HloInstructionPtr custom_call) override {
153     return DefaultAction(custom_call);
154   }
HandleSlice(HloInstructionPtr slice)155   Status HandleSlice(HloInstructionPtr slice) override {
156     return DefaultAction(slice);
157   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)158   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
159     return DefaultAction(dynamic_slice);
160   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)161   Status HandleDynamicUpdateSlice(
162       HloInstructionPtr dynamic_update_slice) override {
163     return DefaultAction(dynamic_update_slice);
164   }
HandleTuple(HloInstructionPtr tuple)165   Status HandleTuple(HloInstructionPtr tuple) override {
166     return DefaultAction(tuple);
167   }
HandleMap(HloInstructionPtr map)168   Status HandleMap(HloInstructionPtr map) override {
169     return DefaultAction(map);
170   }
HandleReduce(HloInstructionPtr reduce)171   Status HandleReduce(HloInstructionPtr reduce) override {
172     return DefaultAction(reduce);
173   }
HandleReduceWindow(HloInstructionPtr reduce_window)174   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
175     return DefaultAction(reduce_window);
176   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)177   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
178     return DefaultAction(select_and_scatter);
179   }
HandleBitcast(HloInstructionPtr bitcast)180   Status HandleBitcast(HloInstructionPtr bitcast) override {
181     return DefaultAction(bitcast);
182   }
HandleBroadcast(HloInstructionPtr broadcast)183   Status HandleBroadcast(HloInstructionPtr broadcast) override {
184     return DefaultAction(broadcast);
185   }
HandlePad(HloInstructionPtr pad)186   Status HandlePad(HloInstructionPtr pad) override {
187     return DefaultAction(pad);
188   }
HandleReshape(HloInstructionPtr reshape)189   Status HandleReshape(HloInstructionPtr reshape) override {
190     return DefaultAction(reshape);
191   }
HandleTranspose(HloInstructionPtr transpose)192   Status HandleTranspose(HloInstructionPtr transpose) override {
193     return DefaultAction(transpose);
194   }
HandleWhile(HloInstructionPtr xla_while)195   Status HandleWhile(HloInstructionPtr xla_while) override {
196     return DefaultAction(xla_while);
197   }
HandleConditional(HloInstructionPtr conditional)198   Status HandleConditional(HloInstructionPtr conditional) override {
199     return DefaultAction(conditional);
200   }
HandleCopyStart(HloInstructionPtr copy_start)201   Status HandleCopyStart(HloInstructionPtr copy_start) override {
202     return DefaultAction(copy_start);
203   }
HandleCopyDone(HloInstructionPtr copy_done)204   Status HandleCopyDone(HloInstructionPtr copy_done) override {
205     return DefaultAction(copy_done);
206   }
HandleRecv(HloInstructionPtr recv)207   Status HandleRecv(HloInstructionPtr recv) override {
208     return DefaultAction(recv);
209   }
HandleRecvDone(HloInstructionPtr recv_done)210   Status HandleRecvDone(HloInstructionPtr recv_done) override {
211     return DefaultAction(recv_done);
212   }
HandleSend(HloInstructionPtr send)213   Status HandleSend(HloInstructionPtr send) override {
214     return DefaultAction(send);
215   }
HandleSendDone(HloInstructionPtr send_done)216   Status HandleSendDone(HloInstructionPtr send_done) override {
217     return DefaultAction(send_done);
218   }
HandleGather(HloInstructionPtr gather)219   Status HandleGather(HloInstructionPtr gather) override {
220     return DefaultAction(gather);
221   }
HandleScatter(HloInstructionPtr scatter)222   Status HandleScatter(HloInstructionPtr scatter) override {
223     return DefaultAction(scatter);
224   }
HandleAfterAll(HloInstructionPtr token)225   Status HandleAfterAll(HloInstructionPtr token) override {
226     return DefaultAction(token);
227   }
HandleGetDimensionSize(HloInstructionPtr get_size)228   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
229     return DefaultAction(get_size);
230   }
HandleSetDimensionSize(HloInstructionPtr get_size)231   Status HandleSetDimensionSize(HloInstructionPtr get_size) override {
232     return DefaultAction(get_size);
233   }
HandleAddDependency(HloInstructionPtr add_dependency)234   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
235     return DefaultAction(add_dependency);
236   }
237 
238   // Invoked to inform the visitor that the traversal has completed, and that
239   // the root was "root".
FinishVisit(HloInstructionPtr)240   Status FinishVisit(HloInstructionPtr /*root*/) override {
241     return Status::OK();
242   }
243 
244  private:
245   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
246 };
247 
248 // Users should use these type aliases which are only two valid instantiations.
249 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
250 using ConstDfsHloVisitorWithDefault =
251     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
252 
253 // A common base class for visitors performing rewriting operation.
254 //
255 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
256 // visiting.
257 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
258  public:
259   // Runs a visitor on the module and returns whether the module has changed.
RunOnModule(HloModule * module)260   StatusOr<bool> RunOnModule(HloModule* module) {
261     bool is_changed = false;
262     for (const auto& computation : module->computations()) {
263       TF_RETURN_IF_ERROR(computation->Accept(this));
264       is_changed |= changed();
265     }
266     return is_changed;
267   }
268 
269   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)270   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
271     return Status::OK();
272   }
273 
changed()274   bool changed() const { return changed_; }
275 
276  protected:
277   // Replaces the existing HLO instruction old_instruction, with
278   // new_instruction, and marks the optimizer status as changed.
279   // Returns the Status representing the result of the replace operation.
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)280   Status ReplaceWithNewInstruction(
281       HloInstruction* old_instruction,
282       std::unique_ptr<HloInstruction> new_instruction) {
283     VLOG(3) << "Replacing instruction:";
284     VLOG(3) << "  old: " << old_instruction->ToString();
285     VLOG(3) << "  new: " << new_instruction->ToString();
286     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
287         old_instruction, std::move(new_instruction)));
288     changed_ = true;
289     return Status::OK();
290   }
291 
292   // Replaces the existing HLO instruction old_instruction, with
293   // new_instruction, and marks the optimizer status as changed.
294   // Returns the Status representing the result of the replace operation.
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)295   Status ReplaceInstruction(HloInstruction* old_instruction,
296                             HloInstruction* new_instruction) {
297     VLOG(3) << "Replacing instruction:";
298     VLOG(3) << "  old: " << old_instruction->ToString();
299     VLOG(3) << "  new: " << new_instruction->ToString();
300     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction(
301         old_instruction, new_instruction));
302     changed_ = true;
303     return Status::OK();
304   }
305 
306   bool changed_ = false;
307 };
308 
309 // (Const)FunctionVisitor lets you transform an
310 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
311 //
312 // This is useful if you have code that needs to handle visitors in the form of
313 // both std::function and DfsHloVisitor.  You can wrap the function in a
314 // FunctionVisitor and then treat it like any other DfsHloVisitor.
315 template <typename HloInstructionPtr>
316 class FunctionVisitorBase
317     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
318  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)319   explicit FunctionVisitorBase(
320       std::function<Status(HloInstructionPtr)> visitor_func)
321       : visitor_func_(std::move(visitor_func)) {}
322 
DefaultAction(HloInstructionPtr hlo_instruction)323   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
324     return visitor_func_(hlo_instruction);
325   }
326 
327  private:
328   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
329 
330   std::function<Status(HloInstructionPtr)> visitor_func_;
331 };
332 
333 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
334 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
335 
336 }  // namespace xla
337 
338 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
339