/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// DfsHloVisitor with default action based on the HloInstruction being visited.
// Users should not use this class directly, but use the type aliases
// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
//
// Do *not* add an override to this class if the opcode is covered by
// HandleElementwiseUnary/Binary.
These opcode handlers dispatch to 42 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler 43 // here will break passes which rely on the HandleElementwiseUnary/Binary 44 // handling these opcodes. 45 template <typename HloInstructionPtr> 46 class DfsHloVisitorWithDefaultBase 47 : public DfsHloVisitorBase<HloInstructionPtr> { 48 public: DfsHloVisitorWithDefaultBase()49 DfsHloVisitorWithDefaultBase() {} ~DfsHloVisitorWithDefaultBase()50 ~DfsHloVisitorWithDefaultBase() override {} 51 52 // Default action performed on HloInstruction. 53 virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0; 54 HandleElementwiseUnary(HloInstructionPtr hlo)55 Status HandleElementwiseUnary(HloInstructionPtr hlo) override { 56 return DefaultAction(hlo); 57 } HandleElementwiseBinary(HloInstructionPtr hlo)58 Status HandleElementwiseBinary(HloInstructionPtr hlo) override { 59 return DefaultAction(hlo); 60 } 61 HandleBatchNormTraining(HloInstructionPtr hlo)62 Status HandleBatchNormTraining(HloInstructionPtr hlo) override { 63 return DefaultAction(hlo); 64 } 65 HandleBatchNormInference(HloInstructionPtr hlo)66 Status HandleBatchNormInference(HloInstructionPtr hlo) override { 67 return DefaultAction(hlo); 68 } 69 HandleBatchNormGrad(HloInstructionPtr hlo)70 Status HandleBatchNormGrad(HloInstructionPtr hlo) override { 71 return DefaultAction(hlo); 72 } 73 HandleClamp(HloInstructionPtr clamp)74 Status HandleClamp(HloInstructionPtr clamp) override { 75 return DefaultAction(clamp); 76 } HandleConcatenate(HloInstructionPtr concatenate)77 Status HandleConcatenate(HloInstructionPtr concatenate) override { 78 return DefaultAction(concatenate); 79 } HandleSelect(HloInstructionPtr select)80 Status HandleSelect(HloInstructionPtr select) override { 81 return DefaultAction(select); 82 } HandleTupleSelect(HloInstructionPtr tuple_select)83 Status HandleTupleSelect(HloInstructionPtr tuple_select) override { 84 return DefaultAction(tuple_select); 85 } 
HandleDot(HloInstructionPtr dot)86 Status HandleDot(HloInstructionPtr dot) override { 87 return DefaultAction(dot); 88 } HandleConvolution(HloInstructionPtr convolution)89 Status HandleConvolution(HloInstructionPtr convolution) override { 90 return DefaultAction(convolution); 91 } HandleFft(HloInstructionPtr fft)92 Status HandleFft(HloInstructionPtr fft) override { 93 return DefaultAction(fft); 94 } HandleTriangularSolve(HloInstructionPtr hlo)95 Status HandleTriangularSolve(HloInstructionPtr hlo) override { 96 return DefaultAction(hlo); 97 } HandleCholesky(HloInstructionPtr hlo)98 Status HandleCholesky(HloInstructionPtr hlo) override { 99 return DefaultAction(hlo); 100 } HandleAllGather(HloInstructionPtr crs)101 Status HandleAllGather(HloInstructionPtr crs) override { 102 return DefaultAction(crs); 103 } HandleAllReduce(HloInstructionPtr crs)104 Status HandleAllReduce(HloInstructionPtr crs) override { 105 return DefaultAction(crs); 106 } HandleAllToAll(HloInstructionPtr hlo)107 Status HandleAllToAll(HloInstructionPtr hlo) override { 108 return DefaultAction(hlo); 109 } HandleCollectivePermute(HloInstructionPtr hlo)110 Status HandleCollectivePermute(HloInstructionPtr hlo) override { 111 return DefaultAction(hlo); 112 } HandleCollectivePermuteStart(HloInstructionPtr hlo)113 Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { 114 return DefaultAction(hlo); 115 } HandleCollectivePermuteDone(HloInstructionPtr hlo)116 Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { 117 return DefaultAction(hlo); 118 } HandleReplicaId(HloInstructionPtr hlo)119 Status HandleReplicaId(HloInstructionPtr hlo) override { 120 return DefaultAction(hlo); 121 } HandlePartitionId(HloInstructionPtr hlo)122 Status HandlePartitionId(HloInstructionPtr hlo) override { 123 return DefaultAction(hlo); 124 } HandleRng(HloInstructionPtr random)125 Status HandleRng(HloInstructionPtr random) override { 126 return DefaultAction(random); 127 } 
HandleRngBitGenerator(HloInstructionPtr random)128 Status HandleRngBitGenerator(HloInstructionPtr random) override { 129 return DefaultAction(random); 130 } HandleRngGetAndUpdateState(HloInstructionPtr random)131 Status HandleRngGetAndUpdateState(HloInstructionPtr random) override { 132 return DefaultAction(random); 133 } HandleInfeed(HloInstructionPtr infeed)134 Status HandleInfeed(HloInstructionPtr infeed) override { 135 return DefaultAction(infeed); 136 } HandleOutfeed(HloInstructionPtr outfeed)137 Status HandleOutfeed(HloInstructionPtr outfeed) override { 138 return DefaultAction(outfeed); 139 } HandleReverse(HloInstructionPtr reverse)140 Status HandleReverse(HloInstructionPtr reverse) override { 141 return DefaultAction(reverse); 142 } HandleSort(HloInstructionPtr sort)143 Status HandleSort(HloInstructionPtr sort) override { 144 return DefaultAction(sort); 145 } HandleConstant(HloInstructionPtr constant)146 Status HandleConstant(HloInstructionPtr constant) override { 147 return DefaultAction(constant); 148 } HandleIota(HloInstructionPtr iota)149 Status HandleIota(HloInstructionPtr iota) override { 150 return DefaultAction(iota); 151 } HandleGetTupleElement(HloInstructionPtr get_tuple_element)152 Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { 153 return DefaultAction(get_tuple_element); 154 } HandleParameter(HloInstructionPtr parameter)155 Status HandleParameter(HloInstructionPtr parameter) override { 156 return DefaultAction(parameter); 157 } HandleFusion(HloInstructionPtr fusion)158 Status HandleFusion(HloInstructionPtr fusion) override { 159 return DefaultAction(fusion); 160 } HandleCall(HloInstructionPtr call)161 Status HandleCall(HloInstructionPtr call) override { 162 return DefaultAction(call); 163 } HandleCustomCall(HloInstructionPtr custom_call)164 Status HandleCustomCall(HloInstructionPtr custom_call) override { 165 return DefaultAction(custom_call); 166 } HandleSlice(HloInstructionPtr slice)167 Status 
HandleSlice(HloInstructionPtr slice) override { 168 return DefaultAction(slice); 169 } HandleDynamicSlice(HloInstructionPtr dynamic_slice)170 Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override { 171 return DefaultAction(dynamic_slice); 172 } HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)173 Status HandleDynamicUpdateSlice( 174 HloInstructionPtr dynamic_update_slice) override { 175 return DefaultAction(dynamic_update_slice); 176 } HandleTuple(HloInstructionPtr tuple)177 Status HandleTuple(HloInstructionPtr tuple) override { 178 return DefaultAction(tuple); 179 } HandleMap(HloInstructionPtr map)180 Status HandleMap(HloInstructionPtr map) override { 181 return DefaultAction(map); 182 } HandleReduce(HloInstructionPtr reduce)183 Status HandleReduce(HloInstructionPtr reduce) override { 184 return DefaultAction(reduce); 185 } HandleReduceWindow(HloInstructionPtr reduce_window)186 Status HandleReduceWindow(HloInstructionPtr reduce_window) override { 187 return DefaultAction(reduce_window); 188 } HandleSelectAndScatter(HloInstructionPtr select_and_scatter)189 Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override { 190 return DefaultAction(select_and_scatter); 191 } HandleBitcast(HloInstructionPtr bitcast)192 Status HandleBitcast(HloInstructionPtr bitcast) override { 193 return DefaultAction(bitcast); 194 } HandleBroadcast(HloInstructionPtr broadcast)195 Status HandleBroadcast(HloInstructionPtr broadcast) override { 196 return DefaultAction(broadcast); 197 } HandlePad(HloInstructionPtr pad)198 Status HandlePad(HloInstructionPtr pad) override { 199 return DefaultAction(pad); 200 } HandleDynamicReshape(HloInstructionPtr dynamic_reshape)201 Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override { 202 return DefaultAction(dynamic_reshape); 203 } HandleReshape(HloInstructionPtr reshape)204 Status HandleReshape(HloInstructionPtr reshape) override { 205 return DefaultAction(reshape); 206 } 
HandleTranspose(HloInstructionPtr transpose)207 Status HandleTranspose(HloInstructionPtr transpose) override { 208 return DefaultAction(transpose); 209 } HandleWhile(HloInstructionPtr xla_while)210 Status HandleWhile(HloInstructionPtr xla_while) override { 211 return DefaultAction(xla_while); 212 } HandleConditional(HloInstructionPtr conditional)213 Status HandleConditional(HloInstructionPtr conditional) override { 214 return DefaultAction(conditional); 215 } HandleCopyStart(HloInstructionPtr copy_start)216 Status HandleCopyStart(HloInstructionPtr copy_start) override { 217 return DefaultAction(copy_start); 218 } HandleCopyDone(HloInstructionPtr copy_done)219 Status HandleCopyDone(HloInstructionPtr copy_done) override { 220 return DefaultAction(copy_done); 221 } HandleRecv(HloInstructionPtr recv)222 Status HandleRecv(HloInstructionPtr recv) override { 223 return DefaultAction(recv); 224 } HandleRecvDone(HloInstructionPtr recv_done)225 Status HandleRecvDone(HloInstructionPtr recv_done) override { 226 return DefaultAction(recv_done); 227 } HandleSend(HloInstructionPtr send)228 Status HandleSend(HloInstructionPtr send) override { 229 return DefaultAction(send); 230 } HandleSendDone(HloInstructionPtr send_done)231 Status HandleSendDone(HloInstructionPtr send_done) override { 232 return DefaultAction(send_done); 233 } HandleGather(HloInstructionPtr gather)234 Status HandleGather(HloInstructionPtr gather) override { 235 return DefaultAction(gather); 236 } HandleScatter(HloInstructionPtr scatter)237 Status HandleScatter(HloInstructionPtr scatter) override { 238 return DefaultAction(scatter); 239 } HandleAfterAll(HloInstructionPtr token)240 Status HandleAfterAll(HloInstructionPtr token) override { 241 return DefaultAction(token); 242 } HandleGetDimensionSize(HloInstructionPtr get_size)243 Status HandleGetDimensionSize(HloInstructionPtr get_size) override { 244 return DefaultAction(get_size); 245 } HandleSetDimensionSize(HloInstructionPtr get_size)246 Status 
HandleSetDimensionSize(HloInstructionPtr get_size) override { 247 return DefaultAction(get_size); 248 } HandleAddDependency(HloInstructionPtr add_dependency)249 Status HandleAddDependency(HloInstructionPtr add_dependency) override { 250 return DefaultAction(add_dependency); 251 } 252 253 // Invoked to inform the visitor that the traversal has completed, and that 254 // the root was "root". FinishVisit(HloInstructionPtr)255 Status FinishVisit(HloInstructionPtr /*root*/) override { 256 return Status::OK(); 257 } 258 259 private: 260 TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase); 261 }; 262 263 // Users should use these type aliases which are only two valid instantiations. 264 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>; 265 using ConstDfsHloVisitorWithDefault = 266 DfsHloVisitorWithDefaultBase<const HloInstruction*>; 267 268 // A common base class for visitors performing rewriting operation. 269 // 270 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while 271 // visiting. 272 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { 273 public: 274 // Runs a visitor on the module and returns whether the module has changed. RunOnModule(HloModule * module)275 StatusOr<bool> RunOnModule(HloModule* module) { 276 bool is_changed = false; 277 for (const auto& computation : module->computations()) { 278 TF_RETURN_IF_ERROR(computation->Accept(this)); 279 is_changed |= changed(); 280 } 281 return is_changed; 282 } 283 284 // Default visitor action is to do nothing and return OK. DefaultAction(HloInstruction *)285 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 286 return Status::OK(); 287 } 288 changed()289 bool changed() const { return changed_; } 290 291 protected: 292 // Replaces the existing HLO instruction old_instruction, with 293 // new_instruction, and marks the optimizer status as changed. 294 // Returns the Status representing the result of the replace operation. 
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)295 Status ReplaceWithNewInstruction( 296 HloInstruction* old_instruction, 297 std::unique_ptr<HloInstruction> new_instruction) { 298 VLOG(3) << "Replacing instruction:"; 299 VLOG(3) << " old: " << old_instruction->ToString(); 300 VLOG(3) << " new: " << new_instruction->ToString(); 301 TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( 302 old_instruction, std::move(new_instruction))); 303 changed_ = true; 304 return Status::OK(); 305 } 306 307 // Replaces the existing HLO instruction old_instruction, with 308 // new_instruction, and marks the optimizer status as changed. 309 // Returns the Status representing the result of the replace operation. ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)310 Status ReplaceInstruction(HloInstruction* old_instruction, 311 HloInstruction* new_instruction) { 312 VLOG(3) << "Replacing instruction:"; 313 VLOG(3) << " old: " << old_instruction->ToString(); 314 VLOG(3) << " new: " << new_instruction->ToString(); 315 TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction( 316 old_instruction, new_instruction)); 317 changed_ = true; 318 return Status::OK(); 319 } 320 321 bool changed_ = false; 322 }; 323 324 // (Const)FunctionVisitor lets you transform an 325 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor. 326 // 327 // This is useful if you have code that needs to handle visitors in the form of 328 // both std::function and DfsHloVisitor. You can wrap the function in a 329 // FunctionVisitor and then treat it like any other DfsHloVisitor. 
330 template <typename HloInstructionPtr> 331 class FunctionVisitorBase 332 : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> { 333 public: FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)334 explicit FunctionVisitorBase( 335 std::function<Status(HloInstructionPtr)> visitor_func) 336 : visitor_func_(std::move(visitor_func)) {} 337 DefaultAction(HloInstructionPtr hlo_instruction)338 Status DefaultAction(HloInstructionPtr hlo_instruction) override { 339 return visitor_func_(hlo_instruction); 340 } 341 342 private: 343 TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase); 344 345 std::function<Status(HloInstructionPtr)> visitor_func_; 346 }; 347 348 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>; 349 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>; 350 351 } // namespace xla 352 353 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 354