/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {

// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
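//
// Illustrative driver flow (a sketch only, not prescribed by this header; the
// exact call sequence lives in the CPU compiler): a client constructs an
// IrEmitter, emits globals for constant buffers once, and then emits each
// computation in the module:
//
//   IrEmitter ir_emitter(hlo_module, assignment, llvm_module,
//                        instruction_to_profile_idx,
//                        computation_to_profile_idx,
//                        &target_machine_features,
//                        /*emit_code_for_msan=*/false);
//   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
//   TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
//                       ir_emitter.EmitComputation(
//                           entry_computation, "entry",
//                           /*is_top_level_computation=*/true, order));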
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used
  //   by the HLO nodes.
  // llvm_module: the LLVM module to emit IR into.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //   index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //   index in the profiling array.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
            llvm::Module* llvm_module,
            std::unordered_map<const HloInstruction*, int64>
                instruction_to_profile_idx,
            std::unordered_map<const HloComputation*, int64>
                computation_to_profile_idx,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU
  // backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the
  //    entry computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order. In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order);

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();

  // Emit code to map one element according to `map_instr`.
  llvm::Value* EmitElementalMap(
      const HloMapInstruction& map_instr,
      absl::Span<llvm::Value* const> elemental_operands,
      absl::string_view name);
  // Emit code to compute the element at `index` for a reduce window
  // instruction.
  StatusOr<llvm::Value*> EmitElementalReduceWindow(
      const HloReduceWindowInstruction* reduce_window,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::IrArray::Index& index);
  // Emit code to compute the element at `index` for a convolution
  // instruction.
  StatusOr<llvm::Value*> EmitElementalConvolution(
      const HloConvolutionInstruction* convolution,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::ElementGenerator& kernel_generator,
      const llvm_ir::IrArray::Index& index);
  // Emit code to compute the element at `index` for a reduce instruction.
  StatusOr<llvm::Value*> EmitElementalReduce(
      const HloReduceInstruction* reduce,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::ElementGenerator& initial_value_generator,
      const llvm_ir::IrArray::Index& index);
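  // Illustrative note (a sketch, not part of the API contract): the
  // ElementGenerator arguments above are callbacks that, given a
  // multi-dimensional index, return the LLVM IR value of the corresponding
  // input element. A caller reading from an existing IrArray might pass
  // something like:
  //
  //   llvm_ir::ElementGenerator input_generator =
  //       [&](const llvm_ir::IrArray::Index& index)
  //           -> StatusOr<llvm::Value*> {
  //         return input_array.EmitReadArrayElement(index, b());
  //       };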
 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleRng(HloInstruction* rng) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
  }

 private:
  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const string& function_name);

  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const std::unordered_map<const T*, int64>& profile_index_map);

  // Convenience functions to generate a GEP into the profile counter
  // parameter which would correspond to the index for a given HLO
  // instruction or computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) {
    return GetProfileCounterCommon<HloInstruction>(instruction,
                                                   instruction_to_profile_idx_);
  }

  llvm::Value* GetProfileCounterFor(const HloComputation& computation) {
    return GetProfileCounterCommon<HloComputation>(computation,
                                                   computation_to_profile_idx_);
  }
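  // For illustration only (assuming the profile counters are an i64 array
  // argument; see GetProfileCountersArgument below): the generated code is
  // conceptually a single GEP into that array, e.g.
  //
  //   %counter = getelementptr inbounds i64, i64* %prof_counters, i64 <idx>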
  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates the buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted -
  // if not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
      HloInstruction* unnested_hlo) {
    return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
  }

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the xla::ExecutableRunOptions that represents the "run_options"
  // argument of the computation function being emitted by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation
  // slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);

  // Emits a call to a thread local function (e.g. to the computation nested
  // within a reduce or a map). Thread local callees (by definition) only
  // write to and read from thread local allocations.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee. The return value is the scalar returned by the callee.
  llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
                                   absl::Span<llvm::Value* const> parameters,
                                   absl::string_view name);

  // Emits a call to a "global" function (e.g. to the computation nested
  // within a kWhile or a kCall). Buffer assignment unambiguously assigns
  // buffers to the parameters and return values for these computations so
  // there is no need to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
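  // An illustrative contrast (a sketch; the handler implementations are the
  // authority): a reduction handler combines scalars via a thread-local
  // call, while a while-loop handler invokes its condition computation via a
  // global call and then reads the result buffer:
  //
  //   llvm::Value* accum = EmitThreadLocalCall(
  //       *reduce->to_apply(), {current_accum, element}, "reduce_function");
  //
  //   EmitGlobalCall(*xla_while->while_condition(), "condition");
  //   llvm::Value* cond_buffer =
  //       GetBufferForGlobalCallReturnValue(*xla_while->while_condition());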
  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target
  // op. This produces a series of nested loops (one for each dimension of
  // the op's shape). The body of the inner-most loop is provided by the
  // body_emitter function.
  //
  // desc is an optional human-readable string that's added to the loop name
  // in IR. Regardless of whether desc is provided, target_op->name() is
  // included in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);

  // Emits a memcpy from the source instruction's result value to the
  // destination's. Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayForOp or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);

  // Structurizes "array_elements" into a multidimensional array that
  // represents "shape". This is a recursive function, and "dimension_index"
  // indicates the index of the current dimension that the function is
  // considering (0 means the most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64 dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure. On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations. We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64> dimensions,
                                      HloComputation* function,
                                      string* failure_reason);
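  // A sketch of the overall flow (illustrative; the helpers declared below
  // are the real pieces): vectorization first matches the reduction
  // computation to a known combiner, then emits a sharded-vector accumulator
  // loop per output element:
  //
  //   string failure_reason;
  //   ReductionGenerator gen =
  //       MatchReductionGenerator(function, &failure_reason);
  //   if (!gen) return false;  // fall back to the generic loop emitter
  //   // ... then, for each output index: create accumulators via
  //   // CreateShardedVectorType, run EmitInnerLoopForVectorizedReduction,
  //   // and store the result with EmitShardedVectorStore.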
  // We'd like to keep one or two cache lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types. We do this by introducing a layer of
  // abstraction: we introduce a high level vector-like concept called a
  // "sharded vector" that models data parallelism, and is mapped to a
  // sequence of scalar and vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped
  // to a sequence of LLVM values of types
  // [<16 x f32>, <8 x f32>, <4 x f32>, f32]. Note that the last element is
  // scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the
  // elements mapped to sharded vectors -- we allow repeated elements, and we
  // allow elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the element-wise llvm::Types of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Create a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);
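  // Illustrative decomposition (an assumption about the strategy, not a
  // contract): with a widest legal vector of W elements, element_count can
  // be covered greedily by power-of-two shards. E.g. with W = 16 and
  // element_count = 29, matching the example above:
  //
  //   29 = 16 + 8 + 4 + 1
  //      --> {<16 x float>, <8 x float>, <4 x float>, float}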
  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              const int alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern. Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said
  // reduction. On failure, this stores a reason string into
  // "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             string* failure_reason) const;

  // Emits the inner loop nest that runs the reduction. Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64> dimensions,
      unsigned element_alignment);

  // Tries to emit a fast concatenate operation using memcpy. Returns true if
  // successful, and false on failure. On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type
  // "primitive_type" from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64 element_count, PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  // Map containing all previously emitted computations.
  std::map<const HloComputation*, llvm::Function*> emitted_functions_;

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM
  // memory management rules, their memory is owned by the module (note that
  // IrFunction creates the encapsulated llvm::Function such that it is added
  // to the llvm module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;

  // The buffer allocation slice for the root of the computation being
  // compiled. Only relevant for thread local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers. Only relevant for thread
  // local computations.
  absl::flat_hash_map<BufferAllocation::Index, int64>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const std::unordered_map<const HloInstruction*, int64>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const std::unordered_map<const HloComputation*, int64>
      computation_to_profile_idx_;

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of root instruction outer dimensions used in parallel loop
  // emission (ParallelLoopEmitter).
  int64 num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel
  // loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension
    // loop bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This class contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);

    // Convenience function to generate a call to an intrinsic which reads
    // the CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end,
                              llvm::Value* cycle_start);
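
    // Illustrative emitted pattern (a sketch using the generic intrinsic;
    // the rdtscp path differs in that it also writes an aux value):
    //
    //   %start = call i64 @llvm.readcyclecounter()   ; RecordCycleStart
    //   ...                                          ; code for the HLO
    //   %end   = call i64 @llvm.readcyclecounter()   ; RecordCycleDelta
    //   %delta = sub i64 %end, %start
    //   %old   = load i64, i64* %counter
    //   %new   = add i64 %old, %delta
    //   store i64 %new, i64* %counter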
   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // An alloca used to hold the aux value returned by the rdtscp intrinsic.
    llvm::Value* aux_i8ptr_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the
    // HLO began to execute.
    std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };

  ProfilingState profiling_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer
  // size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64 buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);

  // Calculate the alignment of a buffer allocated for a given primitive
  // type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64 ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between an {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Returns a ConstExpr bitcast.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const { return literal->Hash(); }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_