1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17 
18 #include <stddef.h>
19 #include <stdint.h>
20 
21 #include <algorithm>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <utility>
26 #include <vector>
27 
28 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "llvm/CodeGen/TargetRegisterInfo.h"
36 #include "llvm/CodeGen/TargetSubtargetInfo.h"
37 #include "llvm/IR/BasicBlock.h"
38 #include "llvm/IR/Constants.h"
39 #include "llvm/IR/GlobalVariable.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/Intrinsics.h"
42 #include "llvm/IR/IntrinsicsX86.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/Value.h"
45 #include "tensorflow/compiler/xla/layout_util.h"
46 #include "tensorflow/compiler/xla/map_util.h"
47 #include "tensorflow/compiler/xla/primitive_util.h"
48 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
49 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
50 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
51 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
52 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
53 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
54 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
55 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
56 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
57 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
58 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
59 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
60 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
61 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
62 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
63 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
64 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
65 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
66 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
67 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
68 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
69 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
71 #include "tensorflow/compiler/xla/shape_util.h"
72 #include "tensorflow/compiler/xla/status_macros.h"
73 #include "tensorflow/compiler/xla/types.h"
74 #include "tensorflow/compiler/xla/util.h"
75 #include "tensorflow/compiler/xla/window_util.h"
76 #include "tensorflow/compiler/xla/xla_data.pb.h"
77 #include "tensorflow/core/lib/core/bits.h"
78 #include "tensorflow/core/lib/core/errors.h"
79 #include "tensorflow/core/lib/math/math_util.h"
80 #include "tensorflow/core/platform/logging.h"
81 
82 namespace xla {
83 
84 namespace {
85 using llvm_ir::IrName;
86 using llvm_ir::SetToFirstInsertPoint;
87 }  // namespace
88 
89 namespace cpu {
90 
91 IrEmitter::IrEmitter(
92     mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
93     const BufferAssignment& assignment, llvm::Module* llvm_module,
94     std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
95     std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
96     const TargetMachineFeatures* target_machine_features,
97     bool emit_code_for_msan)
98     : assignment_(assignment),
99       module_(llvm_module),
100       arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
101       b_(llvm_module->getContext()),
102       mlir_context_(mlir_context),
103       instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
104       computation_to_profile_idx_(std::move(computation_to_profile_idx)),
105       alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
106       hlo_module_config_(hlo_module.config()),
107       is_top_level_computation_(false),
108       target_machine_features_(*target_machine_features),
109       emit_code_for_msan_(emit_code_for_msan) {
110   b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
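  // Partition the module's computations into thread-local and global sets up
  // front; later emission (e.g. HandleSort) checks membership in these sorted
  // lists.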
111   Status s = GatherComputationsByAllocationType(
112       &hlo_module, &thread_local_computations_, &global_computations_);
113   absl::c_sort(thread_local_computations_);
114   absl::c_sort(global_computations_);
115   TF_CHECK_OK(s) << "Should have failed buffer assignment.";
116 }
117 
118 void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
119   llvm::Argument* out_parameter = compute_function_->result_arg();
120   llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
121   const Shape& return_shape = computation->root_instruction()->shape();
122 
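  // Scalar results are stored directly through the out parameter; for tuple
  // results, each element value is loaded from the root's buffer and stored
  // into the corresponding caller-provided element buffer.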
123   if (ShapeUtil::IsScalar(return_shape)) {
124     llvm::Value* ret_value =
125         Load(root_value.GetBasePointer(), "load_ret_value");
126     Store(ret_value,
127           BitCast(out_parameter, root_value.GetBasePointer()->getType()));
128   } else {
129     CHECK(return_shape.IsTuple());
130 
131     llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
132     llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
133     llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
134 
135     for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
136       const Shape& element_shape = return_shape.tuple_shapes(i);
137       llvm::Value* destination = llvm_ir::EmitGetTupleElement(
138           element_shape,
139           /*index=*/i,
140           /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
141           &b_);
142 
143       llvm::Value* source = llvm_ir::EmitGetTupleElement(
144           element_shape,
145           /*index=*/i,
146           /*alignment=*/MinimumAlignmentForShape(element_shape),
147           root_value.GetBasePointer(), &b_);
148 
149       Store(Load(source), destination);
150     }
151   }
152 }
153 
154 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
155     HloComputation* computation, const string& function_name_prefix,
156     bool is_top_level_computation,
157     absl::Span<HloInstruction* const> instruction_order) {
158   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
159   VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
160   is_top_level_computation_ = is_top_level_computation;
161   num_dynamic_loop_bounds_ = 0;
162   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
163     num_dynamic_loop_bounds_ =
164         computation->root_instruction()->outer_dimension_partitions().size();
165   }
166 
167   if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
168     TF_ASSIGN_OR_RETURN(
169         computation_root_allocation_,
170         assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
171   }
172 
173   for (const HloInstruction* param : computation->parameter_instructions()) {
174     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
175                         assignment_.GetUniqueTopLevelSlice(param));
176     computation_parameter_allocations_[param_slice.allocation()->index()] =
177         param->parameter_number();
178   }
179 
180   InitializeIrFunction(function_name);
181   // The rdtscp instruction is x86-specific.  We fall back to LLVM's generic
182   // readcyclecounter if it is unavailable.
183   bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
184                     arch_type_ == llvm::Triple::ArchType::x86_64;
185   profiling_state_ = ProfilingState(use_rdtscp);
186 
187   tracing_state_.set_enabled(
188       computation->parent()->config().cpu_traceme_enabled());
189 
190   TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
191   llvm::Function* ir_function = compute_function_->function();
192   InsertOrDie(&emitted_functions_, computation, ir_function);
193   // Deleting 'compute_function_' (below) finalizes 'ir_function' and restores
194   // the caller's IR insert point.
195 
196   // Function epilogue: copy the result either into the return register or
197   // into the buffers that the return register points to.
198   const BufferAllocation* root_allocation =
199       computation_root_allocation_.allocation();
200   if (root_allocation && root_allocation->is_thread_local()) {
201     EmitThreadLocalFunctionEpilogue(computation);
202   }
203 
204   // Destructor for compute_function_ emits the "ret void" instruction.
205   compute_function_.reset();
206   computation_root_allocation_ = BufferAllocation::Slice();
207   computation_parameter_allocations_.clear();
208   return ir_function;
209 }
210 
211 void IrEmitter::InitializeIrFunction(const string& function_name) {
212   // Functions with local linkage get an inlining bonus.  Because we know
213   // a priori that embedded functions (non-entry functions) will not have their
214   // names resolved, give them local linkage.
215   llvm::Function::LinkageTypes linkage =
216       is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
217                                 : llvm::GlobalValue::InternalLinkage;
218   // Create and initialize new IrFunction.
219   compute_function_.reset(new IrFunction(function_name, linkage,
220                                          hlo_module_config_, module_, &b_,
221                                          num_dynamic_loop_bounds_));
222 }
223 
224 IrEmitter::~IrEmitter() {}
225 
226 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
227   VLOG(2) << "HandleBitcast: " << bitcast->ToString();
228   emitted_value_[bitcast] =
229       BitCast(GetEmittedValueFor(bitcast->operand(0)),
230               IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
231   return Status::OK();
232 }
233 
234 llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
235   llvm::Constant* initializer =
236       llvm_ir::ConvertLiteralToIrConstant(literal, module_);
237   llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
238       /*Module=*/*module_,
239       /*Type=*/initializer->getType(),
240       /*isConstant=*/true,
241       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
242       /*Initializer=*/initializer,
243       /*Name=*/"");
244   result_global->setAlignment(
245       llvm::Align(MinimumAlignmentForShape(literal.shape())));
246   result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
247   return llvm::ConstantExpr::getBitCast(
248       result_global, IrShapeType(literal.shape())->getPointerTo());
249 }
250 
251 Status IrEmitter::EmitConstantGlobals() {
252   for (const BufferAllocation& allocation : assignment_.Allocations()) {
253     if (!allocation.is_constant()) {
254       continue;
255     }
256 
257     const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
258     llvm::Constant* global_for_const;
259     auto it = emitted_literals_.find(&literal);
260     if (it != emitted_literals_.end()) {
261       global_for_const = it->second;
262     } else {
263       global_for_const = EmitGlobalForLiteral(literal);
264       InsertOrDie(&emitted_literals_, &literal, global_for_const);
265     }
266 
267     InsertOrDie(&constant_buffer_to_global_, allocation.index(),
268                 global_for_const);
269   }
270 
271   return Status::OK();
272 }
273 
274 Status IrEmitter::HandleConstant(HloInstruction* constant) {
275   VLOG(2) << "HandleConstant: " << constant->ToString();
276   // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
277   // of the constant.
278   return EmitTargetAddressForOp(constant);
279 }
280 
281 Status IrEmitter::HandleCopy(HloInstruction* copy) {
282   if (copy->shape().IsTuple() ||
283       (copy->shape().IsArray() &&
284        LayoutUtil::Equal(copy->operand(0)->shape().layout(),
285                          copy->shape().layout()))) {
286     // If the layouts are equal this is just a memcpy. kCopy shallow copies a
287     // tuple so just memcpy the top-level buffer for tuples.
288     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
289     return EmitMemcpy(*(copy->operand(0)), *copy);
290   } else if (copy->shape().IsArray()) {
291     // Use the elemental emitter for array shapes.
292     return DefaultAction(copy);
293   }
294   return Unimplemented("unsupported operand type %s for copy instruction",
295                        PrimitiveType_Name(copy->shape().element_type()));
296 }
297 
298 // Calculate the alignment of a buffer allocated for a given primitive type.
299 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
300   int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
301   DCHECK_GE(byte_size, 0);
302   // Largest scalar is a complex128 so we don't need to worry about the
303   // int64->int truncation here.
304   DCHECK_LE(byte_size, 16);
305 
306   // Allocations may be 8-byte aligned if part of a small block.
307   return std::min(int64{8}, byte_size);
308 }
309 
310 int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
311   return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
312 }
313 
314 // Calculate the alignment of a buffer allocated for a given shape.
315 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
316   if (ShapeUtil::IsScalar(shape)) {
317     return MinimumAlignmentForPrimitiveType(shape.element_type());
318   }
319 
320   int64 buffer_size = ByteSizeOf(shape);
321   DCHECK_GE(buffer_size, 0);
322   DCHECK_LE(buffer_size, SIZE_MAX);
323 
324   return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
325 }
326 
327 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
328                                                const Shape& shape) {
329   int alignment = MinimumAlignmentForShape(shape);
330   if (alignment > 1) {
331     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
332   }
333 }
334 
335 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
336                                                int64 buffer_size) {
337   int alignment =
338       target_machine_features_.minimum_alignment_for_allocation(buffer_size);
339   if (alignment > 1) {
340     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
341   }
342 }
343 
344 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
345                                                      const Shape& shape) {
346   AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
347 }
348 
349 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
350                                                      int64 buffer_size) {
351   if (buffer_size > 0) {
352     llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
353   }
354 }
355 
356 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
357   // A tuple is an array of pointers, one for each operand. Each pointer points
358   // to the output buffer of its corresponding operand. A GetTupleElement
359   // instruction forwards a pointer to the tuple element buffer at the given
360   // index.
361   const HloInstruction* operand = get_tuple_element->operand(0);
362   const Shape& shape = get_tuple_element->shape();
363   emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
364       shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
365       GetEmittedValueFor(operand), &b_);
366   return Status::OK();
367 }
368 
369 Status IrEmitter::HandleSelect(HloInstruction* select) {
370   auto pred = select->operand(0);
371   TF_RET_CHECK(pred->shape().element_type() == PRED);
372   return DefaultAction(select);
373 }
374 
375 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
376   auto pred = tuple_select->operand(0);
377   auto on_true = tuple_select->operand(1);
378   auto on_false = tuple_select->operand(2);
379   TF_RET_CHECK(pred->shape().element_type() == PRED);
380   TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
381   TF_RET_CHECK(tuple_select->shape().IsTuple());
382   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
383   llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
384                            GetEmittedValueFor(on_true),
385                            GetEmittedValueFor(on_false), &b_);
386   return Status::OK();
387 }
388 
389 Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
390   HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
391   VLOG(2) << "HandleInfeed: " << infeed->ToString();
392 
393   // The infeed operation produces a two-element tuple containing data and a
394   // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
395   const Shape& data_shape = infeed->infeed_shape();
396   DCHECK(ShapeUtil::Equal(data_shape,
397                           ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
398   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
399 
400   // Write the tuple index table.
401   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
402                       assignment_.GetUniqueSlice(infeed, {0}));
403   llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
404   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
405                       assignment_.GetUniqueSlice(infeed, {1}));
406   llvm::Value* token_address = EmitBufferPointer(
407       token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
408   llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
409 
410   if (data_shape.IsTuple()) {
411     TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
412 
413     // For a tuple, we first copy each of the internal elements to
414     // their corresponding target locations. We then construct the
415     // tuple outer buffer containing pointers to the internal
416     // elements.
417     std::vector<llvm::Value*> tuple_element_addresses;
418     for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
419       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
420                           assignment_.GetUniqueSlice(infeed, {0, i}));
421 
422       const Shape& tuple_element_shape =
423           ShapeUtil::GetTupleElementShape(data_shape, i);
424 
425       // Only the outer tuple buffer's target address is obtained from
426       // GetEmittedValueFor, to handle the case when Infeed is the root
427       // instruction. Target addresses for internal elements can be obtained
428       // from EmitBufferPointer.
429       llvm::Value* tuple_element_address =
430           EmitBufferPointer(buffer, tuple_element_shape);
431 
432       TF_RETURN_IF_ERROR(EmitXfeedTransfer(
433           XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
434 
435       tuple_element_addresses.push_back(tuple_element_address);
436     }
437 
438     llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
439                        tuple_element_addresses, &b_);
440   } else {
441     TF_RETURN_IF_ERROR(
442         EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
443   }
444 
445   return Status::OK();
446 }
447 
448 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
449                                     llvm::Value* program_buffer_address) {
450   int64 length = ByteSizeOf(shape);
451   if (length < 0 || length > std::numeric_limits<int32>::max()) {
452     return InvalidArgument(
453         "xfeed (infeed or outfeed) buffer length %d is outside the valid "
454         "size range",
455         length);
456   }
457   int32 length_32 = static_cast<int32>(length);
458 
459   int32 shape_length;
460   TF_ASSIGN_OR_RETURN(
461       llvm::Value * shape_ptr,
462       llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
463 
464   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
465 
466   const char* acquire_func_name =
467       kind == XfeedKind::kInfeed
468           ? runtime::kAcquireInfeedBufferForDequeueSymbolName
469           : runtime::kAcquireOutfeedBufferForPopulationSymbolName;
470 
471   // Implementation note: this call informs the runtime that it wants a
472   // buffer of size exactly 'length_32', and the runtime is responsible for
473   // check-failing the process if there is a mismatch, versus passing us
474   // back a buffer that we might overrun.
475   llvm::Value* acquired_pointer =
476       EmitCallToFunc(acquire_func_name,
477                      {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
478                       shape_ptr, b_.getInt32(shape_length)},
479                      i8_ptr_type);
480   if (kind == XfeedKind::kInfeed) {
481     // Copy to the program buffer address from the acquired buffer.
482     MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
483            acquired_pointer,
484            /*SrcAlign=*/llvm::Align(1), length_32);
485   } else {
486     // Outfeed -- copy from the in-program address to the acquired buffer.
487     MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
488            program_buffer_address,
489            /*SrcAlign=*/llvm::Align(1), length_32);
490     if (emit_code_for_msan_) {
491       // Mark the outfed data as initialized for msan. The buffer gets read by
492       // the host code, which might be msan-instrumented.
493       // TODO(b/66051036): Run the msan instrumentation pass instead.
494       const llvm::DataLayout& dl = module_->getDataLayout();
495       llvm::Type* intptr_type = b_.getIntPtrTy(dl);
496       EmitCallToFunc(
497           "__msan_unpoison",
498           {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
499           b_.getVoidTy());
500     }
501   }
502 
503   const char* release_func_name =
504       kind == XfeedKind::kInfeed
505           ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
506           : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
507   EmitCallToFunc(release_func_name,
508                  {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
509                   acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
510                  b_.getVoidTy());
511 
512   return Status::OK();
513 }
514 
515 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
516   // Outfeed produces no useful result, but it does return a token[] that can be
517   // threaded through to other side-effecting operations to ensure ordering.  In
518   // the IR emitter we treat this token as a normal u8[] and thus need to insert
519   // an entry for it in emitted_value_.
520   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
521 
522   HloInstruction* operand = outfeed->operands()[0];
523   const Shape& operand_shape = operand->shape();
524 
525   llvm::Value* value = GetEmittedValueFor(operand);
526   if (!operand_shape.IsTuple()) {
527     return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
528   }
529 
530   TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
531 
532   for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
533     const Shape& tuple_element_shape =
534         ShapeUtil::GetTupleElementShape(operand_shape, i);
535     llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
536         tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
537         value, &b_);
538     TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
539                                          tuple_element_shape, tuple_element));
540   }
541 
542   return Status::OK();
543 }
544 
545 Status IrEmitter::HandleSort(HloInstruction* hlo) {
546   const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
547   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
548   Shape keys_shape = sort->keys()->shape();
549   PrimitiveType keys_type = keys_shape.element_type();
550   if (!primitive_util::IsArrayType(keys_type)) {
551     return Unimplemented("Element type %s not supported in the Sort op on CPU.",
552                          PrimitiveType_Name(keys_type));
553   }
554   std::vector<llvm::Value*> destination_addresses(sort->operand_count());
555   for (int64 i = 0; i < sort->operand_count(); ++i) {
556     ShapeIndex shape_index =
557         sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
558     const HloInstruction* operand = sort->operand(i);
559     // We assume that the layout of all involved operands and outputs is the
560     // same.
561     TF_RET_CHECK(
562         LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
563     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
564         keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
565 
566     // The sort is implemented in-place, therefore we first copy the operand
567     // buffer to the output buffer if they are not the same.
568     auto destination_buffer = GetAllocationSlice(*sort, shape_index);
569     destination_addresses[i] =
570         EmitBufferPointer(destination_buffer, operand->shape());
571     auto source_address = GetAllocationSlice(*operand);
572     if (destination_buffer != source_address) {
573       int64 primitive_type_size =
574           ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
575       auto source_buffer = GetEmittedValueFor(operand);
576       int64 size = ByteSizeOf(operand->shape());
577       MemCpy(destination_addresses[i],
578              /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
579              /*SrcAlign=*/llvm::Align(primitive_type_size), size);
580     }
581   }
582 
583   // Normalize the shape and the dimension to sort.
584   Shape normalized_keys_shape =
585       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
586   int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
587       keys_shape.layout())[sort->sort_dimension()];
588 
589   int64 sort_dimension_elements =
590       normalized_keys_shape.dimensions(physical_dimension_to_sort);
591   int64 higher_dimensions = 1;
592   for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
593     higher_dimensions *= normalized_keys_shape.dimensions(i);
594   }
595   int64 lower_dimensions = 1;
596   for (int64 i = normalized_keys_shape.rank() - 1;
597        i > physical_dimension_to_sort; --i) {
598     lower_dimensions *= normalized_keys_shape.dimensions(i);
599   }
600 
601   CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
602   llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
603       b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
604       &b_);
605   llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
606       b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
607       &b_);
608   for (int64 i = 0; i < sort->operand_count(); ++i) {
609     llvm::Value* value_as_i8ptr =
610         PointerCast(destination_addresses[i], b_.getInt8PtrTy());
611     llvm::Value* slot_in_values_alloca =
612         ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
613     Store(value_as_i8ptr, slot_in_values_alloca);
614     llvm::Value* slot_in_sizes_alloca =
615         ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
616     llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
617         sort->operand(i)->shape().element_type()));
618     Store(size, slot_in_sizes_alloca);
619   }
620 
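  // Invoke the CPU runtime's in-place key-value sort, passing the destination
  // buffers, the per-operand element sizes, and the emitted comparator
  // function.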
621   auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
622   EmitCallToFunc(
623       runtime::kKeyValueSortSymbolName,
624       {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
625        b_.getInt64(lower_dimensions), values,
626        b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
627        GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
628        less_than_function},
629       b_.getVoidTy());
630 
631   if (sort->values_count() > 0) {
632     llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
633   }
634   return Status::OK();
635 }
636 
637 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
638   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
639   std::vector<llvm::Value*> base_ptrs;
640   for (auto operand : tuple->operands()) {
641     base_ptrs.push_back(GetEmittedValueFor(operand));
642   }
643   llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
644   return Status::OK();
645 }
646 
647 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
648   // Pseudo code for reduce window:
649   //
650   //   for (coordinates O in the output)
651   //     value = init_value;
652   //     for (coordinates W in the window)
653   //       for each index i:
654   //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
655   //       if I within bounds of input:
656   //         value = function(value, input(I));
657   //     output(O) = value;
658   //
659   // This is completely un-optimized and just here to have something
660   // that works.
661   return DefaultAction(reduce_window);
662 }
663 
664 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
665   CHECK_EQ(select_and_scatter->operand_count(), 3);
666   const auto operand = select_and_scatter->operand(0);
667   const auto source = select_and_scatter->operand(1);
668   const auto init_value = select_and_scatter->operand(2);
669   const Window& window = select_and_scatter->window();
670   PrimitiveType operand_element_type = operand->shape().element_type();
671   const int64 rank = operand->shape().rank();
672   CHECK_EQ(rank, source->shape().rank());
673   CHECK_EQ(rank, window.dimensions_size());
674 
675   // TODO(b/31410564): Implement dilation for select-and-scatter.
676   if (window_util::HasDilation(window)) {
677     return Unimplemented(
678         "Dilation for SelectAndScatter is not implemented on CPU. ");
679   }
680 
681   // Pseudo code for select-and-scatter:
682   //
683   // initialized_flag is initially off for every window, and is turned on after
684   // the first iteration is completed and the first operand value is selected.
685   //
686   // output(*) = init_value
687   // for (coordinates S in the source) {
688   //   initialized_flag = false
689   //   for (coordinates W in the window) {
690   //     I = S * stride + W - pad_low
691   //     if I within bounds of operand:
692   //       if !initialized_flag or select(selected_value, operand(I)) == false:
693   //         selected_value = operand(I)
694   //         selected_index = I
695   //         initialized_flag = true
696   //   }
697   //   output(selected_index) = scatter(output(selected_index), source(S))
698   // }
699   //
700 
701   // Initialize the output array with the given init_value.
702   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
703       select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
704       [this, init_value](const llvm_ir::IrArray::Index& target_index) {
705         llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
706         return Load(init_value_addr);
707       }));
708 
709   // Create a loop to iterate over the source array to scatter to the output.
710   llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
711   const llvm_ir::IrArray::Index source_index =
712       source_loops.AddLoopsForShape(source->shape(), "source");
713   SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
714 
715   // Allocate space to keep the currently selected value, its index, and
716   // the boolean initialized_flag, which is initially set to false.
717   llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
718       llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
719       "selected_value_address", &b_,
720       MinimumAlignmentForPrimitiveType(operand_element_type));
721   llvm::Value* selected_index_address =
722       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
723           b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
724   llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
725       b_.getInt1Ty(), "initialized_flag_address", &b_);
726   Store(b_.getInt1(false), initialized_flag_address);
727 
728   // Create the inner loop to iterate over the window.
729   llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
730   std::vector<int64> window_size;
731   for (const auto& dim : window.dimensions()) {
732     window_size.push_back(dim.size());
733   }
734   const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
735       ShapeUtil::MakeShape(operand_element_type, window_size), "window");
736   SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
737 
738   // Compute the operand index to visit and check whether that index is within
739   // bounds. The unsigned comparison also covers the check that the operand
740   // index is >= 0.
741   std::vector<llvm::Value*> operand_multi_index(source_index.size());
742   llvm::Value* in_bounds_condition = b_.getTrue();
743   for (int64 i = 0; i < rank; ++i) {
744     llvm::Value* strided_index =
745         NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
746     operand_multi_index[i] =
747         NSWSub(NSWAdd(strided_index, window_index[i]),
748                b_.getInt64(window.dimensions(i).padding_low()));
749     llvm::Value* index_condition =
750         ICmpULT(operand_multi_index[i],
751                 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
752     in_bounds_condition = And(in_bounds_condition, index_condition);
753   }
754   CHECK(in_bounds_condition != nullptr);
755 
756   // Only need to do something if the operand index is within the bounds. First
757   // check if the initialized_flag is set.
758   llvm_ir::LlvmIfData if_in_bounds =
759       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
760   SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
761   llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
762       Load(initialized_flag_address), "initialized", &b_);
763 
764   // If the initialized_flag is false, initialize the selected value and index
765   // with the currently visiting operand.
766   SetToFirstInsertPoint(if_initialized.false_block, &b_);
767   const auto save_operand_index =
768       [&](const llvm_ir::IrArray::Index& operand_index) {
769         for (int64 i = 0; i < rank; ++i) {
770           llvm::Value* selected_index_address_slot =
771               InBoundsGEP(selected_index_address, {b_.getInt32(i)});
772           Store(operand_index[i], selected_index_address_slot);
773         }
774       };
775   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
776   llvm_ir::IrArray::Index operand_index(
777       operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
778   llvm::Value* operand_data =
779       operand_array.EmitReadArrayElement(operand_index, &b_);
780   Store(operand_data, selected_value_address);
781   save_operand_index(operand_index);
782   Store(b_.getInt1(true), initialized_flag_address);
783 
784   // If the initialized_flag is true, call the `select` function to potentially
785   // update the selected value and index with the currently visiting operand.
786   SetToFirstInsertPoint(if_initialized.true_block, &b_);
787   llvm::Value* operand_address =
788       operand_array.EmitArrayElementAddress(operand_index, &b_);
789   llvm::Value* operand_element = Load(operand_address);
790   llvm::Value* result = EmitScalarReturningThreadLocalCall(
791       *select_and_scatter->select(),
792       {Load(selected_value_address), operand_element}, "select_function");
793 
794   // If the 'select' function returns false, update the selected value and the
795   // index to the currently visiting operand.
796   llvm::Value* cond = ICmpNE(
797       result,
798       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
799       "boolean_predicate");
800   llvm_ir::LlvmIfData if_select_lhs =
801       llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
802   SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
803   Store(Load(operand_address), selected_value_address);
804   save_operand_index(operand_index);
805 
806   // After iterating over the window elements, scatter the source element to
807   // the selected index of the output. The value we store at the output
808   // location is computed by calling the `scatter` function with the source
809   // value and the current output value.
810   SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
811   std::vector<llvm::Value*> selected_multi_index;
812   for (int64 i = 0; i < rank; ++i) {
813     llvm::Value* selected_index_address_slot =
814         InBoundsGEP(selected_index_address, {b_.getInt32(i)});
815     selected_multi_index.push_back(Load(selected_index_address_slot));
816   }
817   llvm_ir::IrArray source_array(GetIrArrayFor(source));
818   llvm::Value* source_value =
819       source_array.EmitReadArrayElement(source_index, &b_);
820   llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
821   llvm_ir::IrArray::Index selected_index(
822       selected_multi_index, output_array.GetShape(), source_index.GetType());
823   llvm::Value* output_value =
824       output_array.EmitReadArrayElement(selected_index, &b_);
825   llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
826       *select_and_scatter->scatter(), {output_value, source_value},
827       "scatter_function");
828   output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
829 
830   SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
831   return Status::OK();
832 }
833 
834 Status IrEmitter::HandleDot(HloInstruction* dot) {
835   auto lhs = dot->operand(0);
836   auto rhs = dot->operand(1);
837   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
838       /*instruction=*/*dot, /*operands=*/{lhs, rhs},
839       /*supported_types=*/
840       {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
841   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
842 
843   if (dnums.lhs_contracting_dimensions_size() != 1) {
844     // This is disallowed by ShapeInference today.
845     return Unimplemented(
846         "Dot with multiple contracting dimensions not implemented.");
847   }
848 
849   llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
850   llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
851 
852   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
853   llvm_ir::IrArray target_array = GetIrArrayFor(dot);
854 
855   VLOG(2) << "HandleDot: ";
856   VLOG(2) << "  lhs operand: "
857           << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
858   VLOG(2) << "  rhs operand: "
859           << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
860   VLOG(2) << "  target: "
861           << llvm_ir::DumpToString(*target_array.GetBasePointer());
862 
863   // Dot operation is complicated so we delegate to a helper class.
864   return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
865                           /*addend_array=*/nullptr,
866                           GetExecutableRunOptionsArgument(), &b_, mlir_context_,
867                           hlo_module_config_, target_machine_features_);
868 }
869 
870 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
871   auto lhs = convolution->operand(0);
872   auto rhs = convolution->operand(1);
873   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
874       /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
875       /*supported_types=*/{F16, F32, F64, C64, C128}));
876 
877   // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
878   // different data layouts.
879   if (PotentiallyImplementedAsEigenConvolution(*convolution,
880                                                target_machine_features_)) {
881     const Shape& lhs_shape = lhs->shape();
882     const Shape& rhs_shape = rhs->shape();
883     const Shape& convolution_shape = convolution->shape();
884     // The input, kernel and output agree with respect to layout.
885     if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
886         LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
887         LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
888       // We lower 1D convolutions into calls to the same Eigen function as 2D
889       // convolutions, except that we pretend that the 1D convolution is really
890       // a 2D convolution with the missing dimension set to 1.  We also adjust
891   // the padding and dilation parameters as needed.
892       bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
893       llvm::Value* lhs_address = GetEmittedValueFor(lhs);
894       llvm::Value* rhs_address = GetEmittedValueFor(rhs);
895       TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
896 
897       const ConvolutionDimensionNumbers& dnums =
898           convolution->convolution_dimension_numbers();
899 
900       // Input tensor.
901       const Shape& input_shape = convolution->operand(0)->shape();
902       int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
903       int64 input_rows =
904           input_shape.dimensions(dnums.input_spatial_dimensions(0));
905       int64 input_cols =
906           one_dim_convolution
907               ? 1
908               : input_shape.dimensions(dnums.input_spatial_dimensions(1));
909       int64 input_channels =
910           input_shape.dimensions(dnums.input_feature_dimension());
911 
912       // Kernel tensor.
913       const Shape& kernel_shape = convolution->operand(1)->shape();
914       int64 kernel_rows =
915           kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
916       int64 kernel_cols =
917           one_dim_convolution
918               ? 1
919               : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
920       int64 kernel_channels =
921           kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
922       int64 kernel_filters =
923           kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
924 
925       // Output tensor.
926       const Shape& convolution_shape = convolution->shape();
927       int64 output_rows =
928           convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
929       int64 output_cols = one_dim_convolution
930                               ? 1
931                               : convolution_shape.dimensions(
932                                     dnums.output_spatial_dimensions(1));
933 
934       // Extract the window stride for the convolution.
935       const Window& window = convolution->window();
936       int64 row_stride = window.dimensions(0).stride();
937       int64 col_stride =
938           one_dim_convolution ? 1 : window.dimensions(1).stride();
939 
940       int64 padding_top = window.dimensions(0).padding_low();
941       int64 padding_bottom = window.dimensions(0).padding_high();
942       int64 padding_left =
943           one_dim_convolution ? 0 : window.dimensions(1).padding_low();
944       int64 padding_right =
945           one_dim_convolution ? 0 : window.dimensions(1).padding_high();
946 
947       int64 lhs_row_dilation = window.dimensions(0).base_dilation();
948       int64 lhs_col_dilation =
949           one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
950       int64 rhs_row_dilation = window.dimensions(0).window_dilation();
951       int64 rhs_col_dilation =
952           one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
953 
954       PrimitiveType primitive_type = lhs->shape().element_type();
955       llvm::Type* ir_ptr_type = primitive_type == F16
956                                     ? b_.getHalfTy()->getPointerTo()
957                                     : b_.getFloatTy()->getPointerTo();
958       bool multi_threaded =
959           hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
960       bool use_mkl_dnn =
961           hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
962 
963       // TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to a
964       // potential race condition when setting omp_num_threads.
965       const char* fn_name =
966           primitive_type == F16
967               ? (multi_threaded
968                      ? runtime::kEigenConvF16SymbolName
969                      : runtime::kEigenSingleThreadedConvF16SymbolName)
970               : (multi_threaded
971                      ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
972                                     : runtime::kEigenConvF32SymbolName)
973                      : runtime::kEigenSingleThreadedConvF32SymbolName);
974       if (!multi_threaded && use_mkl_dnn) {
975         LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
976                         "conv2d function.";
977       }
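      // Dispatch to the selected Eigen (or MKL-DNN) convolution runtime entry
      // point with the 2D-normalized shape, stride, padding, and dilation
      // parameters computed above.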
978       EmitCallToFunc(fn_name,
979                      {
980                          GetExecutableRunOptionsArgument(),
981                          BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
982                          BitCast(lhs_address, ir_ptr_type),
983                          BitCast(rhs_address, ir_ptr_type),
984                          b_.getInt64(input_batch),
985                          b_.getInt64(input_rows),
986                          b_.getInt64(input_cols),
987                          b_.getInt64(input_channels),
988                          b_.getInt64(kernel_rows),
989                          b_.getInt64(kernel_cols),
990                          b_.getInt64(kernel_channels),
991                          b_.getInt64(kernel_filters),
992                          b_.getInt64(output_rows),
993                          b_.getInt64(output_cols),
994                          b_.getInt64(row_stride),
995                          b_.getInt64(col_stride),
996                          b_.getInt64(padding_top),
997                          b_.getInt64(padding_bottom),
998                          b_.getInt64(padding_left),
999                          b_.getInt64(padding_right),
1000                          b_.getInt64(lhs_row_dilation),
1001                          b_.getInt64(lhs_col_dilation),
1002                          b_.getInt64(rhs_row_dilation),
1003                          b_.getInt64(rhs_col_dilation),
1004                      },
1005                      b_.getVoidTy(), /*does_not_throw=*/true,
1006                      /*only_accesses_arg_memory=*/true);
1007 
1008       return Status::OK();
1009     }
1010   }
1011 
1012   // This is a completely un-optimized version of convolution just to
1013   // have an early version that works. E.g. the input index and
1014   // padding calculation is not hoisted out of the inner loop.
1015   //
1016   // See the description of convolution in the XLA documentation for the pseudo
1017   // code for convolution.
1018   return DefaultAction(convolution);
1019 }
1020 
1021 Status IrEmitter::HandleFft(HloInstruction* fft) {
1022   auto operand = fft->operand(0);
1023   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1024       /*instruction=*/*fft, /*operands=*/{operand},
1025       /*supported_types=*/{F32, F64, C64, C128}));
1026   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1027   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1028   VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1029   VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1030 
1031   llvm::Value* operand_address = GetEmittedValueFor(operand);
1032   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1033 
1034   const std::vector<int64>& fft_length = fft->fft_length();
1035   int64 input_batch = 1;
1036   for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1037     input_batch *= fft->shape().dimensions(i);
1038   }
1039 
1040   // Args have been computed, make the call.
1041   llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1042   bool multi_threaded_eigen =
1043       hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1044   const char* fn_name = multi_threaded_eigen
1045                             ? runtime::kEigenFftSymbolName
1046                             : runtime::kEigenSingleThreadedFftSymbolName;
1047   const int fft_rank = fft_length.size();
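  // Call the Eigen FFT runtime routine; fft_length entries beyond fft_rank are
  // passed as 0.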
1048   EmitCallToFunc(
1049       fn_name,
1050       {GetExecutableRunOptionsArgument(),
1051        BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1052        BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1053        b_.getInt32(operand->shape().element_type() == F64 ||
1054                    operand->shape().element_type() == C128),
1055        b_.getInt32(fft_rank), b_.getInt64(input_batch),
1056        b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1057        b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1058        b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
1059       b_.getVoidTy(), /*does_not_throw=*/true,
1060       /*only_accesses_arg_memory=*/false,
1061       /*only_accesses_inaccessible_mem_or_arg_mem=*/true);
1062 
1063   return Status::OK();
1064 }
1065 
1066 Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
1067   // When there is a single replica, a cross replica sum is the identity
1068   // function, and the buffer assignment expects a copy.
1069   //
1070   // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1071   // in algebraic-simplifier, but currently on some platforms
1072   // HloModuleConfig::num_replicas changes between when the module is compiled
1073   // and when it's run.
1074   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1075 
1076   // CRS with one operand and one replica is simply the identity function.
1077   if (crs->operand_count() == 1) {
1078     return EmitMemcpy(*crs->operand(0), *crs);
1079   }
1080 
1081   // CRS with multiple operands and one replica produces a (one-deep) tuple.
1082   std::vector<llvm::Value*> operand_ptrs;
1083   for (int64 i = 0; i < crs->operand_count(); ++i) {
1084     llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1085     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1086                         assignment_.GetUniqueSlice(crs, {i}));
1087 
1088     const Shape& operand_shape = crs->operand(i)->shape();
1089     CHECK(operand_shape.IsArray())
1090         << "Operands to all-reduce must be arrays: " << crs->ToString();
1091     operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1092 
1093     // TODO(b/63762267): Be more aggressive about specifying alignment.
1094     MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
1095            /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
1096   }
1097   llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1098   return Status::OK();
1099 }
1100 
1101 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
1102   CHECK_GE(crs->operand_count(), 1);
1103   PrimitiveType datatype = crs->operand(0)->shape().element_type();
1104   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1105 
1106   bool is_datatype_supported = [&] {
1107     // TODO(cheshire): Fix duplication wrt. cpu_runtime
1108     switch (datatype) {
1109       case PRED:
1110       case S8:
1111       case U8:
1112       case S32:
1113       case U32:
1114       case S64:
1115       case U64:
1116       case F16:
1117       case F32:
1118       case F64:
1119         return true;
1120       default:
1121         return false;
1122     }
1123   }();
1124 
1125   if (!is_datatype_supported) {
1126     return Unimplemented("AllReduce for datatype '%s' is not supported",
1127                          primitive_util::LowercasePrimitiveTypeName(datatype));
1128   }
1129 
1130   if (!MatchReductionComputation(crs->to_apply()).has_value()) {
1131     return Unimplemented("AllReduce for computation '%s' is not supported",
1132                          crs->to_apply()->ToString());
1133   }
1134 
1135   std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
1136   int32 replica_groups_size = replica_groups.size();
1137   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1138 
1139   bool is_tuple = crs->operand_count() > 1;
1140   std::vector<llvm::Value*> input_buffer_ptrs;
1141   std::vector<llvm::Value*> output_buffer_ptrs;
1142   if (is_tuple) {
1143     CHECK(crs->shape().IsTuple());
1144 
1145     for (int64 i = 0; i < crs->operand_count(); i++) {
1146       const HloInstruction* op = crs->operand(i);
1147       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1148                           assignment_.GetUniqueSlice(crs, {i}));
1149       const Shape& operand_shape = crs->operand(i)->shape();
1150       CHECK(operand_shape.IsArray())
1151           << "Operands to all-reduce must be arrays: " << crs->ToString();
1152       output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1153       input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1154     }
1155   } else {
1156     Shape shape = crs->operand(0)->shape();
1157     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1158                         assignment_.GetUniqueSlice(crs->operand(0), {}));
1159     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1160                         assignment_.GetUniqueSlice(crs, {}));
1161     input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
1162     output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
1163   }
1164 
1165   llvm::Value* input_buffers =
1166       EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1167   llvm::Value* output_buffers =
1168       EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1169 
1170   int32 shape_length;
1171   TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
1172                       llvm_ir::EncodeSelfDescribingShapeConstant(
1173                           crs->shape(), &shape_length, &b_));
1174 
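  // Marshal everything into a call to the CPU runtime's all-reduce entry
  // point: scalar arguments are passed by value, buffers as opaque i8*
  // pointers.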
1175   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1176   EmitCallToFunc(
1177       runtime::kAllReduceSymbolName,
1178       {/*run_options=*/GetExecutableRunOptionsArgument(),
1179        /*replica_groups=*/replica_groups_v,
1180        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1181 
1182        /*channel_id_present=*/
1183        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1184        /*op_id=*/
1185        b_.getInt64(crs->channel_id().has_value()
1186                        ? *crs->channel_id()
1187                        : crs->GetModule()->unique_id()),
1188        /*reduction_kind=*/
1189        b_.getInt32(
1190            static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
1191        /*shape_ptr=*/shape_ptr,
1192        /*shape_length=*/b_.getInt32(shape_length),
1193        /*num_buffers=*/b_.getInt32(crs->operand_count()),
1194        /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1195        /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1196       b_.getVoidTy());
1197 
1198   return Status::OK();
1199 }
1200 
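// With a single replica an all-reduce reduces to a copy of its operand(s);
// otherwise dispatch to the runtime-backed multi-replica path.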
Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1202   if (hlo_module_config_.replica_count() == 1) {
1203     return HandleAllReduceSingleReplica(crs);
1204   }
1205   return HandleAllReduceMultipleReplica(crs);
1206 }
1207 
Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1209   auto* instr = Cast<HloAllToAllInstruction>(instruction);
1210   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1211   CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1212       << "Only tuple AllToAll is supported";
1213 
1214   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1215   std::string replica_groups =
1216       ReplicaGroupsToString(instruction->replica_groups());
1217   int32 replica_groups_size = replica_groups.size();
1218   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1219 
1220   int64 buffer_size = -1;
1221   std::vector<llvm::Value*> input_buffer_ptrs;
1222   std::vector<llvm::Value*> output_buffer_ptrs;
1223 
1224   for (int64 i = 0; i < instruction->operand_count(); i++) {
1225     const HloInstruction* op = instruction->operand(i);
1226     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1227                         assignment_.GetUniqueSlice(instruction, {i}));
1228     const Shape& operand_shape = instruction->operand(i)->shape();
1229     CHECK(operand_shape.IsArray())
1230         << "Operands to all-to-all must be arrays: " << instruction->ToString();
1231     output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1232     input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1233     CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1234     buffer_size = out_slice.size();
1235   }
1236 
1237   llvm::Value* input_buffers =
1238       EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1239   llvm::Value* output_buffers =
1240       EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1241 
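  // The runtime all-to-all expects every participating buffer to have the
  // same byte size (enforced by the CHECK above), so a single buffer_size is
  // passed for all of them.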
1242   EmitCallToFunc(
1243       runtime::kAllToAllSymbolName,
1244       {/*run_options=*/GetExecutableRunOptionsArgument(),
1245        /*channel_id_present=*/
1246        b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
1247        /*op_id=*/
1248        b_.getInt64(instruction->channel_id().has_value()
1249                        ? *instruction->channel_id()
1250                        : instruction->GetModule()->unique_id()),
1251        /*replica_groups=*/replica_groups_v,
1252        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1253        /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1254        /*buffer_size=*/b_.getInt64(buffer_size),
1255        /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1256        /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1257       b_.getVoidTy());
1258 
1259   llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1260   return Status::OK();
1261 }
1262 
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1264   auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1265   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
1266   std::string source_target_pairs = absl::StrJoin(
1267       instr->source_target_pairs(), ",", absl::PairFormatter("="));
1268   llvm::Value* source_target_pairs_v =
1269       b_.CreateGlobalStringPtr(source_target_pairs);
1270 
1271   Shape shape = crs->operand(0)->shape();
1272 
1273   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1274                       assignment_.GetUniqueSlice(crs->operand(0), {}));
1275   llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1276 
1277   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1278                       assignment_.GetUniqueSlice(crs, {}));
1279   llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1280 
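  // The source_target_pairs string passed below is a comma-separated list of
  // "source=target" entries that the runtime parses to route the data.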
1281   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1282   EmitCallToFunc(
1283       runtime::kCollectivePermuteSymbolName,
1284       {/*run_options=*/GetExecutableRunOptionsArgument(),
1285        /*channel_id_present=*/
1286        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1287        /*op_id=*/
1288        b_.getInt64(crs->channel_id().has_value()
1289                        ? *crs->channel_id()
1290                        : crs->GetModule()->unique_id()),
1291        /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1292        /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1293        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1294        /*source_target_pairs=*/source_target_pairs_v,
1295        /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1296       b_.getVoidTy());
1297 
1298   return Status::OK();
1299 }
1300 
Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1302   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1303   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1304                       assignment_.GetUniqueSlice(hlo, {}));
1305   llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1306   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1307   EmitCallToFunc(
1308       runtime::kReplicaIdSymbolName,
1309       {/*run_options=*/GetExecutableRunOptionsArgument(),
1310        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1311       b_.getVoidTy());
1312   return Status::OK();
1313 }
1314 
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1316   VLOG(2) << "HandleParameter: " << parameter->ToString();
1317   return EmitTargetAddressForOp(parameter);
1318 }
1319 
1320 // Returns true if the relative order of the unreduced dimensions stays the same
1321 // through the reduce operation.
static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1323   DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1324 
1325   // Maps dimensions that were not reduced from their dimension numbers in the
  // source shape to their dimension numbers in the destination shape.
1327   //
1328   // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1329   // [0->0, 3->1].
1330   absl::flat_hash_map<int64, int64> unreduced_dim_map;
1331 
1332   absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1333                                           reduce.dimensions().end());
1334 
1335   const Shape& operand_shape = reduce.operand(0)->shape();
1336   const Shape& result_shape = reduce.shape();
1337 
1338   int64 delta = 0;
1339   for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
1340     if (reduced_dims.contains(i)) {
1341       delta++;
1342     } else {
1343       InsertOrDie(&unreduced_dim_map, i, i - delta);
1344     }
1345   }
1346 
1347   // Iterate dimensions minor to major and check that the corresponding
1348   // dimensions in the source and target shapes are equivalent.
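  // For the f32[A,B,C,D] example above this verifies that dimensions 0 and 3
  // appear in the same relative minor-to-major position in both layouts.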
1349   int64 result_dim_idx = 0;
1350   for (int64 operand_dim_idx = 0;
1351        operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1352     int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
1353     if (!reduced_dims.contains(operand_dim)) {
1354       if (FindOrDie(unreduced_dim_map, operand_dim) !=
1355           result_shape.layout().minor_to_major(result_dim_idx++)) {
1356         return false;
1357       }
1358     }
1359   }
1360 
1361   CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1362 
1363   return true;
1364 }
1365 
IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1367     HloComputation* function, string* failure_reason) const {
1368   CHECK_EQ(function->num_parameters(), 2);
1369 
1370   auto root_instruction = function->root_instruction();
1371   CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1372 
1373   if (root_instruction->operand_count() != 2) {
1374     *failure_reason = "root instruction is not a binary operation";
1375     return nullptr;
1376   }
1377 
1378   const Shape& root_shape = root_instruction->shape();
1379   if (ShapeUtil::ElementIsComplex(root_shape)) {
    // TODO(b/65408531): Complex add could be done via a bitcast to
    // <float x [2N]>. Complex multiply would be more challenging. We could
    // perhaps use a strided load to get all reals in one vector and all
    // imaginary parts in another, or use CreateShuffleVector on a bitcast to
    // <float x [2N]>.
1384     *failure_reason = "complex values not supported";
1385     return nullptr;
1386   }
1387   bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1388   bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1389   bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1390 
1391   auto lhs = root_instruction->operand(0);
1392   auto rhs = root_instruction->operand(1);
1393 
1394   auto param_0 = function->parameter_instruction(0);
1395   auto param_1 = function->parameter_instruction(1);
1396   if (!(lhs == param_0 && rhs == param_1) &&
1397       !(rhs == param_0 && lhs == param_1)) {
1398     *failure_reason =
1399         "root instruction is not a binary operation on the incoming arguments";
1400     return nullptr;
1401   }
1402 
1403   CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1404 
1405   // This is visually similar to ElementalIrEmitter, though conceptually we're
1406   // doing something different here.  ElementalIrEmitter emits scalar operations
1407   // while these emit scalar or vector operations depending on the type of the
1408   // operands. See CreateShardedVectorType for the actual types in use here.
1409   switch (root_instruction->opcode()) {
1410     default:
1411       *failure_reason = "did not recognize root instruction opcode";
1412       return nullptr;
1413 
1414     case HloOpcode::kAdd:
1415       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1416                                 llvm::Value* rhs) {
1417         return root_is_integral ? b->CreateAdd(lhs, rhs)
1418                                 : b->CreateFAdd(lhs, rhs);
1419       };
1420 
1421     case HloOpcode::kMultiply:
1422       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1423                                 llvm::Value* rhs) {
1424         return root_is_integral ? b->CreateMul(lhs, rhs)
1425                                 : b->CreateFMul(lhs, rhs);
1426       };
1427 
1428     case HloOpcode::kAnd:
1429       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1430         return b->CreateAnd(lhs, rhs);
1431       };
1432 
1433     case HloOpcode::kOr:
1434       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1435         return b->CreateOr(lhs, rhs);
1436       };
1437 
1438     case HloOpcode::kXor:
1439       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1440         return b->CreateXor(lhs, rhs);
1441       };
1442 
1443     case HloOpcode::kMaximum:
1444       return [root_is_floating_point, root_is_signed](
1445                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1446                  llvm::Value* rhs) -> llvm::Value* {
1447         if (root_is_floating_point) {
1448           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
1449                                               {lhs, rhs}, {lhs->getType()}, b);
1450         }
1451 
1452         return b->CreateSelect(
1453             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1454                                          : llvm::ICmpInst::ICMP_UGE,
1455                           lhs, rhs),
1456             lhs, rhs);
1457       };
1458 
1459     case HloOpcode::kMinimum:
1460       return [root_is_floating_point, root_is_signed](
1461                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1462                  llvm::Value* rhs) -> llvm::Value* {
1463         if (root_is_floating_point) {
1464           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
1465                                               {lhs, rhs}, {lhs->getType()}, b);
1466         }
1467 
1468         return b->CreateSelect(
1469             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1470                                          : llvm::ICmpInst::ICMP_ULE,
1471                           lhs, rhs),
1472             lhs, rhs);
1473       };
1474   }
1475 }
1476 
IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1478     PrimitiveType element_type, unsigned element_count) {
1479   int vector_register_size_in_elements =
1480       target_machine_features_.vector_register_byte_size(
1481           *compute_function_->function()) /
1482       ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1483 
1484   ShardedVectorType sharded_vector_type;
1485   llvm::Type* element_ir_type =
1486       llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1487 
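  // For example, with element_count == 11 (0b1011) and an 8-element vector
  // register this produces {T, <2 x T>, <8 x T>}: one scalar, one 2-wide
  // vector and one full register, matching the set bits of element_count.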
1488   for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1489     // For every power of two present in element_count, we generate one or more
1490     // vector or scalar types.
1491     const unsigned current_size_fragment = 1u << i;
1492     if (!(element_count & current_size_fragment)) {
1493       // Power of two not present in element_count.
1494       continue;
1495     }
1496 
1497     if (current_size_fragment == 1) {
1498       // Single element, use a scalar type.
1499       sharded_vector_type.push_back(element_ir_type);
1500       continue;
1501     }
1502 
1503     // Lower "current_size_fragment" number of elements using (as few as
1504     // possible) vector registers.
1505 
1506     if (current_size_fragment >= vector_register_size_in_elements) {
1507       auto vector_type = llvm::VectorType::get(
1508           element_ir_type, vector_register_size_in_elements, false);
1509       sharded_vector_type.insert(
1510           sharded_vector_type.end(),
1511           current_size_fragment / vector_register_size_in_elements,
1512           vector_type);
1513 
1514       // Both current_size_fragment and vector_register_size_in_elements are
1515       // powers of two.
1516       CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1517       continue;
1518     }
1519 
1520     // For now we assume that vector_register_size_in_elements and lower powers
1521     // of two are all legal vector sizes (or at least can be lowered easily by
1522     // LLVM).
1523     sharded_vector_type.push_back(
1524         llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1525   }
1526   return sharded_vector_type;
1527 }
1528 
StatusOr<IrEmitter::ShardedVector>
IrEmitter::EmitInnerLoopForVectorizedReduction(
1531     const ReductionGenerator& reduction_generator,
1532     const llvm_ir::IrArray::Index& output_index,
1533     const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1534     HloInstruction* arg, absl::Span<const int64> dimensions,
1535     llvm::Align element_alignment) {
1536   ShardedVector accumulator;
1537   accumulator.reserve(accumulator_type.size());
1538   for (auto accumulator_shard_type : accumulator_type) {
1539     accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1540         accumulator_shard_type, "accumulator", &b_, 0));
1541   }
1542 
1543   llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1544 
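  // Initialize every accumulator shard with the init value: vector shards get
  // a splat of the scalar init value, scalar shards get it directly.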
1545   for (llvm::Value* accumulator_shard : accumulator) {
1546     llvm::Value* initial_value;
1547     auto shard_type = accumulator_shard->getType()->getPointerElementType();
1548     if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1549       initial_value =
1550           VectorSplat(vector_type->getElementCount(), init_value_ssa);
1551     } else {
1552       initial_value = init_value_ssa;
1553     }
1554 
1555     AlignedStore(initial_value, accumulator_shard, element_alignment);
1556   }
1557 
1558   llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1559                                            &b_);
1560   std::vector<llvm::Value*> input_multi_index =
1561       reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1562                                                        "reduction_dim");
1563 
1564   SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1565 
1566   llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1567   llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1568 
1569   for (auto& i : input_multi_index) {
1570     if (i == nullptr) {
1571       i = *it++;
1572     }
1573   }
1574   CHECK(output_index.end() == it);
1575   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1576                                       b_.getInt64Ty());
1577 
1578   llvm::Value* input_address = BitCast(
1579       arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1580 
1581   for (int i = 0; i < accumulator.size(); i++) {
1582     auto input_address_typed =
1583         BitCast(input_address, accumulator[i]->getType());
1584     auto current_accumulator_value =
1585         AlignedLoad(accumulator[i], element_alignment);
1586     auto addend = AlignedLoad(input_address_typed, element_alignment);
1587     arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1588 
1589     auto reduced_result =
1590         reduction_generator(&b_, current_accumulator_value, addend);
1591     AlignedStore(reduced_result, accumulator[i], element_alignment);
1592 
1593     if (i != (accumulator.size() - 1)) {
1594       input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1595                                            input_address_typed, 1);
1596     }
1597   }
1598 
1599   SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1600 
1601   ShardedVector result_ssa;
1602   result_ssa.reserve(accumulator.size());
1603   for (auto accumulator_shard : accumulator) {
1604     result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1605   }
1606   return result_ssa;
1607 }
1608 
void IrEmitter::EmitShardedVectorStore(
1610     llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1611     llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
1612   for (int i = 0; i < value_to_store.size(); i++) {
1613     auto store_address_typed =
1614         BitCast(store_address,
1615                 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1616 
1617     auto store_instruction =
1618         AlignedStore(value_to_store[i], store_address_typed, alignment);
1619     containing_array.AnnotateLoadStoreInstructionWithMetadata(
1620         store_instruction);
1621 
1622     if (i != (value_to_store.size() - 1)) {
1623       store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1624                                            store_address_typed, 1);
1625     }
1626   }
1627 }
1628 
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1630     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1631     absl::Span<const int64> dimensions, HloComputation* function,
1632     string* failure_reason) {
1633   if (!reduce->shape().IsArray()) {
1634     *failure_reason = "vectorization of variadic reduce not implemented";
1635     return false;
1636   }
1637 
1638   if (!ReductionPreservesLayout(*reduce)) {
1639     return false;
1640   }
1641 
1642   ReductionGenerator reduction_generator =
1643       MatchReductionGenerator(function, failure_reason);
1644   if (!reduction_generator) {
1645     return false;
1646   }
1647 
1648   int vector_register_size_in_elements =
1649       target_machine_features_.vector_register_byte_size(
1650           *compute_function_->function()) /
1651       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1652   if (vector_register_size_in_elements == 0) {
1653     // Either we don't know the vector register width for the target or the
1654     // vector register is smaller than the size of the primitive type.
1655     return false;
1656   }
1657 
1658   int vectorization_factor_in_bytes =
1659       target_machine_features_.vectorization_factor_in_bytes();
1660 
1661   // We try to process vectorization_factor elements at the same time.
1662   const int vectorization_factor =
1663       vectorization_factor_in_bytes /
1664       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
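  // For example, if the vectorization factor is 32 bytes and the element type
  // is F32 (4 bytes), 8 elements are processed per innermost-loop iteration.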
1665 
1666   bool is_reduction_over_minor_dimension = absl::c_linear_search(
1667       dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1668 
1669   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1670       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1671       MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1672 
1673   if (is_reduction_over_minor_dimension) {
1674     // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1675     *failure_reason = "reduction over minor dimension not implemented";
1676     return false;
1677   }
1678 
1679   CHECK(!reduce->shape().IsTuple());
1680   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1681 
1682   // We know we're not reducing over the most minor dimension, which means we
1683   // can lower the reduction loop as:
1684   //
1685   //  1. We're reducing over dimensions R0, R1.
1686   //  2. D0 is the most minor dimension.
1687   //  3. VS is the vectorization stride (we want to reduce this many elements at
1688   //     once)
1689   //
1690   //  for (d1 in D1) {
1691   //    for (d0 in D0 with stride VS) {
1692   //      vector_acc = init
1693   //      for (r1 in R1) {
1694   //        for (r0 in R0) {
  //          vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1696   //        }
1697   //      }
1698   //      output[d1, d0] = vector_acc
1699   //    }
1700   //  }
1701 
1702   llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1703   std::vector<llvm::Value*> array_multi_index(
1704       reduce->shape().dimensions_size());
1705   for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1706        --i) {
1707     int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1708     int64 start_index = 0;
1709     int64 end_index = reduce->shape().dimensions(dimension);
1710     std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1711         start_index, end_index, absl::StrFormat("dim.%d", dimension));
1712     array_multi_index[dimension] = loop->GetIndVarValue();
1713   }
1714 
1715   int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1716   int64 innermost_dimension_size =
1717       reduce->shape().dimensions(innermost_dimension);
1718 
1719   if (llvm::BasicBlock* innermost_body_bb =
1720           loop_nest.GetInnerLoopBodyBasicBlock()) {
1721     SetToFirstInsertPoint(innermost_body_bb, &b_);
1722   }
1723 
1724   auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1725 
1726   if (innermost_dimension_size >= vectorization_factor) {
1727     int64 start_index = 0;
1728     int64 end_index = (innermost_dimension_size / vectorization_factor) *
1729                       vectorization_factor;
1730     std::unique_ptr<llvm_ir::ForLoop> loop =
1731         loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1732                           absl::StrFormat("dim.%d", innermost_dimension));
1733     array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1734 
1735     SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1736 
1737     ShardedVectorType vector_type = CreateShardedVectorType(
1738         reduce->shape().element_type(), vectorization_factor);
1739     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1740                                         b_.getInt64Ty());
1741     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1742                         EmitInnerLoopForVectorizedReduction(
1743                             reduction_generator, array_index, vector_type,
1744                             init_value, arg, dimensions, element_alignment));
1745 
1746     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1747     llvm::Value* output_address =
1748         target_array.EmitArrayElementAddress(array_index, &b_);
1749     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1750                            target_array);
1751 
1752     if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1753       CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1754       b_.SetInsertPoint(exit_terminator);
1755     } else {
1756       CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1757       b_.SetInsertPoint(loop->GetExitBasicBlock());
1758     }
1759   }
1760 
  // Since we step through the inner dimension with a stride greater than 1, we
  // may need to peel off an "epilogue" iteration to handle the remaining
  // elements when the dimension size is not a multiple of the stride:
1764   if (innermost_dimension_size % vectorization_factor) {
1765     // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1766     array_multi_index[innermost_dimension] =
1767         b_.getInt64(innermost_dimension_size -
1768                     (innermost_dimension_size % vectorization_factor));
1769 
1770     ShardedVectorType vector_type = CreateShardedVectorType(
1771         reduce->shape().element_type(),
1772         innermost_dimension_size % vectorization_factor);
1773     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1774                                         b_.getInt64Ty());
1775     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1776                         EmitInnerLoopForVectorizedReduction(
1777                             reduction_generator, array_index, vector_type,
1778                             init_value, arg, dimensions, element_alignment));
1779 
1780     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1781     llvm::Value* output_address =
1782         target_array.EmitArrayElementAddress(array_index, &b_);
1783     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1784                            target_array);
1785   }
1786 
1787   if (outermost_loop_exit_block) {
1788     b_.SetInsertPoint(outermost_loop_exit_block);
1789   }
1790 
1791   return true;
1792 }
1793 
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1795   auto arg = reduce->mutable_operand(0);
1796   auto init_value = reduce->mutable_operand(1);
1797   absl::Span<const int64> dimensions(reduce->dimensions());
1798   HloComputation* function = reduce->to_apply();
1799   if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1800     string vectorization_failure_reason;
1801     TF_ASSIGN_OR_RETURN(
1802         bool vectorization_successful,
1803         EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1804                              &vectorization_failure_reason));
1805     if (vectorization_successful) {
1806       VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1807               << "\n";
1808       return Status::OK();
1809     } else {
1810       VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1811               << vectorization_failure_reason;
1812     }
1813   }
1814 
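  // Fall back to the generic elemental lowering when vectorization is
  // disabled or could not be applied.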
1815   return DefaultAction(reduce);
1816 }
1817 
Status IrEmitter::HandleSend(HloInstruction* send) {
1819   // TODO(b/33942983): Support Send/Recv on CPU.
1820   return Unimplemented("Send is not implemented on CPU.");
1821 }
1822 
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1824   // TODO(b/33942983): Support Send/Recv on CPU.
1825   return Unimplemented("Send-done is not implemented on CPU.");
1826 }
1827 
Status IrEmitter::HandleScatter(HloInstruction*) {
1829   return Unimplemented("Scatter is not implemented on CPUs.");
1830 }
1831 
Status IrEmitter::HandleSlice(HloInstruction* slice) {
1833   VLOG(2) << "HandleSlice: " << slice->ToString();
1834   auto operand = slice->operand(0);
1835   // The code below emits a sequential loop nest. For the parallel backend, use
1836   // ParallelLoopEmitter which respects dynamic loop bounds.
1837   if (ShouldEmitParallelLoopFor(*slice)) {
1838     return DefaultAction(slice);
1839   }
1840 
1841   // The code below assumes the layouts are equal.
1842   if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1843     return DefaultAction(slice);
1844   }
1845 
1846   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1847 
1848   if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1849     return Status::OK();
1850   }
1851 
1852   const Layout& layout = operand->shape().layout();
1853   const int64 num_dims = operand->shape().dimensions_size();
1854 
1855   // The slice lowering finds maximal contiguous blocks of memory that can be
1856   // copied from the source to the target. This is done by looking at the
  // source/target layout in minor-to-major order and doing the following:
1858   //
1859   // * Find an initial segment of dimensions along which the slice uses the
1860   //   whole dimension. These are the "inner" dimensions and can be folded into
1861   //   the memcpy.
1862   //
1863   // * Of the remaining dimensions decide which ones require loops.
1864   //
1865   // * Implement the memcpy within the innermost loop.
1866 
1867   absl::flat_hash_set<int64> inner_dims;
1868   for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
1869     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1870       break;
1871     }
1872     inner_dims.insert(dim);
1873   }
1874 
1875   const bool is_trivial_copy = (inner_dims.size() == num_dims);
1876   if (is_trivial_copy) {
1877     if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1878       return DefaultAction(slice);
1879     } else {
1880       return EmitMemcpy(*slice, *operand);
1881     }
1882   }
1883 
1884   // The memcpy will copy elements that are logically this shape (allowed to be
1885   // scalar).
1886   const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1887       [&inner_dims](int64 dim) { return inner_dims.contains(dim); },
1888       operand->shape());
1889 
1890   const int64 primitive_elements_per_logical_element =
1891       ShapeUtil::ElementsIn(logical_element_shape);
1892 
1893   // memcpy_dim is the innermost (in terms of layout) dimension for which the
1894   // slice does *not* just copy all the elements along the dimension.
1895   const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1896 
1897   const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1898   // The number of logical elements that can be copied in a single call
1899   // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1900   // stride.
1901   const int64 memcpy_logical_elements =
1902       memcpy_is_contiguous
1903           ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1904           : 1;
1905 
1906   // Determine the dimensions that get lowered as loops.
1907   std::vector<int64> outer_dims;
1908   for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1909     outer_dims.push_back(LayoutUtil::Major(layout, i));
1910   }
1911 
1912   // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1913   // needs to be wrapped around a loop as well.
1914   if (!memcpy_is_contiguous) {
1915     outer_dims.push_back(memcpy_dim);
1916   }
1917 
1918   llvm_ir::IrArray target_array = GetIrArrayFor(slice);
1919 
1920   const int64 num_outer_loops = outer_dims.size();
1921   llvm_ir::ForLoopNest loops(IrName(slice), &b_);
1922   std::vector<llvm::Value*> target_multi_index =
1923       loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
1924 
1925   // Only the indices for the outer dimensions have been initialized in
1926   // target_index. The rest of the indices should get initialized to 0, since
1927   // for the rest of the dimensions the copy writes to the full dimension.
1928   std::replace(target_multi_index.begin(), target_multi_index.end(),
1929                static_cast<llvm::Value*>(nullptr),
1930                static_cast<llvm::Value*>(b_.getInt64(0)));
1931   llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
1932                                        b_.getInt64Ty());
1933 
1934   if (num_outer_loops > 0) {
1935     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1936   }
1937 
1938   llvm_ir::IrArray source_array = GetIrArrayFor(operand);
1939   const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
1940       /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
1941       /*strides=*/slice->slice_strides(), /*builder=*/&b_);
1942 
1943   llvm::Value* memcpy_dest =
1944       target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
1945   llvm::Value* memcpy_source =
1946       source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
1947 
1948   const int64 memcpy_elements =
1949       primitive_elements_per_logical_element * memcpy_logical_elements;
1950 
1951   EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
1952                        slice->shape().element_type(), target_array,
1953                        source_array);
1954 
1955   if (VLOG_IS_ON(2)) {
1956     const int64 memcpy_bytes =
1957         ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
1958     VLOG(2) << "  emitted copy of " << memcpy_bytes << " bytes inside "
1959             << num_outer_loops << " loops";
1960   }
1961 
1962   if (num_outer_loops > 0) {
1963     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1964   }
1965 
1966   return Status::OK();
1967 }
1968 
Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1970   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
1971     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
1972     return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
1973   }
1974   return DefaultAction(dynamic_slice);
1975 }
1976 
Status IrEmitter::HandleDynamicUpdateSlice(
1978     HloInstruction* dynamic_update_slice) {
1979   auto update = dynamic_update_slice->operand(1);
1980   if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
1981     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1982     return EmitMemcpy(*update, *dynamic_update_slice);
1983   } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
1984                                                    assignment_)) {
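    // Buffer assignment aliases the output with the operand being updated, so
    // the update can be emitted in place: only the updated window is written.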
1985     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1986     auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
1987     return llvm_ir::EmitDynamicUpdateSliceInPlace(
1988         operands, GetIrArrayFor(dynamic_update_slice),
1989         IrName(dynamic_update_slice, "in_place"), &b_);
1990   }
1991   return DefaultAction(dynamic_update_slice);
1992 }
1993 
Status IrEmitter::HandleRecv(HloInstruction* recv) {
1995   // TODO(b/33942983): Support Send/Recv on CPU.
1996   return Unimplemented("Recv is not implemented on CPU.");
1997 }
1998 
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2000   // TODO(b/33942983): Support Send/Recv on CPU.
2001   return Unimplemented("Recv-done is not implemented on CPU.");
2002 }
2003 
Status IrEmitter::HandlePad(HloInstruction* pad) {
2005   // CPU backend does not properly handle negative padding but this is ok
2006   // because negative padding should be removed by the algebraic simplifier.
2007   for (auto& padding_dimension : pad->padding_config().dimensions()) {
2008     if (padding_dimension.edge_padding_low() < 0 ||
2009         padding_dimension.edge_padding_high() < 0) {
2010       return InternalErrorStrCat(
2011           "Encountered negative padding in IrEmitter on CPU. "
2012           "This should have been eliminated at the HLO level. ",
2013           pad->ToString());
2014     }
2015   }
2016 
2017   // First, fill in the padding value to all output elements.
2018   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2019       pad, "initialize",
2020       [this, pad](const llvm_ir::IrArray::Index& target_index) {
2021         const HloInstruction* padding_value = pad->operand(1);
2022         llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2023         return Load(padding_value_addr);
2024       }));
2025 
2026   // Create a loop to iterate over the operand elements and update the output
2027   // locations where the operand elements should be stored.
2028   llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2029   const HloInstruction* operand = pad->operand(0);
2030   const llvm_ir::IrArray::Index operand_index =
2031       loops.AddLoopsForShape(operand->shape(), "operand");
2032 
2033   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2034 
2035   // Load an element from the operand.
2036   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2037   llvm::Value* operand_data =
2038       operand_array.EmitReadArrayElement(operand_index, &b_);
2039 
2040   // Compute the output index the operand element should be assigned to.
2041   // output_index := edge_padding_low + operand_index * (interior_padding + 1)
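  // For example, with edge_padding_low = 2 and interior_padding = 1, operand
  // index 3 is written to output index 2 + 3 * 2 = 8.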
2042   const PaddingConfig& padding_config = pad->padding_config();
2043   std::vector<llvm::Value*> output_multi_index;
2044   for (size_t i = 0; i < operand_index.size(); ++i) {
2045     llvm::Value* offset =
2046         Mul(operand_index[i],
2047             b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2048     llvm::Value* index = Add(
2049         offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2050     output_multi_index.push_back(index);
2051   }
2052 
2053   // Store the operand element to the computed output location.
2054   llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2055   llvm_ir::IrArray::Index output_index(
2056       output_multi_index, output_array.GetShape(), operand_index.GetType());
2057   output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2058 
2059   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2060   return Status::OK();
2061 }
2062 
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2064   auto* root = fusion->fused_expression_root();
2065   if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2066     VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2067     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2068     FusedIrEmitter fused_emitter(&elemental_emitter);
2069     BindFusionArguments(fusion, &fused_emitter);
2070 
2071     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2072     // Delegate to common implementation of fused in-place dynamic-update-slice.
2073     return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2074         fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2075   } else if (fusion->IsLoopFusion()) {
2076     VLOG(3) << "HandleFusion kLoop";
2077     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2078     FusedIrEmitter fused_emitter(&elemental_emitter);
2079     BindFusionArguments(fusion, &fused_emitter);
2080     TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2081                                             fusion->fused_expression_root()));
2082     return EmitTargetElementLoop(fusion, generator);
2083   } else if (fusion->IsOutputFusion()) {
2084     VLOG(3) << "HandleFusion kOutput";
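    // Output fusion on CPU fuses an addend into a dot: the fused root is an
    // add whose operands are a dot and the addend, so determine which operand
    // of the root is the dot.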
2085     int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2086     const HloInstruction* dot = root->operand(dot_op_index);
2087     CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2088         << dot->ToString() << "  "
2089         << fusion->fused_instructions_computation()->ToString();
2090 
2091     int64 dot_lhs_param_number = dot->operand(0)->parameter_number();
2092     int64 dot_rhs_param_number = dot->operand(1)->parameter_number();
2093     int64 addend_param_number =
2094         root->operand(1 - dot_op_index)->parameter_number();
2095 
2096     Shape target_shape = fusion->shape();
2097     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2098     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2099 
2100     llvm_ir::IrArray lhs_array(
2101         GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2102     llvm_ir::IrArray rhs_array(
2103         GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2104     llvm_ir::IrArray addend_array(
2105         GetIrArrayFor(fusion->operand(addend_param_number)));
2106 
2107     TF_RETURN_IF_ERROR(EmitDotOperation(
2108         *dot, target_array, lhs_array, rhs_array, &addend_array,
2109         GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2110         hlo_module_config_, target_machine_features_));
2111     return Status::OK();
2112   } else {
2113     return Unimplemented("Fusion kind not implemented on CPU");
2114   }
2115 }
2116 
Status IrEmitter::HandleCall(HloInstruction* call) {
2118   HloComputation* computation = call->to_apply();
2119   llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2120 
2121   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2122 
2123   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2124     // ParallelTaskAssignment assigned partitions, emit call to
2125     // ParallelForkJoin.
2126     std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2127         {}, &b_, computation->name(),
2128         /*return_value_buffer=*/emitted_value_[call],
2129         /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2130         /*buffer_table_arg=*/GetBufferTableArgument(),
2131         /*profile_counters_arg=*/GetProfileCountersArgument());
2132 
2133     HloInstruction* root = computation->root_instruction();
2134     TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2135         call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2136         call_ir_function, computation->name()));
2137   } else {
2138     EmitGlobalCall(*computation, computation->name());
2139   }
2140 
2141   return Status::OK();
2142 }
2143 
Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2145   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2146   std::vector<llvm::Value*> dynamic_dims;
2147   int32 raw_data_size =
2148       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2149   llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2150   llvm::Value* raw_buffer =
2151       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
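  // Operands 1..N carry the dynamic dimension sizes. Each size is stored as an
  // int32 in the metadata region that immediately follows the raw
  // (static-shape) data in the destination buffer.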
2152   for (int64 i = 1; i < hlo->operand_count(); ++i) {
2153     const int64 dim_index = i - 1;
2154     llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2155     llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
2156 
2157     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2158         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2159     b_.CreateStore(dyn_dim_size,
2160                    b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2161     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2162                                             /*isSigned=*/true,
2163                                             "i64_dyn_dim_size"));
2164   }
2165 
2166   llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2167   // Pseudo code for sliceToDynamic:
2168   //
2169   //   for (index i in dynamic_dim)
2170   //     dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2171   //     dest[dest_index] = source[i]
2172   auto loop_body_emitter =
2173       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2174     llvm::Value* source_element =
2175         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2176     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2177     // Delinearize the index based on the static shape.
2178     llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2179                                        &b_);
2180     data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2181     return Status::OK();
2182   };
2183   return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2184                               dynamic_dims, &b_)
2185       .EmitLoop(IrName(hlo));
2186 }
2187 
Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2189   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2190 
2191   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2192                       assignment_.GetUniqueSlice(hlo, {0}));
2193   std::vector<llvm::Value*> dynamic_dims;
2194   std::vector<llvm::Value*> tuple_operand_ptrs;
2195   const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2196   const Shape& input_shape = hlo->operand(0)->shape();
2197   llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2198   llvm_ir::IrArray data_array(data_address, data_shape);
2199   llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2200   llvm::Value* raw_buffer =
2201       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2202   int64 raw_data_size =
2203       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2204 
2205   // Put a placeholder for the data array's pointer
2206   tuple_operand_ptrs.push_back(data_array.GetBasePointer());
  // PadToStatic takes a dynamic tensor as input and produces a variadic
  // number of outputs: (static_tensor, dynamic_dim_0, dynamic_dim_1, ...).
  // The dynamic dimension sizes start at output index 1.
2210   for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2211     // Read from the metadata section of the dynamic input (operand 0).
2212     const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2213     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2214     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2215                         assignment_.GetUniqueSlice(hlo, {i}));
2216     llvm::Value* dest_dim_size_address =
2217         EmitBufferPointer(dim_size_slice, data_shape);
2218     const int64 dim_index = i - 1;
2219     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2220         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2221     llvm::Value* dyn_dim_size = b_.CreateLoad(
2222         b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2223         "dyn_dim_size");
2224     b_.CreateStore(dyn_dim_size,
2225                    b_.CreateBitCast(dest_dim_size_address,
2226                                     b_.getInt32Ty()->getPointerTo()));
2227     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2228                                             /*isSigned=*/true,
2229                                             "i64_dyn_dim_size"));
2230     tuple_operand_ptrs.push_back(dest_dim_size_address);
2231   }
2232 
2233   // Pseudo code for padToStatic:
2234   //
2235   //   for (index i in dynamic_dim)
  //     source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2237   //     dest[i] = source[source_index]
2238   auto loop_body_emitter =
2239       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2240     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2241     llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2242     llvm::Value* source_element =
2243         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2244     data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2245     return Status::OK();
2246   };
2247   TF_RETURN_IF_ERROR(
2248       llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2249           .EmitLoop(IrName(hlo)));
2250 
2251   // Emit static tensor and dynamic sizes as one tuple.
2252   llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2253   return Status::OK();
2254 }
2255 
Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2257   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2258   const HloInstruction* input = hlo->operand(0);
2259   const int64 k = hlo->shape().tuple_shapes(0).dimensions().back();
2260   const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
2261   TF_RET_CHECK(input->shape().element_type() == F32);
2262   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2263       hlo->shape().tuple_shapes(0).layout()));
2264   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2265       hlo->shape().tuple_shapes(1).layout()));
2266   TF_RET_CHECK(
2267       LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
2268 
2269   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2270                       assignment_.GetUniqueSlice(hlo->operand(0), {}));
2271   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2272                       assignment_.GetUniqueSlice(hlo, {0}));
2273   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2274                       assignment_.GetUniqueSlice(hlo, {1}));
2275   llvm::Value* values_ptr =
2276       EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2277   llvm::Value* out_values_ptr =
2278       EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2279   llvm::Value* out_indices_ptr =
2280       EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
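  // Call the runtime TopK routine for F32 inputs with arguments
  // (batch_size, input_length, k, values, out_values, out_indices); a missing
  // batch dimension is treated as batch_size == 1.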
2281   EmitCallToFunc(
2282       runtime::kTopKF32SymbolName,
2283       {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2284        b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2285        BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2286        BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2287        BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2288       b_.getVoidTy());
2289 
2290   llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2291                      &b_);
2292   return Status::OK();
2293 }
2294 
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2296   if (custom_call->custom_call_target() == "PadToStatic") {
2297     return HandlePadToStatic(custom_call);
2298   }
2299   if (custom_call->custom_call_target() == "SliceToDynamic") {
2300     return HandleSliceToDynamic(custom_call);
2301   }
2302   if (custom_call->custom_call_target() == "TopK") {
2303     return HandleTopK(custom_call);
2304   }
2305   absl::Span<HloInstruction* const> operands(custom_call->operands());
2306   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2307   llvm::AllocaInst* operands_alloca =
2308       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2309           i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2310   for (size_t i = 0; i < operands.size(); ++i) {
2311     const HloInstruction* operand = operands[i];
2312     llvm::Value* operand_as_i8ptr =
2313         PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2314     llvm::Value* slot_in_operands_alloca =
2315         InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2316     Store(operand_as_i8ptr, slot_in_operands_alloca);
2317   }
2318   if (emit_code_for_msan_) {
2319     // Mark the alloca as initialized for msan. The buffer gets read by the
2320     // custom callee, which might be msan-instrumented.
2321     // TODO(b/66051036): Run the msan instrumentation pass instead.
2322     const llvm::DataLayout& dl = module_->getDataLayout();
2323     llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2324     EmitCallToFunc(
2325         "__msan_unpoison",
2326         {PointerCast(operands_alloca, i8_ptr_type),
2327          llvm::ConstantInt::get(
2328              intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2329         b_.getVoidTy());
2330   }
2331 
2332   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2333   // Write the tuple table if the output is a tuple.
2334   if (custom_call->shape().IsTuple()) {
2335     std::vector<llvm::Value*> base_ptrs;
2336     for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2337          ++i) {
2338       const Shape& elem_shape =
2339           ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2340       TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2341       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2342                           assignment_.GetUniqueSlice(custom_call, {i}));
2343       llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2344       base_ptrs.push_back(addr);
2345     }
2346     llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2347   }
2348   auto* output_address_arg =
2349       PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2350 
2351   EmitCallToFunc(custom_call->custom_call_target(),
2352                  {output_address_arg, operands_alloca}, b_.getVoidTy());
2353 
2354   return Status::OK();
2355 }
2356 
2357 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2358   // Precondition: Condition computation must return a scalar bool.
2359   HloComputation* condition = xla_while->while_condition();
2360   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2361                condition->root_instruction()->shape().element_type() == PRED)
2362       << "While condition computation must return bool; got: "
2363       << ShapeUtil::HumanString(condition->root_instruction()->shape());
2364   // Check that all while-related buffers share an allocation slice.
2365   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2366       xla_while->shape(),
2367       [this, &xla_while](const Shape& /*subshape*/,
2368                          const ShapeIndex& index) -> Status {
2369         auto check = [this](const HloInstruction* a, const HloInstruction* b,
2370                             const ShapeIndex& index) {
2371           const BufferAllocation::Slice slice_a =
2372               assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2373           const BufferAllocation::Slice slice_b =
2374               assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2375           if (slice_a != slice_b) {
2376             return InternalError(
2377                 "instruction %s %s does not share slice with "
2378                 "instruction %s %s",
2379                 a->ToString(), slice_a.ToString(), b->ToString(),
2380                 slice_b.ToString());
2381           }
2382           return Status::OK();
2383         };
2384         TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2385         TF_RETURN_IF_ERROR(check(
2386             xla_while, xla_while->while_condition()->parameter_instruction(0),
2387             index));
2388         TF_RETURN_IF_ERROR(
2389             check(xla_while, xla_while->while_body()->parameter_instruction(0),
2390                   index));
2391         TF_RETURN_IF_ERROR(check(
2392             xla_while, xla_while->while_body()->root_instruction(), index));
2393         return Status::OK();
2394       }));
2395 
2396   // Set emitted value to that of 'init' with which it shares an allocation.
2397   const HloInstruction* init = xla_while->operand(0);
2398   emitted_value_[xla_while] = GetEmittedValueFor(init);
2399 
2400   // Generating:
2401   //   while (Condition(while_result)) {
2402   //     // CopyInsertion pass inserts copies which enable 'while_result' to
2403   //     // be passed back in as 'Body' parameter.
2404   //     while_result = Body(while_result);  // Insert
2405   //   }
2406 
2407   // Terminates the current block with a branch to a while header.
2408   llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2409       module_->getContext(), IrName(xla_while, "header"),
2410       compute_function_->function());
2411   Br(header_bb);
2412   b_.SetInsertPoint(header_bb);
2413 
2414   // Calls the condition function to determine whether to proceed with the
2415   // body.  It must return a bool, so use the scalar call form.
2416   EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2417   llvm::Value* while_predicate = ICmpNE(
2418       Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2419       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2420 
2421   // Branches to the body or to the while exit depending on the condition.
2422   llvm::BasicBlock* body_bb =
2423       llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2424                                compute_function_->function());
2425   llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2426       module_->getContext(), IrName(xla_while, "exit"));
2427   CondBr(while_predicate, body_bb, exit_bb);
2428 
2429   // Calls the body function from the body block.
2430   b_.SetInsertPoint(body_bb);
2431 
2432   // Calls the body function.
2433   EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2434 
2435   // Finishes with a branch back to the header.
2436   Br(header_bb);
2437 
2438   // Adds the exit block to the function and sets the insert point there.
2439   compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2440   b_.SetInsertPoint(exit_bb);
2441 
2442   return Status::OK();
2443 }
2444 
2445 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2446     HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2447     string* failure_reason) {
2448   if (ShouldEmitParallelLoopFor(*concatenate)) {
2449     *failure_reason =
2450         "cannot generate memcpy-based concat for the parallel CPU backend";
2451     return false;
2452   }
2453 
2454   const Shape& output_shape = concatenate->shape();
2455   for (auto* op : operands) {
2456     if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2457       *failure_reason = "operand has mismatching layouts";
2458       return false;
2459     }
2460   }
2461 
2462   // We split the dimensions into three categories: the dimension over which we
2463   // are concatenating (concat_dim), the dimensions that are minor to it
2464   // (inner_dims) and the dimensions that are major to it (outer_dims).
2465 
2466   int64 concat_dim = concatenate->dimensions(0);
2467   const Layout& output_layout = output_shape.layout();
2468   auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2469   auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2470 
2471   std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
2472   std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
2473                                 output_min2maj.end());
2474 
2475   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2476 
2477   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2478   llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2479 
2480   llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2481   std::vector<llvm::Value*> target_multi_index =
2482       loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
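       // Dimensions not covered by the outer loop nest come back as nullptr;
       // pin them to index 0 so target_index points at the start of each
       // contiguous region.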
2483   std::replace(target_multi_index.begin(), target_multi_index.end(),
2484                static_cast<llvm::Value*>(nullptr),
2485                static_cast<llvm::Value*>(b_.getInt64(0)));
2486   llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2487                                        b_.getInt64Ty());
2488 
2489   if (!outer_dims.empty()) {
2490     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2491   }
2492 
2493   PrimitiveType primitive_type = output_shape.element_type();
2494   unsigned primitive_type_size =
2495       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2496 
2497   // Contiguous subregions from each operand to the concatenate contribute to a
2498   // contiguous subregion in the target buffer starting at target_region_begin.
2499   llvm::Value* target_region_begin = BitCast(
2500       target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2501       i8_ptr_type);
2502   int64 byte_offset_into_target_region = 0;
2503 
2504   int64 inner_dims_product =
2505       std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2506                       [&](int64 product, int64 inner_dim) {
2507                         return product * output_shape.dimensions(inner_dim);
2508                       });
2509 
2510   // For each operand, emit a memcpy from the operand to the target of size
2511   // equal to the product of inner dimensions.
2512   for (HloInstruction* operand : operands) {
2513     const Shape& input_shape = operand->shape();
2514     llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2515     llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2516                                          b_.getInt64Ty());
2517     llvm::Value* copy_source_address = BitCast(
2518         source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2519         i8_ptr_type);
2520 
2521     llvm::Value* copy_target_address =
2522         GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
2523 
2524     EmitTransferElements(
2525         copy_target_address, copy_source_address,
2526         inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2527         target_array, source_array);
2528 
2529     byte_offset_into_target_region += inner_dims_product *
2530                                       input_shape.dimensions(concat_dim) *
2531                                       primitive_type_size;
2532   }
2533 
2534   if (!outer_dims.empty()) {
2535     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2536   }
2537 
2538   return true;
2539 }
2540 
2541 llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2542                                    absl::Span<llvm::Value* const> arguments) {
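       // Builds a varargs call to the C library printf; the format string is
       // materialized as a private global constant.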
2543   llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2544   std::vector<llvm::Value*> call_args;
2545   call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2546   absl::c_copy(arguments, std::back_inserter(call_args));
2547   return b_.CreateCall(
2548       b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2549           "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2550                                             /*isVarArg=*/true)),
2551       call_args);
2552 }
2553 
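     // Declares (or re-uses) a C-callable function with the given signature,
     // applies the requested attributes, and emits a call to it.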
2554 llvm::Value* IrEmitter::EmitCallToFunc(
2555     std::string func_name, const std::vector<llvm::Value*>& arguments,
2556     llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2557     bool only_accesses_inaccessible_mem_or_arg_mem) {
2558   std::vector<llvm::Type*> types;
2559   types.reserve(arguments.size());
2560   absl::c_transform(arguments, std::back_inserter(types),
2561                     [&](llvm::Value* val) { return val->getType(); });
2562   llvm::FunctionType* func_type =
2563       llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2564   auto func = llvm::dyn_cast<llvm::Function>(
2565       module_->getOrInsertFunction(func_name, func_type).getCallee());
2566   func->setCallingConv(llvm::CallingConv::C);
2567   if (does_not_throw) {
2568     func->setDoesNotThrow();
2569   }
2570   if (only_accesses_arg_memory) {
2571     func->setOnlyAccessesArgMemory();
2572   }
2573   if (only_accesses_inaccessible_mem_or_arg_mem) {
2574     func->setOnlyAccessesInaccessibleMemOrArgMem();
2575   }
2576   return b_.CreateCall(func, arguments);
2577 }
2578 
2579 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2580                                      int64 element_count,
2581                                      PrimitiveType primitive_type,
2582                                      const llvm_ir::IrArray& target_array,
2583                                      const llvm_ir::IrArray& source_array) {
2584   unsigned primitive_type_size =
2585       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2586   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2587       primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2588   llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
2589       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
2590 
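       // A single element is copied with an aligned load/store pair so alias
       // metadata can be attached to each instruction; larger transfers are
       // lowered to a memcpy carrying the merged metadata.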
2591   if (element_count == 1) {
2592     auto* load_instruction =
2593         AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
2594     source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2595     auto* store_instruction =
2596         AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2597                      element_alignment);
2598     target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2599   } else {
2600     auto* memcpy_instruction = b_.CreateMemCpy(
2601         target, /*DstAlign=*/llvm::Align(element_alignment), source,
2602         /*SrcAlign=*/llvm::Align(element_alignment),
2603         element_count * primitive_type_size);
2604 
2605     // The memcpy does the load and the store internally.  The aliasing related
2606     // metadata has to reflect that.
2607     std::map<int, llvm::MDNode*> merged_metadata =
2608         llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2609                                target_array.metadata());
2610     for (const auto& kind_md_pair : merged_metadata) {
2611       memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2612     }
2613   }
2614 }
2615 
2616 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2617   absl::Span<HloInstruction* const> operands(concatenate->operands());
2618   string failure_reason;
2619   TF_ASSIGN_OR_RETURN(
2620       bool successful,
2621       EmitFastConcatenate(concatenate, operands, &failure_reason));
2622   if (successful) {
2623     VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2624     return Status::OK();
2625   }
2626 
2627   VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2628           << ": " << failure_reason;
2629 
2630   return DefaultAction(concatenate);
2631 }
2632 
2633 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2634   auto branch_index = conditional->operand(0);
2635   int num_branches = conditional->branch_count();
2636   TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2637                (branch_index->shape().element_type() == PRED ||
2638                 branch_index->shape().element_type() == S32))
2639       << "Branch index on a conditional must be scalar bool or int32; got: "
2640       << ShapeUtil::HumanString(branch_index->shape());
2641 
2642   for (int b = 0; b < num_branches; ++b) {
2643     HloComputation* br_computation = conditional->branch_computation(b);
2644     TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2645                                   br_computation->root_instruction()->shape()))
2646         << "Shape of conditional should be same as the shape of the " << b
2647         << "th branch computation; got: "
2648         << ShapeUtil::HumanString(conditional->shape()) << " and "
2649         << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2650   }
2651 
2652   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2653 
2654   if (branch_index->shape().element_type() == PRED) {
2655     // Emit an if-else to LLVM:
2656     //   if (pred)
2657     //     cond_result = true_computation(true_operand)
2658     //   else
2659     //     cond_result = false_computation(false_operand)
2660     llvm::LoadInst* pred_value = Load(
2661         GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2662     llvm::Value* pred_cond =
2663         ICmpNE(pred_value,
2664                llvm::ConstantInt::get(
2665                    llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2666                "boolean_predicate");
2667     llvm_ir::LlvmIfData if_data =
2668         llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2669 
2670     SetToFirstInsertPoint(if_data.true_block, &b_);
2671     EmitGlobalCall(*conditional->branch_computation(0),
2672                    IrName(conditional, "_true"));
2673 
2674     SetToFirstInsertPoint(if_data.false_block, &b_);
2675     EmitGlobalCall(*conditional->branch_computation(1),
2676                    IrName(conditional, "_false"));
2677 
2678     SetToFirstInsertPoint(if_data.after_block, &b_);
2679     return Status::OK();
2680   }
2681   // We emit a switch statement to LLVM:
2682   // switch (branch_index) {
2683   //   default:
2684   //     result = branch_computations[num_branches-1](operands[num_branches-1]);
2685   //     break;
2686   //   case 0:
2687   //     result = branch_computations[0](operands[0]); break;
2688   //   case 1:
2689   //     result = branch_computations[1](operands[1]); break;
2690   //   ...
2691   //   case [[num_branches-2]]:
2692   //     result = branch_computations[num_branches-2](operands[num_branches-2]);
2693   //     break;
2694   // }
2695   llvm::LoadInst* branch_index_value = Load(
2696       GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2697 
2698   auto case_block = b_.GetInsertBlock();
2699   llvm::BasicBlock* after_block;
2700   // Add a terminator to the case block, if necessary.
2701   if (case_block->getTerminator() == nullptr) {
2702     after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2703     b_.SetInsertPoint(case_block);
2704     b_.CreateBr(after_block);
2705   } else {
2706     after_block =
2707         case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2708   }
2709   // Our basic block should now end with an unconditional branch.  Remove it;
2710   // we're going to replace it with a switch based branch.
2711   case_block->getTerminator()->eraseFromParent();
2712 
2713   // Lower the default branch computation.
2714   auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
2715   b_.SetInsertPoint(default_block);
2716   EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
2717                  IrName(conditional, "_default"));
2718   b_.CreateBr(after_block);
2719 
2720   // Prepare the switch (branch_index) { ... } instruction.
2721   b_.SetInsertPoint(case_block);
2722   llvm::SwitchInst* case_inst =
2723       b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
2724   // Lower each branch's computation.
2725   for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
2726     // Lower the case b: { ... ; break; } computation.
2727     auto branch_block =
2728         llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
2729     b_.SetInsertPoint(branch_block);
2730     EmitGlobalCall(*conditional->branch_computation(b),
2731                    IrName(conditional, absl::StrCat("_branch", b)));
2732     b_.CreateBr(after_block);
2733     case_inst->addCase(b_.getInt32(b), branch_block);
2734   }
2735 
2736   SetToFirstInsertPoint(after_block, &b_);
2737   return Status::OK();
2738 }
2739 
2740 Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
2741   TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
2742   // No code to generate, but we need to emit an address for book-keeping.
2743   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
2744   return Status::OK();
2745 }
2746 
2747 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
2748   // AddDependency just forwards its zero-th operand.
2749   emitted_value_[add_dependency] =
2750       GetEmittedValueFor(add_dependency->operand(0));
2751   return Status::OK();
2752 }
2753 
2754 Status IrEmitter::HandleRng(HloInstruction* rng) {
2755   return Unimplemented("Rng should be expanded for CPU.");
2756 }
2757 
2758 Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
2759   VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
2760   llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
2761       Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
2762       &b_);
2763 
2764   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
2765   llvm::Value* address = GetEmittedValueFor(rng_state);
2766 
2767   // The buffer has an array type while the value is an i128. Cast the
2768   // buffer to an i128 pointer to store the value.
2769   address = BitCast(address, llvm::PointerType::get(
2770                                  old_state->getType()->getScalarType(),
2771                                  address->getType()->getPointerAddressSpace()));
2772   llvm::StoreInst* store = Store(old_state, address);
2773   store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
2774       rng_state->shape().element_type())));
2775 
2776   return Status::OK();
2777 }
2778 
2779 Status IrEmitter::FinishVisit(HloInstruction* root) {
2780   // When this method is called, we should have already emitted an IR value for
2781   // the root (return) op. The IR value holds the address of the buffer holding
2782   // the value. If the root is a constant or parameter, we perform a memcpy from
2783   // this buffer to the retval buffer of the computation. Otherwise, there's
2784   // nothing to do since the result was already written directly into the output
2785   // buffer.
2786   VLOG(2) << "FinishVisit root: " << root->ToString();
2787   if (root->opcode() == HloOpcode::kOutfeed) {
2788     VLOG(2) << "  outfeed with value: "
2789             << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
2790   } else {
2791     VLOG(2) << "  value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
2792   }
2793 
2794   auto record_complete_computation = [&](llvm::Value* prof_counter) {
2795     if (prof_counter) {
2796       profiling_state_.RecordCompleteComputation(&b_, prof_counter);
2797     }
2798   };
2799 
2800   // For the entry computation this increment is cumulative over embedded
2801   // computations, since it includes cycles spent in computations invoked by
2802   // While, Call, etc.
2803   record_complete_computation(GetProfileCounterFor(*root->parent()));
2804   return Status::OK();
2805 }
2806 
2807 template <typename T>
2808 llvm::Value* IrEmitter::GetProfileCounterCommon(
2809     const T& hlo,
2810     const std::unordered_map<const T*, int64>& profile_index_map) {
2811   auto it = profile_index_map.find(&hlo);
2812   if (it == profile_index_map.end()) {
2813     return nullptr;
2814   }
2815 
2816   int64 prof_counter_idx = it->second;
2817   string counter_name = IrName("prof_counter", hlo.name());
2818   return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
2819              counter_name);
2820 }
2821 
2822 llvm::Value* IrEmitter::GetProfileCounterFor(
2823     const HloInstruction& instruction) {
2824   return GetProfileCounterCommon<HloInstruction>(instruction,
2825                                                  instruction_to_profile_idx_);
2826 }
2827 
2828 llvm::Value* IrEmitter::GetProfileCounterFor(
2829     const HloComputation& computation) {
2830   return GetProfileCounterCommon<HloComputation>(computation,
2831                                                  computation_to_profile_idx_);
2832 }
2833 
2834 void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
2835                                                      llvm::Value* prof_counter,
2836                                                      llvm::Value* cycle_end,
2837                                                      llvm::Value* cycle_start) {
2838   auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
2839   llvm::LoadInst* old_cycle_count =
2840       b->CreateLoad(prof_counter, "old_cycle_count");
2841   auto* new_cycle_count =
2842       b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
2843   b->CreateStore(new_cycle_count, prof_counter);
2844 }
2845 
2846 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
2847   llvm::Module* module = b->GetInsertBlock()->getModule();
2848   if (!use_rdtscp_) {
2849     llvm::Function* func_llvm_readcyclecounter =
2850         llvm::Intrinsic::getDeclaration(module,
2851                                         llvm::Intrinsic::readcyclecounter);
2852     return b->CreateCall(func_llvm_readcyclecounter);
2853   }
2854   llvm::Function* func_llvm_x86_rdtscp =
2855       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
2856   llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
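       // rdtscp returns {TSC, TSC_AUX}; only the cycle count is needed here.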
2857   return b->CreateExtractValue(rdtscp_call, {0});
2858 }
2859 
2860 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
2861                                                  HloInstruction* hlo) {
2862   auto* cycle_start = ReadCycleCounter(b);
2863   cycle_start->setName(IrName(hlo, "cycle_start"));
2864   cycle_starts_[hlo] = cycle_start;
2865   if (first_read_cycle_start_ == nullptr) {
2866     first_read_cycle_start_ = cycle_start;
2867   }
2868 }
2869 
2870 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
2871                                                  HloInstruction* hlo,
2872                                                  llvm::Value* prof_counter) {
2873   auto* cycle_end = ReadCycleCounter(b);
2874   cycle_end->setName(IrName(hlo, "cycle_end"));
2875   auto* cycle_start = cycle_starts_[hlo];
2876   UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
2877   last_read_cycle_end_ = cycle_end;
2878 }
2879 
2880 void IrEmitter::ProfilingState::RecordCompleteComputation(
2881     llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
2882   if (last_read_cycle_end_ && first_read_cycle_start_) {
2883     UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
2884                          first_read_cycle_start_);
2885   }
2886 }
2887 
2888 void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
2889                                                HloInstruction* hlo,
2890                                                llvm::Value* run_options) {
2891   if (!enabled_) {
2892     return;
2893   }
2894 
2895   llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
2896   llvm::Type* void_ptr_type =
2897       int8_ptr_type;  // LLVM does not have a void*; we use an int8* instead.
2898   llvm::FunctionType* fn_type =
2899       llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
2900                               /*isVarArg=*/false);
2901 
2902   llvm::Function* function = b->GetInsertBlock()->getParent();
2903   llvm::Module* module = function->getParent();
2904   const char* fn_name = runtime::kTracingStartSymbolName;
2905   llvm::FunctionCallee trace_func =
2906       module->getOrInsertFunction(fn_name, fn_type);
2907   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2908     fn->setCallingConv(llvm::CallingConv::C);
2909     fn->setDoesNotThrow();
2910     fn->setOnlyAccessesArgMemory();
2911   }
2912   auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
2913   auto* activity_id =
2914       b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
2915                                  b->CreateBitCast(hlo_name, int8_ptr_type)});
2916   activity_id->setName(IrName(hlo, "activity_id"));
2917   activity_ids_[hlo] = activity_id;
2918 }
2919 
2920 void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
2921                                              HloInstruction* hlo,
2922                                              llvm::Value* run_options) {
2923   if (!enabled_) {
2924     return;
2925   }
2926 
2927   llvm::Type* void_ptr_type =
2928       b->getInt8Ty()->getPointerTo();  // LLVM does not have a void*; we use
2929                                        // an int8* instead.
2930   llvm::FunctionType* fn_type =
2931       llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
2932                               /*isVarArg=*/false);
2933 
2934   llvm::Function* function = b->GetInsertBlock()->getParent();
2935   llvm::Module* module = function->getParent();
2936   const char* fn_name = runtime::kTracingEndSymbolName;
2937   llvm::FunctionCallee trace_func =
2938       module->getOrInsertFunction(fn_name, fn_type);
2939   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2940     fn->setCallingConv(llvm::CallingConv::C);
2941     fn->setDoesNotThrow();
2942     fn->setOnlyAccessesArgMemory();
2943   }
2944   auto* activity_id = activity_ids_.at(hlo);
2945   b->CreateCall(trace_func,
2946                 {b->CreateBitCast(run_options, void_ptr_type), activity_id});
2947 }
2948 
2949 namespace {
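     // Ops that produce no real computation at runtime; tracing skips them to
     // keep overhead and trace noise down.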
2950 bool IsHloVeryCheap(const HloInstruction* hlo) {
2951   return hlo->opcode() == HloOpcode::kBitcast ||
2952          hlo->opcode() == HloOpcode::kTuple ||
2953          hlo->opcode() == HloOpcode::kGetTupleElement ||
2954          hlo->opcode() == HloOpcode::kParameter ||
2955          hlo->opcode() == HloOpcode::kConstant;
2956 }
2957 }  // namespace
2958 
2959 Status IrEmitter::Preprocess(HloInstruction* hlo) {
2960   VLOG(3) << "Visiting: " << hlo->ToString();
2961   // When profiling is enabled, trace the same HLOs that the profiler does.
2962   if (instruction_to_profile_idx_.count(hlo) ||
2963       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
2964     tracing_state_.EmitTracingStart(&b_, hlo,
2965                                     GetExecutableRunOptionsArgument());
2966     profiling_state_.RecordCycleStart(&b_, hlo);
2967   }
2968   return Status::OK();
2969 }
2970 
2971 Status IrEmitter::Postprocess(HloInstruction* hlo) {
2972   if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
2973     profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
2974   }
2975   // When profiling is enabled, trace the same HLOs that the profiler does.
2976   if (instruction_to_profile_idx_.count(hlo) ||
2977       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
2978     tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
2979   }
2980   return Status::OK();
2981 }
2982 
2983 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
2984   llvm::Value* value_for_op = GetEmittedValueFor(hlo);
2985 
2986   llvm_ir::IrArray array(value_for_op, hlo->shape());
2987   AddAliasingInformationToIrArray(*hlo, &array);
2988   return array;
2989 }
2990 
2991 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
2992     const HloInstruction* hlo) {
2993   std::vector<llvm_ir::IrArray> arrays;
2994   std::transform(
2995       hlo->operands().begin(), hlo->operands().end(),
2996       std::back_inserter(arrays),
2997       [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
2998   return arrays;
2999 }
3000 
3001 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
3002   auto it = emitted_value_.find(hlo);
3003   if (it == emitted_value_.end()) {
3004     LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
3005   }
3006   return it->second;
3007 }
3008 
3009 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
3010   return llvm_ir::ShapeToIrType(shape, module_);
3011 }
3012 
3013 llvm::Value* IrEmitter::GetProfileCountersArgument() {
3014   return compute_function_->profile_counters_arg();
3015 }
3016 
3017 llvm::Value* IrEmitter::GetBufferTableArgument() {
3018   return compute_function_->buffer_table_arg();
3019 }
3020 
3021 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
3022   return compute_function_->exec_run_options_arg();
3023 }
3024 
3025 llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
3026     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3027   const BufferAllocation& allocation = *slice.allocation();
3028   llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
3029     auto param_it =
3030         computation_parameter_allocations_.find(slice.allocation()->index());
3031     if (param_it != computation_parameter_allocations_.end()) {
3032       int64 param_number = param_it->second;
3033       // We have to access the parameter at offset param_number in the params
3034       // array. The code generated here is equivalent to this C code:
3035       //
3036       //   i8* param_address_untyped = params[param_number];
3037       //   Param* param_address_typed = (Param*)param_address_untyped;
3038       //
3039       // Where Param is the actual element type of the underlying buffer (for
3040       // example, float for an XLA F32 element type).
3041       llvm::Value* params = compute_function_->parameters_arg();
3042       llvm::Value* param_address_offset =
3043           llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
3044       llvm::LoadInst* param_address_untyped = Load(param_address_offset);
3045 
3046       if (!target_shape.IsOpaque()) {
3047         AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
3048         AttachDereferenceableMetadataForLoad(param_address_untyped,
3049                                              target_shape);
3050       }
3051       return param_address_untyped;
3052     }
3053 
3054     // Thread-local allocations should only be assigned a single buffer.
3055     const auto& assigned_buffers = allocation.assigned_buffers();
3056     CHECK_EQ(1, assigned_buffers.size());
3057     const Shape& shape = assigned_buffers.begin()->first->shape();
3058 
3059     std::pair<llvm::Function*, BufferAllocation::Slice> key = {
3060         compute_function_->function(), slice};
3061     auto buf_it = thread_local_buffers_.find(key);
3062     if (buf_it == thread_local_buffers_.end()) {
3063       llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3064           IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
3065           &b_, MinimumAlignmentForShape(target_shape));
3066       auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
3067       CHECK(it_inserted_pair.second);
3068       buf_it = it_inserted_pair.first;
3069     }
3070     return buf_it->second;
3071   }();
3072   return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
3073 }
3074 
3075 llvm::Value* IrEmitter::EmitGlobalBufferPointer(
3076     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3077   const BufferAllocation& allocation = *slice.allocation();
3078   llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
3079       GetBufferTableArgument(), slice.index(), &b_);
3080   llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
3081   if (hlo_module_config_.debug_options()
3082           .xla_llvm_enable_invariant_load_metadata()) {
3083     tempbuf_address_base->setMetadata(
3084         llvm::LLVMContext::MD_invariant_load,
3085         llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
3086   }
3087   AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
3088   AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
3089 
3090   llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
3091   if (slice.offset() > 0) {
3092     // Adjust the address to account for the slice offset.
3093     tempbuf_address_untyped =
3094         InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
3095   }
3096   return BitCast(tempbuf_address_untyped,
3097                  IrShapeType(target_shape)->getPointerTo());
3098 }
3099 
3100 llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
3101                                           const Shape& target_shape) {
3102   if (slice.allocation()->is_thread_local()) {
3103     return EmitThreadLocalBufferPointer(slice, target_shape);
3104   } else if (slice.allocation()->is_constant()) {
3105     return BitCast(
3106         FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
3107         IrShapeType(target_shape)->getPointerTo());
3108   } else {
3109     return EmitGlobalBufferPointer(slice, target_shape);
3110   }
3111 }
3112 
3113 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
3114   const Shape& target_shape = op->shape();
3115   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
3116                       assignment_.GetUniqueTopLevelSlice(op));
3117   llvm::Value* addr = EmitBufferPointer(slice, target_shape);
3118   addr->setName(IrName(op));
3119   emitted_value_[op] = addr;
3120   return Status::OK();
3121 }
3122 
3123 Status IrEmitter::EmitTargetElementLoop(
3124     HloInstruction* target_op,
3125     const llvm_ir::ElementGenerator& element_generator) {
3126   return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
3127 }
3128 
3129 Status IrEmitter::EmitTargetElementLoop(
3130     HloInstruction* target_op, absl::string_view desc,
3131     const llvm_ir::ElementGenerator& element_generator) {
3132   VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
3133 
3134   const Shape& target_shape = target_op->shape();
3135   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
3136   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
3137 
3138   if (target_shape.IsTuple() &&
3139       (target_op->opcode() == HloOpcode::kFusion ||
3140        target_op->opcode() == HloOpcode::kReduce ||
3141        target_op->opcode() == HloOpcode::kReduceWindow)) {
3142     // For multiple outputs fusion, we need to emit each operand and the root.
3143     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
3144     std::vector<llvm_ir::IrArray> output_arrays;
3145     for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
3146       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
3147                           assignment_.GetUniqueSlice(target_op, {i}));
3148       const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
3149       llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
3150       output_arrays.push_back(
3151           llvm_ir::IrArray(op_target_address, element_shape));
3152     }
3153     TF_RETURN_IF_ERROR(
3154         llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
3155             .EmitLoop(IrName(target_op)));
3156 
3157     std::vector<llvm::Value*> tuple_operand_ptrs;
3158     for (int64 i = 0; i < output_arrays.size(); ++i) {
3159       tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
3160     }
3161     llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
3162 
3163   } else {
3164     if (ShouldEmitParallelLoopFor(*target_op)) {
3165       // Emit code to read dynamic loop bounds from compute function argument.
3166       std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
3167           compute_function_->GetDynamicLoopBounds();
3168       // Emit parallel loop with dynamic loop bounds for most-major dimensions.
3169       TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
3170                                              &dynamic_loop_bounds, &b_)
3171                              .EmitLoop(IrName(target_op)));
3172     } else {
3173       TF_RETURN_IF_ERROR(
3174           llvm_ir::LoopEmitter(element_generator, target_array, &b_)
3175               .EmitLoop(IrName(target_op)));
3176     }
3177   }
3178   return Status::OK();
3179 }
3180 
3181 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
3182                              const HloInstruction& destination) {
3183   llvm::Value* source_value = GetEmittedValueFor(&source);
3184   llvm::Value* destination_value = GetEmittedValueFor(&destination);
3185   int64 source_size = ByteSizeOf(source.shape());
3186   // TODO(b/63762267): Be more aggressive about specifying alignment.
3187   MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
3188          /*SrcAlign=*/llvm::Align(1), source_size);
3189   return Status::OK();
3190 }
3191 
3192 Status IrEmitter::ElementTypesSameAndSupported(
3193     const HloInstruction& instruction,
3194     absl::Span<const HloInstruction* const> operands,
3195     absl::Span<const PrimitiveType> supported_types) {
3196   for (auto operand : operands) {
3197     TF_RET_CHECK(
3198         ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
3199   }
3200 
3201   TF_RET_CHECK(!operands.empty());
3202   PrimitiveType primitive_type = operands[0]->shape().element_type();
3203   if (!absl::c_linear_search(supported_types, primitive_type)) {
3204     return Unimplemented("unsupported operand type %s in op %s",
3205                          PrimitiveType_Name(primitive_type),
3206                          HloOpcodeString(instruction.opcode()));
3207   }
3208   return Status::OK();
3209 }
3210 
3211 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
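       // Fall back to an elementwise lowering: each operand element is produced
       // on demand by the elemental IR emitter.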
3212   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
3213   for (const HloInstruction* operand : hlo->operands()) {
3214     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
3215       return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3216     };
3217   }
3218   CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
3219   return EmitTargetElementLoop(
3220       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
3221 }
3222 
3223 llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
3224     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3225     absl::string_view name) {
3226   std::vector<llvm::Value*> return_value =
3227       EmitThreadLocalCall(callee, parameters, name);
3228   CHECK_EQ(return_value.size(), 1);
3229   return return_value[0];
3230 }
3231 
3232 std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
3233     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3234     absl::string_view name) {
3235   CHECK(absl::c_binary_search(thread_local_computations_, &callee));
3236   const Shape& return_shape = callee.root_instruction()->shape();
3237   bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
3238   bool is_tuple_of_scalars_return =
3239       return_shape.IsTuple() &&
3240       absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
3241         return ShapeUtil::IsScalar(shape);
3242       });
3243   CHECK(is_scalar_return || is_tuple_of_scalars_return);
3244 
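       // Thread-local callees take their arguments by address, so spill each
       // scalar parameter to a stack slot at the function entry.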
3245   std::vector<llvm::Value*> parameter_addrs;
3246   for (llvm::Value* parameter : parameters) {
3247     CHECK(!parameter->getType()->isPointerTy());
3248     llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
3249         parameter->getType(), "arg_addr", &b_);
3250     Store(parameter, parameter_addr);
3251     parameter_addrs.push_back(parameter_addr);
3252   }
3253 
3254   llvm::Type* return_value_buffer_type =
3255       llvm_ir::ShapeToIrType(return_shape, module_);
3256   std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
3257   int retval_alignment =
3258       is_scalar_return
3259           ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
3260           : 0;
3261   llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3262       return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
3263 
3264   std::vector<llvm::Value*> allocas_for_returned_scalars;
3265   if (is_scalar_return) {
3266     allocas_for_returned_scalars.push_back(return_value_buffer);
3267   } else {
3268     constexpr int max_tuple_size = 1000;
3269     CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
3270         << "Multivalue function can not return more than 1000 elements to avoid"
3271         << " stack smashing";
3272     allocas_for_returned_scalars =
3273         llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
3274     llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
3275 
3276     EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
3277   }
3278 
3279   Call(FindOrDie(emitted_functions_, &callee),
3280        GetArrayFunctionCallArguments(
3281            parameter_addrs, &b_, name,
3282            /*return_value_buffer=*/return_value_buffer,
3283            /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3284            /*buffer_table_arg=*/
3285            llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
3286            /*profile_counters_arg=*/GetProfileCountersArgument()));
3287 
3288   std::vector<llvm::Value*> returned_scalars;
3289   returned_scalars.reserve(allocas_for_returned_scalars.size());
3290   for (llvm::Value* addr : allocas_for_returned_scalars) {
3291     returned_scalars.push_back(Load(addr));
3292   }
3293   return returned_scalars;
3294 }
3295 
3296 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
3297                                absl::string_view name) {
3298   CHECK(absl::c_binary_search(global_computations_, &callee));
3299 
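       // Global computations read from and write to the buffer table directly,
       // so no parameter addresses or return-value buffer are passed.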
3300   Call(FindOrDie(emitted_functions_, &callee),
3301        GetArrayFunctionCallArguments(
3302            /*parameter_addresses=*/{}, &b_, name,
3303            /*return_value_buffer=*/
3304            llvm::Constant::getNullValue(b_.getInt8PtrTy()),
3305            /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3306            /*buffer_table_arg=*/GetBufferTableArgument(),
3307            /*profile_counters_arg=*/GetProfileCountersArgument()));
3308 }
3309 
3310 llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
3311     const HloComputation& callee) {
3312   const HloInstruction* root_inst = callee.root_instruction();
3313   if (root_inst->opcode() == HloOpcode::kOutfeed) {
3314     return llvm::Constant::getNullValue(b_.getInt8PtrTy());
3315   }
3316 
3317   const BufferAllocation::Slice root_buffer =
3318       assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
3319   return EmitBufferPointer(root_buffer, root_inst->shape());
3320 }
3321 
3322 void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
3323                                     FusedIrEmitter* fused_emitter) {
3324   for (int i = 0; i < fusion->operand_count(); i++) {
3325     const HloInstruction* operand = fusion->operand(i);
3326     fused_emitter->BindGenerator(
3327         fusion->fused_parameter(i),
3328         [this, operand](llvm_ir::IrArray::Index index) {
3329           return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3330         });
3331   }
3332 }
3333 
3334 }  // namespace cpu
3335 }  // namespace xla
3336