/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
  friend class CpuElementalIrEmitter;

 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used
  //   by the HLO nodes.
  // mlir_context: the MLIR context used for IR emission.
  // llvm_module: the LLVM module to emit IR into. It's built using the LLVM
  //   context inside of mlir_context.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //   index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //   index in the profiling array.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
            const BufferAssignment& assignment, llvm::Module* llvm_module,
            std::unordered_map<const HloInstruction*, int64>
                instruction_to_profile_idx,
            std::unordered_map<const HloComputation*, int64>
                computation_to_profile_idx,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU
  // backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the
  //    entry computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order. In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order);

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();
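
  // Illustrative usage sketch only (the real driver lives in the CPU
  // compiler; `emitter`, `order`, and the surrounding setup are hypothetical):
  //
  //   IrEmitter emitter(mlir_context, *hlo_module, *assignment, llvm_module,
  //                     std::move(instruction_to_profile_idx),
  //                     std::move(computation_to_profile_idx),
  //                     &target_machine_features,
  //                     /*emit_code_for_msan=*/false);
  //   TF_RETURN_IF_ERROR(emitter.EmitConstantGlobals());
  //   TF_ASSIGN_OR_RETURN(
  //       llvm::Function * entry_fn,
  //       emitter.EmitComputation(hlo_module->entry_computation(), "entry",
  //                               /*is_top_level_computation=*/true, order));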

 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleCollectivePermute(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* rng) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
  }
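
  // For example (illustrative only; `hlo` is a hypothetical handler
  // argument), a handler can obtain the address of its output buffer via:
  //
  //   BufferAllocation::Slice slice = GetAllocationSlice(*hlo);
  //   llvm::Value* out_ptr = EmitBufferPointer(slice, hlo->shape());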

 private:
  Status HandleSliceToDynamic(HloInstruction* hlo);
  Status HandlePadToStatic(HloInstruction* hlo);
  Status HandleTopK(HloInstruction* hlo);
  Status HandleAllReduceSingleReplica(HloInstruction* crs);
  Status HandleAllReduceMultipleReplica(HloInstruction* crs);

  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const string& function_name);

  // Emits the copying epilogue for the function, where it copies the returned
  // value to the reserved alloca. This is only necessary for thread-local
  // functions. Note that since the call graph is flattened, if the same
  // function is called in both thread-local and non-thread-local contexts, it
  // is codegen'd twice, so we know at codegen time whether it's thread-local.
  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);

  // Convenience functions to generate a GEP into the profile counter parameter
  // which would correspond to the index for a given HLO instruction or
  // computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction);
  llvm::Value* GetProfileCounterFor(const HloComputation& computation);

  // Helper function template for the implementation of the above two
  // functions.
  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const std::unordered_map<const T*, int64>& profile_index_map);

  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates the buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted - if
  // not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the xla::ExecutableRunOptions that represents the "run_options"
  // argument of the computation function being emitted by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);
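
  // Illustrative sketch only (`hlo` is a hypothetical handler argument): a
  // typical Handle* method first reserves the output address and then reads
  // its operands through IrArrays, e.g.
  //
  //   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
  //   llvm_ir::IrArray target_array = GetIrArrayFor(hlo);
  //   llvm_ir::IrArray operand_array = GetIrArrayFor(hlo->operand(0));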

  // Emits a call to a thread local function (e.g. to the computation nested
  // within a reduce or a map). Thread local callees (by definition) only write
  // to and read from thread local allocations. Supports only functions
  // returning scalars or tuples of scalars.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee. The return value is the scalar returned by the callee.
  std::vector<llvm::Value*> EmitThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);

  // Similar to EmitThreadLocalCall, but assumes that the function returns a
  // scalar.
  llvm::Value* EmitScalarReturningThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);

  // Emits a call to a "global" function (e.g. to the computation nested within
  // a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
  // the parameters and return values of these computations, so there is no
  // need to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
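
  // Illustrative sketch only (the instructions and values are hypothetical):
  // a map-like handler calls the nested computation once per element, while a
  // kCall handler invokes it once through buffer assignment:
  //
  //   llvm::Value* mapped = EmitScalarReturningThreadLocalCall(
  //       *hlo->to_apply(), {element_value}, "map_element");
  //   EmitGlobalCall(*call->to_apply(), call->to_apply()->name());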

  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target op.
  // This produces a series of nested loops (one for each dimension of the op's
  // shape). The body of the inner-most loop is provided by the body_emitter
  // function.
  //
  // desc is an optional human-readable string that's added to the loop name in
  // IR. Regardless of whether desc is provided, target_op->name() is included
  // in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);

  // Emits a memcpy from the source instruction's result value to the
  // destination's. Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayFor or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);

  // Structurizes "array_elements" into an MD array that represents "shape".
  // This is a recursive function, and "dimension_index" indicates the index of
  // the current dimension that the function is considering (0 means the
  // most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64_t dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure. On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations. We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64> dimensions,
                                      HloComputation* function,
                                      string* failure_reason);

  // We'd like to keep one or two cache lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types. We do this by introducing a layer of
  // abstraction: we introduce a high level vector-like concept called a
  // "sharded vector" that models data parallelism, and is mapped to a
  // sequence of scalar and vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped
  // to a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>,
  // f32]. Note that the last element is scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the elements
  // mapped to sharded vectors -- we allow repeated elements, and we allow
  // elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the list of element-wise llvm::Types of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Create a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);

  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              llvm::Align alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern. Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said reduction.
  // On failure, this stores a reason string into "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             string* failure_reason) const;

  // Emits the inner loop nest that runs the reduction. Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64> dimensions,
      llvm::Align element_alignment);
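
  // Illustrative sketch (hypothetical target and values): on a machine with
  // 512-bit vectors, CreateShardedVectorType(F32, 29) could map to the type
  // sequence {<16 x float>, <8 x float>, <4 x float>, float}, and
  // MatchReductionGenerator on an add-reduction would return roughly
  //
  //   [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   };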

  // Tries to emit a fast concatenate operation using memcpy. Returns true if
  // successful, and false on failure. On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(
      HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
      string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type
  // "primitive_type" from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64_t element_count,
                            PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Emits printing during execution.
  llvm::Value* EmitPrintf(absl::string_view fmt,
                          absl::Span<llvm::Value* const> arguments);
  llvm::Value* EmitPrintfToStderr(absl::string_view fmt,
                                  absl::Span<llvm::Value* const> arguments);

  // Emits a call to a non-variadic function `func_name` with arguments
  // `arguments`, assuming C calling convention.
  llvm::Value* EmitCallToFunc(
      std::string func_name, const std::vector<llvm::Value*>& arguments,
      llvm::Type* return_type, bool does_not_throw = true,
      bool only_accesses_arg_memory = false,
      bool only_accesses_inaccessible_mem_or_arg_mem = false);

  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  // Map containing all previously emitted computations.
  std::map<const HloComputation*, llvm::Function*> emitted_functions_;

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module (Note that
  // IrFunction creates the encapsulated llvm::Function s.t. it is added to the
  // llvm module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;
  mlir::MLIRContext* mlir_context_;

  // The buffer allocation slice for the root of the computation being
  // compiled. Only relevant for thread local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers. Only relevant for thread local
  // computations.
  absl::flat_hash_map<BufferAllocation::Index, int64>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const std::unordered_map<const HloInstruction*, int64>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const std::unordered_map<const HloComputation*, int64>
      computation_to_profile_idx_;

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of root instruction outer dimensions used in parallel loop
  // emission (ParallelLoopEmitter).
  int64 num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel
  // loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension loop
    // bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This struct contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);

    // Convenience function to generate a call to an intrinsic which reads the
    // CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end,
                              llvm::Value* cycle_start);

   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the HLO
    // began to execute.
    std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };

  ProfilingState profiling_state_;

  class TracingState {
   public:
    TracingState() : enabled_(false) {}
    void set_enabled(bool value) { enabled_ = value; }
    void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* run_options);
    void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
                        llvm::Value* run_options);

   private:
    bool enabled_;
    // Maps from HLO to the activity id returned by xprof::TraceMe.
    std::unordered_map<const HloInstruction*, llvm::Value*> activity_ids_;
  };
  TracingState tracing_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                      const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                      int64_t buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64_t buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);
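
  // Illustrative sketch (the load and shape are hypothetical): a load of a
  // buffer address is typically annotated with both kinds of metadata, e.g.
  //
  //   llvm::LoadInst* load = ...;  // load of a buffer pointer
  //   AttachAlignmentMetadataForLoad(load, shape);
  //   AttachDereferenceableMetadataForLoad(load, shape);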

  // Calculate the alignment of a buffer allocated for a given primitive type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64 ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Returns a ConstExpr bitcast.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const { return literal->Hash(); }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_