/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler.  It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
  friend class CpuElementalIrEmitter;

 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used by
  //             the HLO nodes.
  // mlir_context: the MLIR context used for IR emission.
  // llvm_module: the LLVM module to emit IR into. It's built using the LLVM
  //              context inside of mlir_context.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //              index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //              index in the profiling array.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
            const BufferAssignment& assignment, llvm::Module* llvm_module,
            std::unordered_map<const HloInstruction*, int64>
                instruction_to_profile_idx,
            std::unordered_map<const HloComputation*, int64>
                computation_to_profile_idx,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR
  // function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the entry
  //    computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order.  In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order);

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();

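  // Illustrative usage sketch (not part of this header): a driver typically
  // constructs an IrEmitter, emits the constant globals once, and then emits
  // each scheduled computation.  The names `assignment`, `schedule`, and the
  // two profile-index maps below are hypothetical placeholders.
  //
  //   IrEmitter ir_emitter(mlir_context, *hlo_module, *assignment, llvm_module,
  //                        std::move(instruction_profile_idx),
  //                        std::move(computation_profile_idx),
  //                        &target_machine_features,
  //                        /*emit_code_for_msan=*/false);
  //   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
  //   TF_ASSIGN_OR_RETURN(
  //       llvm::Function* entry_function,
  //       ir_emitter.EmitComputation(hlo_module->entry_computation(), "entry",
  //                                  /*is_top_level_computation=*/true,
  //                                  schedule));
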
 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleCollectivePermute(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* rng) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
  }

 private:
  Status HandleSliceToDynamic(HloInstruction* hlo);
  Status HandlePadToStatic(HloInstruction* hlo);
  Status HandleTopK(HloInstruction* hlo);
  Status HandleAllReduceSingleReplica(HloInstruction* crs);
  Status HandleAllReduceMultipleReplica(HloInstruction* crs);

  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const string& function_name);

  // Emits the copying epilogue for the function, which copies the returned
  // value to the reserved alloca.  This is only necessary for thread-local
  // functions.  Note that since the call graph is flattened, if the same
  // computation is called in both thread-local and non-thread-local contexts,
  // it is codegen'd twice, so whether a function is thread-local is known at
  // codegen time.
  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);

  // Convenience functions to generate a GEP into the profile counter parameter
  // which would correspond to the index for a given HLO instruction or
  // computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction);
  llvm::Value* GetProfileCounterFor(const HloComputation& computation);

  // Helper function template for the implementation of the above two functions.
  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const std::unordered_map<const T*, int64>& profile_index_map);

  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted - if
  // not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the xla::ExecutableRunOptions that represents the "run_options"
  // argument of the computation function being emitted by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);

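  // Illustrative sketch (hypothetical handler code, not part of this header):
  // handlers typically combine GetAllocationSlice with EmitBufferPointer to
  // obtain the address of an instruction's output buffer, e.g.
  //
  //   llvm::Value* addr =
  //       EmitBufferPointer(GetAllocationSlice(*hlo), hlo->shape());
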
  // Emits a call to a thread-local function (e.g. to the computation nested
  // within a reduce or a map).  Thread-local callees (by definition) only write
  // to and read from thread-local allocations.
  // Supports only functions returning scalars or tuples of scalars.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee.  The returned vector holds the scalar(s) returned by the callee.
  std::vector<llvm::Value*> EmitThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);

  // Similar to EmitThreadLocalCall, but assumes that the callee returns a
  // single scalar.
  llvm::Value* EmitScalarReturningThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);

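  // Illustrative sketch (hypothetical variables, not part of this header):
  // when lowering a reduce whose reducer is `add_computation`, the running
  // accumulator and the current element could be combined with
  //
  //   llvm::Value* accum = EmitScalarReturningThreadLocalCall(
  //       *add_computation, {current_accum, element}, "reduce_body");
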
  // Emits a call to a "global" function (e.g. to the computation nested within
  // a kWhile or a kCall).  Buffer assignment unambiguously assigns buffers to
  // the parameters and return values for these computations so there is no need
  // to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);

  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target op.
  // This produces a series of nested loops (one for each dimension of the op's
  // shape). The body of the inner-most loop is provided by the
  // element_generator function.
  //
  // desc is an optional human-readable string that's added to the loop name in
  // IR.  Regardless of whether desc is provided, target_op->name() is included
  // in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);

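  // Illustrative sketch (hypothetical handler code, not part of this header):
  // an element generator receives the index of one output element and returns
  // the llvm::Value* to store there.  For a unary elementwise op this could
  // look roughly like
  //
  //   Status status = EmitTargetElementLoop(
  //       hlo, [&](const llvm_ir::IrArray::Index& index)
  //                -> StatusOr<llvm::Value*> {
  //         llvm::Value* operand = GetIrArrayFor(hlo->operand(0))
  //                                    .EmitReadArrayElement(index, &b_);
  //         return b_.CreateFNeg(operand);
  //       });
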
  // Emits a memcpy from the source instruction's result value to the
  // destination's.  Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayFor or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);

  // Structurizes "array_elements" into a multidimensional array that represents
  // "shape".  This is a recursive function, and "dimension_index" indicates the
  // index of the current dimension that the function is considering (0 means
  // the most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64 dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure.  On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations.  We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64> dimensions,
                                      HloComputation* function,
                                      string* failure_reason);

  // We'd like to keep one or two cache lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types.  We do this by introducing a layer of
  // abstraction: we introduce a high level vector-like concept called a
  // "sharded vector" that models data parallelism, and is mapped to a sequence
  // of scalar and vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped to
  // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
  // Note that the last element is scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the elements
  // mapped to sharded vectors -- we allow repeated elements, and we allow
  // elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the element-wise llvm::Type's of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Creates a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);

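  // Illustration (following the example above; the actual shard sizes depend on
  // the target's vector register width): CreateShardedVectorType(F32, 29) could
  // return a ShardedVectorType whose elements correspond to
  //
  //   {<16 x float>, <8 x float>, <4 x float>, float}
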
  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              llvm::Align alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern.  Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said reduction.
  // On failure, this stores a reason string into "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             string* failure_reason) const;

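  // Illustrative sketch (not part of this header): for an f32 add reduction, a
  // matched generator would be morally equivalent to
  //
  //   ReductionGenerator add_generator = [](llvm::IRBuilder<>* b,
  //                                         llvm::Value* lhs,
  //                                         llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   };
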
  // Emits the inner loop nest that runs the reduction.  Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64> dimensions,
      llvm::Align element_alignment);

  // Tries to emit a fast concatenate operation using memcpy.  Returns true if
  // successful, and false on failure.  On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
  // from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64 element_count, PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Emits printing during the execution.
  llvm::Value* EmitPrintf(absl::string_view fmt,
                          absl::Span<llvm::Value* const> arguments);

  // Emits a call to a non-variadic function `func_name` with arguments
  // `arguments` assuming C calling convention.
  llvm::Value* EmitCallToFunc(
      std::string func_name, const std::vector<llvm::Value*>& arguments,
      llvm::Type* return_type, bool does_not_throw = true,
      bool only_accesses_arg_memory = false,
      bool only_accesses_inaccessible_mem_or_arg_mem = false);

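  // Illustrative sketch (hypothetical symbol and arguments, not part of this
  // header): emitting a call to a C-linkage runtime helper returning an i64
  // could look like
  //
  //   llvm::Value* result = EmitCallToFunc(
  //       "my_runtime_helper", {arg0, arg1}, b_.getInt64Ty(),
  //       /*does_not_throw=*/true, /*only_accesses_arg_memory=*/true);
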
  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  // Map containing all previously emitted computations.
  std::map<const HloComputation*, llvm::Function*> emitted_functions_;

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module (Note that IrFunction
  // creates the encapsulated llvm::Function s.t. it is added to the llvm
  // module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;
  mlir::MLIRContext* mlir_context_;

  // The buffer allocation slice for the root of the computation being compiled.
  // Only relevant for thread local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers.  Only relevant for thread local
  // computations.
  absl::flat_hash_map<BufferAllocation::Index, int64>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const std::unordered_map<const HloInstruction*, int64>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const std::unordered_map<const HloComputation*, int64>
      computation_to_profile_idx_;

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of root instruction outer dimensions used in parallel loop
  // emission (ParallelLoopEmitter).
  int64 num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension loop
    // bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This struct contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);

    // Convenience function to generate a call to an intrinsic which reads the
    // CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end, llvm::Value* cycle_start);

   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the HLO
    // began to execute.
    std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };

  ProfilingState profiling_state_;

  class TracingState {
   public:
    TracingState() : enabled_(false) {}
    void set_enabled(bool value) { enabled_ = value; }
    void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* run_options);
    void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
                        llvm::Value* run_options);

   private:
    bool enabled_;
    // Maps from HLO to the activity id returned by xprof::TraceMe.
    std::unordered_map<const HloInstruction*, llvm::Value*> activity_ids_;
  };
  TracingState tracing_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64 buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);

  // Calculate the alignment of a buffer allocated for a given primitive type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64 ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Returns a ConstExpr bitcast.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const { return literal->Hash(); }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_