1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 18 19 #include <type_traits> 20 #include <vector> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/literal.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/status.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 class HloComputation; 37 class HloInstruction; 38 39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an 40 // instruction, all its operands were already visited. User code can subclass 41 // this to iterate over an HloInstruction DAG. The Handle* routines have 42 // operands / data unpacked for ease of use in the visitor subclass. 43 // 44 // No instruction will ever be visited twice; however, the root instruction will 45 // be reported again when the traversal is done via a call to FinishVisit. 46 // 47 // A subclass must override at least 48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and 49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)). 50 // The default Handle methods for (unary, binary) ops call 51 // (HandleElementwiseUnary, HandleElementwiseBinary). 52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an 53 // "unimplemented" error status. 54 // 55 // Note: this may change to an iterator in the future for flexibility purposes. 56 // 57 // Users should not use this class directly, but use the type-aliases 58 // DfsHloVisitor/ConstDfsHloVisitor instead. 59 template <typename HloInstructionPtr> 60 class DfsHloVisitorBase { 61 static_assert( 62 std::is_same<HloInstruction*, HloInstructionPtr>::value || 63 std::is_same<const HloInstruction*, HloInstructionPtr>::value, 64 "Template argument expected to be HloInstruction* or const " 65 "HloInstruction*"); 66 67 public: DfsHloVisitorBase()68 DfsHloVisitorBase() {} ~DfsHloVisitorBase()69 virtual ~DfsHloVisitorBase() {} 70 71 // These routines are self-descriptive, see class comment for usage 72 // information. 73 74 virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); 75 virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); 76 77 virtual Status HandleClamp(HloInstructionPtr hlo) = 0; 78 virtual Status HandleSelect(HloInstructionPtr hlo) = 0; 79 virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0; HandleMaximum(HloInstructionPtr hlo)80 virtual Status HandleMaximum(HloInstructionPtr hlo) { 81 return HandleElementwiseBinary(hlo); 82 } HandleMinimum(HloInstructionPtr hlo)83 virtual Status HandleMinimum(HloInstructionPtr hlo) { 84 return HandleElementwiseBinary(hlo); 85 } 86 virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; HandleConvert(HloInstructionPtr hlo)87 virtual Status HandleConvert(HloInstructionPtr hlo) { 88 return HandleElementwiseUnary(hlo); 89 } HandleBitcastConvert(HloInstructionPtr hlo)90 virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { 91 return HandleElementwiseUnary(hlo); 92 } HandleCopy(HloInstructionPtr hlo)93 virtual Status HandleCopy(HloInstructionPtr hlo) { 94 return HandleElementwiseUnary(hlo); 95 } HandleComplex(HloInstructionPtr hlo)96 virtual Status HandleComplex(HloInstructionPtr hlo) { 97 return HandleElementwiseBinary(hlo); 98 } HandleMultiply(HloInstructionPtr hlo)99 virtual Status HandleMultiply(HloInstructionPtr hlo) { 100 return HandleElementwiseBinary(hlo); 101 } 102 virtual Status HandleDot(HloInstructionPtr hlo) = 0; HandlePower(HloInstructionPtr hlo)103 virtual Status HandlePower(HloInstructionPtr hlo) { 104 return HandleElementwiseBinary(hlo); 105 } HandleSqrt(HloInstructionPtr hlo)106 virtual Status HandleSqrt(HloInstructionPtr hlo) { 107 return HandleElementwiseUnary(hlo); 108 } HandleRsqrt(HloInstructionPtr hlo)109 virtual Status HandleRsqrt(HloInstructionPtr hlo) { 110 return HandleElementwiseUnary(hlo); 111 } 112 virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; 113 virtual Status HandleFft(HloInstructionPtr fft) = 0; 114 virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; 115 virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; 116 virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; 117 virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; 118 virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; 119 virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; 120 virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; HandleCompare(HloInstructionPtr hlo)121 virtual Status HandleCompare(HloInstructionPtr hlo) { 122 return HandleElementwiseBinary(hlo); 123 } HandleAdd(HloInstructionPtr hlo)124 virtual Status HandleAdd(HloInstructionPtr hlo) { 125 return HandleElementwiseBinary(hlo); 126 } HandleDivide(HloInstructionPtr hlo)127 virtual Status HandleDivide(HloInstructionPtr hlo) { 128 return HandleElementwiseBinary(hlo); 129 } HandleRemainder(HloInstructionPtr hlo)130 virtual Status HandleRemainder(HloInstructionPtr hlo) { 131 return HandleElementwiseBinary(hlo); 132 } HandleSubtract(HloInstructionPtr hlo)133 virtual Status HandleSubtract(HloInstructionPtr hlo) { 134 return HandleElementwiseBinary(hlo); 135 } HandleAbs(HloInstructionPtr hlo)136 virtual Status HandleAbs(HloInstructionPtr hlo) { 137 return HandleElementwiseUnary(hlo); 138 } HandleAtan2(HloInstructionPtr hlo)139 virtual Status HandleAtan2(HloInstructionPtr hlo) { 140 return HandleElementwiseBinary(hlo); 141 } HandleRound(HloInstructionPtr hlo)142 virtual Status HandleRound(HloInstructionPtr hlo) { 143 return HandleElementwiseUnary(hlo); 144 } HandleSign(HloInstructionPtr hlo)145 virtual Status HandleSign(HloInstructionPtr hlo) { 146 return HandleElementwiseUnary(hlo); 147 } HandleNegate(HloInstructionPtr hlo)148 virtual Status HandleNegate(HloInstructionPtr hlo) { 149 return HandleElementwiseUnary(hlo); 150 } HandleExp(HloInstructionPtr hlo)151 virtual Status HandleExp(HloInstructionPtr hlo) { 152 return HandleElementwiseUnary(hlo); 153 } HandleExpm1(HloInstructionPtr hlo)154 virtual Status HandleExpm1(HloInstructionPtr hlo) { 155 return HandleElementwiseUnary(hlo); 156 } HandleFloor(HloInstructionPtr hlo)157 virtual Status HandleFloor(HloInstructionPtr hlo) { 158 return HandleElementwiseUnary(hlo); 159 } HandleCeil(HloInstructionPtr hlo)160 virtual Status HandleCeil(HloInstructionPtr hlo) { 161 return HandleElementwiseUnary(hlo); 162 } HandleLog(HloInstructionPtr hlo)163 virtual Status HandleLog(HloInstructionPtr hlo) { 164 return HandleElementwiseUnary(hlo); 165 } HandleClz(HloInstructionPtr hlo)166 virtual Status HandleClz(HloInstructionPtr hlo) { 167 return HandleElementwiseUnary(hlo); 168 } HandleLog1p(HloInstructionPtr hlo)169 virtual Status HandleLog1p(HloInstructionPtr hlo) { 170 return HandleElementwiseUnary(hlo); 171 } HandleCos(HloInstructionPtr hlo)172 virtual Status HandleCos(HloInstructionPtr hlo) { 173 return HandleElementwiseUnary(hlo); 174 } HandleSin(HloInstructionPtr hlo)175 virtual Status HandleSin(HloInstructionPtr hlo) { 176 return HandleElementwiseUnary(hlo); 177 } HandleTanh(HloInstructionPtr hlo)178 virtual Status HandleTanh(HloInstructionPtr hlo) { 179 return HandleElementwiseUnary(hlo); 180 } HandleReal(HloInstructionPtr hlo)181 virtual Status HandleReal(HloInstructionPtr hlo) { 182 return HandleElementwiseUnary(hlo); 183 } HandleImag(HloInstructionPtr hlo)184 virtual Status HandleImag(HloInstructionPtr hlo) { 185 return HandleElementwiseUnary(hlo); 186 } HandleIsFinite(HloInstructionPtr hlo)187 virtual Status HandleIsFinite(HloInstructionPtr hlo) { 188 return HandleElementwiseUnary(hlo); 189 } HandleAnd(HloInstructionPtr hlo)190 virtual Status HandleAnd(HloInstructionPtr hlo) { 191 return HandleElementwiseBinary(hlo); 192 } HandleNot(HloInstructionPtr hlo)193 virtual Status HandleNot(HloInstructionPtr hlo) { 194 return HandleElementwiseUnary(hlo); 195 } HandleOr(HloInstructionPtr hlo)196 virtual Status HandleOr(HloInstructionPtr hlo) { 197 return HandleElementwiseBinary(hlo); 198 } HandleXor(HloInstructionPtr hlo)199 virtual Status HandleXor(HloInstructionPtr hlo) { 200 return HandleElementwiseBinary(hlo); 201 } HandleShiftLeft(HloInstructionPtr hlo)202 virtual Status HandleShiftLeft(HloInstructionPtr hlo) { 203 return HandleElementwiseBinary(hlo); 204 } HandleShiftRightArithmetic(HloInstructionPtr hlo)205 virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { 206 return HandleElementwiseBinary(hlo); 207 } HandleShiftRightLogical(HloInstructionPtr hlo)208 virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { 209 return HandleElementwiseBinary(hlo); 210 } 211 HandleReducePrecision(HloInstructionPtr hlo)212 virtual Status HandleReducePrecision(HloInstructionPtr hlo) { 213 return HandleElementwiseUnary(hlo); 214 } 215 HandleDomain(HloInstructionPtr hlo)216 virtual Status HandleDomain(HloInstructionPtr hlo) { 217 return HandleElementwiseUnary(hlo); 218 } 219 220 virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; 221 virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; 222 virtual Status HandleRng(HloInstructionPtr hlo) = 0; 223 virtual Status HandleReverse(HloInstructionPtr hlo) = 0; 224 virtual Status HandleSort(HloInstructionPtr hlo) = 0; 225 virtual Status HandleConstant(HloInstructionPtr hlo) = 0; 226 virtual Status HandleIota(HloInstructionPtr hlo) = 0; 227 virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; 228 virtual Status HandleReduce(HloInstructionPtr hlo) = 0; 229 virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; 230 virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; 231 virtual Status HandleReshape(HloInstructionPtr hlo) = 0; 232 virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; 233 virtual Status HandleParameter(HloInstructionPtr hlo) = 0; 234 virtual Status HandleFusion(HloInstructionPtr hlo) = 0; 235 virtual Status HandleCall(HloInstructionPtr hlo) = 0; 236 virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; 237 virtual Status HandleSlice(HloInstructionPtr hlo) = 0; 238 virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; 239 virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; 240 virtual Status HandleTuple(HloInstructionPtr hlo) = 0; 241 virtual Status HandleMap(HloInstructionPtr hlo) = 0; 242 virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; 243 virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; 244 virtual Status HandleWhile(HloInstructionPtr hlo) = 0; 245 virtual Status HandleConditional(HloInstructionPtr hlo) = 0; 246 virtual Status HandleGather(HloInstructionPtr hlo) = 0; 247 virtual Status HandleScatter(HloInstructionPtr hlo) = 0; 248 249 virtual Status HandlePad(HloInstructionPtr hlo) = 0; 250 251 virtual Status HandleSend(HloInstructionPtr send) = 0; 252 virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; 253 254 virtual Status HandleRecv(HloInstructionPtr recv) = 0; 255 virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; 256 257 virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; 258 259 virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; 260 261 virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; 262 263 virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; 264 virtual Status HandleAfterAll(HloInstructionPtr token) = 0; 265 266 // Invoked to inform the visitor that the traversal has completed, and that 267 // the root was "root". 268 virtual Status FinishVisit(HloInstructionPtr root) = 0; 269 270 // 3 possible visitation states of HLO instructions. Each instruction's 271 // state only flows one way: kNotVisited -> kVisiting -> kVisited. 272 enum VisitState { 273 kNotVisited = 0, 274 kVisiting = 1, 275 kVisited = 2, 276 }; 277 GetVisitState(int id)278 VisitState GetVisitState(int id) { 279 auto iter = visit_state_.find(id); 280 if (iter == visit_state_.end()) { 281 return VisitState::kNotVisited; 282 } 283 return iter->second; 284 } 285 VisitState GetVisitState(const HloInstruction& instruction); 286 287 // Resize internal state if necessary to hold state for ids <= num. 288 // This call is purely a performance hint and can be omitted without 289 // affecting correctness. ReserveVisitStates(int num)290 void ReserveVisitStates(int num) { visit_state_.reserve(num); } 291 292 // Useful when we want to visit the same computation more than once with the 293 // same visitor. ResetVisitStates()294 void ResetVisitStates() { visit_state_.clear(); } 295 SetVisitState(int id,VisitState state)296 void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } 297 298 // Sets the visitation state of the given instruction as kVisiting. 299 // 300 // Precondition: current state must be kNotVisited. 301 void SetVisiting(const HloInstruction& instruction); 302 303 // Sets the visitation state of the given instruction as kVisited. 304 // 305 // Precondition: current state must be either kNotVisited or kVisiting. 306 void SetVisited(const HloInstruction& instruction); 307 308 // Returns whether the state of the given instruction is kVisiting. IsVisiting(const HloInstruction & instruction)309 bool IsVisiting(const HloInstruction& instruction) { 310 return GetVisitState(instruction) == kVisiting; 311 } 312 313 // Returns whether the state of the given instruction is kVisited. DidVisit(const HloInstruction & instruction)314 bool DidVisit(const HloInstruction& instruction) { 315 return GetVisitState(instruction) == kVisited; 316 } 317 318 // Returns whether the state of the given instruction is kNotVisited. NotVisited(const HloInstruction & instruction)319 bool NotVisited(const HloInstruction& instruction) { 320 return GetVisitState(instruction) == kNotVisited; 321 } 322 323 // This method should be overridden by subclasses that wish to run some 324 // operation on an op before its Handle* visitor method is called. 325 // 326 // For any HLO op, the order of calls is: 327 // 328 // Preprocess(op); 329 // Handle/OpType/(op); 330 // Postprocess(op); 331 // 332 // Overriding methods should call DfsHloVisitor::Preprocess before doing their 333 // own preprocessing. 334 virtual Status Preprocess(HloInstructionPtr hlo); 335 336 // This method should be overridden by subclasses that wish to run some 337 // operation on an op after its Handle* visitor method is called. See 338 // Preprocess for more details. 339 // 340 // Overriding methods should call DfsHloVisitor::Postprocess after doing their 341 // own postprocessing. 342 virtual Status Postprocess(HloInstructionPtr hlo); 343 344 private: 345 absl::flat_hash_map<int, VisitState> visit_state_; 346 347 TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); 348 }; 349 350 // Users should use one of these two type aliases, which are the only two valid 351 // instantiations of DfsHloVisitorBase. 352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>; 353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>; 354 355 } // namespace xla 356 357 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 358