1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 18 19 #include "absl/strings/string_view.h" 20 #include "absl/types/span.h" 21 #include "tensorflow/compiler/xla/literal.h" 22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // DfsHloVisitor with default action based on the HloInstruction being visited. 37 // Users should not use this class directly, but use the type aliases 38 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. 39 // 40 // Do *not* add an override to this class if the opcode is covered by 41 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to 42 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler 43 // here will break passes which rely on the HandleElementwiseUnary/Binary 44 // handling these opcodes. 45 template <typename HloInstructionPtr> 46 class DfsHloVisitorWithDefaultBase 47 : public DfsHloVisitorBase<HloInstructionPtr> { 48 public: DfsHloVisitorWithDefaultBase()49 DfsHloVisitorWithDefaultBase() {} ~DfsHloVisitorWithDefaultBase()50 ~DfsHloVisitorWithDefaultBase() override {} 51 52 // Default action performed on HloInstruction. 53 virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0; 54 HandleElementwiseUnary(HloInstructionPtr hlo)55 Status HandleElementwiseUnary(HloInstructionPtr hlo) override { 56 return DefaultAction(hlo); 57 } HandleElementwiseBinary(HloInstructionPtr hlo)58 Status HandleElementwiseBinary(HloInstructionPtr hlo) override { 59 return DefaultAction(hlo); 60 } 61 HandleBatchNormTraining(HloInstructionPtr hlo)62 Status HandleBatchNormTraining(HloInstructionPtr hlo) override { 63 return DefaultAction(hlo); 64 } 65 HandleBatchNormInference(HloInstructionPtr hlo)66 Status HandleBatchNormInference(HloInstructionPtr hlo) override { 67 return DefaultAction(hlo); 68 } 69 HandleBatchNormGrad(HloInstructionPtr hlo)70 Status HandleBatchNormGrad(HloInstructionPtr hlo) override { 71 return DefaultAction(hlo); 72 } 73 HandleClamp(HloInstructionPtr clamp)74 Status HandleClamp(HloInstructionPtr clamp) override { 75 return DefaultAction(clamp); 76 } HandleConcatenate(HloInstructionPtr concatenate)77 Status HandleConcatenate(HloInstructionPtr concatenate) override { 78 return DefaultAction(concatenate); 79 } HandleSelect(HloInstructionPtr select)80 Status HandleSelect(HloInstructionPtr select) override { 81 return DefaultAction(select); 82 } HandleTupleSelect(HloInstructionPtr tuple_select)83 Status HandleTupleSelect(HloInstructionPtr tuple_select) override { 84 return DefaultAction(tuple_select); 85 } HandleDot(HloInstructionPtr dot)86 Status HandleDot(HloInstructionPtr dot) override { 87 return DefaultAction(dot); 88 } HandleConvolution(HloInstructionPtr convolution)89 Status HandleConvolution(HloInstructionPtr convolution) override { 90 return DefaultAction(convolution); 91 } HandleFft(HloInstructionPtr fft)92 Status HandleFft(HloInstructionPtr fft) override { 93 return DefaultAction(fft); 94 } HandleTriangularSolve(HloInstructionPtr hlo)95 Status HandleTriangularSolve(HloInstructionPtr hlo) override { 96 return DefaultAction(hlo); 97 } HandleCholesky(HloInstructionPtr hlo)98 Status HandleCholesky(HloInstructionPtr hlo) override { 99 return DefaultAction(hlo); 100 } HandleAllGather(HloInstructionPtr crs)101 Status HandleAllGather(HloInstructionPtr crs) override { 102 return DefaultAction(crs); 103 } HandleAllGatherStart(HloInstructionPtr crs)104 Status HandleAllGatherStart(HloInstructionPtr crs) override { 105 return DefaultAction(crs); 106 } HandleAllGatherDone(HloInstructionPtr crs)107 Status HandleAllGatherDone(HloInstructionPtr crs) override { 108 return DefaultAction(crs); 109 } HandleAllReduce(HloInstructionPtr crs)110 Status HandleAllReduce(HloInstructionPtr crs) override { 111 return DefaultAction(crs); 112 } HandleReduceScatter(HloInstructionPtr hlo)113 Status HandleReduceScatter(HloInstructionPtr hlo) override { 114 return DefaultAction(hlo); 115 } HandleAllReduceStart(HloInstructionPtr hlo)116 Status HandleAllReduceStart(HloInstructionPtr hlo) override { 117 return DefaultAction(hlo); 118 } HandleAllReduceDone(HloInstructionPtr hlo)119 Status HandleAllReduceDone(HloInstructionPtr hlo) override { 120 return DefaultAction(hlo); 121 } HandleAllToAll(HloInstructionPtr hlo)122 Status HandleAllToAll(HloInstructionPtr hlo) override { 123 return DefaultAction(hlo); 124 } HandleCollectivePermute(HloInstructionPtr hlo)125 Status HandleCollectivePermute(HloInstructionPtr hlo) override { 126 return DefaultAction(hlo); 127 } HandleCollectivePermuteStart(HloInstructionPtr hlo)128 Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { 129 return DefaultAction(hlo); 130 } HandleCollectivePermuteDone(HloInstructionPtr hlo)131 Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { 132 return DefaultAction(hlo); 133 } HandleReplicaId(HloInstructionPtr hlo)134 Status HandleReplicaId(HloInstructionPtr hlo) override { 135 return DefaultAction(hlo); 136 } HandlePartitionId(HloInstructionPtr hlo)137 Status HandlePartitionId(HloInstructionPtr hlo) override { 138 return DefaultAction(hlo); 139 } HandleRng(HloInstructionPtr random)140 Status HandleRng(HloInstructionPtr random) override { 141 return DefaultAction(random); 142 } HandleRngBitGenerator(HloInstructionPtr random)143 Status HandleRngBitGenerator(HloInstructionPtr random) override { 144 return DefaultAction(random); 145 } HandleRngGetAndUpdateState(HloInstructionPtr random)146 Status HandleRngGetAndUpdateState(HloInstructionPtr random) override { 147 return DefaultAction(random); 148 } HandleInfeed(HloInstructionPtr infeed)149 Status HandleInfeed(HloInstructionPtr infeed) override { 150 return DefaultAction(infeed); 151 } HandleOutfeed(HloInstructionPtr outfeed)152 Status HandleOutfeed(HloInstructionPtr outfeed) override { 153 return DefaultAction(outfeed); 154 } HandleReverse(HloInstructionPtr reverse)155 Status HandleReverse(HloInstructionPtr reverse) override { 156 return DefaultAction(reverse); 157 } HandleSort(HloInstructionPtr sort)158 Status HandleSort(HloInstructionPtr sort) override { 159 return DefaultAction(sort); 160 } HandleConstant(HloInstructionPtr constant)161 Status HandleConstant(HloInstructionPtr constant) override { 162 return DefaultAction(constant); 163 } HandleIota(HloInstructionPtr iota)164 Status HandleIota(HloInstructionPtr iota) override { 165 return DefaultAction(iota); 166 } HandleGetTupleElement(HloInstructionPtr get_tuple_element)167 Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { 168 return DefaultAction(get_tuple_element); 169 } HandleParameter(HloInstructionPtr parameter)170 Status HandleParameter(HloInstructionPtr parameter) override { 171 return DefaultAction(parameter); 172 } HandleFusion(HloInstructionPtr fusion)173 Status HandleFusion(HloInstructionPtr fusion) override { 174 return DefaultAction(fusion); 175 } HandleCall(HloInstructionPtr call)176 Status HandleCall(HloInstructionPtr call) override { 177 return DefaultAction(call); 178 } HandleCustomCall(HloInstructionPtr custom_call)179 Status HandleCustomCall(HloInstructionPtr custom_call) override { 180 return DefaultAction(custom_call); 181 } HandleSlice(HloInstructionPtr slice)182 Status HandleSlice(HloInstructionPtr slice) override { 183 return DefaultAction(slice); 184 } HandleDynamicSlice(HloInstructionPtr dynamic_slice)185 Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override { 186 return DefaultAction(dynamic_slice); 187 } HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)188 Status HandleDynamicUpdateSlice( 189 HloInstructionPtr dynamic_update_slice) override { 190 return DefaultAction(dynamic_update_slice); 191 } HandleTuple(HloInstructionPtr tuple)192 Status HandleTuple(HloInstructionPtr tuple) override { 193 return DefaultAction(tuple); 194 } HandleMap(HloInstructionPtr map)195 Status HandleMap(HloInstructionPtr map) override { 196 return DefaultAction(map); 197 } HandleReduce(HloInstructionPtr reduce)198 Status HandleReduce(HloInstructionPtr reduce) override { 199 return DefaultAction(reduce); 200 } HandleReduceWindow(HloInstructionPtr reduce_window)201 Status HandleReduceWindow(HloInstructionPtr reduce_window) override { 202 return DefaultAction(reduce_window); 203 } HandleSelectAndScatter(HloInstructionPtr select_and_scatter)204 Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override { 205 return DefaultAction(select_and_scatter); 206 } HandleBitcast(HloInstructionPtr bitcast)207 Status HandleBitcast(HloInstructionPtr bitcast) override { 208 return DefaultAction(bitcast); 209 } HandleBroadcast(HloInstructionPtr broadcast)210 Status HandleBroadcast(HloInstructionPtr broadcast) override { 211 return DefaultAction(broadcast); 212 } HandlePad(HloInstructionPtr pad)213 Status HandlePad(HloInstructionPtr pad) override { 214 return DefaultAction(pad); 215 } HandleDynamicReshape(HloInstructionPtr dynamic_reshape)216 Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override { 217 return DefaultAction(dynamic_reshape); 218 } HandleReshape(HloInstructionPtr reshape)219 Status HandleReshape(HloInstructionPtr reshape) override { 220 return DefaultAction(reshape); 221 } HandleTranspose(HloInstructionPtr transpose)222 Status HandleTranspose(HloInstructionPtr transpose) override { 223 return DefaultAction(transpose); 224 } HandleWhile(HloInstructionPtr xla_while)225 Status HandleWhile(HloInstructionPtr xla_while) override { 226 return DefaultAction(xla_while); 227 } HandleConditional(HloInstructionPtr conditional)228 Status HandleConditional(HloInstructionPtr conditional) override { 229 return DefaultAction(conditional); 230 } HandleCopyStart(HloInstructionPtr copy_start)231 Status HandleCopyStart(HloInstructionPtr copy_start) override { 232 return DefaultAction(copy_start); 233 } HandleCopyDone(HloInstructionPtr copy_done)234 Status HandleCopyDone(HloInstructionPtr copy_done) override { 235 return DefaultAction(copy_done); 236 } HandleRecv(HloInstructionPtr recv)237 Status HandleRecv(HloInstructionPtr recv) override { 238 return DefaultAction(recv); 239 } HandleRecvDone(HloInstructionPtr recv_done)240 Status HandleRecvDone(HloInstructionPtr recv_done) override { 241 return DefaultAction(recv_done); 242 } HandleSend(HloInstructionPtr send)243 Status HandleSend(HloInstructionPtr send) override { 244 return DefaultAction(send); 245 } HandleSendDone(HloInstructionPtr send_done)246 Status HandleSendDone(HloInstructionPtr send_done) override { 247 return DefaultAction(send_done); 248 } HandleGather(HloInstructionPtr gather)249 Status HandleGather(HloInstructionPtr gather) override { 250 return DefaultAction(gather); 251 } HandleScatter(HloInstructionPtr scatter)252 Status HandleScatter(HloInstructionPtr scatter) override { 253 return DefaultAction(scatter); 254 } HandleAfterAll(HloInstructionPtr token)255 Status HandleAfterAll(HloInstructionPtr token) override { 256 return DefaultAction(token); 257 } HandleGetDimensionSize(HloInstructionPtr get_size)258 Status HandleGetDimensionSize(HloInstructionPtr get_size) override { 259 return DefaultAction(get_size); 260 } HandleSetDimensionSize(HloInstructionPtr get_size)261 Status HandleSetDimensionSize(HloInstructionPtr get_size) override { 262 return DefaultAction(get_size); 263 } HandleAddDependency(HloInstructionPtr add_dependency)264 Status HandleAddDependency(HloInstructionPtr add_dependency) override { 265 return DefaultAction(add_dependency); 266 } 267 268 // Invoked to inform the visitor that the traversal has completed, and that 269 // the root was "root". FinishVisit(HloInstructionPtr)270 Status FinishVisit(HloInstructionPtr /*root*/) override { 271 return Status::OK(); 272 } 273 274 private: 275 TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase); 276 }; 277 278 // Users should use these type aliases which are only two valid instantiations. 279 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>; 280 using ConstDfsHloVisitorWithDefault = 281 DfsHloVisitorWithDefaultBase<const HloInstruction*>; 282 283 // A common base class for visitors performing rewriting operation. 284 // 285 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while 286 // visiting. 287 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { 288 public: 289 // Runs a visitor on the module and returns whether the module has changed. RunOnModule(HloModule * module)290 StatusOr<bool> RunOnModule(HloModule* module) { 291 bool is_changed = false; 292 for (const auto& computation : module->computations()) { 293 TF_RETURN_IF_ERROR(computation->Accept(this)); 294 is_changed |= changed(); 295 } 296 return is_changed; 297 } 298 299 // Default visitor action is to do nothing and return OK. DefaultAction(HloInstruction *)300 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 301 return Status::OK(); 302 } 303 changed()304 bool changed() const { return changed_; } 305 306 protected: 307 // Replaces the existing HLO instruction old_instruction, with 308 // new_instruction, and marks the optimizer status as changed. 309 // Returns the Status representing the result of the replace operation. ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)310 Status ReplaceWithNewInstruction( 311 HloInstruction* old_instruction, 312 std::unique_ptr<HloInstruction> new_instruction) { 313 VLOG(3) << "Replacing instruction:"; 314 VLOG(3) << " old: " << old_instruction->ToString(); 315 VLOG(3) << " new: " << new_instruction->ToString(); 316 TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( 317 old_instruction, std::move(new_instruction))); 318 changed_ = true; 319 return Status::OK(); 320 } 321 322 // Replaces the existing HLO instruction old_instruction, with 323 // new_instruction, and marks the optimizer status as changed. 324 // Returns the Status representing the result of the replace operation. ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)325 Status ReplaceInstruction(HloInstruction* old_instruction, 326 HloInstruction* new_instruction) { 327 VLOG(3) << "Replacing instruction:"; 328 VLOG(3) << " old: " << old_instruction->ToString(); 329 VLOG(3) << " new: " << new_instruction->ToString(); 330 TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction( 331 old_instruction, new_instruction)); 332 changed_ = true; 333 return Status::OK(); 334 } 335 336 bool changed_ = false; 337 }; 338 339 // (Const)FunctionVisitor lets you transform an 340 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor. 341 // 342 // This is useful if you have code that needs to handle visitors in the form of 343 // both std::function and DfsHloVisitor. You can wrap the function in a 344 // FunctionVisitor and then treat it like any other DfsHloVisitor. 345 template <typename HloInstructionPtr> 346 class FunctionVisitorBase 347 : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> { 348 public: FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)349 explicit FunctionVisitorBase( 350 std::function<Status(HloInstructionPtr)> visitor_func) 351 : visitor_func_(std::move(visitor_func)) {} 352 DefaultAction(HloInstructionPtr hlo_instruction)353 Status DefaultAction(HloInstructionPtr hlo_instruction) override { 354 return visitor_func_(hlo_instruction); 355 } 356 357 private: 358 TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase); 359 360 std::function<Status(HloInstructionPtr)> visitor_func_; 361 }; 362 363 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>; 364 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>; 365 366 } // namespace xla 367 368 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 369