/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler.  It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used by
  //             the HLO nodes.
  // llvm_module: the LLVM module to emit IR into.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //              index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //              index in the profiling array.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
            llvm::Module* llvm_module,
            std::unordered_map<const HloInstruction*, int64>
                instruction_to_profile_idx,
            std::unordered_map<const HloComputation*, int64>
                computation_to_profile_idx,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR
  // function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the entry
  //    computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order.  In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order);
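
  // A hypothetical usage sketch (not part of this interface; the variable
  // names below are assumptions, not code from this file). A caller would
  // typically emit the constant globals once and then each scheduled
  // computation:
  //
  //   IrEmitter ir_emitter(hlo_module, assignment, llvm_module,
  //                        std::move(instruction_to_profile_idx),
  //                        std::move(computation_to_profile_idx),
  //                        &target_machine_features,
  //                        /*emit_code_for_msan=*/false);
  //   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
  //   TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
  //                       ir_emitter.EmitComputation(
  //                           entry_computation, entry_computation->name(),
  //                           /*is_top_level_computation=*/true,
  //                           instruction_order));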

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();

  // Emit code to map one element according to `map_instr`.
  llvm::Value* EmitElementalMap(
      const HloMapInstruction& map_instr,
      absl::Span<llvm::Value* const> elemental_operands,
      absl::string_view name);
  // Emit code to compute the element at `index` for a reduce window
  // instruction.
  StatusOr<llvm::Value*> EmitElementalReduceWindow(
      const HloReduceWindowInstruction* reduce_window,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::IrArray::Index& index);
  // Emit code to compute the element at `index` for a convolution instruction.
  StatusOr<llvm::Value*> EmitElementalConvolution(
      const HloConvolutionInstruction* convolution,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::ElementGenerator& kernel_generator,
      const llvm_ir::IrArray::Index& index);
  // Emit code to compute the element at `index` for a reduce instruction.
  StatusOr<llvm::Value*> EmitElementalReduce(
      const HloReduceInstruction* reduce,
      const llvm_ir::ElementGenerator& input_generator,
      const llvm_ir::ElementGenerator& initial_value_generator,
      const llvm_ir::IrArray::Index& index);

 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleRng(HloInstruction* rng) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
  }

 private:
  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const string& function_name);

  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const std::unordered_map<const T*, int64>& profile_index_map);

  // Convenience functions to generate a GEP into the profile counter parameter
  // which would correspond to the index for a given HLO instruction or
  // computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) {
    return GetProfileCounterCommon<HloInstruction>(instruction,
                                                   instruction_to_profile_idx_);
  }

  llvm::Value* GetProfileCounterFor(const HloComputation& computation) {
    return GetProfileCounterCommon<HloComputation>(computation,
                                                   computation_to_profile_idx_);
  }

  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates the buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted - if
  // not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays(
      HloInstruction* unnested_hlo) {
    return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); };
  }

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the xla::ExecutableRunOptions that represents the "run_options"
  // argument of the computation function being emitted by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);

  // Emits a call to a thread local function (e.g. to the computation nested
  // within a reduce or a map).  Thread local callees (by definition) only write
  // to and read from thread local allocations.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee.  The return value is the scalar returned by the callee.
  llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
                                   absl::Span<llvm::Value* const> parameters,
                                   absl::string_view name);
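
  // For example (an illustrative sketch only; `max_computation`, `lhs`, and
  // `rhs` are assumed to exist at the call site), applying a scalar
  // computation to two already-loaded scalar values:
  //
  //   llvm::Value* max = EmitThreadLocalCall(*max_computation, {lhs, rhs},
  //                                          "max_elem");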

  // Emits a call to a "global" function (e.g. to the computation nested within
  // a kWhile or a kCall).  Buffer assignment unambiguously assigns buffers to
  // the parameters and return values for these computations, so there is no
  // need to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
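
  // A sketch of the expected pairing (`body` is an assumed name for a kWhile
  // body computation): the call runs for its side effects on the assigned
  // buffers, and the result is read back from buffer assignment afterwards:
  //
  //   EmitGlobalCall(*body, body->name());
  //   llvm::Value* result_buffer = GetBufferForGlobalCallReturnValue(*body);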

  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target op.
  // This produces a series of nested loops (one for each dimension of the op's
  // shape). The body of the innermost loop is provided by the element_generator
  // function.
  //
  // desc is an optional human-readable string that's added to the loop name in
  // IR.  Regardless of whether desc is provided, target_op->name() is included
  // in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);
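
  // Conceptually, for a rank-2 target op the emitted code behaves like the
  // following loop nest (a sketch, not the literal IR):
  //
  //   for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0)
  //     for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1)
  //       target_array[{i0, i1}] = element_generator({i0, i1});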

  // Emits a memcpy from the source instruction's result value to the
  // destination's.  Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayForOp or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);

  // Structures "array_elements" into a multidimensional array that represents
  // "shape".  This is a recursive function, and "dimension_index" indicates the
  // index of the current dimension that the function is considering (0 means
  // the most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64 dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure.  On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations.  We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64> dimensions,
                                      HloComputation* function,
                                      string* failure_reason);

  // We'd like to keep one or two cache-lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types.  We do this by introducing a layer of
  // abstraction: a high-level vector-like concept called a "sharded vector"
  // that models data parallelism and is mapped to a sequence of scalar and
  // vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped to
  // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
  // Note that the last element is scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the elements
  // mapped to sharded vectors -- we allow repeated elements, and we allow
  // elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the element-wise list of llvm::Types of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Creates a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);
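
  // For example, under the assumption that the widest legal vector register
  // holds 16 f32s, CreateShardedVectorType(F32, 29) corresponds to the type
  // sequence from the example above, roughly (`f32_type` is an assumed
  // llvm::Type* for float):
  //
  //   ShardedVectorType sharded_type = {
  //       llvm::VectorType::get(f32_type, 16),  // <16 x float>
  //       llvm::VectorType::get(f32_type, 8),   // <8 x float>
  //       llvm::VectorType::get(f32_type, 4),   // <4 x float>
  //       f32_type};                            // float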

  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              const int alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern.  Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said reduction.
  // On failure, this stores a reason string into "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             string* failure_reason) const;
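
  // For example, if `function` computes a floating-point add reduction, a
  // matching generator would be equivalent to (sketch):
  //
  //   [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   };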

  // Emits the inner loop nest that runs the reduction.  Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64> dimensions,
      unsigned element_alignment);

  // Tries to emit a fast concatenate operation using memcpy.  Returns true if
  // successful, and false on failure.  On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
  // from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64 element_count, PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  // Map containing all previously emitted computations.
  std::map<const HloComputation*, llvm::Function*> emitted_functions_;

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module (Note that IrFunction
  // creates the encapsulated llvm::Function s.t. it is added to the llvm
  // module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;

  // The buffer allocation slice for the root of the computation being compiled.
  // Only relevant for thread local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers.  Only relevant for thread local
  // computations.
  absl::flat_hash_map<BufferAllocation::Index, int64>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const std::unordered_map<const HloInstruction*, int64>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const std::unordered_map<const HloComputation*, int64>
      computation_to_profile_idx_;

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of root instruction outer dimensions used in parallel loop
  // emission (ParallelLoopEmitter).
  int64 num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension loop
    // bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This class contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);

    // Convenience function to generate a call to an intrinsic which reads the
    // CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end, llvm::Value* cycle_start);

   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // An alloca used to hold the output of the aux value returned by the rdtscp
    // intrinsic.
    llvm::Value* aux_i8ptr_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the HLO
    // began to execute.
    std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };
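
  // Expected usage (a sketch; see Preprocess/Postprocess for the actual call
  // sites): each HLO's emission is bracketed by a start/delta pair, e.g.:
  //
  //   profiling_state_.RecordCycleStart(&b_, hlo);
  //   // ... emit code for hlo ...
  //   profiling_state_.RecordCycleDelta(&b_, hlo, GetProfileCounterFor(*hlo));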

  ProfilingState profiling_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64 buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);

  // Calculate the alignment of a buffer allocated for a given primitive type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64 ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between an {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Returns a ConstExpr bitcast.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const { return literal->Hash(); }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_