/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
}  // namespace

namespace cpu {

IrEmitter::IrEmitter(
    mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
    const BufferAssignment& assignment, llvm::Module* llvm_module,
    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
    const TargetMachineFeatures* target_machine_features,
    bool emit_code_for_msan)
    : assignment_(assignment),
      module_(llvm_module),
      arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
      b_(llvm_module->getContext()),
      mlir_context_(mlir_context),
      instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
      computation_to_profile_idx_(std::move(computation_to_profile_idx)),
      alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
      hlo_module_config_(hlo_module.config()),
      is_top_level_computation_(false),
      target_machine_features_(*target_machine_features),
      emit_code_for_msan_(emit_code_for_msan) {
  b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
  Status s = GatherComputationsByAllocationType(
      &hlo_module, &thread_local_computations_, &global_computations_);
  absl::c_sort(thread_local_computations_);
  absl::c_sort(global_computations_);
  TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}

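// Copies the computation's root value from its buffer into the caller-provided
// result argument, handling both scalar and (shallow) tuple return shapes.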
void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
  llvm::Argument* out_parameter = compute_function_->result_arg();
  llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
  const Shape& return_shape = computation->root_instruction()->shape();

  if (ShapeUtil::IsScalar(return_shape)) {
    llvm::Value* ret_value =
        Load(root_value.GetBasePointer(), "load_ret_value");
    Store(ret_value,
          BitCast(out_parameter, root_value.GetBasePointer()->getType()));
  } else {
    CHECK(return_shape.IsTuple());

    llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
    llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
    llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);

    for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
      const Shape& element_shape = return_shape.tuple_shapes(i);
      llvm::Value* destination = llvm_ir::EmitGetTupleElement(
          element_shape,
          /*index=*/i,
          /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
          &b_);

      llvm::Value* source = llvm_ir::EmitGetTupleElement(
          element_shape,
          /*index=*/i,
          /*alignment=*/MinimumAlignmentForShape(element_shape),
          root_value.GetBasePointer(), &b_);

      Store(Load(source), destination);
    }
  }
}

StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    HloComputation* computation, const string& function_name_prefix,
    bool is_top_level_computation,
    absl::Span<HloInstruction* const> instruction_order) {
  string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
  is_top_level_computation_ = is_top_level_computation;
  num_dynamic_loop_bounds_ = 0;
  if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
    num_dynamic_loop_bounds_ =
        computation->root_instruction()->outer_dimension_partitions().size();
  }

  if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
    TF_ASSIGN_OR_RETURN(
        computation_root_allocation_,
        assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
  }

  for (const HloInstruction* param : computation->parameter_instructions()) {
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
                        assignment_.GetUniqueTopLevelSlice(param));
    computation_parameter_allocations_[param_slice.allocation()->index()] =
        param->parameter_number();
  }

  InitializeIrFunction(function_name);
  // The rdtscp instruction is x86 specific.  We fall back to LLVM's generic
  // readcyclecounter intrinsic if it is unavailable.
  bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
                    arch_type_ == llvm::Triple::ArchType::x86_64;
  profiling_state_ = ProfilingState(use_rdtscp);

  tracing_state_.set_enabled(
      computation->parent()->config().cpu_traceme_enabled());

  TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
  llvm::Function* ir_function = compute_function_->function();
  InsertOrDie(&emitted_functions_, computation, ir_function);
  // Delete 'compute_function', finalizing 'ir_function' and restoring caller
  // IR insert point.

  // Function epilogue: copying the value over to either the return register,
  // or values pointing from the return register.
  const BufferAllocation* root_allocation =
      computation_root_allocation_.allocation();
  if (root_allocation && root_allocation->is_thread_local()) {
    EmitThreadLocalFunctionEpilogue(computation);
  }

  // Destructor for compute_function_ emits the "ret void" instruction.
  compute_function_.reset();
  computation_root_allocation_ = BufferAllocation::Slice();
  computation_parameter_allocations_.clear();
  return ir_function;
}

void IrEmitter::InitializeIrFunction(const string& function_name) {
  // Functions with local linkage get an inlining bonus.  Because we know
  // a priori that embedded functions (non-entry functions) will not have their
  // names resolved externally, give them local linkage.
  llvm::Function::LinkageTypes linkage =
      is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
                                : llvm::GlobalValue::InternalLinkage;
  // Create and initialize new IrFunction.
  compute_function_.reset(new IrFunction(function_name, linkage,
                                         hlo_module_config_, module_, &b_,
                                         num_dynamic_loop_bounds_));
}

IrEmitter::~IrEmitter() {}

Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
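  // A bitcast reinterprets the operand's buffer without copying any data, so
  // reuse the operand's emitted value, cast to a pointer to the result shape's
  // IR type.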
  emitted_value_[bitcast] =
      BitCast(GetEmittedValueFor(bitcast->operand(0)),
              IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
  return Status::OK();
}

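// Emits the literal as a private constant global in the module and returns a
// pointer to it, cast to the IR type corresponding to the literal's shape.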
llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
  llvm::Constant* initializer =
      llvm_ir::ConvertLiteralToIrConstant(literal, module_);
  llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
      /*Module=*/*module_,
      /*Type=*/initializer->getType(),
      /*isConstant=*/true,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/initializer,
      /*Name=*/"");
  result_global->setAlignment(
      llvm::Align(MinimumAlignmentForShape(literal.shape())));
  result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
  return llvm::ConstantExpr::getBitCast(
      result_global, IrShapeType(literal.shape())->getPointerTo());
}

Status IrEmitter::EmitConstantGlobals() {
  for (const BufferAllocation& allocation : assignment_.Allocations()) {
    if (!allocation.is_constant()) {
      continue;
    }

    const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
    llvm::Constant* global_for_const;
    auto it = emitted_literals_.find(&literal);
    if (it != emitted_literals_.end()) {
      global_for_const = it->second;
    } else {
      global_for_const = EmitGlobalForLiteral(literal);
      InsertOrDie(&emitted_literals_, &literal, global_for_const);
    }

    InsertOrDie(&constant_buffer_to_global_, allocation.index(),
                global_for_const);
  }

  return Status::OK();
}

Status IrEmitter::HandleConstant(HloInstruction* constant) {
  VLOG(2) << "HandleConstant: " << constant->ToString();
  // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
  // of the constant.
  return EmitTargetAddressForOp(constant);
}

Status IrEmitter::HandleCopy(HloInstruction* copy) {
  if (copy->shape().IsTuple() ||
      (copy->shape().IsArray() &&
       LayoutUtil::Equal(copy->operand(0)->shape().layout(),
                         copy->shape().layout()))) {
    // If the layouts are equal this is just a memcpy. kCopy shallow copies a
    // tuple so just memcpy the top-level buffer for tuples.
    TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
    return EmitMemcpy(*(copy->operand(0)), *copy);
  } else if (copy->shape().IsArray()) {
    // Use the elemental emitter for array shapes.
    return DefaultAction(copy);
  }
  return Unimplemented("unsupported operand type %s for copy instruction",
                       PrimitiveType_Name(copy->shape().element_type()));
}

// Calculate the alignment of a buffer allocated for a given primitive type.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
  int64_t byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  DCHECK_GE(byte_size, 0);
  // Largest scalar is a complex128 so we don't need to worry about the
  // int64->int truncation here.
  DCHECK_LE(byte_size, 16);

  // Allocations may be 8-byte aligned if part of a small block.
  return std::min(int64{8}, byte_size);
}

int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
  return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
}

// Calculate the alignment of a buffer allocated for a given shape.
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
  if (ShapeUtil::IsScalar(shape)) {
    return MinimumAlignmentForPrimitiveType(shape.element_type());
  }

  int64_t buffer_size = ByteSizeOf(shape);
  DCHECK_GE(buffer_size, 0);
  DCHECK_LE(buffer_size, SIZE_MAX);

  return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               const Shape& shape) {
  int alignment = MinimumAlignmentForShape(shape);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               int64_t buffer_size) {
  int alignment =
      target_machine_features_.minimum_alignment_for_allocation(buffer_size);
  if (alignment > 1) {
    llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
  }
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     const Shape& shape) {
  AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
}

void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     int64_t buffer_size) {
  if (buffer_size > 0) {
    llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
  }
}

Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  // A tuple is an array of pointers, one for each operand. Each pointer points
  // to the output buffer of its corresponding operand. A GetTupleElement
  // instruction forwards a pointer to the tuple element buffer at the given
  // index.
  const HloInstruction* operand = get_tuple_element->operand(0);
  const Shape& shape = get_tuple_element->shape();
  emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
      shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
      GetEmittedValueFor(operand), &b_);
  return Status::OK();
}

Status IrEmitter::HandleSelect(HloInstruction* select) {
  auto pred = select->operand(0);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  return DefaultAction(select);
}

Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
  auto pred = tuple_select->operand(0);
  auto on_true = tuple_select->operand(1);
  auto on_false = tuple_select->operand(2);
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
  TF_RET_CHECK(tuple_select->shape().IsTuple());
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
  llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
                           GetEmittedValueFor(on_true),
                           GetEmittedValueFor(on_false), &b_);
  return Status::OK();
}

Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  VLOG(2) << "HandleInfeed: " << infeed->ToString();

  // The infeed operation produces a two-element tuple containing data and a
  // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
  const Shape& data_shape = infeed->infeed_shape();
  DCHECK(ShapeUtil::Equal(data_shape,
                          ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));

  // Write the tuple index table.
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
                      assignment_.GetUniqueSlice(infeed, {0}));
  llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
                      assignment_.GetUniqueSlice(infeed, {1}));
  llvm::Value* token_address = EmitBufferPointer(
      token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
  llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);

  if (data_shape.IsTuple()) {
    TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));

    // For a tuple, we first copy each of the internal elements to
    // their corresponding target locations. We then construct the
    // tuple outer buffer containing pointers to the internal
    // elements.
    std::vector<llvm::Value*> tuple_element_addresses;
    for (int64_t i = 0; i < data_shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
                          assignment_.GetUniqueSlice(infeed, {0, i}));

      const Shape& tuple_element_shape =
          ShapeUtil::GetTupleElementShape(data_shape, i);

      // Only the outer tuple buffer's target address is obtained from
      // GetEmittedValueFor, to handle the case when Infeed is the root
      // instruction. Target addresses for internal elements can be obtained
      // from EmitBufferPointer.
      llvm::Value* tuple_element_address =
          EmitBufferPointer(buffer, tuple_element_shape);

      TF_RETURN_IF_ERROR(EmitXfeedTransfer(
          XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));

      tuple_element_addresses.push_back(tuple_element_address);
    }

    llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
                       tuple_element_addresses, &b_);
  } else {
    TF_RETURN_IF_ERROR(
        EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
  }

  return Status::OK();
}

Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                                    llvm::Value* program_buffer_address) {
  int64_t length = ByteSizeOf(shape);
  if (length < 0 || length > std::numeric_limits<int32>::max()) {
    return InvalidArgument(
        "xfeed (infeed or outfeed) buffer length %d is outside the valid "
        "size range",
        length);
  }
  int32_t length_32 = static_cast<int32>(length);

  int32_t shape_length;
  TF_ASSIGN_OR_RETURN(
      llvm::Value * shape_ptr,
      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));

  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());

  const char* acquire_func_name =
      kind == XfeedKind::kInfeed
          ? runtime::kAcquireInfeedBufferForDequeueSymbolName
          : runtime::kAcquireOutfeedBufferForPopulationSymbolName;

  // Implementation note: this call informs the runtime that it wants a
  // buffer of size exactly 'length_32', and the runtime is responsible for
  // check-failing the process if there is a mismatch, versus passing us
  // back a buffer that we might overrun.
  llvm::Value* acquired_pointer =
      EmitCallToFunc(acquire_func_name,
                     {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                      shape_ptr, b_.getInt32(shape_length)},
                     i8_ptr_type);
  if (kind == XfeedKind::kInfeed) {
    // Copy to the program buffer address from the acquired buffer.
    MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
           acquired_pointer,
           /*SrcAlign=*/llvm::Align(1), length_32);
  } else {
    // Outfeed -- copy from the in-program address to the acquired buffer.
    MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
           program_buffer_address,
           /*SrcAlign=*/llvm::Align(1), length_32);
    if (emit_code_for_msan_) {
      // Mark the outfed data as initialized for msan. The buffer gets read by
      // the host code, which might be msan-instrumented.
      // TODO(b/66051036): Run the msan instrumentation pass instead.
      const llvm::DataLayout& dl = module_->getDataLayout();
      llvm::Type* intptr_type = b_.getIntPtrTy(dl);
      EmitCallToFunc(
          "__msan_unpoison",
          {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
          b_.getVoidTy());
    }
  }

  const char* release_func_name =
      kind == XfeedKind::kInfeed
          ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
          : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
  EmitCallToFunc(release_func_name,
                 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
                  acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
                 b_.getVoidTy());

  return Status::OK();
}

Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed produces no useful result, but it does return a token[] that can be
  // threaded through to other side effecting operations to ensure ordering.  In
  // the IR emitter we treat this token as a normal u8[] and thus need to insert
  // an entry for it in emitted_value_.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));

  HloInstruction* operand = outfeed->operands()[0];
  const Shape& operand_shape = operand->shape();

  llvm::Value* value = GetEmittedValueFor(operand);
  if (!operand_shape.IsTuple()) {
    return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
  }

  TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));

  for (int64_t i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(operand_shape, i);
    llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
        tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
        value, &b_);
    TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
                                         tuple_element_shape, tuple_element));
  }

  return Status::OK();
}

Status IrEmitter::HandleSort(HloInstruction* hlo) {
  const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
  Shape keys_shape = sort->keys()->shape();
  PrimitiveType keys_type = keys_shape.element_type();
  if (!primitive_util::IsArrayType(keys_type)) {
    return Unimplemented("Element type %s not supported in the Sort op on CPU.",
                         PrimitiveType_Name(keys_type));
  }
  std::vector<llvm::Value*> destination_addresses(sort->operand_count());
  for (int64_t i = 0; i < sort->operand_count(); ++i) {
    ShapeIndex shape_index =
        sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
    const HloInstruction* operand = sort->operand(i);
    // We assume that the layout of all involved operands and outputs is the
    // same.
    TF_RET_CHECK(
        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
        keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));

    // The sort is implemented in place, so we first copy the operand buffer
    // to the output buffer if they are not the same.
    auto destination_buffer = GetAllocationSlice(*sort, shape_index);
    destination_addresses[i] =
        EmitBufferPointer(destination_buffer, operand->shape());
    auto source_address = GetAllocationSlice(*operand);
    if (destination_buffer != source_address) {
      int64_t primitive_type_size =
          ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
      auto source_buffer = GetEmittedValueFor(operand);
      int64_t size = ByteSizeOf(operand->shape());
      MemCpy(destination_addresses[i],
             /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
             /*SrcAlign=*/llvm::Align(primitive_type_size), size);
    }
  }

  // Normalize the shape and the dimension to sort.
  Shape normalized_keys_shape =
      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
  int64_t physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
      keys_shape.layout())[sort->sort_dimension()];

  int64_t sort_dimension_elements =
      normalized_keys_shape.dimensions(physical_dimension_to_sort);
  int64_t higher_dimensions = 1;
  for (int64_t i = 0; i < physical_dimension_to_sort; ++i) {
    higher_dimensions *= normalized_keys_shape.dimensions(i);
  }
  int64_t lower_dimensions = 1;
  for (int64_t i = normalized_keys_shape.rank() - 1;
       i > physical_dimension_to_sort; --i) {
    lower_dimensions *= normalized_keys_shape.dimensions(i);
  }

  CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
  llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
      &b_);
  llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
      b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
      &b_);
  for (int64_t i = 0; i < sort->operand_count(); ++i) {
    llvm::Value* value_as_i8ptr =
        PointerCast(destination_addresses[i], b_.getInt8PtrTy());
    llvm::Value* slot_in_values_alloca =
        ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
    Store(value_as_i8ptr, slot_in_values_alloca);
    llvm::Value* slot_in_sizes_alloca =
        ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
    llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
        sort->operand(i)->shape().element_type()));
    Store(size, slot_in_sizes_alloca);
  }

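  // Delegate the in-place sort to the key-value sort runtime routine, passing
  // the per-operand buffer pointers and element sizes along with the emitted
  // comparator function.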
  auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
  EmitCallToFunc(
      runtime::kKeyValueSortSymbolName,
      {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
       b_.getInt64(lower_dimensions), values,
       b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
       GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
       less_than_function},
      b_.getVoidTy());

  if (sort->values_count() > 0) {
    llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
  }
  return Status::OK();
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
  std::vector<llvm::Value*> base_ptrs;
  for (auto operand : tuple->operands()) {
    base_ptrs.push_back(GetEmittedValueFor(operand));
  }
  llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
  return Status::OK();
}

Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
  // Pseudo code for reduce window:
  //
  //   for (coordinates O in the output)
  //     value = init_value;
  //     for (coordinates W in the window)
  //       for each index i:
  //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
  //       if I within bounds of input:
  //         value = function(value, input(I));
  //     output(O) = value;
  //
  // This is completely un-optimized and just here to have something
  // that works.
  return DefaultAction(reduce_window);
}

Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  CHECK_EQ(select_and_scatter->operand_count(), 3);
  const auto operand = select_and_scatter->operand(0);
  const auto source = select_and_scatter->operand(1);
  const auto init_value = select_and_scatter->operand(2);
  const Window& window = select_and_scatter->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
  const int64_t rank = operand->shape().rank();
  CHECK_EQ(rank, source->shape().rank());
  CHECK_EQ(rank, window.dimensions_size());

  // TODO(b/31410564): Implement dilation for select-and-scatter.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for SelectAndScatter is not implemented on CPU. ");
  }

  // Pseudo code for select-and-scatter:
  //
  // initialized_flag is initially off for every window, and is turned on after
  // the first iteration is completed and the first operand value is selected.
  //
  // output(*) = init_value
  // for (coordinates S in the source) {
  //   initialized_flag = false
  //   for (coordinates W in the window) {
  //     I = S * stride + W - pad_low
  //     if I within bounds of operand:
  //       if !initialized_flag or select(selected_value, operand(I)) == false:
  //         selected_value = operand(I)
  //         selected_index = I
  //         initialized_flag = true
  //   }
  //   output(selected_index) = scatter(output(selected_index), source(S))
  // }
  //

  // Initialize the output array with the given init_value.
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
      [this, init_value](const llvm_ir::IrArray::Index& target_index) {
        llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
        return Load(init_value_addr);
      }));

  // Create a loop to iterate over the source array to scatter to the output.
  llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
  const llvm_ir::IrArray::Index source_index =
      source_loops.AddLoopsForShape(source->shape(), "source");
  SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Allocate space to keep the currently selected value, its index, and
  // the boolean initialized_flag, which is initially set to false.
  llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
      "selected_value_address", &b_,
      MinimumAlignmentForPrimitiveType(operand_element_type));
  llvm::Value* selected_index_address =
      llvm_ir::EmitAllocaAtFunctionEntryWithCount(
          b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
  llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
      b_.getInt1Ty(), "initialized_flag_address", &b_);
  Store(b_.getInt1(false), initialized_flag_address);

  // Create the inner loop to iterate over the window.
  llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
  std::vector<int64> window_size;
  for (const auto& dim : window.dimensions()) {
    window_size.push_back(dim.size());
  }
  const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
  SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);

  // Compute the operand index to visit and evaluate the condition whether the
  // operand index is within the bounds. The unsigned comparison includes
  // checking whether the operand index >= 0.
  std::vector<llvm::Value*> operand_multi_index(source_index.size());
  llvm::Value* in_bounds_condition = b_.getTrue();
  for (int64_t i = 0; i < rank; ++i) {
    llvm::Value* strided_index =
        NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
    operand_multi_index[i] =
        NSWSub(NSWAdd(strided_index, window_index[i]),
               b_.getInt64(window.dimensions(i).padding_low()));
    llvm::Value* index_condition =
        ICmpULT(operand_multi_index[i],
                b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
    in_bounds_condition = And(in_bounds_condition, index_condition);
  }
  CHECK(in_bounds_condition != nullptr);

  // Only need to do something if the operand index is within the bounds. First
  // check if the initialized_flag is set.
  llvm_ir::LlvmIfData if_in_bounds =
      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
  SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
  llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
      Load(initialized_flag_address), "initialized", &b_);

  // If the initialized_flag is false, initialize the selected value and index
  // with the currently visiting operand.
  SetToFirstInsertPoint(if_initialized.false_block, &b_);
  const auto save_operand_index =
      [&](const llvm_ir::IrArray::Index& operand_index) {
        for (int64_t i = 0; i < rank; ++i) {
          llvm::Value* selected_index_address_slot =
              InBoundsGEP(selected_index_address, {b_.getInt32(i)});
          Store(operand_index[i], selected_index_address_slot);
        }
      };
  llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
  llvm_ir::IrArray::Index operand_index(
      operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
  llvm::Value* operand_data =
      operand_array.EmitReadArrayElement(operand_index, &b_);
  Store(operand_data, selected_value_address);
  save_operand_index(operand_index);
  Store(b_.getInt1(true), initialized_flag_address);

  // If the initialized_flag is true, call the `select` function to potentially
  // update the selected value and index with the currently visiting operand.
  SetToFirstInsertPoint(if_initialized.true_block, &b_);
  llvm::Value* operand_address =
      operand_array.EmitArrayElementAddress(operand_index, &b_);
  llvm::Value* operand_element = Load(operand_address);
  llvm::Value* result = EmitScalarReturningThreadLocalCall(
      *select_and_scatter->select(),
      {Load(selected_value_address), operand_element}, "select_function");

  // If the 'select' function returns false, update the selected value and the
  // index to the currently visiting operand.
  llvm::Value* cond = ICmpNE(
      result,
      llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
      "boolean_predicate");
  llvm_ir::LlvmIfData if_select_lhs =
      llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
  SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
  Store(Load(operand_address), selected_value_address);
  save_operand_index(operand_index);

  // After iterating over the window elements, scatter the source element to
  // the selected index of the output. The value we store at the output
  // location is computed by calling the `scatter` function with the source
  // value and the current output value.
  SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
  std::vector<llvm::Value*> selected_multi_index;
  for (int64_t i = 0; i < rank; ++i) {
    llvm::Value* selected_index_address_slot =
        InBoundsGEP(selected_index_address, {b_.getInt32(i)});
    selected_multi_index.push_back(Load(selected_index_address_slot));
  }
  llvm_ir::IrArray source_array(GetIrArrayFor(source));
  llvm::Value* source_value =
      source_array.EmitReadArrayElement(source_index, &b_);
  llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
  llvm_ir::IrArray::Index selected_index(
      selected_multi_index, output_array.GetShape(), source_index.GetType());
  llvm::Value* output_value =
      output_array.EmitReadArrayElement(selected_index, &b_);
  llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
      *select_and_scatter->scatter(), {output_value, source_value},
      "scatter_function");
  output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);

  SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
  return Status::OK();
}

Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs = dot->operand(0);
  auto rhs = dot->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*dot, /*operands=*/{lhs, rhs},
      /*supported_types=*/
      {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();

  if (dnums.lhs_contracting_dimensions_size() != 1) {
    // This is disallowed by ShapeInference today.
    return Unimplemented(
        "Dot with multiple contracting dimensions not implemented.");
  }

  llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
  llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
  llvm_ir::IrArray target_array = GetIrArrayFor(dot);

  VLOG(2) << "HandleDot: ";
  VLOG(2) << "  lhs operand: "
          << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
  VLOG(2) << "  rhs operand: "
          << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
  VLOG(2) << "  target: "
          << llvm_ir::DumpToString(*target_array.GetBasePointer());

  // Dot operation is complicated so we delegate to a helper class.
  return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
                          /*addend_array=*/nullptr,
                          GetExecutableRunOptionsArgument(), &b_, mlir_context_,
                          hlo_module_config_, target_machine_features_);
}

Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
      /*supported_types=*/
      {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));

  // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
  // different data layouts.
  if (PotentiallyImplementedAsEigenConvolution(*convolution,
                                               target_machine_features_)) {
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();
    const Shape& convolution_shape = convolution->shape();
    // The input, kernel and output agree with respect to layout.
    if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
        LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
      // We lower 1D convolutions into calls to the same Eigen function as 2D
      // convolutions, except that we pretend that the 1D convolution is really
      // a 2D convolution with the missing dimension set to 1.  We also adjust
      // the padding and dilation parameters as needed.
      bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
      llvm::Value* lhs_address = GetEmittedValueFor(lhs);
      llvm::Value* rhs_address = GetEmittedValueFor(rhs);
      TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));

      const ConvolutionDimensionNumbers& dnums =
          convolution->convolution_dimension_numbers();

      // Input tensor.
      const Shape& input_shape = convolution->operand(0)->shape();
      int64_t input_batch =
          input_shape.dimensions(dnums.input_batch_dimension());
      int64_t input_rows =
          input_shape.dimensions(dnums.input_spatial_dimensions(0));
      int64_t input_cols =
          one_dim_convolution
              ? 1
              : input_shape.dimensions(dnums.input_spatial_dimensions(1));
      int64_t input_channels =
          input_shape.dimensions(dnums.input_feature_dimension());

      // Kernel tensor.
      const Shape& kernel_shape = convolution->operand(1)->shape();
      int64_t kernel_rows =
          kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
      int64_t kernel_cols =
          one_dim_convolution
              ? 1
              : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
      int64_t kernel_channels =
          kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
      int64_t kernel_filters =
          kernel_shape.dimensions(dnums.kernel_output_feature_dimension());

      // Output tensor.
      const Shape& convolution_shape = convolution->shape();
      int64_t output_rows =
          convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
      int64_t output_cols = one_dim_convolution
                                ? 1
                                : convolution_shape.dimensions(
                                      dnums.output_spatial_dimensions(1));

      // Extract the window stride for the convolution.
      const Window& window = convolution->window();
      int64_t row_stride = window.dimensions(0).stride();
      int64_t col_stride =
          one_dim_convolution ? 1 : window.dimensions(1).stride();

      int64_t padding_top = window.dimensions(0).padding_low();
      int64_t padding_bottom = window.dimensions(0).padding_high();
      int64_t padding_left =
          one_dim_convolution ? 0 : window.dimensions(1).padding_low();
      int64_t padding_right =
          one_dim_convolution ? 0 : window.dimensions(1).padding_high();

      int64_t lhs_row_dilation = window.dimensions(0).base_dilation();
      int64_t lhs_col_dilation =
          one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
      int64_t rhs_row_dilation = window.dimensions(0).window_dilation();
      int64_t rhs_col_dilation =
          one_dim_convolution ? 1 : window.dimensions(1).window_dilation();

      PrimitiveType primitive_type = lhs->shape().element_type();
      llvm::Type* ir_ptr_type = primitive_type == F16
                                    ? b_.getHalfTy()->getPointerTo()
                                    : b_.getFloatTy()->getPointerTo();
      bool multi_threaded =
          hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
      bool use_mkl_dnn =
          hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();

      // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the
      // potential race condition by setting the omp_num_threads.
      const char* fn_name =
          primitive_type == F16
              ? (multi_threaded
                     ? runtime::kEigenConvF16SymbolName
                     : runtime::kEigenSingleThreadedConvF16SymbolName)
              : (multi_threaded
                     ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
                                    : runtime::kEigenConvF32SymbolName)
                     : runtime::kEigenSingleThreadedConvF32SymbolName);
      if (!multi_threaded && use_mkl_dnn) {
        LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
                        "conv2d function.";
      }
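      // Call the selected convolution runtime routine with the unpacked
      // shape, stride, padding, and dilation parameters.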
      EmitCallToFunc(fn_name,
                     {
                         GetExecutableRunOptionsArgument(),
                         BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
                         BitCast(lhs_address, ir_ptr_type),
                         BitCast(rhs_address, ir_ptr_type),
                         b_.getInt64(input_batch),
                         b_.getInt64(input_rows),
                         b_.getInt64(input_cols),
                         b_.getInt64(input_channels),
                         b_.getInt64(kernel_rows),
                         b_.getInt64(kernel_cols),
                         b_.getInt64(kernel_channels),
                         b_.getInt64(kernel_filters),
                         b_.getInt64(output_rows),
                         b_.getInt64(output_cols),
                         b_.getInt64(row_stride),
                         b_.getInt64(col_stride),
                         b_.getInt64(padding_top),
                         b_.getInt64(padding_bottom),
                         b_.getInt64(padding_left),
                         b_.getInt64(padding_right),
                         b_.getInt64(lhs_row_dilation),
                         b_.getInt64(lhs_col_dilation),
                         b_.getInt64(rhs_row_dilation),
                         b_.getInt64(rhs_col_dilation),
                     },
                     b_.getVoidTy(), /*does_not_throw=*/true,
                     /*only_accesses_arg_memory=*/true);

      return Status::OK();
    }
  }

  // This is a completely un-optimized version of convolution just to
  // have an early version that works. E.g. the input index and
  // padding calculation is not hoisted out of the inner loop.
  //
  // See the description of convolution in the XLA documentation for the pseudo
  // code for convolution.
  return DefaultAction(convolution);
}

Status IrEmitter::HandleFft(HloInstruction* fft) {
  auto operand = fft->operand(0);
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*fft, /*operands=*/{operand},
      /*supported_types=*/{F32, F64, C64, C128}));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
  VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
  VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());

  llvm::Value* operand_address = GetEmittedValueFor(operand);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));

  const std::vector<int64>& fft_length = fft->fft_length();
  int64_t input_batch = 1;
  for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
    input_batch *= fft->shape().dimensions(i);
  }

  // Args have been computed, make the call.
  llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
  bool multi_threaded_eigen =
      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  const char* fn_name = multi_threaded_eigen
                            ? runtime::kEigenFftSymbolName
                            : runtime::kEigenSingleThreadedFftSymbolName;
  const int fft_rank = fft_length.size();
  EmitCallToFunc(
      fn_name,
      {GetExecutableRunOptionsArgument(),
       BitCast(GetEmittedValueFor(fft), int8_ptr_type),
       BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
       b_.getInt32(operand->shape().element_type() == F64 ||
                   operand->shape().element_type() == C128),
       b_.getInt32(fft_rank), b_.getInt64(input_batch),
       b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
       b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
       b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
      b_.getVoidTy(), /*does_not_throw=*/true,
      /*only_accesses_arg_memory=*/false,
      /*only_accesses_inaccessible_mem_or_arg_mem=*/true);

  return Status::OK();
}

Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
  // When there is a single replica, a cross replica sum is the identity
  // function, and the buffer assignment expects a copy.
  //
  // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
  // in algebraic-simplifier, but currently on some platforms
  // HloModuleConfig::num_replicas changes between when the module is compiled
  // and when it's run.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));

  // CRS with one operand and one replica is simply the identity function.
  if (crs->operand_count() == 1) {
    return EmitMemcpy(*crs->operand(0), *crs);
  }

  // CRS with multiple operands and one replica produces a (one-deep) tuple.
  std::vector<llvm::Value*> operand_ptrs;
  for (int64_t i = 0; i < crs->operand_count(); ++i) {
    llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
                        assignment_.GetUniqueSlice(crs, {i}));

    const Shape& operand_shape = crs->operand(i)->shape();
    CHECK(operand_shape.IsArray())
        << "Operands to all-reduce must be arrays: " << crs->ToString();
    operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));

    // TODO(b/63762267): Be more aggressive about specifying alignment.
    MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
           /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
  }
  llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
  return Status::OK();
}

Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
  CHECK_GE(crs->operand_count(), 1);
  PrimitiveType datatype = crs->operand(0)->shape().element_type();
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));

  bool is_datatype_supported = [&] {
    // TODO(cheshire): Fix duplication wrt. cpu_runtime
    switch (datatype) {
      case PRED:
      case S8:
      case U8:
      case S32:
      case U32:
      case S64:
      case U64:
      case F16:
      case F32:
      case F64:
        return true;
      default:
        return false;
    }
  }();

  if (!is_datatype_supported) {
    return Unimplemented("AllReduce for datatype '%s' is not supported",
                         primitive_util::LowercasePrimitiveTypeName(datatype));
  }

  if (!MatchReductionComputation(crs->to_apply()).has_value()) {
    return Unimplemented("AllReduce for computation '%s' is not supported",
                         crs->to_apply()->ToString());
  }

  std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
  int32_t replica_groups_size = replica_groups.size();
  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);

  bool is_tuple = crs->operand_count() > 1;
  std::vector<llvm::Value*> input_buffer_ptrs;
  std::vector<llvm::Value*> output_buffer_ptrs;
  if (is_tuple) {
    CHECK(crs->shape().IsTuple());

    for (int64_t i = 0; i < crs->operand_count(); i++) {
      const HloInstruction* op = crs->operand(i);
      TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
                          assignment_.GetUniqueSlice(crs, {i}));
      const Shape& operand_shape = crs->operand(i)->shape();
      CHECK(operand_shape.IsArray())
          << "Operands to all-reduce must be arrays: " << crs->ToString();
      output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
      input_buffer_ptrs.push_back(GetEmittedValueFor(op));
    }
  } else {
    Shape shape = crs->operand(0)->shape();
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
                        assignment_.GetUniqueSlice(crs->operand(0), {}));
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
                        assignment_.GetUniqueSlice(crs, {}));
    input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
    output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
  }

  llvm::Value* input_buffers =
      EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
  llvm::Value* output_buffers =
      EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);

  int32_t shape_length;
  TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
                      llvm_ir::EncodeSelfDescribingShapeConstant(
                          crs->shape(), &shape_length, &b_));

1177   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1178   EmitCallToFunc(
1179       runtime::kAllReduceSymbolName,
1180       {/*run_options=*/GetExecutableRunOptionsArgument(),
1181        /*replica_groups=*/replica_groups_v,
1182        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1183 
1184        /*channel_id_present=*/
1185        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1186        /*op_id=*/
1187        b_.getInt64(crs->channel_id().has_value()
1188                        ? *crs->channel_id()
1189                        : crs->GetModule()->unique_id()),
1190        /*reduction_kind=*/
1191        b_.getInt32(
1192            static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
1193        /*shape_ptr=*/shape_ptr,
1194        /*shape_length=*/b_.getInt32(shape_length),
1195        /*num_buffers=*/b_.getInt32(crs->operand_count()),
1196        /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1197        /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1198       b_.getVoidTy());
1199 
1200   return Status::OK();
1201 }
1202 
1203 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1204   if (hlo_module_config_.replica_count() == 1) {
1205     return HandleAllReduceSingleReplica(crs);
1206   }
1207   return HandleAllReduceMultipleReplica(crs);
1208 }
1209 
1210 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1211   auto* instr = Cast<HloAllToAllInstruction>(instruction);
1212   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1213   CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1214       << "Only tuple AllToAll is supported";
1215 
1216   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1217   std::string replica_groups =
1218       ReplicaGroupsToString(instruction->replica_groups());
1219   int32_t replica_groups_size = replica_groups.size();
1220   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1221 
1222   int64_t buffer_size = -1;
1223   std::vector<llvm::Value*> input_buffer_ptrs;
1224   std::vector<llvm::Value*> output_buffer_ptrs;
1225 
1226   for (int64_t i = 0; i < instruction->operand_count(); i++) {
1227     const HloInstruction* op = instruction->operand(i);
1228     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1229                         assignment_.GetUniqueSlice(instruction, {i}));
1230     const Shape& operand_shape = instruction->operand(i)->shape();
1231     CHECK(operand_shape.IsArray())
1232         << "Operands to all-to-all must be arrays: " << instruction->ToString();
1233     output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1234     input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1235     CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1236     buffer_size = out_slice.size();
1237   }
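  // All participating buffers must have identical byte sizes: the runtime call
  // below receives a single buffer_size together with num_buffers flat pointers.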
1238 
1239   llvm::Value* input_buffers =
1240       EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1241   llvm::Value* output_buffers =
1242       EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1243 
1244   EmitCallToFunc(
1245       runtime::kAllToAllSymbolName,
1246       {/*run_options=*/GetExecutableRunOptionsArgument(),
1247        /*channel_id_present=*/
1248        b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
1249        /*op_id=*/
1250        b_.getInt64(instruction->channel_id().has_value()
1251                        ? *instruction->channel_id()
1252                        : instruction->GetModule()->unique_id()),
1253        /*replica_groups=*/replica_groups_v,
1254        /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1255        /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1256        /*buffer_size=*/b_.getInt64(buffer_size),
1257        /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1258        /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1259       b_.getVoidTy());
1260 
1261   llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1262   return Status::OK();
1263 }
1264 
1265 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1266   auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1267   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
1268   std::string source_target_pairs = absl::StrJoin(
1269       instr->source_target_pairs(), ",", absl::PairFormatter("="));
1270   llvm::Value* source_target_pairs_v =
1271       b_.CreateGlobalStringPtr(source_target_pairs);
1272 
1273   Shape shape = crs->operand(0)->shape();
1274 
1275   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1276                       assignment_.GetUniqueSlice(crs->operand(0), {}));
1277   llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1278 
1279   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1280                       assignment_.GetUniqueSlice(crs, {}));
1281   llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1282 
1283   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1284   EmitCallToFunc(
1285       runtime::kCollectivePermuteSymbolName,
1286       {/*run_options=*/GetExecutableRunOptionsArgument(),
1287        /*channel_id_present=*/
1288        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1289        /*op_id=*/
1290        b_.getInt64(crs->channel_id().has_value()
1291                        ? *crs->channel_id()
1292                        : crs->GetModule()->unique_id()),
1293        /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1294        /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1295        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1296        /*source_target_pairs=*/source_target_pairs_v,
1297        /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1298       b_.getVoidTy());
1299 
1300   return Status::OK();
1301 }
1302 
1303 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1304   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1305   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1306                       assignment_.GetUniqueSlice(hlo, {}));
1307   llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1308   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1309   EmitCallToFunc(
1310       runtime::kReplicaIdSymbolName,
1311       {/*run_options=*/GetExecutableRunOptionsArgument(),
1312        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1313       b_.getVoidTy());
1314   return Status::OK();
1315 }
1316 
1317 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1318   VLOG(2) << "HandleParameter: " << parameter->ToString();
1319   return EmitTargetAddressForOp(parameter);
1320 }
1321 
1322 // Returns true if the relative order of the unreduced dimensions stays the same
1323 // through the reduce operation.
1324 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1325   DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1326 
1327   // Maps dimensions that were not reduced from their dimension numbers in the
1328   // source shape to their dimension numbers in the destination shape.
1329   //
1330   // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1331   // [0->0, 3->1].
1332   absl::flat_hash_map<int64, int64> unreduced_dim_map;
1333 
1334   absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1335                                           reduce.dimensions().end());
1336 
1337   const Shape& operand_shape = reduce.operand(0)->shape();
1338   const Shape& result_shape = reduce.shape();
1339 
1340   int64_t delta = 0;
1341   for (int64_t i = 0; i < operand_shape.dimensions_size(); i++) {
1342     if (reduced_dims.contains(i)) {
1343       delta++;
1344     } else {
1345       InsertOrDie(&unreduced_dim_map, i, i - delta);
1346     }
1347   }
1348 
1349   // Iterate dimensions minor to major and check that the corresponding
1350   // dimensions in the source and target shapes are equivalent.
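  // For example (an illustrative case, not specific to any module): reducing
  // f32[A,B,C,D]{3,2,1,0} over dimensions {1,2} leaves dimensions 0 and 3, with
  // unreduced_dim_map = [0->0, 3->1]; the layout is preserved iff the result
  // orders those surviving dimensions the same way minor-to-major, i.e. {1,0}.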
1351   int64_t result_dim_idx = 0;
1352   for (int64_t operand_dim_idx = 0;
1353        operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1354     int64_t operand_dim =
1355         operand_shape.layout().minor_to_major(operand_dim_idx);
1356     if (!reduced_dims.contains(operand_dim)) {
1357       if (FindOrDie(unreduced_dim_map, operand_dim) !=
1358           result_shape.layout().minor_to_major(result_dim_idx++)) {
1359         return false;
1360       }
1361     }
1362   }
1363 
1364   CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1365 
1366   return true;
1367 }
1368 
1369 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1370     HloComputation* function, string* failure_reason) const {
1371   CHECK_EQ(function->num_parameters(), 2);
1372 
1373   auto root_instruction = function->root_instruction();
1374   CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1375 
1376   if (root_instruction->operand_count() != 2) {
1377     *failure_reason = "root instruction is not a binary operation";
1378     return nullptr;
1379   }
1380 
1381   const Shape& root_shape = root_instruction->shape();
1382   if (ShapeUtil::ElementIsComplex(root_shape)) {
1383     // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
1384     // Complex multiply would be more challenging. We could perhaps use a strided
1385     // load to get all reals in a vector and all imaginary parts in a vector, or
1386     // use CreateShuffleVector on a bitcast to <float x [2N]>.
1387     *failure_reason = "complex values not supported";
1388     return nullptr;
1389   }
1390   bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1391   bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1392   bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1393 
1394   auto lhs = root_instruction->operand(0);
1395   auto rhs = root_instruction->operand(1);
1396 
1397   auto param_0 = function->parameter_instruction(0);
1398   auto param_1 = function->parameter_instruction(1);
1399   if (!(lhs == param_0 && rhs == param_1) &&
1400       !(rhs == param_0 && lhs == param_1)) {
1401     *failure_reason =
1402         "root instruction is not a binary operation on the incoming arguments";
1403     return nullptr;
1404   }
1405 
1406   CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1407 
1408   // This is visually similar to ElementalIrEmitter, though conceptually we're
1409   // doing something different here.  ElementalIrEmitter emits scalar operations
1410   // while these emit scalar or vector operations depending on the type of the
1411   // operands. See CreateShardedVectorType for the actual types in use here.
1412   switch (root_instruction->opcode()) {
1413     default:
1414       *failure_reason = "did not recognize root instruction opcode";
1415       return nullptr;
1416 
1417     case HloOpcode::kAdd:
1418       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1419                                 llvm::Value* rhs) {
1420         return root_is_integral ? b->CreateAdd(lhs, rhs)
1421                                 : b->CreateFAdd(lhs, rhs);
1422       };
1423 
1424     case HloOpcode::kMultiply:
1425       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1426                                 llvm::Value* rhs) {
1427         return root_is_integral ? b->CreateMul(lhs, rhs)
1428                                 : b->CreateFMul(lhs, rhs);
1429       };
1430 
1431     case HloOpcode::kAnd:
1432       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1433         return b->CreateAnd(lhs, rhs);
1434       };
1435 
1436     case HloOpcode::kOr:
1437       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1438         return b->CreateOr(lhs, rhs);
1439       };
1440 
1441     case HloOpcode::kXor:
1442       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1443         return b->CreateXor(lhs, rhs);
1444       };
1445 
1446     case HloOpcode::kMaximum:
1447       return [root_is_floating_point, root_is_signed, this](
1448                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1449                  llvm::Value* rhs) -> llvm::Value* {
1450         if (root_is_floating_point) {
1451           return llvm_ir::EmitFloatMax(
1452               lhs, rhs, b,
1453               hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1454         }
1455 
1456         return b->CreateSelect(
1457             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1458                                          : llvm::ICmpInst::ICMP_UGE,
1459                           lhs, rhs),
1460             lhs, rhs);
1461       };
1462 
1463     case HloOpcode::kMinimum:
1464       return [root_is_floating_point, root_is_signed, this](
1465                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1466                  llvm::Value* rhs) -> llvm::Value* {
1467         if (root_is_floating_point) {
1468           return llvm_ir::EmitFloatMin(
1469               lhs, rhs, b,
1470               hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max());
1471         }
1472 
1473         return b->CreateSelect(
1474             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1475                                          : llvm::ICmpInst::ICMP_ULE,
1476                           lhs, rhs),
1477             lhs, rhs);
1478       };
1479   }
1480 }
1481 
1482 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1483     PrimitiveType element_type, unsigned element_count) {
1484   int vector_register_size_in_elements =
1485       target_machine_features_.vector_register_byte_size(
1486           *compute_function_->function()) /
1487       ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1488 
1489   ShardedVectorType sharded_vector_type;
1490   llvm::Type* element_ir_type =
1491       llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1492 
1493   for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1494     // For every power of two present in element_count, we generate one or more
1495     // vector or scalar types.
1496     const unsigned current_size_fragment = 1u << i;
1497     if (!(element_count & current_size_fragment)) {
1498       // Power of two not present in element_count.
1499       continue;
1500     }
1501 
1502     if (current_size_fragment == 1) {
1503       // Single element, use a scalar type.
1504       sharded_vector_type.push_back(element_ir_type);
1505       continue;
1506     }
1507 
1508     // Lower "current_size_fragment" number of elements using (as few as
1509     // possible) vector registers.
1510 
1511     if (current_size_fragment >= vector_register_size_in_elements) {
1512       auto vector_type = llvm::VectorType::get(
1513           element_ir_type, vector_register_size_in_elements, false);
1514       sharded_vector_type.insert(
1515           sharded_vector_type.end(),
1516           current_size_fragment / vector_register_size_in_elements,
1517           vector_type);
1518 
1519       // Both current_size_fragment and vector_register_size_in_elements are
1520       // powers of two.
1521       CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1522       continue;
1523     }
1524 
1525     // For now we assume that vector_register_size_in_elements and lower powers
1526     // of two are all legal vector sizes (or at least can be lowered easily by
1527     // LLVM).
1528     sharded_vector_type.push_back(
1529         llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1530   }
1531   return sharded_vector_type;
1532 }
1533 
1534 StatusOr<IrEmitter::ShardedVector>
1535 IrEmitter::EmitInnerLoopForVectorizedReduction(
1536     const ReductionGenerator& reduction_generator,
1537     const llvm_ir::IrArray::Index& output_index,
1538     const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1539     HloInstruction* arg, absl::Span<const int64> dimensions,
1540     llvm::Align element_alignment) {
1541   ShardedVector accumulator;
1542   accumulator.reserve(accumulator_type.size());
1543   for (auto accumulator_shard_type : accumulator_type) {
1544     accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1545         accumulator_shard_type, "accumulator", &b_, 0));
1546   }
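  // Each accumulator shard gets its own stack slot; together the shards cover
  // one vectorization-factor-sized (or epilogue-sized) group of output elements.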
1547 
1548   llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1549 
1550   for (llvm::Value* accumulator_shard : accumulator) {
1551     llvm::Value* initial_value;
1552     auto shard_type = accumulator_shard->getType()->getPointerElementType();
1553     if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1554       initial_value =
1555           VectorSplat(vector_type->getElementCount(), init_value_ssa);
1556     } else {
1557       initial_value = init_value_ssa;
1558     }
1559 
1560     AlignedStore(initial_value, accumulator_shard, element_alignment);
1561   }
1562 
1563   llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1564                                            &b_);
1565   std::vector<llvm::Value*> input_multi_index =
1566       reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1567                                                        "reduction_dim");
1568 
1569   SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1570 
1571   llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1572   llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1573 
1574   for (auto& i : input_multi_index) {
1575     if (i == nullptr) {
1576       i = *it++;
1577     }
1578   }
1579   CHECK(output_index.end() == it);
1580   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1581                                       b_.getInt64Ty());
1582 
1583   llvm::Value* input_address = BitCast(
1584       arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1585 
1586   for (int i = 0; i < accumulator.size(); i++) {
1587     auto input_address_typed =
1588         BitCast(input_address, accumulator[i]->getType());
1589     auto current_accumulator_value =
1590         AlignedLoad(accumulator[i], element_alignment);
1591     auto addend = AlignedLoad(input_address_typed, element_alignment);
1592     arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1593 
1594     auto reduced_result =
1595         reduction_generator(&b_, current_accumulator_value, addend);
1596     AlignedStore(reduced_result, accumulator[i], element_alignment);
1597 
1598     if (i != (accumulator.size() - 1)) {
1599       input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1600                                            input_address_typed, 1);
1601     }
1602   }
1603 
1604   SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1605 
1606   ShardedVector result_ssa;
1607   result_ssa.reserve(accumulator.size());
1608   for (auto accumulator_shard : accumulator) {
1609     result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1610   }
1611   return result_ssa;
1612 }
1613 
1614 void IrEmitter::EmitShardedVectorStore(
1615     llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1616     llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
1617   for (int i = 0; i < value_to_store.size(); i++) {
1618     auto store_address_typed =
1619         BitCast(store_address,
1620                 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1621 
1622     auto store_instruction =
1623         AlignedStore(value_to_store[i], store_address_typed, alignment);
1624     containing_array.AnnotateLoadStoreInstructionWithMetadata(
1625         store_instruction);
1626 
1627     if (i != (value_to_store.size() - 1)) {
1628       store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1629                                            store_address_typed, 1);
1630     }
1631   }
1632 }
1633 
1634 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1635     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1636     absl::Span<const int64> dimensions, HloComputation* function,
1637     string* failure_reason) {
1638   if (!reduce->shape().IsArray()) {
1639     *failure_reason = "vectorization of variadic reduce not implemented";
1640     return false;
1641   }
1642 
1643   if (!ReductionPreservesLayout(*reduce)) {
1644     return false;
1645   }
1646 
1647   ReductionGenerator reduction_generator =
1648       MatchReductionGenerator(function, failure_reason);
1649   if (!reduction_generator) {
1650     return false;
1651   }
1652 
1653   int vector_register_size_in_elements =
1654       target_machine_features_.vector_register_byte_size(
1655           *compute_function_->function()) /
1656       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1657   if (vector_register_size_in_elements == 0) {
1658     // Either we don't know the vector register width for the target or the
1659     // vector register is smaller than the size of the primitive type.
1660     return false;
1661   }
1662 
1663   int vectorization_factor_in_bytes =
1664       target_machine_features_.vectorization_factor_in_bytes();
1665 
1666   // We try to process vectorization_factor elements at the same time.
1667   const int vectorization_factor =
1668       vectorization_factor_in_bytes /
1669       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1670 
1671   bool is_reduction_over_minor_dimension = absl::c_linear_search(
1672       dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1673 
1674   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1675       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1676       MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1677 
1678   if (is_reduction_over_minor_dimension) {
1679     // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1680     *failure_reason = "reduction over minor dimension not implemented";
1681     return false;
1682   }
1683 
1684   CHECK(!reduce->shape().IsTuple());
1685   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1686 
1687   // We know we're not reducing over the most minor dimension, which means we
1688   // can lower the reduction loop as:
1689   //
1690   //  1. We're reducing over dimensions R0, R1.
1691   //  2. D0 is the most minor dimension.
1692   //  3. VS is the vectorization stride (we want to reduce this many elements at
1693   //     once)
1694   //
1695   //  for (d1 in D1) {
1696   //    for (d0 in D0 with stride VS) {
1697   //      vector_acc = init
1698   //      for (r1 in R1) {
1699   //        for (r0 in R0) {
1700   //          vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1701   //        }
1702   //      }
1703   //      output[d1, d0] = vector_acc
1704   //    }
1705   //  }
1706 
1707   llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1708   std::vector<llvm::Value*> array_multi_index(
1709       reduce->shape().dimensions_size());
1710   for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1711        --i) {
1712     int64_t dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1713     int64_t start_index = 0;
1714     int64_t end_index = reduce->shape().dimensions(dimension);
1715     std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1716         start_index, end_index, absl::StrFormat("dim.%d", dimension));
1717     array_multi_index[dimension] = loop->GetIndVarValue();
1718   }
1719 
1720   int64_t innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1721   int64_t innermost_dimension_size =
1722       reduce->shape().dimensions(innermost_dimension);
1723 
1724   if (llvm::BasicBlock* innermost_body_bb =
1725           loop_nest.GetInnerLoopBodyBasicBlock()) {
1726     SetToFirstInsertPoint(innermost_body_bb, &b_);
1727   }
1728 
1729   auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1730 
1731   if (innermost_dimension_size >= vectorization_factor) {
1732     int64_t start_index = 0;
1733     int64_t end_index = (innermost_dimension_size / vectorization_factor) *
1734                         vectorization_factor;
1735     std::unique_ptr<llvm_ir::ForLoop> loop =
1736         loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1737                           absl::StrFormat("dim.%d", innermost_dimension));
1738     array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1739 
1740     SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1741 
1742     ShardedVectorType vector_type = CreateShardedVectorType(
1743         reduce->shape().element_type(), vectorization_factor);
1744     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1745                                         b_.getInt64Ty());
1746     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1747                         EmitInnerLoopForVectorizedReduction(
1748                             reduction_generator, array_index, vector_type,
1749                             init_value, arg, dimensions, element_alignment));
1750 
1751     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1752     llvm::Value* output_address =
1753         target_array.EmitArrayElementAddress(array_index, &b_);
1754     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1755                            target_array);
1756 
1757     if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1758       CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1759       b_.SetInsertPoint(exit_terminator);
1760     } else {
1761       CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1762       b_.SetInsertPoint(loop->GetExitBasicBlock());
1763     }
1764   }
1765 
1766   // Since we step through the innermost dimension with a stride greater than
1767   // one, we may need to peel off an "epilogue" iteration to handle the
1768   // remaining elements when the dimension size is not a multiple of the stride:
1769   if (innermost_dimension_size % vectorization_factor) {
1770     // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1771     array_multi_index[innermost_dimension] =
1772         b_.getInt64(innermost_dimension_size -
1773                     (innermost_dimension_size % vectorization_factor));
1774 
1775     ShardedVectorType vector_type = CreateShardedVectorType(
1776         reduce->shape().element_type(),
1777         innermost_dimension_size % vectorization_factor);
1778     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1779                                         b_.getInt64Ty());
1780     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1781                         EmitInnerLoopForVectorizedReduction(
1782                             reduction_generator, array_index, vector_type,
1783                             init_value, arg, dimensions, element_alignment));
1784 
1785     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1786     llvm::Value* output_address =
1787         target_array.EmitArrayElementAddress(array_index, &b_);
1788     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1789                            target_array);
1790   }
1791 
1792   if (outermost_loop_exit_block) {
1793     b_.SetInsertPoint(outermost_loop_exit_block);
1794   }
1795 
1796   return true;
1797 }
1798 
1799 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1800   auto arg = reduce->mutable_operand(0);
1801   auto init_value = reduce->mutable_operand(1);
1802   absl::Span<const int64> dimensions(reduce->dimensions());
1803   HloComputation* function = reduce->to_apply();
1804   if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1805     string vectorization_failure_reason;
1806     TF_ASSIGN_OR_RETURN(
1807         bool vectorization_successful,
1808         EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1809                              &vectorization_failure_reason));
1810     if (vectorization_successful) {
1811       VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1812               << "\n";
1813       return Status::OK();
1814     } else {
1815       VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1816               << vectorization_failure_reason;
1817     }
1818   }
1819 
1820   return DefaultAction(reduce);
1821 }
1822 
1823 Status IrEmitter::HandleSend(HloInstruction* send) {
1824   // TODO(b/33942983): Support Send/Recv on CPU.
1825   return Unimplemented("Send is not implemented on CPU.");
1826 }
1827 
1828 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1829   // TODO(b/33942983): Support Send/Recv on CPU.
1830   return Unimplemented("Send-done is not implemented on CPU.");
1831 }
1832 
1833 Status IrEmitter::HandleScatter(HloInstruction*) {
1834   return Unimplemented("Scatter is not implemented on CPUs.");
1835 }
1836 
1837 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1838   VLOG(2) << "HandleSlice: " << slice->ToString();
1839   auto operand = slice->operand(0);
1840   // The code below emits a sequential loop nest. For the parallel backend, use
1841   // ParallelLoopEmitter which respects dynamic loop bounds.
1842   if (ShouldEmitParallelLoopFor(*slice)) {
1843     return DefaultAction(slice);
1844   }
1845 
1846   // The code below assumes the layouts are equal.
1847   if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1848     return DefaultAction(slice);
1849   }
1850 
1851   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1852 
1853   if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1854     return Status::OK();
1855   }
1856 
1857   const Layout& layout = operand->shape().layout();
1858   const int64_t num_dims = operand->shape().dimensions_size();
1859 
1860   // The slice lowering finds maximal contiguous blocks of memory that can be
1861   // copied from the source to the target. This is done by looking at the
1862   // source/target layout in minor-to-major order and doing the following:
1863   //
1864   // * Find an initial segment of dimensions along which the slice uses the
1865   //   whole dimension. These are the "inner" dimensions and can be folded into
1866   //   the memcpy.
1867   //
1868   // * Of the remaining dimensions decide which ones require loops.
1869   //
1870   // * Implement the memcpy within the innermost loop.
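  //
  // For example (hypothetical shapes, for illustration only): slicing
  // f32[8,16,32]{2,1,0} down to f32[4,16,32] with unit strides keeps dimensions
  // 2 and 1 whole, so they fold into the memcpy; dimension 0 becomes the memcpy
  // dimension, and because its stride is 1 the whole slice is one contiguous
  // memcpy of 4*16*32 elements with no loops.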
1871 
1872   absl::flat_hash_set<int64> inner_dims;
1873   for (int64_t dim : LayoutUtil::MinorToMajor(layout)) {
1874     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1875       break;
1876     }
1877     inner_dims.insert(dim);
1878   }
1879 
1880   const bool is_trivial_copy = (inner_dims.size() == num_dims);
1881   if (is_trivial_copy) {
1882     if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1883       return DefaultAction(slice);
1884     } else {
1885       return EmitMemcpy(*slice, *operand);
1886     }
1887   }
1888 
1889   // The memcpy will copy elements that are logically this shape (allowed to be
1890   // scalar).
1891   const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1892       [&inner_dims](int64_t dim) { return inner_dims.contains(dim); },
1893       operand->shape());
1894 
1895   const int64_t primitive_elements_per_logical_element =
1896       ShapeUtil::ElementsIn(logical_element_shape);
1897 
1898   // memcpy_dim is the innermost (in terms of layout) dimension for which the
1899   // slice does *not* just copy all the elements along the dimension.
1900   const int64_t memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1901 
1902   const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1903   // The number of logical elements that can be copied in a single call
1904   // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1905   // stride.
1906   const int64_t memcpy_logical_elements =
1907       memcpy_is_contiguous
1908           ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1909           : 1;
1910 
1911   // Determine the dimensions that get lowered as loops.
1912   std::vector<int64> outer_dims;
1913   for (int64_t i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1914     outer_dims.push_back(LayoutUtil::Major(layout, i));
1915   }
1916 
1917   // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1918   // needs to be wrapped in a loop as well.
1919   if (!memcpy_is_contiguous) {
1920     outer_dims.push_back(memcpy_dim);
1921   }
1922 
1923   llvm_ir::IrArray target_array = GetIrArrayFor(slice);
1924 
1925   const int64_t num_outer_loops = outer_dims.size();
1926   llvm_ir::ForLoopNest loops(IrName(slice), &b_);
1927   std::vector<llvm::Value*> target_multi_index =
1928       loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
1929 
1930   // Only the indices for the outer dimensions have been initialized in
1931   // target_index. The rest of the indices should get initialized to 0, since
1932   // for the rest of the dimensions the copy writes to the full dimension.
1933   std::replace(target_multi_index.begin(), target_multi_index.end(),
1934                static_cast<llvm::Value*>(nullptr),
1935                static_cast<llvm::Value*>(b_.getInt64(0)));
1936   llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
1937                                        b_.getInt64Ty());
1938 
1939   if (num_outer_loops > 0) {
1940     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1941   }
1942 
1943   llvm_ir::IrArray source_array = GetIrArrayFor(operand);
1944   const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
1945       /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
1946       /*strides=*/slice->slice_strides(), /*builder=*/&b_);
1947 
1948   llvm::Value* memcpy_dest =
1949       target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
1950   llvm::Value* memcpy_source =
1951       source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
1952 
1953   const int64_t memcpy_elements =
1954       primitive_elements_per_logical_element * memcpy_logical_elements;
1955 
1956   EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
1957                        slice->shape().element_type(), target_array,
1958                        source_array);
1959 
1960   if (VLOG_IS_ON(2)) {
1961     const int64_t memcpy_bytes =
1962         ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
1963     VLOG(2) << "  emitted copy of " << memcpy_bytes << " bytes inside "
1964             << num_outer_loops << " loops";
1965   }
1966 
1967   if (num_outer_loops > 0) {
1968     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1969   }
1970 
1971   return Status::OK();
1972 }
1973 
1974 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1975   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
1976     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
1977     return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
1978   }
1979   return DefaultAction(dynamic_slice);
1980 }
1981 
1982 Status IrEmitter::HandleDynamicUpdateSlice(
1983     HloInstruction* dynamic_update_slice) {
1984   auto update = dynamic_update_slice->operand(1);
1985   if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
1986     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1987     return EmitMemcpy(*update, *dynamic_update_slice);
1988   } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
1989                                                    assignment_)) {
1990     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1991     auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
1992     return llvm_ir::EmitDynamicUpdateSliceInPlace(
1993         operands, GetIrArrayFor(dynamic_update_slice),
1994         IrName(dynamic_update_slice, "in_place"), &b_);
1995   }
1996   return DefaultAction(dynamic_update_slice);
1997 }
1998 
1999 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2000   // TODO(b/33942983): Support Send/Recv on CPU.
2001   return Unimplemented("Recv is not implemented on CPU.");
2002 }
2003 
2004 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2005   // TODO(b/33942983): Support Send/Recv on CPU.
2006   return Unimplemented("Recv-done is not implemented on CPU.");
2007 }
2008 
2009 Status IrEmitter::HandlePad(HloInstruction* pad) {
2010   // The CPU backend does not properly handle negative padding, but this is OK
2011   // because negative padding should have been removed by the algebraic simplifier.
2012   for (auto& padding_dimension : pad->padding_config().dimensions()) {
2013     if (padding_dimension.edge_padding_low() < 0 ||
2014         padding_dimension.edge_padding_high() < 0) {
2015       return InternalErrorStrCat(
2016           "Encountered negative padding in IrEmitter on CPU. "
2017           "This should have been eliminated at the HLO level. ",
2018           pad->ToString());
2019     }
2020   }
2021 
2022   // First, fill in the padding value to all output elements.
2023   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2024       pad, "initialize",
2025       [this, pad](const llvm_ir::IrArray::Index& target_index) {
2026         const HloInstruction* padding_value = pad->operand(1);
2027         llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2028         return Load(padding_value_addr);
2029       }));
2030 
2031   // Create a loop to iterate over the operand elements and update the output
2032   // locations where the operand elements should be stored.
2033   llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2034   const HloInstruction* operand = pad->operand(0);
2035   const llvm_ir::IrArray::Index operand_index =
2036       loops.AddLoopsForShape(operand->shape(), "operand");
2037 
2038   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2039 
2040   // Load an element from the operand.
2041   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2042   llvm::Value* operand_data =
2043       operand_array.EmitReadArrayElement(operand_index, &b_);
2044 
2045   // Compute the output index the operand element should be assigned to.
2046   // output_index := edge_padding_low + operand_index * (interior_padding + 1)
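  // For instance, with edge_padding_low = 2 and interior_padding = 1, operand
  // element i is written to output element 2 + 2*i (0->2, 1->4, 2->6, ...).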
2047   const PaddingConfig& padding_config = pad->padding_config();
2048   std::vector<llvm::Value*> output_multi_index;
2049   for (size_t i = 0; i < operand_index.size(); ++i) {
2050     llvm::Value* offset =
2051         Mul(operand_index[i],
2052             b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2053     llvm::Value* index = Add(
2054         offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2055     output_multi_index.push_back(index);
2056   }
2057 
2058   // Store the operand element to the computed output location.
2059   llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2060   llvm_ir::IrArray::Index output_index(
2061       output_multi_index, output_array.GetShape(), operand_index.GetType());
2062   output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2063 
2064   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2065   return Status::OK();
2066 }
2067 
2068 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2069   auto* root = fusion->fused_expression_root();
2070   if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2071     VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2072     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2073     FusedIrEmitter fused_emitter(&elemental_emitter);
2074     BindFusionArguments(fusion, &fused_emitter);
2075 
2076     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2077     // Delegate to common implementation of fused in-place dynamic-update-slice.
2078     return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2079         fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2080   } else if (fusion->IsLoopFusion()) {
2081     VLOG(3) << "HandleFusion kLoop";
2082     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2083     FusedIrEmitter fused_emitter(&elemental_emitter);
2084     BindFusionArguments(fusion, &fused_emitter);
2085     TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2086                                             fusion->fused_expression_root()));
2087     return EmitTargetElementLoop(fusion, generator);
2088   } else if (fusion->IsOutputFusion()) {
2089     VLOG(3) << "HandleFusion kOutput";
2090     int64_t dot_op_index =
2091         root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2092     const HloInstruction* dot = root->operand(dot_op_index);
2093     CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2094         << dot->ToString() << "  "
2095         << fusion->fused_instructions_computation()->ToString();
2096 
2097     int64_t dot_lhs_param_number = dot->operand(0)->parameter_number();
2098     int64_t dot_rhs_param_number = dot->operand(1)->parameter_number();
2099     int64_t addend_param_number =
2100         root->operand(1 - dot_op_index)->parameter_number();
2101 
2102     Shape target_shape = fusion->shape();
2103     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2104     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2105 
2106     llvm_ir::IrArray lhs_array(
2107         GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2108     llvm_ir::IrArray rhs_array(
2109         GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2110     llvm_ir::IrArray addend_array(
2111         GetIrArrayFor(fusion->operand(addend_param_number)));
2112 
2113     TF_RETURN_IF_ERROR(EmitDotOperation(
2114         *dot, target_array, lhs_array, rhs_array, &addend_array,
2115         GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2116         hlo_module_config_, target_machine_features_));
2117     return Status::OK();
2118   } else {
2119     return Unimplemented("Fusion kind not implemented on CPU");
2120   }
2121 }
2122 
2123 Status IrEmitter::HandleCall(HloInstruction* call) {
2124   HloComputation* computation = call->to_apply();
2125   llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2126 
2127   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2128 
2129   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2130     // ParallelTaskAssignment assigned partitions, emit call to
2131     // ParallelForkJoin.
2132     std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2133         {}, &b_, computation->name(),
2134         /*return_value_buffer=*/emitted_value_[call],
2135         /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2136         /*buffer_table_arg=*/GetBufferTableArgument(),
2137         /*profile_counters_arg=*/GetProfileCountersArgument());
2138 
2139     HloInstruction* root = computation->root_instruction();
2140     TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2141         call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2142         call_ir_function, computation->name()));
2143   } else {
2144     EmitGlobalCall(*computation, computation->name());
2145   }
2146 
2147   return Status::OK();
2148 }
2149 
2150 Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2151   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2152   std::vector<llvm::Value*> dynamic_dims;
2153   int32_t raw_data_size =
2154       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2155   llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2156   llvm::Value* raw_buffer =
2157       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
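  // The dynamic result is laid out as the statically-shaped payload followed by
  // a metadata section of int32 dimension sizes, one entry per dynamic dimension.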
2158   for (int64_t i = 1; i < hlo->operand_count(); ++i) {
2159     const int64_t dim_index = i - 1;
2160     llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2161     llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
2162 
2163     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2164         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2165     b_.CreateStore(dyn_dim_size,
2166                    b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2167     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2168                                             /*isSigned=*/true,
2169                                             "i64_dyn_dim_size"));
2170   }
2171 
2172   llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2173   // Pseudo code for sliceToDynamic:
2174   //
2175   //   for (index i in dynamic_dim)
2176   //     dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2177   //     dest[dest_index] = source[i]
2178   auto loop_body_emitter =
2179       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2180     llvm::Value* source_element =
2181         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2182     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2183     // Delinearize the index based on the static shape.
2184     llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2185                                        &b_);
2186     data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2187     return Status::OK();
2188   };
2189   return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2190                               dynamic_dims, &b_)
2191       .EmitLoop(IrName(hlo));
2192 }
2193 
2194 Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2195   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2196 
2197   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2198                       assignment_.GetUniqueSlice(hlo, {0}));
2199   std::vector<llvm::Value*> dynamic_dims;
2200   std::vector<llvm::Value*> tuple_operand_ptrs;
2201   const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2202   const Shape& input_shape = hlo->operand(0)->shape();
2203   llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2204   llvm_ir::IrArray data_array(data_address, data_shape);
2205   llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2206   llvm::Value* raw_buffer =
2207       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2208   int64_t raw_data_size =
2209       ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2210 
2211   // Put a placeholder for the data array's pointer
2212   tuple_operand_ptrs.push_back(data_array.GetBasePointer());
2213   // PadToStatic takes a dynamic tensor as input and produces a variadic number
2214   // of outputs: (static_tensor, dynamic_dim_0, dynamic_dim_1, ... )
2215   // Dynamic dimension sizes start at output index 1.
2216   for (int64_t i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2217     // Read from the metadata section of the dynamic input (operand 0).
2218     const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2219     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2220     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2221                         assignment_.GetUniqueSlice(hlo, {i}));
2222     llvm::Value* dest_dim_size_address =
2223         EmitBufferPointer(dim_size_slice, data_shape);
2224     const int64_t dim_index = i - 1;
2225     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2226         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2227     llvm::Value* dyn_dim_size = b_.CreateLoad(
2228         b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2229         "dyn_dim_size");
2230     b_.CreateStore(dyn_dim_size,
2231                    b_.CreateBitCast(dest_dim_size_address,
2232                                     b_.getInt32Ty()->getPointerTo()));
2233     dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2234                                             /*isSigned=*/true,
2235                                             "i64_dyn_dim_size"));
2236     tuple_operand_ptrs.push_back(dest_dim_size_address);
2237   }
2238 
2239   // Pseudo code for padToStatic:
2240   //
2241   //   for (index i in dynamic_dim)
2242   //     source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2243   //     dest[i] = source[source_index]
2244   auto loop_body_emitter =
2245       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2246     llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2247     llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2248     llvm::Value* source_element =
2249         GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2250     data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2251     return Status::OK();
2252   };
2253   TF_RETURN_IF_ERROR(
2254       llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2255           .EmitLoop(IrName(hlo)));
2256 
2257   // Emit static tensor and dynamic sizes as one tuple.
2258   llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2259   return Status::OK();
2260 }
2261 
2262 Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2263   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2264   const HloInstruction* input = hlo->operand(0);
2265   const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
2266   const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
2267   TF_RET_CHECK(input->shape().element_type() == F32);
2268   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2269       hlo->shape().tuple_shapes(0).layout()));
2270   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2271       hlo->shape().tuple_shapes(1).layout()));
2272   TF_RET_CHECK(
2273       LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
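  // These checks reflect what the emitted runtime call below assumes: a
  // row-major f32 input of shape [batch, n] (batch may be absent), producing
  // the top k values per row together with their int32 indices.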
2274 
2275   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2276                       assignment_.GetUniqueSlice(hlo->operand(0), {}));
2277   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2278                       assignment_.GetUniqueSlice(hlo, {0}));
2279   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2280                       assignment_.GetUniqueSlice(hlo, {1}));
2281   llvm::Value* values_ptr =
2282       EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2283   llvm::Value* out_values_ptr =
2284       EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2285   llvm::Value* out_indices_ptr =
2286       EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
2287   EmitCallToFunc(
2288       runtime::kTopKF32SymbolName,
2289       {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2290        b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2291        BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2292        BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2293        BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2294       b_.getVoidTy());
2295 
2296   llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2297                      &b_);
2298   return Status::OK();
2299 }
2300 
2301 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2302   if (custom_call->custom_call_target() == "PadToStatic") {
2303     return HandlePadToStatic(custom_call);
2304   }
2305   if (custom_call->custom_call_target() == "SliceToDynamic") {
2306     return HandleSliceToDynamic(custom_call);
2307   }
2308   if (custom_call->custom_call_target() == "TopK") {
2309     return HandleTopK(custom_call);
2310   }
2311 
2312   auto typed_custom_call = Cast<HloCustomCallInstruction>(custom_call);
2313   switch (typed_custom_call->api_version()) {
2314     case CustomCallApiVersion::API_VERSION_ORIGINAL:
2315       break;
2316     case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
2317       // TODO(b/194529780): Support status-returning custom calls on CPU.
2318       return Unimplemented(
2319           "XLA CPU does not support custom calls that return a success/failure "
2320           "status");
2321     default:
2322       return InternalError(
2323           "Unknown custom-call API version enum value: %d (%s)",
2324           typed_custom_call->api_version(),
2325           CustomCallApiVersion_Name(typed_custom_call->api_version()));
2326   }
2327 
2328   absl::Span<HloInstruction* const> operands(custom_call->operands());
2329   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2330   llvm::AllocaInst* operands_alloca =
2331       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2332           i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
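  // The custom-call target is invoked as target(output, operands), where
  // operands is an array of i8* pointers, one per operand, filled in below.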
2333   for (size_t i = 0; i < operands.size(); ++i) {
2334     const HloInstruction* operand = operands[i];
2335     llvm::Value* operand_as_i8ptr =
2336         PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2337     llvm::Value* slot_in_operands_alloca =
2338         InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2339     Store(operand_as_i8ptr, slot_in_operands_alloca);
2340   }
2341   if (emit_code_for_msan_) {
2342     // Mark the alloca as initialized for msan. The buffer gets read by the
2343     // custom callee, which might be msan-instrumented.
2344     // TODO(b/66051036): Run the msan instrumentation pass instead.
2345     const llvm::DataLayout& dl = module_->getDataLayout();
2346     llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2347     EmitCallToFunc(
2348         "__msan_unpoison",
2349         {PointerCast(operands_alloca, i8_ptr_type),
2350          llvm::ConstantInt::get(
2351              intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2352         b_.getVoidTy());
2353   }
2354 
2355   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2356   // Write the tuple table if the output is a tuple.
2357   if (custom_call->shape().IsTuple()) {
2358     std::vector<llvm::Value*> base_ptrs;
2359     for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2360          ++i) {
2361       const Shape& elem_shape =
2362           ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2363       TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2364       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2365                           assignment_.GetUniqueSlice(custom_call, {i}));
2366       llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2367       base_ptrs.push_back(addr);
2368     }
2369     llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2370   }
2371   auto* output_address_arg =
2372       PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2373 
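  // Invoke the custom-call target. For API_VERSION_ORIGINAL the callee is
  // expected to look roughly like `void target(void* output,
  // const void** operands)`; this is only a sketch of the calling convention,
  // the authoritative prototype is defined by the XLA custom-call ABI.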
2374   EmitCallToFunc(custom_call->custom_call_target(),
2375                  {output_address_arg, operands_alloca}, b_.getVoidTy());
2376 
2377   return Status::OK();
2378 }
2379 
2380 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2381   // Precondition: Condition computation must return a scalar bool.
2382   HloComputation* condition = xla_while->while_condition();
2383   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2384                condition->root_instruction()->shape().element_type() == PRED)
2385       << "While condition computation must return bool; got: "
2386       << ShapeUtil::HumanString(condition->root_instruction()->shape());
2387   // Check that all while-related buffers share an allocation slice.
2388   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2389       xla_while->shape(),
2390       [this, &xla_while](const Shape& /*subshape*/,
2391                          const ShapeIndex& index) -> Status {
2392         auto check = [this](const HloInstruction* a, const HloInstruction* b,
2393                             const ShapeIndex& index) {
2394           const BufferAllocation::Slice slice_a =
2395               assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2396           const BufferAllocation::Slice slice_b =
2397               assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2398           if (slice_a != slice_b) {
2399             return InternalError(
2400                 "instruction %s %s does not share slice with "
2401                 "instruction %s %s",
2402                 a->ToString(), slice_a.ToString(), b->ToString(),
2403                 slice_b.ToString());
2404           }
2405           return Status::OK();
2406         };
2407         TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2408         TF_RETURN_IF_ERROR(check(
2409             xla_while, xla_while->while_condition()->parameter_instruction(0),
2410             index));
2411         TF_RETURN_IF_ERROR(
2412             check(xla_while, xla_while->while_body()->parameter_instruction(0),
2413                   index));
2414         TF_RETURN_IF_ERROR(check(
2415             xla_while, xla_while->while_body()->root_instruction(), index));
2416         return Status::OK();
2417       }));
2418 
2419   // Set emitted value to that of 'init' with which it shares an allocation.
2420   const HloInstruction* init = xla_while->operand(0);
2421   emitted_value_[xla_while] = GetEmittedValueFor(init);
2422 
2423   // Generating:
2424   //   while (Condition(while_result)) {
2425   //     // CopyInsertion pass inserts copies which enable 'while_result' to
2426   //     // be passed back in as 'Body' parameter.
2427   //     while_result = Body(while_result);  // Insert
2428   //   }
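  //
  // A sketch of the emitted control flow (block names follow IrName below):
  //   ...current block...  br %header
  //   header: call condition; br i1 %pred, label %body, label %exit
  //   body:   call body;      br %header
  //   exit:   (subsequent code)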
2429 
2430   // Terminates the current block with a branch to a while header.
2431   llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2432       module_->getContext(), IrName(xla_while, "header"),
2433       compute_function_->function());
2434   Br(header_bb);
2435   b_.SetInsertPoint(header_bb);
2436 
2437   // Calls the condition function to determine whether to proceed with the
2438   // body.  It must return a bool, so use the scalar call form.
2439   EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2440   llvm::Value* while_predicate = ICmpNE(
2441       Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2442       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2443 
2444   // Branches to the body or to the while exit depending on the condition.
2445   llvm::BasicBlock* body_bb =
2446       llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2447                                compute_function_->function());
2448   llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2449       module_->getContext(), IrName(xla_while, "exit"));
2450   CondBr(while_predicate, body_bb, exit_bb);
2451 
2452   // Calls the body function from the body block.
2453   b_.SetInsertPoint(body_bb);
2454 
2455   // Calls the body function.
2456   EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2457 
2458   // Finishes with a branch back to the header.
2459   Br(header_bb);
2460 
2461   // Adds the exit block to the function and sets the insert point there.
2462   compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2463   b_.SetInsertPoint(exit_bb);
2464 
2465   return Status::OK();
2466 }
2467 
2468 StatusOr<bool> IrEmitter::EmitFastConcatenate(
2469     HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2470     string* failure_reason) {
2471   if (ShouldEmitParallelLoopFor(*concatenate)) {
2472     *failure_reason =
2473         "cannot generate memcpy-based concat for the parallel CPU backend";
2474     return false;
2475   }
2476 
2477   const Shape& output_shape = concatenate->shape();
2478   for (auto* op : operands) {
2479     if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2480       *failure_reason = "operand has mismatching layouts";
2481       return false;
2482     }
2483   }
2484 
2485   // We split the dimensions into three categories: the dimension over which we
2486   // are concatenating (concat_dim), the dimensions that are minor to it
2487   // (inner_dims) and the dimensions that are major to it (outer_dims).
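  //
  // For example, concatenating rank-3 arrays along dimension 1 with a
  // {2,1,0} minor-to-major layout gives inner_dims = {2} and
  // outer_dims = {0}: for each index of dimension 0 we can memcpy a
  // contiguous [dims(1) * dims(2)]-element slab from each operand.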
2488 
2489   int64_t concat_dim = concatenate->dimensions(0);
2490   const Layout& output_layout = output_shape.layout();
2491   auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2492   auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2493 
2494   std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
2495   std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
2496                                 output_min2maj.end());
2497 
2498   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2499 
2500   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2501   llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2502 
2503   llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2504   std::vector<llvm::Value*> target_multi_index =
2505       loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2506   std::replace(target_multi_index.begin(), target_multi_index.end(),
2507                static_cast<llvm::Value*>(nullptr),
2508                static_cast<llvm::Value*>(b_.getInt64(0)));
2509   llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2510                                        b_.getInt64Ty());
2511 
2512   if (!outer_dims.empty()) {
2513     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2514   }
2515 
2516   PrimitiveType primitive_type = output_shape.element_type();
2517   unsigned primitive_type_size =
2518       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2519 
2520   // Contiguous subregions from each operand to the concatenate contribute to a
2521   // contiguous subregion in the target buffer starting at target_region_begin.
2522   llvm::Value* target_region_begin = BitCast(
2523       target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2524       i8_ptr_type);
2525   int64_t byte_offset_into_target_region = 0;
2526 
2527   int64_t inner_dims_product =
2528       std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2529                       [&](int64_t product, int64_t inner_dim) {
2530                         return product * output_shape.dimensions(inner_dim);
2531                       });
2532 
2533   // For each operand, emit a memcpy from the operand to the target of size
2534   // equal to the product of inner dimensions.
2535   for (HloInstruction* operand : operands) {
2536     const Shape& input_shape = operand->shape();
2537     llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2538     llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2539                                          b_.getInt64Ty());
2540     llvm::Value* copy_source_address = BitCast(
2541         source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2542         i8_ptr_type);
2543 
2544     llvm::Value* copy_target_address =
2545         GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
2546 
2547     EmitTransferElements(
2548         copy_target_address, copy_source_address,
2549         inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2550         target_array, source_array);
2551 
2552     byte_offset_into_target_region += inner_dims_product *
2553                                       input_shape.dimensions(concat_dim) *
2554                                       primitive_type_size;
2555   }
2556 
2557   if (!outer_dims.empty()) {
2558     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2559   }
2560 
2561   return true;
2562 }
2563 
2564 llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2565                                    absl::Span<llvm::Value* const> arguments) {
2566   llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2567   std::vector<llvm::Value*> call_args;
2568   call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2569   absl::c_copy(arguments, std::back_inserter(call_args));
2570   return b_.CreateCall(
2571       b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2572           "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2573                                             /*isVarArg=*/true)),
2574       call_args);
2575 }
2576 
2577 llvm::Value* IrEmitter::EmitPrintfToStderr(
2578     absl::string_view fmt, absl::Span<llvm::Value* const> arguments) {
2579   llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2580   std::vector<llvm::Value*> call_args;
2581   call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2582   absl::c_copy(arguments, std::back_inserter(call_args));
2583   return b_.CreateCall(
2584       b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2585           runtime::kPrintfToStderrSymbolName,
2586           llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2587                                   /*isVarArg=*/true)),
2588       call_args);
2589 }
2590 
2591 llvm::Value* IrEmitter::EmitCallToFunc(
2592     std::string func_name, const std::vector<llvm::Value*>& arguments,
2593     llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2594     bool only_accesses_inaccessible_mem_or_arg_mem) {
2595   std::vector<llvm::Type*> types;
2596   types.reserve(arguments.size());
2597   absl::c_transform(arguments, std::back_inserter(types),
2598                     [&](llvm::Value* val) { return val->getType(); });
2599   llvm::FunctionType* func_type =
2600       llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2601   auto func = llvm::dyn_cast<llvm::Function>(
2602       module_->getOrInsertFunction(func_name, func_type).getCallee());
2603   func->setCallingConv(llvm::CallingConv::C);
2604   if (does_not_throw) {
2605     func->setDoesNotThrow();
2606   }
2607   if (only_accesses_arg_memory) {
2608     func->setOnlyAccessesArgMemory();
2609   }
2610   if (only_accesses_inaccessible_mem_or_arg_mem) {
2611     func->setOnlyAccessesInaccessibleMemOrArgMem();
2612   }
2613   return b_.CreateCall(func, arguments);
2614 }
2615 
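// Copies `element_count` elements of `primitive_type` from `source` to
// `target`: a single aligned load/store pair for one element, an LLVM memcpy
// otherwise, propagating the IrArrays' aliasing metadata in both cases.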
2616 void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2617                                      int64_t element_count,
2618                                      PrimitiveType primitive_type,
2619                                      const llvm_ir::IrArray& target_array,
2620                                      const llvm_ir::IrArray& source_array) {
2621   unsigned primitive_type_size =
2622       ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2623   llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2624       primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2625   llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
2626       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
2627 
2628   if (element_count == 1) {
2629     auto* load_instruction =
2630         AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
2631     source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2632     auto* store_instruction =
2633         AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2634                      element_alignment);
2635     target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2636   } else {
2637     auto* memcpy_instruction = b_.CreateMemCpy(
2638         target, /*DstAlign=*/llvm::Align(element_alignment), source,
2639         /*SrcAlign=*/llvm::Align(element_alignment),
2640         element_count * primitive_type_size);
2641 
2642     // The memcpy does the load and the store internally.  The aliasing related
2643     // metadata has to reflect that.
2644     std::map<int, llvm::MDNode*> merged_metadata =
2645         llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2646                                target_array.metadata());
2647     for (const auto& kind_md_pair : merged_metadata) {
2648       memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2649     }
2650   }
2651 }
2652 
2653 Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
2654   absl::Span<HloInstruction* const> operands(concatenate->operands());
2655   string failure_reason;
2656   TF_ASSIGN_OR_RETURN(
2657       bool successful,
2658       EmitFastConcatenate(concatenate, operands, &failure_reason));
2659   if (successful) {
2660     VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
2661     return Status::OK();
2662   }
2663 
2664   VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
2665           << ": " << failure_reason;
2666 
2667   return DefaultAction(concatenate);
2668 }
2669 
2670 Status IrEmitter::HandleConditional(HloInstruction* conditional) {
2671   auto branch_index = conditional->operand(0);
2672   int num_branches = conditional->branch_count();
2673   TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
2674                (branch_index->shape().element_type() == PRED ||
2675                 branch_index->shape().element_type() == S32))
2676       << "Branch index on a conditional must be scalar bool or int32; got: "
2677       << ShapeUtil::HumanString(branch_index->shape());
2678 
2679   for (int b = 0; b < num_branches; ++b) {
2680     HloComputation* br_computation = conditional->branch_computation(b);
2681     TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
2682                                   br_computation->root_instruction()->shape()))
2683         << "Shape of conditional should be same as the shape of the " << b
2684         << "th branch computation; got: "
2685         << ShapeUtil::HumanString(conditional->shape()) << " and "
2686         << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
2687   }
2688 
2689   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
2690 
2691   if (branch_index->shape().element_type() == PRED) {
2692     // Emit an if-else to LLVM:
2693     //   if (pred)
2694     //     cond_result = true_computation(true_operand)
2695     //   else
2696     //     cond_result = false_computation(false_operand)
2697     llvm::LoadInst* pred_value = Load(
2698         GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
2699     llvm::Value* pred_cond =
2700         ICmpNE(pred_value,
2701                llvm::ConstantInt::get(
2702                    llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
2703                "boolean_predicate");
2704     llvm_ir::LlvmIfData if_data =
2705         llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
2706 
2707     SetToFirstInsertPoint(if_data.true_block, &b_);
2708     EmitGlobalCall(*conditional->branch_computation(0),
2709                    IrName(conditional, "_true"));
2710 
2711     SetToFirstInsertPoint(if_data.false_block, &b_);
2712     EmitGlobalCall(*conditional->branch_computation(1),
2713                    IrName(conditional, "_false"));
2714 
2715     SetToFirstInsertPoint(if_data.after_block, &b_);
2716     return Status::OK();
2717   }
2718   // We emit a switch statement to LLVM:
2719   // switch (branch_index) {
2720   //   default:
2721   //     result = branch_computations[num_branches-1](operands[num_branches-1]);
2722   //     break;
2723   //   case 0:
2724   //     result = branch_computations[0](operands[0]); break;
2725   //   case 1:
2726   //     result = branch_computations[1](operands[1]); break;
2727   //   ...
2728   //   case [[num_branches-2]]:
2729   //     result = branch_computations[num_branches-2](operands[num_branches-2]);
2730   //     break;
2731   // }
2732   llvm::LoadInst* branch_index_value = Load(
2733       GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
2734 
2735   auto case_block = b_.GetInsertBlock();
2736   llvm::BasicBlock* after_block;
2737   // Add a terminator to the case block, if necessary.
2738   if (case_block->getTerminator() == nullptr) {
2739     after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
2740     b_.SetInsertPoint(case_block);
2741     b_.CreateBr(after_block);
2742   } else {
2743     after_block =
2744         case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
2745   }
2746   // Our basic block should now end with an unconditional branch.  Remove it;
2747   // we're going to replace it with a switch based branch.
2748   case_block->getTerminator()->eraseFromParent();
2749 
2750   // Lower the default branch computation.
2751   auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
2752   b_.SetInsertPoint(default_block);
2753   EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
2754                  IrName(conditional, "_default"));
2755   b_.CreateBr(after_block);
2756 
2757   // Prepare the switch (branch_index) { ... } instruction.
2758   b_.SetInsertPoint(case_block);
2759   llvm::SwitchInst* case_inst =
2760       b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
2761   // Lower each branch's computation.
2762   for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
2763     // Lower the case b: { ... ; break; } computation.
2764     auto branch_block =
2765         llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
2766     b_.SetInsertPoint(branch_block);
2767     EmitGlobalCall(*conditional->branch_computation(b),
2768                    IrName(conditional, absl::StrCat("_branch", b)));
2769     b_.CreateBr(after_block);
2770     case_inst->addCase(b_.getInt32(b), branch_block);
2771   }
2772 
2773   SetToFirstInsertPoint(after_block, &b_);
2774   return Status::OK();
2775 }
2776 
2777 Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
2778   TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
2779   // No code to generate, but we need to emit an address for book-keeping.
2780   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
2781   return Status::OK();
2782 }
2783 
2784 Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
2785   // AddDependency just forwards its zero-th operand.
2786   emitted_value_[add_dependency] =
2787       GetEmittedValueFor(add_dependency->operand(0));
2788   return Status::OK();
2789 }
2790 
2791 Status IrEmitter::HandleRng(HloInstruction* rng) {
2792   return Unimplemented("Rng should be expanded for CPU.");
2793 }
2794 
2795 Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
2796   VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
2797   llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
2798       Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
2799       &b_);
2800 
2801   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
2802   llvm::Value* address = GetEmittedValueFor(rng_state);
2803 
2804   // The buffer has an array type while the value is an i128. Cast the
2805   // buffer to an i128 pointer to store the value.
2806   address = BitCast(address, llvm::PointerType::get(
2807                                  old_state->getType()->getScalarType(),
2808                                  address->getType()->getPointerAddressSpace()));
2809   llvm::StoreInst* store = Store(old_state, address);
2810   store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
2811       rng_state->shape().element_type())));
2812 
2813   return Status::OK();
2814 }
2815 
2816 Status IrEmitter::FinishVisit(HloInstruction* root) {
2817   // When this method is called, we should have already emitted an IR value for
2818   // the root (return) op. The IR value holds the address of the buffer holding
2819   // the value. If the root is a constant or parameter, we perform a memcpy from
2820   // this buffer to the retval buffer of the computation. Otherwise, there's
2821   // nothing to do since the result was already written directly into the output
2822   // buffer.
2823   VLOG(2) << "FinishVisit root: " << root->ToString();
2824   if (root->opcode() == HloOpcode::kOutfeed) {
2825     VLOG(2) << "  outfeed with value: "
2826             << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
2827   } else {
2828     VLOG(2) << "  value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
2829   }
2830 
2831   auto record_complete_computation = [&](llvm::Value* prof_counter) {
2832     if (prof_counter) {
2833       profiling_state_.RecordCompleteComputation(&b_, prof_counter);
2834     }
2835   };
2836 
2837   // For the entry computation this increment is cumulative over embedded
2838   // computations, since it includes cycles spent in computations invoked by
2839   // While, Call, etc.
2840   record_complete_computation(GetProfileCounterFor(*root->parent()));
2841   return Status::OK();
2842 }
2843 
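// Returns a GEP into the profile counters array for `hlo` (an instruction or
// a computation), or nullptr if `hlo` has no assigned profile index.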
2844 template <typename T>
2845 llvm::Value* IrEmitter::GetProfileCounterCommon(
2846     const T& hlo,
2847     const std::unordered_map<const T*, int64>& profile_index_map) {
2848   auto it = profile_index_map.find(&hlo);
2849   if (it == profile_index_map.end()) {
2850     return nullptr;
2851   }
2852 
2853   int64_t prof_counter_idx = it->second;
2854   string counter_name = IrName("prof_counter", hlo.name());
2855   return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
2856              counter_name);
2857 }
2858 
2859 llvm::Value* IrEmitter::GetProfileCounterFor(
2860     const HloInstruction& instruction) {
2861   return GetProfileCounterCommon<HloInstruction>(instruction,
2862                                                  instruction_to_profile_idx_);
2863 }
2864 
2865 llvm::Value* IrEmitter::GetProfileCounterFor(
2866     const HloComputation& computation) {
2867   return GetProfileCounterCommon<HloComputation>(computation,
2868                                                  computation_to_profile_idx_);
2869 }
2870 
2871 void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
2872                                                      llvm::Value* prof_counter,
2873                                                      llvm::Value* cycle_end,
2874                                                      llvm::Value* cycle_start) {
2875   auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
2876   llvm::LoadInst* old_cycle_count =
2877       b->CreateLoad(prof_counter, "old_cycle_count");
2878   auto* new_cycle_count =
2879       b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
2880   b->CreateStore(new_cycle_count, prof_counter);
2881 }
2882 
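// Reads the CPU cycle counter: the llvm.readcyclecounter intrinsic by
// default, or x86 rdtscp (keeping only the TSC value from the returned
// struct) when use_rdtscp_ is set.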
2883 llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
2884   llvm::Module* module = b->GetInsertBlock()->getModule();
2885   if (!use_rdtscp_) {
2886     llvm::Function* func_llvm_readcyclecounter =
2887         llvm::Intrinsic::getDeclaration(module,
2888                                         llvm::Intrinsic::readcyclecounter);
2889     return b->CreateCall(func_llvm_readcyclecounter);
2890   }
2891   llvm::Function* func_llvm_x86_rdtscp =
2892       llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
2893   llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
2894   return b->CreateExtractValue(rdtscp_call, {0});
2895 }
2896 
2897 void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
2898                                                  HloInstruction* hlo) {
2899   auto* cycle_start = ReadCycleCounter(b);
2900   cycle_start->setName(IrName(hlo, "cycle_start"));
2901   cycle_starts_[hlo] = cycle_start;
2902   if (first_read_cycle_start_ == nullptr) {
2903     first_read_cycle_start_ = cycle_start;
2904   }
2905 }
2906 
2907 void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
2908                                                  HloInstruction* hlo,
2909                                                  llvm::Value* prof_counter) {
2910   auto* cycle_end = ReadCycleCounter(b);
2911   cycle_end->setName(IrName(hlo, "cycle_end"));
2912   auto* cycle_start = cycle_starts_[hlo];
2913   UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
2914   last_read_cycle_end_ = cycle_end;
2915 }
2916 
2917 void IrEmitter::ProfilingState::RecordCompleteComputation(
2918     llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
2919   if (last_read_cycle_end_ && first_read_cycle_start_) {
2920     UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
2921                          first_read_cycle_start_);
2922   }
2923 }
2924 
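// Emits a call to the runtime tracing-start entry point with the
// ExecutableRunOptions pointer and the HLO's name, and records the returned
// activity id so EmitTracingEnd can close the same activity.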
2925 void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
2926                                                HloInstruction* hlo,
2927                                                llvm::Value* run_options) {
2928   if (!enabled_) {
2929     return;
2930   }
2931 
2932   llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
2933   llvm::Type* void_ptr_type =
2934       int8_ptr_type;  // LLVM does not have a void*; we use an int8* instead.
2935   llvm::FunctionType* fn_type =
2936       llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
2937                               /*isVarArg=*/false);
2938 
2939   llvm::Function* function = b->GetInsertBlock()->getParent();
2940   llvm::Module* module = function->getParent();
2941   const char* fn_name = runtime::kTracingStartSymbolName;
2942   llvm::FunctionCallee trace_func =
2943       module->getOrInsertFunction(fn_name, fn_type);
2944   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2945     fn->setCallingConv(llvm::CallingConv::C);
2946     fn->setDoesNotThrow();
2947     fn->setOnlyAccessesArgMemory();
2948   }
2949   auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
2950   auto* activity_id =
2951       b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
2952                                  b->CreateBitCast(hlo_name, int8_ptr_type)});
2953   activity_id->setName(IrName(hlo, "activity_id"));
2954   activity_ids_[hlo] = activity_id;
2955 }
2956 
2957 void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
2958                                              HloInstruction* hlo,
2959                                              llvm::Value* run_options) {
2960   if (!enabled_) {
2961     return;
2962   }
2963 
2964   llvm::Type* void_ptr_type =
2965       b->getInt8Ty()->getPointerTo();  // LLVM does not have a void*; we use
2966                                        // an int8* instead.
2967   llvm::FunctionType* fn_type =
2968       llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
2969                               /*isVarArg=*/false);
2970 
2971   llvm::Function* function = b->GetInsertBlock()->getParent();
2972   llvm::Module* module = function->getParent();
2973   const char* fn_name = runtime::kTracingEndSymbolName;
2974   llvm::FunctionCallee trace_func =
2975       module->getOrInsertFunction(fn_name, fn_type);
2976   if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
2977     fn->setCallingConv(llvm::CallingConv::C);
2978     fn->setDoesNotThrow();
2979     fn->setOnlyAccessesArgMemory();
2980   }
2981   auto* activity_id = activity_ids_.at(hlo);
2982   b->CreateCall(trace_func,
2983                 {b->CreateBitCast(run_options, void_ptr_type), activity_id});
2984 }
2985 
2986 namespace {
2987 bool IsHloVeryCheap(const HloInstruction* hlo) {
2988   return hlo->opcode() == HloOpcode::kBitcast ||
2989          hlo->opcode() == HloOpcode::kTuple ||
2990          hlo->opcode() == HloOpcode::kGetTupleElement ||
2991          hlo->opcode() == HloOpcode::kParameter ||
2992          hlo->opcode() == HloOpcode::kConstant;
2993 }
2994 }  // namespace
2995 
2996 Status IrEmitter::Preprocess(HloInstruction* hlo) {
2997   VLOG(3) << "Visiting: " << hlo->ToString();
2998   // When profiling is enabled, trace the same HLOs that the profiler does.
2999   if (instruction_to_profile_idx_.count(hlo) ||
3000       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3001     tracing_state_.EmitTracingStart(&b_, hlo,
3002                                     GetExecutableRunOptionsArgument());
3003     profiling_state_.RecordCycleStart(&b_, hlo);
3004   }
3005   return Status::OK();
3006 }
3007 
3008 Status IrEmitter::Postprocess(HloInstruction* hlo) {
3009   if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
3010     profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
3011   }
3012   // When profiling is enabled, trace the same HLOs that the profiler does.
3013   if (instruction_to_profile_idx_.count(hlo) ||
3014       (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
3015     tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
3016   }
3017   return Status::OK();
3018 }
3019 
3020 llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
3021   llvm::Value* value_for_op = GetEmittedValueFor(hlo);
3022 
3023   llvm_ir::IrArray array(value_for_op, hlo->shape());
3024   AddAliasingInformationToIrArray(*hlo, &array);
3025   return array;
3026 }
3027 
3028 std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
3029     const HloInstruction* hlo) {
3030   std::vector<llvm_ir::IrArray> arrays;
3031   std::transform(
3032       hlo->operands().begin(), hlo->operands().end(),
3033       std::back_inserter(arrays),
3034       [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
3035   return arrays;
3036 }
3037 
3038 llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
3039   auto it = emitted_value_.find(hlo);
3040   if (it == emitted_value_.end()) {
3041     LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
3042   }
3043   return it->second;
3044 }
3045 
3046 llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
3047   return llvm_ir::ShapeToIrType(shape, module_);
3048 }
3049 
3050 llvm::Value* IrEmitter::GetProfileCountersArgument() {
3051   return compute_function_->profile_counters_arg();
3052 }
3053 
3054 llvm::Value* IrEmitter::GetBufferTableArgument() {
3055   return compute_function_->buffer_table_arg();
3056 }
3057 
3058 llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
3059   return compute_function_->exec_run_options_arg();
3060 }
3061 
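// Emits a typed pointer to a thread-local buffer slice: computation
// parameters are loaded from the params argument, while other thread-local
// slices get a function-entry alloca that is cached per (function, slice).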
3062 llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
3063     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3064   const BufferAllocation& allocation = *slice.allocation();
3065   llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
3066     auto param_it =
3067         computation_parameter_allocations_.find(slice.allocation()->index());
3068     if (param_it != computation_parameter_allocations_.end()) {
3069       int64_t param_number = param_it->second;
3070       // We have to access the parameter at offset param_number in the params
3071       // array. The code generated here is equivalent to this C code:
3072       //
3073       //   i8* param_address_untyped = params[param_number];
3074       //   Param* param_address_typed = (Param*)param_address_untyped;
3075       //
3076       // Where Param is the actual element type of the underlying buffer (for
3077       // example, float for an XLA F32 element type).
3078       llvm::Value* params = compute_function_->parameters_arg();
3079       llvm::Value* param_address_offset =
3080           llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
3081       llvm::LoadInst* param_address_untyped = Load(param_address_offset);
3082 
3083       if (!target_shape.IsOpaque()) {
3084         AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
3085         AttachDereferenceableMetadataForLoad(param_address_untyped,
3086                                              target_shape);
3087       }
3088       return param_address_untyped;
3089     }
3090 
3091     // Thread-local allocations should only be assigned a single buffer.
3092     const auto& assigned_buffers = allocation.assigned_buffers();
3093     CHECK_EQ(1, assigned_buffers.size());
3094     const Shape& shape = assigned_buffers.begin()->first->shape();
3095 
3096     std::pair<llvm::Function*, BufferAllocation::Slice> key = {
3097         compute_function_->function(), slice};
3098     auto buf_it = thread_local_buffers_.find(key);
3099     if (buf_it == thread_local_buffers_.end()) {
3100       llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3101           IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
3102           &b_, MinimumAlignmentForShape(target_shape));
3103       auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
3104       CHECK(it_inserted_pair.second);
3105       buf_it = it_inserted_pair.first;
3106     }
3107     return buf_it->second;
3108   }();
3109   return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
3110 }
3111 
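// Emits a typed pointer to a buffer slice that lives in the buffer table:
// loads the allocation's base address (tagged as an invariant load when
// enabled) and then applies the slice's byte offset.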
3112 llvm::Value* IrEmitter::EmitGlobalBufferPointer(
3113     const BufferAllocation::Slice& slice, const Shape& target_shape) {
3114   const BufferAllocation& allocation = *slice.allocation();
3115   llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
3116       GetBufferTableArgument(), slice.index(), &b_);
3117   llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
3118   if (hlo_module_config_.debug_options()
3119           .xla_llvm_enable_invariant_load_metadata()) {
3120     tempbuf_address_base->setMetadata(
3121         llvm::LLVMContext::MD_invariant_load,
3122         llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
3123   }
3124   AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
3125   AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());
3126 
3127   llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
3128   if (slice.offset() > 0) {
3129     // Adjust the address to account for the slice offset.
3130     tempbuf_address_untyped =
3131         InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
3132   }
3133   return BitCast(tempbuf_address_untyped,
3134                  IrShapeType(target_shape)->getPointerTo());
3135 }
3136 
3137 llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
3138                                           const Shape& target_shape) {
3139   if (slice.allocation()->is_thread_local()) {
3140     return EmitThreadLocalBufferPointer(slice, target_shape);
3141   } else if (slice.allocation()->is_constant()) {
3142     return BitCast(
3143         FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
3144         IrShapeType(target_shape)->getPointerTo());
3145   } else {
3146     return EmitGlobalBufferPointer(slice, target_shape);
3147   }
3148 }
3149 
3150 Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
3151   const Shape& target_shape = op->shape();
3152   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
3153                       assignment_.GetUniqueTopLevelSlice(op));
3154   llvm::Value* addr = EmitBufferPointer(slice, target_shape);
3155   addr->setName(IrName(op));
3156   emitted_value_[op] = addr;
3157   return Status::OK();
3158 }
3159 
3160 Status IrEmitter::EmitTargetElementLoop(
3161     HloInstruction* target_op,
3162     const llvm_ir::ElementGenerator& element_generator) {
3163   return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
3164 }
3165 
3166 Status IrEmitter::EmitTargetElementLoop(
3167     HloInstruction* target_op, absl::string_view desc,
3168     const llvm_ir::ElementGenerator& element_generator) {
3169   VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
3170 
3171   const Shape& target_shape = target_op->shape();
3172   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
3173   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
3174 
3175   if (target_shape.IsTuple() &&
3176       (target_op->opcode() == HloOpcode::kFusion ||
3177        target_op->opcode() == HloOpcode::kReduce ||
3178        target_op->opcode() == HloOpcode::kReduceWindow)) {
3179     // For multi-output ops, emit each output buffer and then the root tuple
3180     // that points at them.
3180     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
3181     std::vector<llvm_ir::IrArray> output_arrays;
3182     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
3183       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
3184                           assignment_.GetUniqueSlice(target_op, {i}));
3185       const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
3186       llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
3187       output_arrays.push_back(
3188           llvm_ir::IrArray(op_target_address, element_shape));
3189     }
3190     TF_RETURN_IF_ERROR(
3191         llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
3192             .EmitLoop(IrName(target_op)));
3193 
3194     std::vector<llvm::Value*> tuple_operand_ptrs;
3195     for (int64_t i = 0; i < output_arrays.size(); ++i) {
3196       tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
3197     }
3198     llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);
3199 
3200   } else {
3201     if (ShouldEmitParallelLoopFor(*target_op)) {
3202       // Emit code to read dynamic loop bounds from compute function argument.
3203       std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
3204           compute_function_->GetDynamicLoopBounds();
3205       // Emit parallel loop with dynamic loop bounds for most-major dimensions.
3206       TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
3207                                              &dynamic_loop_bounds, &b_)
3208                              .EmitLoop(IrName(target_op)));
3209     } else {
3210       TF_RETURN_IF_ERROR(
3211           llvm_ir::LoopEmitter(element_generator, target_array, &b_)
3212               .EmitLoop(IrName(target_op)));
3213     }
3214   }
3215   return Status::OK();
3216 }
3217 
3218 Status IrEmitter::EmitMemcpy(const HloInstruction& source,
3219                              const HloInstruction& destination) {
3220   llvm::Value* source_value = GetEmittedValueFor(&source);
3221   llvm::Value* destination_value = GetEmittedValueFor(&destination);
3222   int64_t source_size = ByteSizeOf(source.shape());
3223   // TODO(b/63762267): Be more aggressive about specifying alignment.
3224   MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
3225          /*SrcAlign=*/llvm::Align(1), source_size);
3226   return Status::OK();
3227 }
3228 
3229 Status IrEmitter::ElementTypesSameAndSupported(
3230     const HloInstruction& instruction,
3231     absl::Span<const HloInstruction* const> operands,
3232     absl::Span<const PrimitiveType> supported_types) {
3233   for (auto operand : operands) {
3234     TF_RET_CHECK(
3235         ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
3236   }
3237 
3238   TF_RET_CHECK(!operands.empty());
3239   PrimitiveType primitive_type = operands[0]->shape().element_type();
3240   if (!absl::c_linear_search(supported_types, primitive_type)) {
3241     return Unimplemented("unsupported operand type %s in op %s",
3242                          PrimitiveType_Name(primitive_type),
3243                          HloOpcodeString(instruction.opcode()));
3244   }
3245   return Status::OK();
3246 }
3247 
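// Fallback lowering: wraps each operand in a generator that reads from its
// IrArray, then drives EmitTargetElementLoop with the elemental emitter's
// generator for `hlo`.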
3248 Status IrEmitter::DefaultAction(HloInstruction* hlo) {
3249   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
3250   for (const HloInstruction* operand : hlo->operands()) {
3251     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
3252       return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3253     };
3254   }
3255   CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
3256   return EmitTargetElementLoop(
3257       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
3258 }
3259 
3260 llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
3261     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3262     absl::string_view name) {
3263   std::vector<llvm::Value*> return_value =
3264       EmitThreadLocalCall(callee, parameters, name);
3265   CHECK_EQ(return_value.size(), 1);
3266   return return_value[0];
3267 }
3268 
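// Calls a thread-local computation: scalar arguments are spilled to
// function-entry allocas, a return buffer (or one alloca per returned scalar
// for tuple-of-scalars results) is set up, and the callee is passed a null
// buffer table since all of its buffers are thread-local.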
3269 std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
3270     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
3271     absl::string_view name) {
3272   CHECK(absl::c_binary_search(thread_local_computations_, &callee));
3273   const Shape& return_shape = callee.root_instruction()->shape();
3274   bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
3275   bool is_tuple_of_scalars_return =
3276       return_shape.IsTuple() &&
3277       absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
3278         return ShapeUtil::IsScalar(shape);
3279       });
3280   CHECK(is_scalar_return || is_tuple_of_scalars_return);
3281 
3282   std::vector<llvm::Value*> parameter_addrs;
3283   for (llvm::Value* parameter : parameters) {
3284     CHECK(!parameter->getType()->isPointerTy());
3285     llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
3286         parameter->getType(), "arg_addr", &b_);
3287     Store(parameter, parameter_addr);
3288     parameter_addrs.push_back(parameter_addr);
3289   }
3290 
3291   llvm::Type* return_value_buffer_type =
3292       llvm_ir::ShapeToIrType(return_shape, module_);
3293   std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
3294   int retval_alignment =
3295       is_scalar_return
3296           ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
3297           : 0;
3298   llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
3299       return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
3300 
3301   std::vector<llvm::Value*> allocas_for_returned_scalars;
3302   if (is_scalar_return) {
3303     allocas_for_returned_scalars.push_back(return_value_buffer);
3304   } else {
3305     constexpr int max_tuple_size = 1000;
3306     CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
3307         << "Multivalue function can not return more than 1000 elements to avoid"
3308         << " stack smashing";
3309     allocas_for_returned_scalars =
3310         llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
3311     llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
3312 
3313     EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
3314   }
3315 
3316   Call(FindOrDie(emitted_functions_, &callee),
3317        GetArrayFunctionCallArguments(
3318            parameter_addrs, &b_, name,
3319            /*return_value_buffer=*/return_value_buffer,
3320            /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3321            /*buffer_table_arg=*/
3322            llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
3323            /*profile_counters_arg=*/GetProfileCountersArgument()));
3324 
3325   std::vector<llvm::Value*> returned_scalars;
3326   returned_scalars.reserve(allocas_for_returned_scalars.size());
3327   for (llvm::Value* addr : allocas_for_returned_scalars) {
3328     returned_scalars.push_back(Load(addr));
3329   }
3330   return returned_scalars;
3331 }
3332 
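// Calls a global computation: no parameter addresses or return buffer are
// passed because its operands and result live at fixed slices in the shared
// buffer table, which is forwarded as-is.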
3333 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
3334                                absl::string_view name) {
3335   CHECK(absl::c_binary_search(global_computations_, &callee));
3336 
3337   Call(FindOrDie(emitted_functions_, &callee),
3338        GetArrayFunctionCallArguments(
3339            /*parameter_addresses=*/{}, &b_, name,
3340            /*return_value_buffer=*/
3341            llvm::Constant::getNullValue(b_.getInt8PtrTy()),
3342            /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
3343            /*buffer_table_arg=*/GetBufferTableArgument(),
3344            /*profile_counters_arg=*/GetProfileCountersArgument()));
3345 }
3346 
3347 llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
3348     const HloComputation& callee) {
3349   const HloInstruction* root_inst = callee.root_instruction();
3350   if (root_inst->opcode() == HloOpcode::kOutfeed) {
3351     return llvm::Constant::getNullValue(b_.getInt8PtrTy());
3352   }
3353 
3354   const BufferAllocation::Slice root_buffer =
3355       assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
3356   return EmitBufferPointer(root_buffer, root_inst->shape());
3357 }
3358 
3359 void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
3360                                     FusedIrEmitter* fused_emitter) {
3361   for (int i = 0; i < fusion->operand_count(); i++) {
3362     const HloInstruction* operand = fusion->operand(i);
3363     fused_emitter->BindGenerator(
3364         fusion->fused_parameter(i),
3365         [this, operand](llvm_ir::IrArray::Index index) {
3366           return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
3367         });
3368   }
3369 }
3370 
3371 }  // namespace cpu
3372 }  // namespace xla
3373