1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17 
18 #include <stddef.h>
19 #include <stdint.h>
20 #include <algorithm>
21 #include <iterator>
22 #include <limits>
23 #include <memory>
24 #include <utility>
25 #include <vector>
26 
27 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/types/span.h"
33 #include "llvm/CodeGen/TargetRegisterInfo.h"
34 #include "llvm/CodeGen/TargetSubtargetInfo.h"
35 #include "llvm/IR/BasicBlock.h"
36 #include "llvm/IR/Constants.h"
37 #include "llvm/IR/GlobalVariable.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/Intrinsics.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "tensorflow/compiler/xla/layout_util.h"
42 #include "tensorflow/compiler/xla/map_util.h"
43 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
44 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
45 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
46 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
47 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
48 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
49 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
50 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
51 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
52 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
53 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
54 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
55 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
56 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
57 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
58 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
59 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
60 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
61 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
62 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
63 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
64 #include "tensorflow/compiler/xla/shape_util.h"
65 #include "tensorflow/compiler/xla/status_macros.h"
66 #include "tensorflow/compiler/xla/types.h"
67 #include "tensorflow/compiler/xla/util.h"
68 #include "tensorflow/compiler/xla/window_util.h"
69 #include "tensorflow/core/lib/core/bits.h"
70 #include "tensorflow/core/lib/core/errors.h"
71 #include "tensorflow/core/lib/math/math_util.h"
72 #include "tensorflow/core/platform/logging.h"
73 
74 namespace xla {
75 
76 namespace {
77 using llvm_ir::IrName;
78 using llvm_ir::SetToFirstInsertPoint;
79 }  // namespace
80 
81 namespace cpu {
82 
83 IrEmitter::IrEmitter(
84     const HloModule& hlo_module, const BufferAssignment& assignment,
85     llvm::Module* llvm_module,
86     std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
87     std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
88     const TargetMachineFeatures* target_machine_features,
89     bool emit_code_for_msan)
90     : assignment_(assignment),
91       module_(llvm_module),
92       arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
93       b_(llvm_module->getContext()),
94       instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
95       computation_to_profile_idx_(std::move(computation_to_profile_idx)),
96       alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
97       hlo_module_config_(hlo_module.config()),
98       is_top_level_computation_(false),
99       target_machine_features_(*target_machine_features),
100       emit_code_for_msan_(emit_code_for_msan) {
101   b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
102   Status s = GatherComputationsByAllocationType(
103       &hlo_module, &thread_local_computations_, &global_computations_);
104   absl::c_sort(thread_local_computations_);
105   absl::c_sort(global_computations_);
106   TF_CHECK_OK(s) << "Should have failed buffer assignment.";
107 }
108 
109 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
110     HloComputation* computation, const string& function_name_prefix,
111     bool is_top_level_computation,
112     absl::Span<HloInstruction* const> instruction_order) {
113   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
114   VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
115   is_top_level_computation_ = is_top_level_computation;
116   num_dynamic_loop_bounds_ = 0;
117   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
118     num_dynamic_loop_bounds_ =
119         computation->root_instruction()->outer_dimension_partitions().size();
120   }
121 
122   if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
123     TF_ASSIGN_OR_RETURN(
124         computation_root_allocation_,
125         assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
126   }
127 
128   for (const HloInstruction* param : computation->parameter_instructions()) {
129     TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
130                         assignment_.GetUniqueTopLevelSlice(param));
131     computation_parameter_allocations_[param_slice.allocation()->index()] =
132         param->parameter_number();
133   }
134 
135   InitializeIrFunction(function_name);
136   // The rdtscp instruction is x86-specific.  We will fall back to LLVM's generic
137   // readcyclecounter if it is unavailable.
138   bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
139                     arch_type_ == llvm::Triple::ArchType::x86_64;
140   profiling_state_ = ProfilingState(use_rdtscp);
141   TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
142   llvm::Function* ir_function = compute_function_->function();
143   InsertOrDie(&emitted_functions_, computation, ir_function);
144   // Delete 'compute_function', finalizing 'ir_function' and restoring caller
145   // IR insert point.
146   compute_function_.reset();
147   computation_root_allocation_ = BufferAllocation::Slice();
148   computation_parameter_allocations_.clear();
149   return ir_function;
150 }
151 
152 void IrEmitter::InitializeIrFunction(const string& function_name) {
153   // Functions with local linkage get an inlining bonus.  Because we know
154   // a priori that embedded functions (non-entry functions) will not have their
155   // names resolved, give them local linkage.
156   llvm::Function::LinkageTypes linkage =
157       is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
158                                 : llvm::GlobalValue::InternalLinkage;
159   // Create and initialize new IrFunction.
160   compute_function_.reset(new IrFunction(function_name, linkage,
161                                          hlo_module_config_, module_, &b_,
162                                          num_dynamic_loop_bounds_));
163 }
164 
165 IrEmitter::~IrEmitter() {}
166 
167 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
168   VLOG(2) << "HandleBitcast: " << bitcast->ToString();
169   emitted_value_[bitcast] =
170       BitCast(GetEmittedValueFor(bitcast->operand(0)),
171               IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
172   return Status::OK();
173 }
174 
175 llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
176   llvm::Constant* initializer =
177       llvm_ir::ConvertLiteralToIrConstant(literal, module_);
178   llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
179       /*Module=*/*module_,
180       /*Type=*/initializer->getType(),
181       /*isConstant=*/true,
182       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
183       /*Initializer=*/initializer,
184       /*Name=*/"");
185   result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
186   result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
187   return llvm::ConstantExpr::getBitCast(
188       result_global, IrShapeType(literal.shape())->getPointerTo());
189 }
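// Editorial sketch (not part of the original source; the literal and the exact
// IR spelling are hypothetical): for an f32[4] literal {1, 2, 3, 4} the code
// above emits roughly
//
//   @0 = private unnamed_addr constant [4 x float]
//            [float 1.0, float 2.0, float 3.0, float 4.0], align <A>
//
// where <A> is whatever MinimumAlignmentForShape returns for the shape, and
// the function returns that global bit-cast to the IR pointer type of the
// literal's shape.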
190 
191 Status IrEmitter::EmitConstantGlobals() {
192   for (const BufferAllocation& allocation : assignment_.Allocations()) {
193     if (!allocation.is_constant()) {
194       continue;
195     }
196 
197     const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
198     llvm::Constant* global_for_const;
199     auto it = emitted_literals_.find(&literal);
200     if (it != emitted_literals_.end()) {
201       global_for_const = it->second;
202     } else {
203       global_for_const = EmitGlobalForLiteral(literal);
204       InsertOrDie(&emitted_literals_, &literal, global_for_const);
205     }
206 
207     InsertOrDie(&constant_buffer_to_global_, allocation.index(),
208                 global_for_const);
209   }
210 
211   return Status::OK();
212 }
213 
214 Status IrEmitter::HandleConstant(HloInstruction* constant) {
215   VLOG(2) << "HandleConstant: " << constant->ToString();
216   // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
217   // of the constant.
218   return EmitTargetAddressForOp(constant);
219 }
220 
221 Status IrEmitter::HandleCopy(HloInstruction* copy) {
222   if (copy->shape().IsTuple()) {
223     // kCopy shallow copies a tuple so just memcpy the top-level buffer.
224     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
225     return EmitMemcpy(*(copy->operand(0)), *copy);
226   } else if (copy->shape().IsArray()) {
227     // Use the elemental emitter for array shapes.
228     return DefaultAction(copy);
229   }
230   return Unimplemented("unsupported operand type %s for copy instruction",
231                        PrimitiveType_Name(copy->shape().element_type()));
232 }
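// Editorial note (hypothetical example, not part of the original source): for
// a kCopy of a tuple (f32[1024], f32[1024]), the tuple branch above memcpys
// only the small top-level buffer of element pointers; the element payloads
// themselves are not copied, which is what "shallow copy" means here.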
233 
234 // Calculate the alignment of a buffer allocated for a given primitive type.
235 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
236   int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
237   DCHECK_GE(byte_size, 0);
238   // Largest scalar is a complex128 so we don't need to worry about the
239   // int64->int truncation here.
240   DCHECK_LE(byte_size, 16);
241 
242   // Allocations may be 8-byte aligned if part of a small block.
243   return std::min(8LL, byte_size);
244 }
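// Editorial note (worked example, not part of the original source): under the
// rule above a 4-byte F32 scalar is aligned to min(8, 4) = 4 bytes, while a
// 16-byte C128 scalar is capped at min(8, 16) = 8 bytes.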
245 
246 int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
247   return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
248 }
249 
250 // Calculate the alignment of a buffer allocated for a given shape.
251 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
252   if (ShapeUtil::IsScalar(shape)) {
253     return MinimumAlignmentForPrimitiveType(shape.element_type());
254   }
255 
256   int64 buffer_size = ByteSizeOf(shape);
257   DCHECK_GE(buffer_size, 0);
258   DCHECK_LE(buffer_size, SIZE_MAX);
259 
260   return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
261 }
262 
263 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
264                                                const Shape& shape) {
265   int alignment = MinimumAlignmentForShape(shape);
266   if (alignment > 1) {
267     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
268   }
269 }
270 
271 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
272                                                int64 buffer_size) {
273   int alignment =
274       target_machine_features_.minimum_alignment_for_allocation(buffer_size);
275   if (alignment > 1) {
276     llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
277   }
278 }
279 
280 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
281                                                      const Shape& shape) {
282   AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
283 }
284 
285 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
286                                                      int64 buffer_size) {
287   if (buffer_size > 0) {
288     llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
289   }
290 }
291 
292 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
293   // A tuple is an array of pointers, one for each operand. Each pointer points
294   // to the output buffer of its corresponding operand. A GetTupleElement
295   // instruction forwards a pointer to the tuple element buffer at the given
296   // index.
297   auto operand = get_tuple_element->operand(0);
298   const Shape& shape = get_tuple_element->shape();
299   emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
300       shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
301       GetEmittedValueFor(operand), &b_);
302   return Status::OK();
303 }
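// Editorial sketch (hypothetical example, not part of the original source):
// for a tuple t = (f32[4], s32[2]) the emitted tuple buffer is an array of two
// pointers, one per element buffer.  A get-tuple-element(t, index=1) handled
// above simply loads the pointer stored in slot 1 and treats it as pointing at
// the s32[2] buffer; no element data is copied.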
304 
305 Status IrEmitter::HandleSelect(HloInstruction* select) {
306   auto pred = select->operand(0);
307   TF_RET_CHECK(pred->shape().element_type() == PRED);
308   return DefaultAction(select);
309 }
310 
311 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
312   auto pred = tuple_select->operand(0);
313   auto on_true = tuple_select->operand(1);
314   auto on_false = tuple_select->operand(2);
315   TF_RET_CHECK(pred->shape().element_type() == PRED);
316   TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
317   TF_RET_CHECK(tuple_select->shape().IsTuple());
318   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
319   llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
320                            GetEmittedValueFor(on_true),
321                            GetEmittedValueFor(on_false), &b_);
322   return Status::OK();
323 }
324 
325 Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
326   HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
327   VLOG(2) << "HandleInfeed: " << infeed->ToString();
328 
329   // The infeed operation produces a two-element tuple containing data and a
330   // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
331   const Shape& data_shape = infeed->infeed_shape();
332   DCHECK(ShapeUtil::Equal(data_shape,
333                           ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
334   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
335 
336   // Write the tuple index table.
337   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
338                       assignment_.GetUniqueSlice(infeed, {0}));
339   llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
340   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
341                       assignment_.GetUniqueSlice(infeed, {1}));
342   llvm::Value* token_address = EmitBufferPointer(
343       token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
344   llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
345 
346   if (data_shape.IsTuple()) {
347     TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
348 
349     // For a tuple, we first copy each of the internal elements to
350     // their corresponding target locations. We then construct the
351     // tuple outer buffer containing pointers to the internal
352     // elements.
353     std::vector<llvm::Value*> tuple_element_addresses;
354     for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
355       TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
356                           assignment_.GetUniqueSlice(infeed, {0, i}));
357 
358       const Shape& tuple_element_shape =
359           ShapeUtil::GetTupleElementShape(data_shape, i);
360 
361       // Only the outer tuple buffer's target address is obtained from
362       // GetEmittedValueFor, to handle the case when Infeed is the root
363       // instruction. Target addresses for internal elements can be obtained
364       // from EmitBufferPointer.
365       llvm::Value* tuple_element_address =
366           EmitBufferPointer(buffer, tuple_element_shape);
367 
368       TF_RETURN_IF_ERROR(EmitXfeedTransfer(
369           XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
370 
371       tuple_element_addresses.push_back(tuple_element_address);
372     }
373 
374     llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
375                        tuple_element_addresses, &b_);
376   } else {
377     TF_RETURN_IF_ERROR(
378         EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
379   }
380 
381   return Status::OK();
382 }
383 
384 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
385                                     llvm::Value* program_buffer_address) {
386   int64 length = ByteSizeOf(shape);
387   if (length <= 0 || length > std::numeric_limits<int32>::max()) {
388     return InvalidArgument(
389         "xfeed (infeed or outfeed) buffer length %d is outside the valid "
390         "size range",
391         length);
392   }
393   int32 length_32 = static_cast<int32>(length);
394 
395   int32 shape_length;
396   TF_ASSIGN_OR_RETURN(
397       llvm::Value * shape_ptr,
398       llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
399 
400   llvm::Type* int32_type = b_.getInt32Ty();
401   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
402   llvm::FunctionType* acquire_type = llvm::FunctionType::get(
403       i8_ptr_type,
404       {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
405        /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
406       /*isVarArg=*/false);
407 
408   llvm::Function* acquire_func;
409   if (kind == XfeedKind::kInfeed) {
410     acquire_func = llvm::dyn_cast<llvm::Function>(
411         module_
412             ->getOrInsertFunction(
413                 runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)
414             .getCallee());
415   } else {
416     acquire_func = llvm::dyn_cast<llvm::Function>(
417         module_
418             ->getOrInsertFunction(
419                 runtime::kAcquireOutfeedBufferForPopulationSymbolName,
420                 acquire_type)
421             .getCallee());
422   }
423   acquire_func->setCallingConv(llvm::CallingConv::C);
424 
425   llvm::FunctionType* release_type = llvm::FunctionType::get(
426       b_.getVoidTy(),
427       {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
428        /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
429        /*shape_length*/ int32_type},
430       /*isVarArg=*/false);
431 
432   llvm::Function* release_func;
433   if (kind == XfeedKind::kInfeed) {
434     release_func = llvm::dyn_cast<llvm::Function>(
435         module_
436             ->getOrInsertFunction(
437                 runtime::kReleaseInfeedBufferAfterDequeueSymbolName,
438                 release_type)
439             .getCallee());
440   } else {
441     release_func = llvm::dyn_cast<llvm::Function>(
442         module_
443             ->getOrInsertFunction(
444                 runtime::kReleaseOutfeedBufferAfterPopulationSymbolName,
445                 release_type)
446             .getCallee());
447   }
448   release_func->setCallingConv(llvm::CallingConv::C);
449 
450   // Implementation note: this call informs the runtime that it wants a buffer
451   // of size exactly 'length_32', and the runtime is responsible for
452   // check-failing the process if there is a mismatch, versus passing us back a
453   // buffer that we might overrun.
454   llvm::Value* acquired_pointer = Call(
455       acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
456                      shape_ptr, b_.getInt32(shape_length)});
457 
458   if (kind == XfeedKind::kInfeed) {
459     // Copy to the program buffer address from the acquired buffer.
460     MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
461            /*SrcAlign=*/1, length_32);
462   } else {
463     // Outfeed -- copy from the in-program address to the acquired buffer.
464     MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
465            /*SrcAlign=*/1, length_32);
466   }
467 
468   Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
469                       acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
470 
471   return Status::OK();
472 }
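// Editorial sketch (derived from the llvm::FunctionType definitions above;
// the names are illustrative, not from the original source): the emitted
// calls behave as if the runtime exported
//
//   void* acquire(void* run_options, int32 buffer_length,
//                 void* shape, int32 shape_length);
//   void  release(void* run_options, int32 buffer_length, void* buffer_ptr,
//                 void* shape, int32 shape_length);
//
// with buffer_length required to match exactly, per the implementation note
// above.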
473 
474 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
475   // Outfeed produces no useful result, but it does return a token[] that can be
476   // threaded through to other side effecting operations to ensure ordering.  In
477   // the IR emitter we treat this token as a normal u8[] and thus need to insert
478   // an entry for it in emitted_value_.
479   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
480 
481   HloInstruction* operand = outfeed->operands()[0];
482   const Shape& operand_shape = operand->shape();
483 
484   llvm::Value* value = GetEmittedValueFor(operand);
485   if (!operand_shape.IsTuple()) {
486     return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
487   }
488 
489   TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
490 
491   for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
492     const Shape& tuple_element_shape =
493         ShapeUtil::GetTupleElementShape(operand_shape, i);
494     llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
495         tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
496         value, &b_);
497     TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
498                                          tuple_element_shape, tuple_element));
499   }
500 
501   return Status::OK();
502 }
503 
504 Status IrEmitter::HandleSort(HloInstruction* hlo) {
505   const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
506   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
507   Shape keys_shape = sort->keys()->shape();
508   PrimitiveType keys_type = keys_shape.element_type();
509   switch (keys_type) {
510     case PRED:
511     case S8:
512     case U8:
513     case S16:
514     case U16:
515     case BF16:
516     case F16:
517     case S32:
518     case U32:
519     case F32:
520     case S64:
521     case U64:
522     case F64:
523       break;
524     default:
525       return Unimplemented(
526           "Element type %s not supported in the Sort op on CPU.",
527           PrimitiveType_Name(keys_type));
528   }
529   std::vector<llvm::Value*> destination_addresses(sort->operand_count());
530   for (int64 i = 0; i < sort->operand_count(); ++i) {
531     ShapeIndex shape_index =
532         sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
533     const HloInstruction* operand = sort->operand(i);
534     // We assume that the layout of all involved operands and outputs is the
535     // same.
536     TF_RET_CHECK(
537         LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
538     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
539         keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
540 
541     // The sort is implemented in-place, therefore we first copy the operand
542     // buffer to the output buffer if they are not the same.
543     auto destination_buffer = GetAllocationSlice(*sort, shape_index);
544     destination_addresses[i] =
545         EmitBufferPointer(destination_buffer, operand->shape());
546     auto source_address = GetAllocationSlice(*operand);
547     if (destination_buffer != source_address) {
548       int64 primitive_type_size =
549           ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
550       auto source_buffer = GetEmittedValueFor(operand);
551       int64 size = ByteSizeOf(operand->shape());
552       MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size,
553              source_buffer,
554              /*SrcAlign=*/primitive_type_size, size);
555     }
556   }
557 
558   // Normalize the shape and the dimension to sort.
559   Shape normalized_keys_shape =
560       ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
561   int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
562       keys_shape.layout())[sort->sort_dimension()];
563 
564   int64 sort_dimension_elements =
565       normalized_keys_shape.dimensions(physical_dimension_to_sort);
566   int64 higher_dimensions = 1;
567   for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
568     higher_dimensions *= normalized_keys_shape.dimensions(i);
569   }
570   int64 lower_dimensions = 1;
571   for (int64 i = normalized_keys_shape.rank() - 1;
572        i > physical_dimension_to_sort; --i) {
573     lower_dimensions *= normalized_keys_shape.dimensions(i);
574   }
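  // Editorial note (worked example, not part of the original source): if the
  // normalized keys shape is [H, S, L0, L1] and the physical dimension to sort
  // is 1, then higher_dimensions = H (product of dimensions before the sort
  // dimension), sort_dimension_elements = S, and lower_dimensions = L0 * L1
  // (product of dimensions after it).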
575 
576   auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
577   CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
578   llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
579       b_.getVoidTy(),
580       {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
581        b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
582        b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(),
583        b_.getInt64Ty()->getPointerTo(), less_than_function->getType()},
584       /*isVarArg=*/false);
585   auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
586       module_
587           ->getOrInsertFunction(runtime::kKeyValueSortSymbolName,
588                                 key_value_sort_type)
589           .getCallee());
590   key_value_sort_func->setCallingConv(llvm::CallingConv::C);
591   key_value_sort_func->setDoesNotThrow();
592   llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
593       b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
594       &b_);
595   llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
596       b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
597       &b_);
598   for (int64 i = 0; i < sort->operand_count(); ++i) {
599     llvm::Value* value_as_i8ptr =
600         PointerCast(destination_addresses[i], b_.getInt8PtrTy());
601     llvm::Value* slot_in_values_alloca =
602         ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
603     Store(value_as_i8ptr, slot_in_values_alloca);
604     llvm::Value* slot_in_sizes_alloca =
605         ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
606     llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
607         sort->operand(i)->shape().element_type()));
608     Store(size, slot_in_sizes_alloca);
609   }
610 
611   Call(key_value_sort_func,
612        {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
613         b_.getInt64(lower_dimensions), values,
614         b_.getInt32(sort->operand_count()), sizes,
615         b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(),
616         GetProfileCountersArgument(), less_than_function});
617 
618   if (sort->values_count() > 0) {
619     llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
620   }
621   return Status::OK();
622 }
623 
624 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
625   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
626   std::vector<llvm::Value*> base_ptrs;
627   for (auto operand : tuple->operands()) {
628     base_ptrs.push_back(GetEmittedValueFor(operand));
629   }
630   llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
631   return Status::OK();
632 }
633 
634 llvm::Value* IrEmitter::EmitElementalMap(
635     const HloMapInstruction& map_instr,
636     absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
637   return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
638 }
639 
640 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
641     const HloReduceWindowInstruction* reduce_window,
642     const llvm_ir::ElementGenerator& input_generator,
643     const llvm_ir::IrArray::Index& index) {
644   const HloInstruction* operand = reduce_window->operand(0);
645   const Window& window = reduce_window->window();
646 
647   // We fold inputs into the accumulator and initialize it to
648   // the initial value on the reduce_window.
649   PrimitiveType operand_element_type = operand->shape().element_type();
650   llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
651       llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
652       "reduce_window_accumulator_address", &b_,
653       MinimumAlignmentForPrimitiveType(operand_element_type));
654   Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
655         accumulator_address);
656 
657   llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
658   std::vector<int64> window_size;
659   for (const auto& dim : window.dimensions()) {
660     window_size.push_back(dim.size());
661   }
662   const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
663       ShapeUtil::MakeShape(operand_element_type, window_size), "window");
664   CHECK_EQ(window_index.size(), index.size());
665 
666   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
667 
668   std::vector<llvm::Value*> input_multi_index(index.size());
669   llvm::Value* in_bounds_condition = nullptr;
670   for (size_t i = 0; i < index.size(); ++i) {
671     llvm::Value* strided_index =
672         NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
673     input_multi_index[i] = NSWSub(
674         NSWAdd(strided_index,
675                NSWMul(window_index[i],
676                       b_.getInt64(window.dimensions(i).window_dilation()))),
677         b_.getInt64(window.dimensions(i).padding_low()));
678 
679     // We need to verify that we are not in the dilated base area.
680     llvm::Value* dilation_condition =
681         ICmpEQ(SRem(input_multi_index[i],
682                     b_.getInt64(window.dimensions(i).base_dilation())),
683                b_.getInt64(0));
684     if (in_bounds_condition == nullptr) {
685       in_bounds_condition = dilation_condition;
686     } else {
687       in_bounds_condition = And(in_bounds_condition, dilation_condition);
688     }
689 
690     // Apply base dilation to the index.
691     input_multi_index[i] =
692         SDiv(input_multi_index[i],
693              b_.getInt64(window.dimensions(i).base_dilation()));
694 
695     // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we
696     // are in the padding so that we can skip the computation. That is
697     // equivalent to input_multi_index[i] < bound as an *unsigned* comparison,
698     // since a negative value will wrap to a large positive value.
699     llvm::Value* index_condition =
700         ICmpULT(input_multi_index[i],
701                 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
702     if (in_bounds_condition == nullptr) {
703       in_bounds_condition = index_condition;
704     } else {
705       in_bounds_condition = And(in_bounds_condition, index_condition);
706     }
707   }
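  // Editorial note (summary of the loop above, not part of the original
  // source): each input coordinate is computed as
  //   raw_i = index[i] * stride_i + window_index[i] * window_dilation_i
  //           - padding_low_i
  //   input_multi_index[i] = raw_i / base_dilation_i
  // and the element contributes only if raw_i % base_dilation_i == 0 and
  // 0 <= input_multi_index[i] < operand bound (checked with one unsigned
  // comparison).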
708   CHECK(in_bounds_condition != nullptr);
709 
710   llvm_ir::LlvmIfData if_data =
711       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
712   SetToFirstInsertPoint(if_data.true_block, &b_);
713 
714   // We are not in the padding, so carry out the computation.
715   llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(),
716                                       b_.getInt64Ty());
717   TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
718                       input_generator(input_index));
719   llvm::Value* result = EmitThreadLocalCall(
720       *reduce_window->to_apply(), {Load(accumulator_address), input_value},
721       "reducer_function");
722   Store(result, accumulator_address);
723 
724   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
725   return Load(accumulator_address);
726 }
727 
728 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
729   // Pseudo code for reduce window:
730   //
731   //   for (coordinates O in the output)
732   //     value = init_value;
733   //     for (coordinates W in the window)
734   //       for each index i:
735   //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
736   //       if I within bounds of input:
737   //         value = function(value, input(I));
738   //     output(O) = value;
739   //
740   // This is completely un-optimized and just here to have something
741   // that works.
742   return DefaultAction(reduce_window);
743 }
744 
745 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
746   CHECK_EQ(select_and_scatter->operand_count(), 3);
747   const auto operand = select_and_scatter->operand(0);
748   const auto source = select_and_scatter->operand(1);
749   const auto init_value = select_and_scatter->operand(2);
750   const Window& window = select_and_scatter->window();
751   PrimitiveType operand_element_type = operand->shape().element_type();
752   const int64 rank = operand->shape().rank();
753   CHECK_EQ(rank, source->shape().rank());
754   CHECK_EQ(rank, window.dimensions_size());
755 
756   // TODO(b/31410564): Implement dilation for select-and-scatter.
757   if (window_util::HasDilation(window)) {
758     return Unimplemented(
759         "Dilation for SelectAndScatter is not implemented on CPU. ");
760   }
761 
762   // Pseudo code for select-and-scatter:
763   //
764   // initialized_flag is initially off for every window, and is turned on after
765   // the first iteration is completed and the first operand value is selected.
766   //
767   // output(*) = init_value
768   // for (coordinates S in the source) {
769   //   initialized_flag = false
770   //   for (coordinates W in the window) {
771   //     I = S * stride + W - pad_low
772   //     if I within bounds of operand:
773   //       if !initialized_flag or select(selected_value, operand(I)) == false:
774   //         selected_value = operand(I)
775   //         selected_index = I
776   //         initialized_flag = true
777   //   }
778   //   output(selected_index) = scatter(output(selected_index), source(S))
779   // }
780   //
781 
782   // Initialize the output array with the given init_value.
783   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
784       select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
785       [this, init_value](const llvm_ir::IrArray::Index& target_index) {
786         llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
787         return Load(init_value_addr);
788       }));
789 
790   // Create a loop to iterate over the source array to scatter to the output.
791   llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
792   const llvm_ir::IrArray::Index source_index =
793       source_loops.AddLoopsForShape(source->shape(), "source");
794   SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
795 
796   // Allocate space to keep the currently selected value, its index, and
797   // the boolean initialized_flag, which is initially set to false.
798   llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
799       llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
800       "selected_value_address", &b_,
801       MinimumAlignmentForPrimitiveType(operand_element_type));
802   llvm::Value* selected_index_address =
803       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
804           b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
805   llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
806       b_.getInt1Ty(), "initialized_flag_address", &b_);
807   Store(b_.getInt1(false), initialized_flag_address);
808 
809   // Create the inner loop to iterate over the window.
810   llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
811   std::vector<int64> window_size;
812   for (const auto& dim : window.dimensions()) {
813     window_size.push_back(dim.size());
814   }
815   const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
816       ShapeUtil::MakeShape(operand_element_type, window_size), "window");
817   SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
818 
819   // Compute the operand index to visit and evaluate the condition whether the
820   // operand index is within the bounds. The unsigned comparison includes
821   // checking whether the operand index >= 0.
822   std::vector<llvm::Value*> operand_multi_index(source_index.size());
823   llvm::Value* in_bounds_condition = b_.getTrue();
824   for (int64 i = 0; i < rank; ++i) {
825     llvm::Value* strided_index =
826         NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
827     operand_multi_index[i] =
828         NSWSub(NSWAdd(strided_index, window_index[i]),
829                b_.getInt64(window.dimensions(i).padding_low()));
830     llvm::Value* index_condition =
831         ICmpULT(operand_multi_index[i],
832                 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
833     in_bounds_condition = And(in_bounds_condition, index_condition);
834   }
835   CHECK(in_bounds_condition != nullptr);
836 
837   // Only need to do something if the operand index is within the bounds. First
838   // check if the initialized_flag is set.
839   llvm_ir::LlvmIfData if_in_bounds =
840       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
841   SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
842   llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
843       Load(initialized_flag_address), "initialized", &b_);
844 
845   // If the initialized_flag is false, initialize the selected value and index
846   // with the currently visiting operand.
847   SetToFirstInsertPoint(if_initialized.false_block, &b_);
848   const auto save_operand_index =
849       [&](const llvm_ir::IrArray::Index& operand_index) {
850         for (int64 i = 0; i < rank; ++i) {
851           llvm::Value* selected_index_address_slot =
852               InBoundsGEP(selected_index_address, {b_.getInt32(i)});
853           Store(operand_index[i], selected_index_address_slot);
854         }
855       };
856   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
857   llvm_ir::IrArray::Index operand_index(
858       operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
859   llvm::Value* operand_data =
860       operand_array.EmitReadArrayElement(operand_index, &b_);
861   Store(operand_data, selected_value_address);
862   save_operand_index(operand_index);
863   Store(b_.getInt1(true), initialized_flag_address);
864 
865   // If the initialized_flag is true, call the `select` function to potentially
866   // update the selected value and index with the currently visiting operand.
867   SetToFirstInsertPoint(if_initialized.true_block, &b_);
868   llvm::Value* operand_address =
869       operand_array.EmitArrayElementAddress(operand_index, &b_);
870   llvm::Value* operand_element = Load(operand_address);
871   llvm::Value* result = EmitThreadLocalCall(
872       *select_and_scatter->select(),
873       {Load(selected_value_address), operand_element}, "select_function");
874 
875   // If the 'select' function returns false, update the selected value and the
876   // index to the currently visiting operand.
877   llvm::Value* cond = ICmpNE(
878       result,
879       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
880       "boolean_predicate");
881   llvm_ir::LlvmIfData if_select_lhs =
882       llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
883   SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
884   Store(Load(operand_address), selected_value_address);
885   save_operand_index(operand_index);
886 
887   // After iterating over the window elements, scatter the source element to
888   // the selected index of the output. The value we store at the output
889   // location is computed by calling the `scatter` function with the source
890   // value and the current output value.
891   SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
892   std::vector<llvm::Value*> selected_multi_index;
893   for (int64 i = 0; i < rank; ++i) {
894     llvm::Value* selected_index_address_slot =
895         InBoundsGEP(selected_index_address, {b_.getInt32(i)});
896     selected_multi_index.push_back(Load(selected_index_address_slot));
897   }
898   llvm_ir::IrArray source_array(GetIrArrayFor(source));
899   llvm::Value* source_value =
900       source_array.EmitReadArrayElement(source_index, &b_);
901   llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
902   llvm_ir::IrArray::Index selected_index(
903       selected_multi_index, output_array.GetShape(), source_index.GetType());
904   llvm::Value* output_value =
905       output_array.EmitReadArrayElement(selected_index, &b_);
906   llvm::Value* scatter_value =
907       EmitThreadLocalCall(*select_and_scatter->scatter(),
908                           {output_value, source_value}, "scatter_function");
909   output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
910 
911   SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
912   return Status::OK();
913 }
914 
915 Status IrEmitter::HandleDot(HloInstruction* dot) {
916   auto lhs = dot->operand(0);
917   auto rhs = dot->operand(1);
918   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
919       /*instruction=*/*dot, /*operands=*/{lhs, rhs},
920       /*supported_types=*/{F16, F32, F64, C64, C128}));
921   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
922 
923   if (dnums.lhs_contracting_dimensions_size() != 1) {
924     // This is disallowed by ShapeInference today.
925     return Unimplemented(
926         "Dot with multiple contracting dimensions not implemented.");
927   }
928 
929   llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
930   llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
931 
932   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
933   llvm_ir::IrArray target_array = GetIrArrayFor(dot);
934 
935   VLOG(2) << "HandleDot: ";
936   VLOG(2) << "  lhs operand: "
937           << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
938   VLOG(2) << "  rhs operand: "
939           << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
940   VLOG(2) << "  target: "
941           << llvm_ir::DumpToString(*target_array.GetBasePointer());
942 
943   // Dot operation is complicated so we delegate to a helper class.
944   return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
945                           /*addend_array=*/nullptr,
946                           GetExecutableRunOptionsArgument(), &b_,
947                           hlo_module_config_, target_machine_features_);
948 }
949 
950 StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
951     const HloConvolutionInstruction* convolution,
952     const llvm_ir::ElementGenerator& input_generator,
953     const llvm_ir::ElementGenerator& kernel_generator,
954     const llvm_ir::IrArray::Index& index) {
955   const HloInstruction* lhs = convolution->operand(0);
956   const HloInstruction* rhs = convolution->operand(1);
957   const Window& window = convolution->window();
958 
959   const ConvolutionDimensionNumbers& dnums =
960       convolution->convolution_dimension_numbers();
961   int num_spatial_dims = dnums.output_spatial_dimensions_size();
962   std::vector<llvm::Value*> output_spatial(num_spatial_dims);
963   for (int i = 0; i < num_spatial_dims; ++i) {
964     output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
965   }
966   llvm::Value* output_feature = index[dnums.output_feature_dimension()];
967   llvm::Value* batch = index[dnums.output_batch_dimension()];
968 
969   // We will accumulate the products into this sum to calculate the output entry
970   // at the given index.
971   PrimitiveType lhs_element_type = lhs->shape().element_type();
972   llvm::Type* lhs_llvm_type =
973       llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
974   llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
975       lhs_llvm_type, "convolution_sum_address", &b_,
976       MinimumAlignmentForPrimitiveType(lhs_element_type));
977   llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
978   Store(constant_zero, sum_address);
979 
980   llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
981   std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
982   for (int i = 0; i < num_spatial_dims; ++i) {
983     kernel_spatial[i] =
984         loops
985             .AddLoop(
986                 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
987                 absl::StrCat("k", i))
988             ->GetIndVarValue();
989   }
990   llvm::Value* input_feature =
991       loops
992           .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
993                    "iz")
994           ->GetIndVarValue();
995 
996   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
997 
998   // Calculate the spatial index in the input array, taking striding, dilation
999   // and padding into account. An index in the padding will be out of the bounds
1000   // of the array.
1001   const auto calculate_input_index = [this](llvm::Value* output_index,
1002                                             llvm::Value* kernel_index,
1003                                             const WindowDimension& window_dim) {
1004     llvm::Value* strided_index =
1005         NSWMul(output_index, b_.getInt64(window_dim.stride()));
1006     llvm::Value* dilated_kernel_index =
1007         NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
1008     return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
1009                   b_.getInt64(window_dim.padding_low()));
1010   };
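  // Editorial note (not part of the original source): in scalar form the
  // lambda above computes, per spatial dimension,
  //   input_index = output_index * stride + kernel_index * window_dilation
  //                 - padding_low
  // which may land in the padding or in a base-dilation hole; both cases are
  // filtered out by the bounds checks below.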
1011   std::vector<llvm::Value*> input_spatial(num_spatial_dims);
1012   for (int i = 0; i < num_spatial_dims; ++i) {
1013     input_spatial[i] = calculate_input_index(
1014         output_spatial[i], kernel_spatial[i], window.dimensions(i));
1015   }
1016 
1017   // We need to check if 0 <= input dim < bound, as otherwise we are in the
1018   // padding so that we can skip the computation. That is equivalent to input
1019   // dim < bound as an *unsigned* comparison, since a negative value will wrap
1020   // to a large positive value. The input dim is dilated, so we need to dilate
1021   // the bound as well to match.
1022 
1023   // Also need to check that the input coordinates are not in one of the
1024   // holes created by base dilation.
1025   const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
1026     llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
1027     return ICmpEQ(remainder, b_.getInt64(0));
1028   };
1029 
1030   llvm::Value* in_bounds_condition = b_.getInt1(true);
1031   for (int i = 0; i < num_spatial_dims; ++i) {
1032     llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
1033         lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
1034         window.dimensions(i).base_dilation()));
1035     llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
1036     llvm::Value* dim_not_in_hole =
1037         not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
1038     llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
1039     in_bounds_condition = And(in_bounds_condition, dim_ok);
1040   }
1041 
1042   // Now we need to map the dilated base coordinates back to the actual
1043   // data indices on the lhs.
1044   const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
1045     return SDiv(input_index, b_.getInt64(base_dilation));
1046   };
1047   for (int i = 0; i < num_spatial_dims; ++i) {
1048     input_spatial[i] =
1049         undilate(input_spatial[i], window.dimensions(i).base_dilation());
1050   }
1051 
1052   llvm_ir::LlvmIfData if_data =
1053       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
1054   SetToFirstInsertPoint(if_data.true_block, &b_);
1055 
1056   // We are not in the padding, so carry out the computation.
1057   int num_dims = num_spatial_dims + 2;
1058   std::vector<llvm::Value*> input_multi_index(num_dims);
1059   for (int i = 0; i < num_spatial_dims; ++i) {
1060     input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
1061   }
1062   input_multi_index[dnums.input_feature_dimension()] = input_feature;
1063   input_multi_index[dnums.input_batch_dimension()] = batch;
1064 
1065   std::vector<llvm::Value*> kernel_multi_index(num_dims);
1066   for (int i = 0; i < num_spatial_dims; ++i) {
1067     kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
1068         window.dimensions(i).window_reversal()
1069             ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
1070                      kernel_spatial[i])
1071             : kernel_spatial[i];
1072   }
1073 
1074   kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
1075   kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
1076 
1077   llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
1078                                       b_.getInt64Ty());
1079   TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
1080                       input_generator(input_index));
1081   llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
1082                                        b_.getInt64Ty());
1083   TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
1084                       kernel_generator(kernel_index));
1085   llvm::Value* product = FMul(input_value, kernel_value);
1086   llvm::Value* sum = FAdd(Load(sum_address), product);
1087   Store(sum, sum_address);
1088 
1089   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1090   return Load(sum_address);
1091 }
1092 
1093 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
1094   auto lhs = convolution->operand(0);
1095   auto rhs = convolution->operand(1);
1096   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1097       /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
1098       /*supported_types=*/{F16, F32, F64, C64, C128}));
1099 
1100   // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
1101   // different data layouts.
1102   if (PotentiallyImplementedAsEigenConvolution(*convolution,
1103                                                target_machine_features_)) {
1104     const Shape& lhs_shape = lhs->shape();
1105     const Shape& rhs_shape = rhs->shape();
1106     const Shape& convolution_shape = convolution->shape();
1107     // The input, kernel and output agree with respect to layout.
1108     if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
1109         LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
1110         LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
1111       // We lower 1D convolutions into calls to the same Eigen function as 2D
1112       // convolutions, except that we pretend that the 1D convolution is really
1113       // a 2D convolution with the missing dimension set to 1.  We also adjust
1114       // the padding, dilation parameters as needed.
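      // Editorial note (not part of the original source): concretely, when the
      // input has rank 3 (one spatial dimension), the second spatial extent of
      // the input, kernel and output is taken to be 1 below, and the
      // corresponding stride, padding and dilation parameters are pinned to
      // their identity values (stride 1, padding 0, dilation 1).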
1115       bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
1116       llvm::Value* lhs_address = GetEmittedValueFor(lhs);
1117       llvm::Value* rhs_address = GetEmittedValueFor(rhs);
1118       TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
1119 
1120       const ConvolutionDimensionNumbers& dnums =
1121           convolution->convolution_dimension_numbers();
1122 
1123       // Input tensor.
1124       const Shape& input_shape = convolution->operand(0)->shape();
1125       int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
1126       int64 input_rows =
1127           input_shape.dimensions(dnums.input_spatial_dimensions(0));
1128       int64 input_cols =
1129           one_dim_convolution
1130               ? 1
1131               : input_shape.dimensions(dnums.input_spatial_dimensions(1));
1132       int64 input_channels =
1133           input_shape.dimensions(dnums.input_feature_dimension());
1134 
1135       // Kernel tensor.
1136       const Shape& kernel_shape = convolution->operand(1)->shape();
1137       int64 kernel_rows =
1138           kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
1139       int64 kernel_cols =
1140           one_dim_convolution
1141               ? 1
1142               : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
1143       int64 kernel_channels =
1144           kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
1145       int64 kernel_filters =
1146           kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
1147 
1148       // Output tensor.
1149       const Shape& convolution_shape = convolution->shape();
1150       int64 output_rows =
1151           convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
1152       int64 output_cols = one_dim_convolution
1153                               ? 1
1154                               : convolution_shape.dimensions(
1155                                     dnums.output_spatial_dimensions(1));
1156 
1157       // Extract the window stride for the convolution.
1158       const Window& window = convolution->window();
1159       int64 row_stride = window.dimensions(0).stride();
1160       int64 col_stride =
1161           one_dim_convolution ? 1 : window.dimensions(1).stride();
1162 
1163       int64 padding_top = window.dimensions(0).padding_low();
1164       int64 padding_bottom = window.dimensions(0).padding_high();
1165       int64 padding_left =
1166           one_dim_convolution ? 0 : window.dimensions(1).padding_low();
1167       int64 padding_right =
1168           one_dim_convolution ? 0 : window.dimensions(1).padding_high();
1169 
1170       int64 lhs_row_dilation = window.dimensions(0).base_dilation();
1171       int64 lhs_col_dilation =
1172           one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
1173       int64 rhs_row_dilation = window.dimensions(0).window_dilation();
1174       int64 rhs_col_dilation =
1175           one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
1176 
1177       PrimitiveType primitive_type = lhs->shape().element_type();
1178       llvm::Type* ir_ptr_type = primitive_type == F16
1179                                     ? b_.getHalfTy()->getPointerTo()
1180                                     : b_.getFloatTy()->getPointerTo();
1181       llvm::Type* int64_type = b_.getInt64Ty();
1182       llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1183       llvm::FunctionType* conv_type = llvm::FunctionType::get(
1184           b_.getVoidTy(),
1185           {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
1186            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
1187            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
1188            int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
1189            int64_type,    int64_type,  int64_type,  int64_type},
1190           /*isVarArg=*/false);
1191       bool multi_threaded =
1192           hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1193       bool use_mkl_dnn =
1194           hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
1195 
1196       // TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to a
1197       // potential race condition caused by setting omp_num_threads.
1198       const char* fn_name =
1199           primitive_type == F16
1200               ? (multi_threaded
1201                      ? runtime::kEigenConvF16SymbolName
1202                      : runtime::kEigenSingleThreadedConvF16SymbolName)
1203               : (multi_threaded
1204                      ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
1205                                     : runtime::kEigenConvF32SymbolName)
1206                      : runtime::kEigenSingleThreadedConvF32SymbolName);
1207       if (!multi_threaded && use_mkl_dnn) {
1208         LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
1209                         "conv2d function.";
1210       }
1211       llvm::Function* conv_func = llvm::dyn_cast<llvm::Function>(
1212           module_->getOrInsertFunction(fn_name, conv_type).getCallee());
1213       conv_func->setCallingConv(llvm::CallingConv::C);
1214       conv_func->setDoesNotThrow();
1215       conv_func->setOnlyAccessesArgMemory();
1216       Call(conv_func, {
1217                           GetExecutableRunOptionsArgument(),
1218                           BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
1219                           BitCast(lhs_address, ir_ptr_type),
1220                           BitCast(rhs_address, ir_ptr_type),
1221                           b_.getInt64(input_batch),
1222                           b_.getInt64(input_rows),
1223                           b_.getInt64(input_cols),
1224                           b_.getInt64(input_channels),
1225                           b_.getInt64(kernel_rows),
1226                           b_.getInt64(kernel_cols),
1227                           b_.getInt64(kernel_channels),
1228                           b_.getInt64(kernel_filters),
1229                           b_.getInt64(output_rows),
1230                           b_.getInt64(output_cols),
1231                           b_.getInt64(row_stride),
1232                           b_.getInt64(col_stride),
1233                           b_.getInt64(padding_top),
1234                           b_.getInt64(padding_bottom),
1235                           b_.getInt64(padding_left),
1236                           b_.getInt64(padding_right),
1237                           b_.getInt64(lhs_row_dilation),
1238                           b_.getInt64(lhs_col_dilation),
1239                           b_.getInt64(rhs_row_dilation),
1240                           b_.getInt64(rhs_col_dilation),
1241                       });
1242 
1243       return Status::OK();
1244     }
1245   }
1246 
1247   // This is a completely unoptimized version of convolution, just to
1248   // have an early version that works. E.g. the input index and
1249   // padding calculations are not hoisted out of the inner loop.
1250   //
1251   // See the description of convolution in the XLA documentation for the pseudo
1252   // code for convolution.
1253   return DefaultAction(convolution);
1254 }
1255 
1256 Status IrEmitter::HandleFft(HloInstruction* fft) {
1257   auto operand = fft->operand(0);
1258   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1259       /*instruction=*/*fft, /*operands=*/{operand},
1260       /*supported_types=*/{F32, C64}));
1261   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1262   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1263   VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1264   VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1265 
1266   llvm::Value* operand_address = GetEmittedValueFor(operand);
1267   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1268 
1269   const std::vector<int64>& fft_length = fft->fft_length();
1270   int64 input_batch = 1;
1271   for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1272     input_batch *= fft->shape().dimensions(i);
1273   }
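  // For example (illustrative): a rank-2 FFT over a [8, 4, 16, 16] shape has
  // fft_length.size() == 2, so the two leading dimensions are folded into
  // input_batch = 8 * 4 = 32.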
1274 
1275   // Args have been computed, make the call.
1276   llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1277   llvm::Type* int32_type = b_.getInt32Ty();
1278   llvm::Type* int64_type = b_.getInt64Ty();
1279   llvm::FunctionType* fft_type = llvm::FunctionType::get(
1280       b_.getVoidTy(),
1281       {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
1282        int64_type, int64_type, int64_type, int64_type},
1283       /*isVarArg=*/false);
1284 
1285   bool multi_threaded_eigen =
1286       hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1287   const char* fn_name = multi_threaded_eigen
1288                             ? runtime::kEigenFftSymbolName
1289                             : runtime::kEigenSingleThreadedFftSymbolName;
1290 
1291   llvm::Function* fft_func = llvm::dyn_cast<llvm::Function>(
1292       module_->getOrInsertFunction(fn_name, fft_type).getCallee());
1293   fft_func->setCallingConv(llvm::CallingConv::C);
1294   fft_func->setDoesNotThrow();
1295   fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
1296   const int fft_rank = fft_length.size();
1297   Call(fft_func,
1298        {GetExecutableRunOptionsArgument(),
1299         BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1300         BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1301         b_.getInt32(fft_rank), b_.getInt64(input_batch),
1302         b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1303         b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1304         b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
1305 
1306   return Status::OK();
1307 }
1308 
1309 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1310   if (hlo_module_config_.replica_count() != 1) {
1311     // TODO(b/33011107): Support nontrivial cross replica sum on CPU.
1312     return Unimplemented(
1313         "AllReduce with >1 replica is not implemented on CPU.");
1314   }
1315 
1316   // When there is a single replica, a cross replica sum is the identity
1317   // function, and the buffer assignment expects a copy.
1318   //
1319   // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1320   // in algebraic-simplifier, but currently on some platforms
1321   // HloModuleConfig::num_replicas changes between when the module is compiled
1322   // and when it's run.
1323   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1324 
1325   // CRS with one operand and one replica is simply the identity function.
1326   if (crs->operand_count() == 1) {
1327     return EmitMemcpy(*crs->operand(0), *crs);
1328   }
1329 
1330   // CRS with multiple operands and one replica produces a (one-deep) tuple.
1331   std::vector<llvm::Value*> operand_ptrs;
1332   for (int64 i = 0; i < crs->operand_count(); ++i) {
1333     llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1334     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1335                         assignment_.GetUniqueSlice(crs, {i}));
1336 
1337     const Shape& operand_shape = crs->operand(i)->shape();
1338     CHECK(operand_shape.IsArray())
1339         << "Operands to all-reduce must be arrays: " << crs->ToString();
1340     operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1341 
1342     // TODO(b/63762267): Be more aggressive about specifying alignment.
1343     MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
1344            /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
1345   }
1346   llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1347   return Status::OK();
1348 }
1349 
1350 Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1351   VLOG(2) << "HandleParameter: " << parameter->ToString();
1352   return EmitTargetAddressForOp(parameter);
1353 }
1354 
1355 // Returns true if the relative order of the unreduced dimensions stays the same
1356 // through the reduce operation.
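// For example (illustrative): reducing f32[A,B,C]{2,1,0} over dimension 1
// produces f32[A,C]; the layout is preserved iff the result keeps C more
// minor than A, i.e. has layout {1,0}.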
1357 static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1358   DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1359 
1360   // Maps dimensions that were not reduced from their dimension numbers in the
1361   // source shape to their dimension numbers in the destination shape.
1362   //
1363   // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1364   // [0->0, 3->1].
1365   absl::flat_hash_map<int64, int64> unreduced_dim_map;
1366 
1367   absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1368                                           reduce.dimensions().end());
1369 
1370   const Shape& operand_shape = reduce.operand(0)->shape();
1371   const Shape& result_shape = reduce.shape();
1372 
1373   int64 delta = 0;
1374   for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
1375     if (reduced_dims.contains(i)) {
1376       delta++;
1377     } else {
1378       InsertOrDie(&unreduced_dim_map, i, i - delta);
1379     }
1380   }
1381 
1382   // Iterate dimensions minor to major and check that the corresponding
1383   // dimensions in the source and target shapes are equivalent.
1384   int64 result_dim_idx = 0;
1385   for (int64 operand_dim_idx = 0;
1386        operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1387     int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
1388     if (!reduced_dims.contains(operand_dim)) {
1389       if (FindOrDie(unreduced_dim_map, operand_dim) !=
1390           result_shape.layout().minor_to_major(result_dim_idx++)) {
1391         return false;
1392       }
1393     }
1394   }
1395 
1396   CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1397 
1398   return true;
1399 }
1400 
1401 IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1402     HloComputation* function, string* failure_reason) const {
1403   CHECK_EQ(function->num_parameters(), 2);
1404 
1405   auto root_instruction = function->root_instruction();
1406   CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1407 
1408   if (root_instruction->operand_count() != 2) {
1409     *failure_reason = "root instruction is not a binary operation";
1410     return nullptr;
1411   }
1412 
1413   const Shape& root_shape = root_instruction->shape();
1414   if (ShapeUtil::ElementIsComplex(root_shape)) {
1415     // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
1416     // Complex multiply would be more challenging. We could perhaps use a
1417     // strided load to get all reals in one vector and all imaginary parts in
1418     // another, or use CreateShuffleVector on a bitcast to float x [2N].
1419     *failure_reason = "complex values not supported";
1420     return nullptr;
1421   }
1422   bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1423   bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1424   bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1425 
1426   auto lhs = root_instruction->operand(0);
1427   auto rhs = root_instruction->operand(1);
1428 
1429   auto param_0 = function->parameter_instruction(0);
1430   auto param_1 = function->parameter_instruction(1);
1431   if (!(lhs == param_0 && rhs == param_1) &&
1432       !(rhs == param_0 && lhs == param_1)) {
1433     *failure_reason =
1434         "root instruction is not a binary operation on the incoming arguments";
1435     return nullptr;
1436   }
1437 
1438   CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1439 
1440   // This is visually similar to ElementalIrEmitter, though conceptually we're
1441   // doing something different here.  ElementalIrEmitter emits scalar operations
1442   // while these emit scalar or vector operations depending on the type of the
1443   // operands. See CreateShardedVectorType for the actual types in use here.
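  // For example (illustrative): a reduction computation whose root is
  // add(param0, param1) over F32 scalars matches the kAdd case below and
  // yields a generator that emits fadd on scalar or vector operands.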
1444   switch (root_instruction->opcode()) {
1445     default:
1446       *failure_reason = "did not recognize root instruction opcode";
1447       return nullptr;
1448 
1449     case HloOpcode::kAdd:
1450       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1451                                 llvm::Value* rhs) {
1452         return root_is_integral ? b->CreateAdd(lhs, rhs)
1453                                 : b->CreateFAdd(lhs, rhs);
1454       };
1455 
1456     case HloOpcode::kMultiply:
1457       return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1458                                 llvm::Value* rhs) {
1459         return root_is_integral ? b->CreateMul(lhs, rhs)
1460                                 : b->CreateFMul(lhs, rhs);
1461       };
1462 
1463     case HloOpcode::kAnd:
1464       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1465         return b->CreateAnd(lhs, rhs);
1466       };
1467 
1468     case HloOpcode::kOr:
1469       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1470         return b->CreateOr(lhs, rhs);
1471       };
1472 
1473     case HloOpcode::kXor:
1474       return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1475         return b->CreateXor(lhs, rhs);
1476       };
1477 
1478     case HloOpcode::kMaximum:
1479       return [root_is_floating_point, root_is_signed](
1480                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1481                  llvm::Value* rhs) -> llvm::Value* {
1482         if (root_is_floating_point) {
1483           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
1484                                               {lhs, rhs}, {lhs->getType()}, b);
1485         }
1486 
1487         return b->CreateSelect(
1488             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1489                                          : llvm::ICmpInst::ICMP_UGE,
1490                           lhs, rhs),
1491             lhs, rhs);
1492       };
1493 
1494     case HloOpcode::kMinimum:
1495       return [root_is_floating_point, root_is_signed](
1496                  llvm::IRBuilder<>* b, llvm::Value* lhs,
1497                  llvm::Value* rhs) -> llvm::Value* {
1498         if (root_is_floating_point) {
1499           return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
1500                                               {lhs, rhs}, {lhs->getType()}, b);
1501         }
1502 
1503         return b->CreateSelect(
1504             b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1505                                          : llvm::ICmpInst::ICMP_ULE,
1506                           lhs, rhs),
1507             lhs, rhs);
1508       };
1509   }
1510 }
1511 
1512 IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1513     PrimitiveType element_type, unsigned element_count) {
1514   int vector_register_size_in_elements =
1515       target_machine_features_.vector_register_byte_size(
1516           *compute_function_->function()) /
1517       ShapeUtil::ByteSizeOfPrimitiveType(element_type);
1518 
1519   ShardedVectorType sharded_vector_type;
1520   llvm::Type* element_ir_type =
1521       llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1522 
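  // For example (illustrative): with F32 elements, an 8-element vector
  // register and element_count == 13 (= 8 + 4 + 1), the loop below produces
  // the shards {float, <4 x float>, <8 x float>}, smallest fragment first.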
1523   for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1524     // For every power of two present in element_count, we generate one or more
1525     // vector or scalar types.
1526     const unsigned current_size_fragment = 1u << i;
1527     if (!(element_count & current_size_fragment)) {
1528       // Power of two not present in element_count.
1529       continue;
1530     }
1531 
1532     if (current_size_fragment == 1) {
1533       // Single element, use a scalar type.
1534       sharded_vector_type.push_back(element_ir_type);
1535       continue;
1536     }
1537 
1538     // Lower "current_size_fragment" elements using as few vector registers
1539     // as possible.
1540 
1541     if (current_size_fragment >= vector_register_size_in_elements) {
1542       auto vector_type = llvm::VectorType::get(
1543           element_ir_type, vector_register_size_in_elements);
1544       sharded_vector_type.insert(
1545           sharded_vector_type.end(),
1546           current_size_fragment / vector_register_size_in_elements,
1547           vector_type);
1548 
1549       // Both current_size_fragment and vector_register_size_in_elements are
1550       // powers of two.
1551       CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1552       continue;
1553     }
1554 
1555     // For now we assume that vector_register_size_in_elements and lower powers
1556     // of two are all legal vector sizes (or at least can be lowered easily by
1557     // LLVM).
1558     sharded_vector_type.push_back(
1559         llvm::VectorType::get(element_ir_type, current_size_fragment));
1560   }
1561   return sharded_vector_type;
1562 }
1563 
1564 StatusOr<IrEmitter::ShardedVector>
1565 IrEmitter::EmitInnerLoopForVectorizedReduction(
1566     const ReductionGenerator& reduction_generator,
1567     const llvm_ir::IrArray::Index& output_index,
1568     const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1569     HloInstruction* arg, absl::Span<const int64> dimensions,
1570     unsigned element_alignment) {
1571   ShardedVector accumulator;
1572   accumulator.reserve(accumulator_type.size());
1573   for (auto accumulator_shard_type : accumulator_type) {
1574     accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1575         accumulator_shard_type, "accumulator", &b_, 0));
1576   }
1577 
1578   llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1579 
1580   for (llvm::Value* accumulator_shard : accumulator) {
1581     llvm::Value* initial_value;
1582     auto shard_type = accumulator_shard->getType()->getPointerElementType();
1583     if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1584       initial_value =
1585           VectorSplat(vector_type->getNumElements(), init_value_ssa);
1586     } else {
1587       initial_value = init_value_ssa;
1588     }
1589 
1590     AlignedStore(initial_value, accumulator_shard, element_alignment);
1591   }
1592 
1593   llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1594                                            &b_);
1595   std::vector<llvm::Value*> input_multi_index =
1596       reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1597                                                        "reduction_dim");
1598 
1599   SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1600 
1601   llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1602   llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1603 
1604   for (auto& i : input_multi_index) {
1605     if (i == nullptr) {
1606       i = *it++;
1607     }
1608   }
1609   CHECK(output_index.end() == it);
1610   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1611                                       b_.getInt64Ty());
1612 
1613   llvm::Value* input_address = BitCast(
1614       arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1615 
1616   for (int i = 0; i < accumulator.size(); i++) {
1617     auto input_address_typed =
1618         BitCast(input_address, accumulator[i]->getType());
1619     auto current_accumulator_value =
1620         AlignedLoad(accumulator[i], element_alignment);
1621     auto addend = AlignedLoad(input_address_typed, element_alignment);
1622     arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1623 
1624     auto reduced_result =
1625         reduction_generator(&b_, current_accumulator_value, addend);
1626     AlignedStore(reduced_result, accumulator[i], element_alignment);
1627 
1628     if (i != (accumulator.size() - 1)) {
1629       input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1630                                            input_address_typed, 1);
1631     }
1632   }
1633 
1634   SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1635 
1636   ShardedVector result_ssa;
1637   result_ssa.reserve(accumulator.size());
1638   for (auto accumulator_shard : accumulator) {
1639     result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1640   }
1641   return result_ssa;
1642 }
1643 
1644 void IrEmitter::EmitShardedVectorStore(
1645     llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1646     const int alignment, const llvm_ir::IrArray& containing_array) {
1647   for (int i = 0; i < value_to_store.size(); i++) {
1648     auto store_address_typed =
1649         BitCast(store_address,
1650                 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1651 
1652     auto store_instruction =
1653         AlignedStore(value_to_store[i], store_address_typed, alignment);
1654     containing_array.AnnotateLoadStoreInstructionWithMetadata(
1655         store_instruction);
1656 
1657     if (i != (value_to_store.size() - 1)) {
1658       store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1659                                            store_address_typed, 1);
1660     }
1661   }
1662 }
1663 
1664 StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1665     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1666     absl::Span<const int64> dimensions, HloComputation* function,
1667     string* failure_reason) {
1668   if (!ReductionPreservesLayout(*reduce)) {
1669     return false;
1670   }
1671 
1672   ReductionGenerator reduction_generator =
1673       MatchReductionGenerator(function, failure_reason);
1674   if (!reduction_generator) {
1675     return false;
1676   }
1677 
1678   int vectorization_factor_in_bytes =
1679       target_machine_features_.vectorization_factor_in_bytes();
1680 
1681   // We try to process vectorization_factor elements at the same time.
1682   const int vectorization_factor =
1683       vectorization_factor_in_bytes /
1684       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1685 
1686   bool is_reduction_over_minor_dimension = absl::c_linear_search(
1687       dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1688 
1689   unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
1690       ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1691       MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
1692 
1693   if (is_reduction_over_minor_dimension) {
1694     // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1695     *failure_reason = "reduction over minor dimension not implemented";
1696     return false;
1697   }
1698 
1699   CHECK(!reduce->shape().IsTuple());
1700   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1701 
1702   // We know we're not reducing over the most minor dimension, which means we
1703   // can lower the reduction loop as:
1704   //
1705   //  1. We're reducing over dimensions R0, R1.
1706   //  2. D0 is the most minor dimension.
1707   //  3. VS is the vectorization stride (we want to reduce this many elements at
1708   //     once)
1709   //
1710   //  for (d1 in D1) {
1711   //    for (d0 in D0 with stride VS) {
1712   //      vector_acc = init
1713   //      for (r1 in R1) {
1714   //        for (r0 in R0) {
1715   //          vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1716   //        }
1717   //      }
1718   //      output[d1, d0] = vector_acc
1719   //    }
1720   //  }
1721 
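  //  For example (illustrative): with vectorization_factor == 8 and an
  //  innermost dimension of size 20, the vectorized loop below covers indices
  //  [0, 16) in steps of 8 and the epilogue handles the remaining 4 elements
  //  with a narrower sharded vector type.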
1722   llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1723   std::vector<llvm::Value*> array_multi_index(
1724       reduce->shape().dimensions_size());
1725   for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1726        --i) {
1727     int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1728     int64 start_index = 0;
1729     int64 end_index = reduce->shape().dimensions(dimension);
1730     std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1731         start_index, end_index, absl::StrFormat("dim.%d", dimension));
1732     array_multi_index[dimension] = loop->GetIndVarValue();
1733   }
1734 
1735   int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1736   int64 innermost_dimension_size =
1737       reduce->shape().dimensions(innermost_dimension);
1738 
1739   if (llvm::BasicBlock* innermost_body_bb =
1740           loop_nest.GetInnerLoopBodyBasicBlock()) {
1741     SetToFirstInsertPoint(innermost_body_bb, &b_);
1742   }
1743 
1744   auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1745 
1746   if (innermost_dimension_size >= vectorization_factor) {
1747     int64 start_index = 0;
1748     int64 end_index = (innermost_dimension_size / vectorization_factor) *
1749                       vectorization_factor;
1750     std::unique_ptr<llvm_ir::ForLoop> loop =
1751         loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1752                           absl::StrFormat("dim.%d", innermost_dimension));
1753     array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1754 
1755     SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1756 
1757     ShardedVectorType vector_type = CreateShardedVectorType(
1758         reduce->shape().element_type(), vectorization_factor);
1759     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1760                                         b_.getInt64Ty());
1761     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1762                         EmitInnerLoopForVectorizedReduction(
1763                             reduction_generator, array_index, vector_type,
1764                             init_value, arg, dimensions, element_alignment));
1765 
1766     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1767     llvm::Value* output_address =
1768         target_array.EmitArrayElementAddress(array_index, &b_);
1769     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1770                            target_array);
1771 
1772     if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1773       CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1774       b_.SetInsertPoint(exit_terminator);
1775     } else {
1776       CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1777       b_.SetInsertPoint(loop->GetExitBasicBlock());
1778     }
1779   }
1780 
1781   // Since we step through the inner dimension by more than one element per
1782   // iteration, we may need to peel off an "epilogue" iteration to handle the
1783   // remaining elements in the following case:
1784   if (innermost_dimension_size % vectorization_factor) {
1785     // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1786     array_multi_index[innermost_dimension] =
1787         b_.getInt64(innermost_dimension_size -
1788                     (innermost_dimension_size % vectorization_factor));
1789 
1790     ShardedVectorType vector_type = CreateShardedVectorType(
1791         reduce->shape().element_type(),
1792         innermost_dimension_size % vectorization_factor);
1793     llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1794                                         b_.getInt64Ty());
1795     TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1796                         EmitInnerLoopForVectorizedReduction(
1797                             reduction_generator, array_index, vector_type,
1798                             init_value, arg, dimensions, element_alignment));
1799 
1800     llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1801     llvm::Value* output_address =
1802         target_array.EmitArrayElementAddress(array_index, &b_);
1803     EmitShardedVectorStore(output_address, accumulator, element_alignment,
1804                            target_array);
1805   }
1806 
1807   if (outermost_loop_exit_block) {
1808     b_.SetInsertPoint(outermost_loop_exit_block);
1809   }
1810 
1811   return true;
1812 }
1813 
1814 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
1815     const HloReduceInstruction* reduce,
1816     const llvm_ir::ElementGenerator& input_generator,
1817     const llvm_ir::ElementGenerator& initial_value_generator,
1818     const llvm_ir::IrArray::Index& index) {
1819   const HloInstruction* arg = reduce->operand(0);
1820   absl::Span<const int64> dimensions(reduce->dimensions());
1821 
1822   // Initialize an accumulator with init_value.
1823   PrimitiveType accumulator_type = reduce->shape().element_type();
1824   llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
1825       llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
1826       &b_, MinimumAlignmentForPrimitiveType(accumulator_type));
1827   TF_ASSIGN_OR_RETURN(
1828       llvm::Value* const init_value,
1829       initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
1830   Store(init_value, accumulator_addr);
1831 
1832   // The enclosing loops go over all the target elements. Now we have to compute
1833   // the actual target element. For this, we build a new loop nest to iterate
1834   // over all the reduction dimensions in the argument.
1835   // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
1836   // are placed for each dimension in dimensions, and all the rest are nullptrs.
1837   llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
1838   std::vector<llvm::Value*> input_multi_index =
1839       loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1840                                          "reduction_dim");
1841 
1842   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1843 
1844   // Build a full index for the input argument, using input_multi_index as the
1845   // base. In input_multi_index only the reduction dimensions are filled in. We
1846   // fill in the rest of the dimensions with induction Value*s taken from
1847   // 'index', which iterates over the target array.  See the high-level
1848   // description in the XLA documentation for details.
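  // For example (illustrative): reducing f32[A,B,C,D] over dimensions {1,2},
  // input_multi_index starts as [nullptr, r1, r2, nullptr] and the two nullptr
  // slots are filled below from 'index', which ranges over the [A,D] result.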
1849   llvm_ir::IrArray::Index::const_iterator it = index.begin();
1850 
1851   for (auto& i : input_multi_index) {
1852     if (i == nullptr) {
1853       i = *it++;
1854     }
1855   }
1856   CHECK(index.end() == it);
1857   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1858                                       b_.getInt64Ty());
1859 
1860   // Apply the reduction function to the loaded value.
1861   TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
1862                       input_generator(input_index));
1863   llvm::Value* result = EmitThreadLocalCall(
1864       *reduce->to_apply(), {Load(accumulator_addr), input_element},
1865       "reduce_function");
1866   Store(result, accumulator_addr);
1867 
1868   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1869   return Load(accumulator_addr);
1870 }
1871 
1872 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1873   // TODO(b/118333695): Support variadic reduce.
1874   if (!reduce->shape().IsArray()) {
1875     return Unimplemented("Variadic reduce is not supported on CPU");
1876   }
1877   auto arg = reduce->mutable_operand(0);
1878   auto init_value = reduce->mutable_operand(1);
1879   absl::Span<const int64> dimensions(reduce->dimensions());
1880   HloComputation* function = reduce->to_apply();
1881   if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1882     string vectorization_failure_reason;
1883     TF_ASSIGN_OR_RETURN(
1884         bool vectorization_successful,
1885         EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1886                              &vectorization_failure_reason));
1887     if (vectorization_successful) {
1888       VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1889               << "\n";
1890       return Status::OK();
1891     } else {
1892       VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1893               << vectorization_failure_reason;
1894     }
1895   }
1896 
1897   return DefaultAction(reduce);
1898 }
1899 
1900 Status IrEmitter::HandleAllToAll(HloInstruction*) {
1901   return Unimplemented("AllToAll is not implemented on CPU.");
1902 }
1903 
1904 Status IrEmitter::HandleSend(HloInstruction* send) {
1905   // TODO(b/33942983): Support Send/Recv on CPU.
1906   return Unimplemented("Send is not implemented on CPU.");
1907 }
1908 
1909 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1910   // TODO(b/33942983): Support Send/Recv on CPU.
1911   return Unimplemented("Send-done is not implemented on CPU.");
1912 }
1913 
1914 Status IrEmitter::HandleScatter(HloInstruction*) {
1915   return Unimplemented("Scatter is not implemented on CPUs.");
1916 }
1917 
1918 Status IrEmitter::HandleSlice(HloInstruction* slice) {
1919   VLOG(2) << "HandleSlice: " << slice->ToString();
1920   auto operand = slice->operand(0);
1921   // The code below emits a sequential loop nest. For the parallel backend, use
1922   // ParallelLoopEmitter which respects dynamic loop bounds.
1923   if (ShouldEmitParallelLoopFor(*slice)) {
1924     return DefaultAction(slice);
1925   }
1926 
1927   // The code below assumes the layouts are equal.
1928   if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1929     return DefaultAction(slice);
1930   }
1931 
1932   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1933 
1934   if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1935     return Status::OK();
1936   }
1937 
1938   const Layout& layout = operand->shape().layout();
1939   const int64 num_dims = operand->shape().dimensions_size();
1940 
1941   // The slice lowering finds maximal contiguous blocks of memory that can be
1942   // copied from the source to the target. This is done by looking at the
1943   // source/target layout in minor to major order and doing the following:
1944   //
1945   // * Find an initial segment of dimensions along which the slice uses the
1946   //   whole dimension. These are the "inner" dimensions and can be folded into
1947   //   the memcpy.
1948   //
1949   // * Of the remaining dimensions decide which ones require loops.
1950   //
1951   // * Implement the memcpy within the innermost loop.
1952 
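  // For example (illustrative): slicing f32[4,8,128]{2,1,0} down to
  // f32[2,5,128] with unit strides gives inner_dims = {2} and a contiguous
  // memcpy_dim = 1, so we emit one loop over dimension 0 whose body copies
  // 5 * 128 contiguous elements per iteration.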
1953   absl::flat_hash_set<int64> inner_dims;
1954   for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
1955     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1956       break;
1957     }
1958     inner_dims.insert(dim);
1959   }
1960 
1961   const bool is_trivial_copy = (inner_dims.size() == num_dims);
1962   if (is_trivial_copy) {
1963     if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1964       return DefaultAction(slice);
1965     } else {
1966       return EmitMemcpy(*slice, *operand);
1967     }
1968   }
1969 
1970   // The memcpy will copy elements that are logically this shape (allowed to be
1971   // scalar).
1972   const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1973       [&inner_dims](int64 dim) { return inner_dims.contains(dim); },
1974       operand->shape());
1975 
1976   const int64 primitive_elements_per_logical_element =
1977       ShapeUtil::ElementsIn(logical_element_shape);
1978 
1979   // memcpy_dim is the innermost (in terms of layout) dimension for which the
1980   // slice does *not* just copy all the elements along the dimension.
1981   const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1982 
1983   const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1984   // The number of logical elements that can be copied in a single call
1985   // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1986   // stride.
1987   const int64 memcpy_logical_elements =
1988       memcpy_is_contiguous
1989           ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1990           : 1;
1991 
1992   // Determine the dimensions that get lowered as loops.
1993   std::vector<int64> outer_dims;
1994   for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1995     outer_dims.push_back(LayoutUtil::Major(layout, i));
1996   }
1997 
1998   // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1999   // also needs to be wrapped in a loop.
2000   if (!memcpy_is_contiguous) {
2001     outer_dims.push_back(memcpy_dim);
2002   }
2003 
2004   llvm_ir::IrArray target_array = GetIrArrayFor(slice);
2005 
2006   const int64 num_outer_loops = outer_dims.size();
2007   llvm_ir::ForLoopNest loops(IrName(slice), &b_);
2008   std::vector<llvm::Value*> target_multi_index =
2009       loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
2010 
2011   // Only the indices for the outer dimensions have been initialized in
2012   // target_multi_index. The rest of the indices are initialized to 0, since
2013   // the copy writes to the full extent of the remaining dimensions.
2014   std::replace(target_multi_index.begin(), target_multi_index.end(),
2015                static_cast<llvm::Value*>(nullptr),
2016                static_cast<llvm::Value*>(b_.getInt64(0)));
2017   llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
2018                                        b_.getInt64Ty());
2019 
2020   if (num_outer_loops > 0) {
2021     SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2022   }
2023 
2024   llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2025   const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
2026       /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
2027       /*strides=*/slice->slice_strides(), /*builder=*/&b_);
2028 
2029   llvm::Value* memcpy_dest =
2030       target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
2031   llvm::Value* memcpy_source =
2032       source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
2033 
2034   const int64 memcpy_elements =
2035       primitive_elements_per_logical_element * memcpy_logical_elements;
2036 
2037   EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
2038                        slice->shape().element_type(), target_array,
2039                        source_array);
2040 
2041   if (VLOG_IS_ON(2)) {
2042     const int64 memcpy_bytes =
2043         ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
2044     VLOG(2) << "  emitted copy of " << memcpy_bytes << " bytes inside "
2045             << num_outer_loops << " loops";
2046   }
2047 
2048   if (num_outer_loops > 0) {
2049     SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2050   }
2051 
2052   return Status::OK();
2053 }
2054 
2055 Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
2056   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
2057     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
2058     return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
2059   }
2060   return DefaultAction(dynamic_slice);
2061 }
2062 
2063 Status IrEmitter::HandleDynamicUpdateSlice(
2064     HloInstruction* dynamic_update_slice) {
2065   auto update = dynamic_update_slice->operand(1);
2066   if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
2067     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2068     return EmitMemcpy(*update, *dynamic_update_slice);
2069   } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
2070                                                    assignment_)) {
2071     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
2072     auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
2073     return llvm_ir::EmitDynamicUpdateSliceInPlace(
2074         operands, GetIrArrayFor(dynamic_update_slice),
2075         IrName(dynamic_update_slice, "in_place"), &b_);
2076   }
2077   return DefaultAction(dynamic_update_slice);
2078 }
2079 
2080 Status IrEmitter::HandleRecv(HloInstruction* recv) {
2081   // TODO(b/33942983): Support Send/Recv on CPU.
2082   return Unimplemented("Recv is not implemented on CPU.");
2083 }
2084 
2085 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2086   // TODO(b/33942983): Support Send/Recv on CPU.
2087   return Unimplemented("Recv-done is not implemented on CPU.");
2088 }
2089 
2090 Status IrEmitter::HandlePad(HloInstruction* pad) {
2091   // The CPU backend does not properly handle negative padding, but this is OK
2092   // because negative padding should have been removed by the algebraic simplifier.
2093   for (auto& padding_dimension : pad->padding_config().dimensions()) {
2094     if (padding_dimension.edge_padding_low() < 0 ||
2095         padding_dimension.edge_padding_high() < 0) {
2096       return InternalErrorStrCat(
2097           "Encountered negative padding in IrEmitter on CPU. "
2098           "This should have been eliminated at the HLO level. ",
2099           pad->ToString());
2100     }
2101   }
2102 
2103   // First, fill in the padding value to all output elements.
2104   TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2105       pad, "initialize",
2106       [this, pad](const llvm_ir::IrArray::Index& target_index) {
2107         const HloInstruction* padding_value = pad->operand(1);
2108         llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2109         return Load(padding_value_addr);
2110       }));
2111 
2112   // Create a loop to iterate over the operand elements and update the output
2113   // locations where the operand elements should be stored.
2114   llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2115   const HloInstruction* operand = pad->operand(0);
2116   const llvm_ir::IrArray::Index operand_index =
2117       loops.AddLoopsForShape(operand->shape(), "operand");
2118 
2119   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2120 
2121   // Load an element from the operand.
2122   llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2123   llvm::Value* operand_data =
2124       operand_array.EmitReadArrayElement(operand_index, &b_);
2125 
2126   // Compute the output index the operand element should be assigned to.
2127   // output_index := edge_padding_low + operand_index * (interior_padding + 1)
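  // For example (illustrative): with edge_padding_low = 2 and
  // interior_padding = 1, operand indices 0, 1, 2 map to output indices
  // 2, 4, 6.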
2128   const PaddingConfig& padding_config = pad->padding_config();
2129   std::vector<llvm::Value*> output_multi_index;
2130   for (size_t i = 0; i < operand_index.size(); ++i) {
2131     llvm::Value* offset =
2132         Mul(operand_index[i],
2133             b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2134     llvm::Value* index = Add(
2135         offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2136     output_multi_index.push_back(index);
2137   }
2138 
2139   // Store the operand element to the computed output location.
2140   llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2141   llvm_ir::IrArray::Index output_index(
2142       output_multi_index, output_array.GetShape(), operand_index.GetType());
2143   output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2144 
2145   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2146   return Status::OK();
2147 }
2148 
2149 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2150   auto* root = fusion->fused_expression_root();
2151   if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2152     VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2153     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2154     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2155     // Delegate to common implementation of fused in-place dynamic-update-slice.
2156     return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2157         fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion),
2158         &elemental_emitter, &b_);
2159   } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
2160     VLOG(3) << "HandleFusion kLoop";
2161     CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2162     auto operands = GetIrArraysForOperandsOf(fusion);
2163     FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
2164                                  &elemental_emitter);
2165     TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
2166 
2167     return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
2168   } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) {
2169     VLOG(3) << "HandleFusion kOutput";
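    // Illustrative example: an output fusion whose root is add(dot(a, b), c)
    // (or add(c, dot(a, b))) is lowered to a single EmitDotOperation call
    // below, with c passed in as the addend array.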
2170     int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2171     const HloInstruction* dot = root->operand(dot_op_index);
2172     CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2173         << dot->ToString() << "  "
2174         << fusion->fused_instructions_computation()->ToString();
2175 
2176     int64 dot_lhs_param_number = dot->operand(0)->parameter_number();
2177     int64 dot_rhs_param_number = dot->operand(1)->parameter_number();
2178     int64 addend_param_number =
2179         root->operand(1 - dot_op_index)->parameter_number();
2180 
2181     Shape target_shape = fusion->shape();
2182     TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2183     llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2184 
2185     llvm_ir::IrArray lhs_array(
2186         GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2187     llvm_ir::IrArray rhs_array(
2188         GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2189     llvm_ir::IrArray addend_array(
2190         GetIrArrayFor(fusion->operand(addend_param_number)));
2191 
2192     TF_RETURN_IF_ERROR(
2193         EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
2194                          &addend_array, GetExecutableRunOptionsArgument(), &b_,
2195                          hlo_module_config_, target_machine_features_));
2196     return Status::OK();
2197   } else {
2198     return Unimplemented("Fusion kind not implemented on CPU");
2199   }
2200 }
2201 
2202 Status IrEmitter::HandleCall(HloInstruction* call) {
2203   HloComputation* computation = call->to_apply();
2204   llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2205 
2206   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2207 
2208   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2209     // ParallelTaskAssignment assigned partitions, emit call to
2210     // ParallelForkJoin.
2211     std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2212         {}, &b_, computation->name(),
2213         /*return_value_buffer=*/emitted_value_[call],
2214         /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2215         /*buffer_table_arg=*/GetBufferTableArgument(),
2216         /*profile_counters_arg=*/GetProfileCountersArgument());
2217 
2218     HloInstruction* root = computation->root_instruction();
2219     TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2220         call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2221         call_ir_function, computation->name()));
2222   } else {
2223     EmitGlobalCall(*computation, computation->name());
2224   }
2225 
2226   return Status::OK();
2227 }
2228 
2229 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2230   absl::Span<HloInstruction* const> operands(custom_call->operands());
2231   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2232   llvm::AllocaInst* operands_alloca =
2233       llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2234           i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2235   for (size_t i = 0; i < operands.size(); ++i) {
2236     const HloInstruction* operand = operands[i];
2237     llvm::Value* operand_as_i8ptr =
2238         PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2239     llvm::Value* slot_in_operands_alloca =
2240         InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2241     Store(operand_as_i8ptr, slot_in_operands_alloca);
2242   }
2243   if (emit_code_for_msan_) {
2244     // Mark the alloca as initialized for msan. The buffer gets read by the
2245     // custom callee, which might be msan-instrumented.
2246     // TODO(b/66051036): Run the msan instrumentation pass instead.
2247     const llvm::DataLayout& dl = module_->getDataLayout();
2248     llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2249     auto* msan_unpoison_ir_function = llvm::cast<llvm::Function>(
2250         module_
2251             ->getOrInsertFunction(
2252                 "__msan_unpoison",
2253                 llvm::FunctionType::get(
2254                     /*Result=*/b_.getVoidTy(),
2255                     /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false))
2256             .getCallee());
2257     Call(msan_unpoison_ir_function,
2258          {PointerCast(operands_alloca, i8_ptr_type),
2259           llvm::ConstantInt::get(
2260               intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)});
2261   }
2262   auto* custom_call_ir_function = llvm::dyn_cast<llvm::Function>(
2263       module_
2264           ->getOrInsertFunction(
2265               custom_call->custom_call_target(),
2266               llvm::FunctionType::get(
2267                   /*Result=*/b_.getVoidTy(),
2268                   /*Params=*/{i8_ptr_type, operands_alloca->getType()},
2269                   /*isVarArg=*/false))
2270           .getCallee());
2271 
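  // Put differently (illustrative): the custom-call target is invoked as if it
  // had the C signature `void target(void* output, void** operands)`, where
  // operands[i] is the buffer of the i-th operand stored into
  // cc_operands_alloca above.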
2272   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2273   // Write the tuple table if the output is a tuple.
2274   if (custom_call->shape().IsTuple()) {
2275     std::vector<llvm::Value*> base_ptrs;
2276     for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2277          ++i) {
2278       const Shape& elem_shape =
2279           ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2280       TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2281       TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2282                           assignment_.GetUniqueSlice(custom_call, {i}));
2283       llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2284       base_ptrs.push_back(addr);
2285     }
2286     llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2287   }
2288   auto* output_address_arg =
2289       PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2290 
2291   Call(custom_call_ir_function, {output_address_arg, operands_alloca});
2292 
2293   return Status::OK();
2294 }
2295 
2296 Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2297   // Precondition: Condition computation must return a scalar bool.
2298   HloComputation* condition = xla_while->while_condition();
2299   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2300                condition->root_instruction()->shape().element_type() == PRED)
2301       << "While condition computation must return bool; got: "
2302       << ShapeUtil::HumanString(condition->root_instruction()->shape());
2303   // Check that all while-related buffers share an allocation slice.
2304   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2305       xla_while->shape(),
2306       [this, &xla_while](const Shape& /*subshape*/,
2307                          const ShapeIndex& index) -> Status {
2308         auto check = [this](const HloInstruction* a, const HloInstruction* b,
2309                             const ShapeIndex& index) {
2310           const BufferAllocation::Slice slice_a =
2311               assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2312           const BufferAllocation::Slice slice_b =
2313               assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2314           if (slice_a != slice_b) {
2315             return InternalError(
2316                 "instruction %s %s does not share slice with "
2317                 "instruction %s %s",
2318                 a->ToString(), slice_a.ToString(), b->ToString(),
2319                 slice_b.ToString());
2320           }
2321           return Status::OK();
2322         };
2323         TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2324         TF_RETURN_IF_ERROR(check(
2325             xla_while, xla_while->while_condition()->parameter_instruction(0),
2326             index));
2327         TF_RETURN_IF_ERROR(
2328             check(xla_while, xla_while->while_body()->parameter_instruction(0),
2329                   index));
2330         TF_RETURN_IF_ERROR(check(
2331             xla_while, xla_while->while_body()->root_instruction(), index));
2332         return Status::OK();
2333       }));
2334 
2335   // Set emitted value to that of 'init' with which it shares an allocation.
2336   const HloInstruction* init = xla_while->operand(0);
2337   emitted_value_[xla_while] = GetEmittedValueFor(init);
2338 
2339   // Generating:
2340   //   while (Condition(while_result)) {
2341   //     // CopyInsertion pass inserts copies which enable 'while_result' to
2342   //     // be passed back in as 'Body' parameter.
2343   //     while_result = Body(while_result);  // Insert
2344   //   }
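  //
  // A rough sketch of the control-flow graph this lowers to (illustrative
  // only; block names follow the IrName() calls below):
  //
  //   <name>.header:
  //     call @condition(...)
  //     %pred = icmp ne <loaded condition result>, 0
  //     br i1 %pred, label %<name>.body, label %<name>.exit
  //   <name>.body:
  //     call @body(...)
  //     br label %<name>.header
  //   <name>.exit:
  //     ; execution continues here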

  // Terminates the current block with a branch to a while header.
  llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
      module_->getContext(), IrName(xla_while, "header"),
      compute_function_->function());
  Br(header_bb);
  b_.SetInsertPoint(header_bb);

  // Calls the condition function to determine whether to proceed with the
  // body.  It must return a bool, so use the scalar call form.
  EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
  llvm::Value* while_predicate = ICmpNE(
      Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
      llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));

  // Branches to the body or to the while exit depending on the condition.
  llvm::BasicBlock* body_bb =
      llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
                               compute_function_->function());
  llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
      module_->getContext(), IrName(xla_while, "exit"));
  CondBr(while_predicate, body_bb, exit_bb);

  // Calls the body function from the body block.
  b_.SetInsertPoint(body_bb);

  // Calls the body function.
  EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));

  // Finishes with a branch back to the header.
  Br(header_bb);

  // Adds the exit block to the function and sets the insert point there.
  compute_function_->function()->getBasicBlockList().push_back(exit_bb);
  b_.SetInsertPoint(exit_bb);

  return Status::OK();
}

StatusOr<bool> IrEmitter::EmitFastConcatenate(
    HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
    string* failure_reason) {
  if (ShouldEmitParallelLoopFor(*concatenate)) {
    *failure_reason =
        "cannot generate memcpy-based concat for the parallel CPU backend";
    return false;
  }

  const Shape& output_shape = concatenate->shape();
  for (auto* op : operands) {
    if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
      *failure_reason = "operand has mismatching layouts";
      return false;
    }
  }

  // We split the dimensions into three categories: the dimension over which we
  // are concatenating (concat_dim), the dimensions that are minor to it
  // (inner_dims) and the dimensions that are major to it (outer_dims).
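  //
  // Illustrative example (an assumption for exposition only): concatenating
  // f32[2,4] and f32[2,2] along dimension 1 into f32[2,6] with the row-major
  // {1,0} layout gives concat_dim = 1, inner_dims = {} and outer_dims = {0}.
  // The loop nest below then iterates over the two rows and, for each row,
  // memcpys 4 contiguous elements from the first operand followed by 2
  // contiguous elements from the second.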

  int64 concat_dim = concatenate->dimensions(0);
  const Layout& output_layout = output_shape.layout();
  auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
  auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);

  std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
  std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
                                output_min2maj.end());

  llvm::Type* i8_ptr_type = b_.getInt8PtrTy();

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
  llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);

  llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
  std::vector<llvm::Value*> target_multi_index =
      loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
  std::replace(target_multi_index.begin(), target_multi_index.end(),
               static_cast<llvm::Value*>(nullptr),
               static_cast<llvm::Value*>(b_.getInt64(0)));
  llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
                                       b_.getInt64Ty());

  if (!outer_dims.empty()) {
    SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
  }

  PrimitiveType primitive_type = output_shape.element_type();
  unsigned primitive_type_size =
      ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);

  // Contiguous subregions from each operand to the concatenate contribute to a
  // contiguous subregion in the target buffer starting at target_region_begin.
  llvm::Value* target_region_begin = BitCast(
      target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
      i8_ptr_type);
  int64 byte_offset_into_target_region = 0;

  int64 inner_dims_product =
      std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
                      [&](int64 product, int64 inner_dim) {
                        return product * output_shape.dimensions(inner_dim);
                      });

  // For each operand, emit a memcpy from the operand to the target of size
  // equal to the product of inner dimensions.
  for (HloInstruction* operand : operands) {
    const Shape& input_shape = operand->shape();
    llvm_ir::IrArray source_array = GetIrArrayFor(operand);
    llvm::Value* copy_source_address = BitCast(
        source_array.EmitArrayElementAddress(target_index, &b_, "src_addr"),
        i8_ptr_type);

    llvm::Value* copy_target_address =
        GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));

    EmitTransferElements(
        copy_target_address, copy_source_address,
        inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
        target_array, source_array);

    byte_offset_into_target_region += inner_dims_product *
                                      input_shape.dimensions(concat_dim) *
                                      primitive_type_size;
  }

  if (!outer_dims.empty()) {
    SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
  }

  return true;
}

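// Copies 'element_count' elements of 'primitive_type' from 'source' to
// 'target'.  A single element is moved with an aligned load/store pair so the
// aliasing metadata of the source and target arrays can be attached to each
// instruction; larger transfers are lowered to a memcpy that carries the
// merged metadata of both arrays.  The alignment is the GCD of the element
// size and the backend's minimum alignment for the type, a conservative bound
// for elements at arbitrary offsets inside the buffers.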
void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
                                     int64 element_count,
                                     PrimitiveType primitive_type,
                                     const llvm_ir::IrArray& target_array,
                                     const llvm_ir::IrArray& source_array) {
  unsigned primitive_type_size =
      ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
      primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
  llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
      llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));

  if (element_count == 1) {
    auto* load_instruction =
        AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
    source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
    auto* store_instruction =
        AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
                     element_alignment);
    target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
  } else {
    auto* memcpy_instruction = MemCpy(
        target, /*DstAlign=*/element_alignment, source,
        /*SrcAlign=*/element_alignment, element_count * primitive_type_size);

    // The memcpy does the load and the store internally.  The aliasing related
    // metadata has to reflect that.
    std::map<int, llvm::MDNode*> merged_metadata =
        llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
                               target_array.metadata());
    for (const auto& kind_md_pair : merged_metadata) {
      memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
    }
  }
}

Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  string failure_reason;
  TF_ASSIGN_OR_RETURN(
      bool successful,
      EmitFastConcatenate(concatenate, operands, &failure_reason));
  if (successful) {
    VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
    return Status::OK();
  }

  VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
          << ": " << failure_reason;

  return DefaultAction(concatenate);
}

Status IrEmitter::HandleConditional(HloInstruction* conditional) {
  auto branch_index = conditional->operand(0);
  int num_branches = conditional->branch_count();
  TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
               (branch_index->shape().element_type() == PRED ||
                branch_index->shape().element_type() == S32))
      << "Branch index on a conditional must be scalar bool or int32; got: "
      << ShapeUtil::HumanString(branch_index->shape());

  for (int b = 0; b < num_branches; ++b) {
    HloComputation* br_computation = conditional->branch_computation(b);
    TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
                                  br_computation->root_instruction()->shape()))
        << "Shape of conditional should be same as the shape of the " << b
        << "th branch computation; got: "
        << ShapeUtil::HumanString(conditional->shape()) << " and "
        << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
  }

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));

  if (branch_index->shape().element_type() == PRED) {
    // Emit an if-else to LLVM:
    //   if (pred)
    //     cond_result = true_computation(true_operand)
    //   else
    //     cond_result = false_computation(false_operand)
    llvm::LoadInst* pred_value = Load(
        GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
    llvm::Value* pred_cond =
        ICmpNE(pred_value,
               llvm::ConstantInt::get(
                   llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
               "boolean_predicate");
    llvm_ir::LlvmIfData if_data =
        llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);

    SetToFirstInsertPoint(if_data.true_block, &b_);
    EmitGlobalCall(*conditional->branch_computation(0),
                   IrName(conditional, "_true"));

    SetToFirstInsertPoint(if_data.false_block, &b_);
    EmitGlobalCall(*conditional->branch_computation(1),
                   IrName(conditional, "_false"));

    SetToFirstInsertPoint(if_data.after_block, &b_);
    return Status::OK();
  }
  // We emit a switch statement to LLVM:
  // switch (branch_index) {
  //   default:
  //     result = branch_computations[num_branches-1](operands[num_branches-1]);
  //     break;
  //   case 0:
  //     result = branch_computations[0](operands[0]); break;
  //   case 1:
  //     result = branch_computations[1](operands[1]); break;
  //   ...
  //   case [[num_branches-2]]:
  //     result = branch_computations[num_branches-2](operands[num_branches-2]);
  //     break;
  // }
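  //
  // Illustrative block layout for the lowering below (block names as created
  // by CreateBasicBlock / splitBasicBlock; sketch only):
  //
  //   case_block:     switch i32 %branch_index, label %case-default
  //                     [ i32 0, label %case-branch0 ] ...
  //   case-branchN:   call @branch_computation_N(...); br label %case-after
  //   case-default:   call @branch_computation_last(...); br label %case-after
  //   case-after:     ; execution continues here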
  llvm::LoadInst* branch_index_value = Load(
      GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");

  auto case_block = b_.GetInsertBlock();
  llvm::BasicBlock* after_block;
  // Add a terminator to the case block, if necessary.
  if (case_block->getTerminator() == nullptr) {
    after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
    b_.SetInsertPoint(case_block);
    b_.CreateBr(after_block);
  } else {
    after_block =
        case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
  }
  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a switch based branch.
  case_block->getTerminator()->eraseFromParent();

  // Lower the default branch computation.
  auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
  b_.SetInsertPoint(default_block);
  EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
                 IrName(conditional, "_default"));
  b_.CreateBr(after_block);

  // Prepare the switch (branch_index) { ... } instruction.
  b_.SetInsertPoint(case_block);
  llvm::SwitchInst* case_inst =
      b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
  // Lower each branch's computation.
  for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
    // Lower the case b: { ... ; break; } computation.
    auto branch_block =
        llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
    b_.SetInsertPoint(branch_block);
    EmitGlobalCall(*conditional->branch_computation(b),
                   IrName(conditional, absl::StrCat("_branch", b)));
    b_.CreateBr(after_block);
    case_inst->addCase(b_.getInt32(b), branch_block);
  }

  SetToFirstInsertPoint(after_block, &b_);
  return Status::OK();
}

Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
  TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
  // No code to generate, but we need to emit an address for book-keeping.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zero-th operand.
  emitted_value_[add_dependency] =
      GetEmittedValueFor(add_dependency->operand(0));
  return Status::OK();
}

Status IrEmitter::HandleRng(HloInstruction* rng) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : rng->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }

  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  TF_RETURN_IF_ERROR(EmitTargetElementLoop(
      rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator)));

  llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);

  return Status::OK();
}

Status IrEmitter::FinishVisit(HloInstruction* root) {
  // When this method is called, we should have already emitted an IR value for
  // the root (return) op. The IR value holds the address of the buffer holding
  // the value. If the root is a constant or parameter, we perform a memcpy from
  // this buffer to the retval buffer of the computation. Otherwise, there's
  // nothing to do since the result was already written directly into the output
  // buffer.
  VLOG(2) << "FinishVisit root: " << root->ToString();
  if (root->opcode() == HloOpcode::kOutfeed) {
    VLOG(2) << "  outfeed with value: "
            << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
  } else {
    VLOG(2) << "  value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
  }

  auto record_complete_computation = [&](llvm::Value* prof_counter) {
    if (prof_counter) {
      profiling_state_.RecordCompleteComputation(&b_, prof_counter);
    }
  };

  // For the entry computation this increment is cumulative of embedded
  // computations since it includes cycles spent in computations invoked by
  // While, Call etc.
  record_complete_computation(GetProfileCounterFor(*root->parent()));
  return Status::OK();
}

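// Returns a GEP into the profile counters argument that points at the slot
// for 'hlo', i.e. conceptually &prof_counters[profile_index_map.at(&hlo)], or
// nullptr if the instruction or computation is not being profiled.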
template <typename T>
llvm::Value* IrEmitter::GetProfileCounterCommon(
    const T& hlo,
    const std::unordered_map<const T*, int64>& profile_index_map) {
  auto it = profile_index_map.find(&hlo);
  if (it == profile_index_map.end()) {
    return nullptr;
  }

  int64 prof_counter_idx = it->second;
  string counter_name = IrName("prof_counter", hlo.name());
  return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
             counter_name);
}

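// Emits IR that accumulates the elapsed cycles into the profile counter,
// i.e. *prof_counter += cycle_end - cycle_start.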
void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
                                                     llvm::Value* prof_counter,
                                                     llvm::Value* cycle_end,
                                                     llvm::Value* cycle_start) {
  auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
  llvm::LoadInst* old_cycle_count =
      b->CreateLoad(prof_counter, "old_cycle_count");
  auto* new_cycle_count =
      b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
  b->CreateStore(new_cycle_count, prof_counter);
}

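// Reads the CPU cycle counter.  When rdtscp is not requested this falls back
// to the generic llvm.readcyclecounter intrinsic; otherwise it calls
// llvm.x86.rdtscp with a small stack slot that receives the 32-bit TSC_AUX
// value (only the returned cycle count is used), bracketing the slot with
// lifetime markers.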
llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  if (!use_rdtscp_) {
    llvm::Function* func_llvm_readcyclecounter =
        llvm::Intrinsic::getDeclaration(module,
                                        llvm::Intrinsic::readcyclecounter);
    return b->CreateCall(func_llvm_readcyclecounter);
  }
  llvm::Function* func_llvm_x86_rdtscp =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
  if (!aux_i8ptr_) {
    llvm::AllocaInst* rdtscp_aux =
        llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b);
    aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy());
  }
  llvm::ConstantInt* alloca_size = b->getInt64(4);
  llvm::Function* func_llvm_lifetime_start =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
  b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
  llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
  llvm::Function* func_llvm_lifetime_end =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
  b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
  return rdtscp_call;
}

void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo) {
  auto* cycle_start = ReadCycleCounter(b);
  cycle_start->setName(IrName(hlo, "cycle_start"));
  cycle_starts_[hlo] = cycle_start;
  if (first_read_cycle_start_ == nullptr) {
    first_read_cycle_start_ = cycle_start;
  }
}

void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo,
                                                 llvm::Value* prof_counter) {
  auto* cycle_end = ReadCycleCounter(b);
  cycle_end->setName(IrName(hlo, "cycle_end"));
  auto* cycle_start = cycle_starts_[hlo];
  UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
  last_read_cycle_end_ = cycle_end;
}

void IrEmitter::ProfilingState::RecordCompleteComputation(
    llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
  if (last_read_cycle_end_ && first_read_cycle_start_) {
    UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
                         first_read_cycle_start_);
  }
}

Status IrEmitter::Preprocess(HloInstruction* hlo) {
  VLOG(3) << "Visiting: " << hlo->ToString();
  if (instruction_to_profile_idx_.count(hlo)) {
    profiling_state_.RecordCycleStart(&b_, hlo);
  }
  return Status::OK();
}

Status IrEmitter::Postprocess(HloInstruction* hlo) {
  if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
    profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
  }
  return Status::OK();
}

llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
  llvm::Value* value_for_op = GetEmittedValueFor(hlo);

  llvm_ir::IrArray array(value_for_op, hlo->shape());
  AddAliasingInformationToIrArray(*hlo, &array);
  return array;
}

std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
    const HloInstruction* hlo) {
  std::vector<llvm_ir::IrArray> arrays;
  std::transform(
      hlo->operands().begin(), hlo->operands().end(),
      std::back_inserter(arrays),
      [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
  return arrays;
}

llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
  auto it = emitted_value_.find(hlo);
  if (it == emitted_value_.end()) {
    LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
  }
  return it->second;
}

llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
  return llvm_ir::ShapeToIrType(shape, module_);
}

llvm::Value* IrEmitter::GetProfileCountersArgument() {
  return compute_function_->profile_counters_arg();
}

llvm::Value* IrEmitter::GetBufferTableArgument() {
  return compute_function_->buffer_table_arg();
}

llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
  return compute_function_->exec_run_options_arg();
}

llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
    if (slice == computation_root_allocation_) {
      llvm::Argument* retval = compute_function_->result_arg();
      llvm::AttrBuilder attr_builder;
      attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
      attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
      retval->addAttrs(attr_builder);
      return retval;
    }

    auto param_it =
        computation_parameter_allocations_.find(slice.allocation()->index());
    if (param_it != computation_parameter_allocations_.end()) {
      int64 param_number = param_it->second;
      // We have to access the parameter at offset param_number in the params
      // array. The code generated here is equivalent to this C code:
      //
      //   i8* param_address_untyped = params[param_number];
      //   Param* param_address_typed = (Param*)param_address_untyped;
      //
      // Where Param is the actual element type of the underlying buffer (for
      // example, float for an XLA F32 element type).
      llvm::Value* params = compute_function_->parameters_arg();
      llvm::Value* param_address_offset =
          llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
      llvm::LoadInst* param_address_untyped = Load(param_address_offset);

      if (!target_shape.IsOpaque()) {
        AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
        AttachDereferenceableMetadataForLoad(param_address_untyped,
                                             target_shape);
      }
      return param_address_untyped;
    }

    // Thread-local allocations should only be assigned a single buffer.
    const auto& assigned_buffers = allocation.assigned_buffers();
    CHECK_EQ(1, assigned_buffers.size());
    const Shape& shape = assigned_buffers.begin()->first->shape();

    std::pair<llvm::Function*, BufferAllocation::Slice> key = {
        compute_function_->function(), slice};
    auto buf_it = thread_local_buffers_.find(key);
    if (buf_it == thread_local_buffers_.end()) {
      llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
          IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
          &b_, MinimumAlignmentForShape(target_shape));
      auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
      CHECK(it_inserted_pair.second);
      buf_it = it_inserted_pair.first;
    }
    return buf_it->second;
  }();
  return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}

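// Returns a typed pointer to the buffer for 'slice' in the buffer table: the
// allocation base is loaded from the buffer-table argument (and marked as an
// invariant load when enabled), the slice offset is added with an in-bounds
// GEP, and the result is bitcast to the IR type of 'target_shape'.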
llvm::Value* IrEmitter::EmitGlobalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
      GetBufferTableArgument(), slice.index(), &b_);
  llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
  if (hlo_module_config_.debug_options()
          .xla_llvm_enable_invariant_load_metadata()) {
    tempbuf_address_base->setMetadata(
        llvm::LLVMContext::MD_invariant_load,
        llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
  }
  AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
  AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());

  llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
  if (slice.offset() > 0) {
    // Adjust the address to account for the slice offset.
    tempbuf_address_untyped =
        InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
  }
  return BitCast(tempbuf_address_untyped,
                 IrShapeType(target_shape)->getPointerTo());
}

llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
                                          const Shape& target_shape) {
  if (slice.allocation()->is_thread_local()) {
    return EmitThreadLocalBufferPointer(slice, target_shape);
  } else if (slice.allocation()->is_constant()) {
    return BitCast(
        FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
        IrShapeType(target_shape)->getPointerTo());
  } else {
    return EmitGlobalBufferPointer(slice, target_shape);
  }
}

Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
  const Shape& target_shape = op->shape();
  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                      assignment_.GetUniqueTopLevelSlice(op));
  llvm::Value* addr = EmitBufferPointer(slice, target_shape);
  addr->setName(IrName(op));
  emitted_value_[op] = addr;
  return Status::OK();
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op,
    const llvm_ir::ElementGenerator& element_generator) {
  return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op, absl::string_view desc,
    const llvm_ir::ElementGenerator& element_generator) {
  VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();

  const Shape& target_shape = target_op->shape();
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
  llvm_ir::IrArray target_array = GetIrArrayFor(target_op);

  if (target_op->IsMultiOutputFusion()) {
    // For multiple outputs fusion, we need to emit each operand and the root.
    TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
    std::vector<llvm_ir::IrArray> output_arrays;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                          assignment_.GetUniqueSlice(target_op, {i}));
      const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
      llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
      output_arrays.push_back(
          llvm_ir::IrArray(op_target_address, element_shape));
    }
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
            .EmitLoop(IrName(target_op)));

    std::vector<llvm::Value*> tuple_operand_ptrs;
    for (int64 i = 0; i < output_arrays.size(); ++i) {
      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
    }
    llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);

  } else {
    if (ShouldEmitParallelLoopFor(*target_op)) {
      // Emit code to read dynamic loop bounds from compute function argument.
      std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
          compute_function_->GetDynamicLoopBounds();
      // Emit parallel loop with dynamic loop bounds for most-major dimensions.
      TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
                                             &dynamic_loop_bounds, &b_)
                             .EmitLoop(IrName(target_op)));
    } else {
      TF_RETURN_IF_ERROR(
          llvm_ir::LoopEmitter(element_generator, target_array, &b_)
              .EmitLoop(IrName(target_op)));
    }
  }
  return Status::OK();
}

Status IrEmitter::EmitMemcpy(const HloInstruction& source,
                             const HloInstruction& destination) {
  llvm::Value* source_value = GetEmittedValueFor(&source);
  llvm::Value* destination_value = GetEmittedValueFor(&destination);
  int64 source_size = ByteSizeOf(source.shape());
  // TODO(b/63762267): Be more aggressive about specifying alignment.
  MemCpy(destination_value, /*DstAlign=*/1, source_value,
         /*SrcAlign=*/1, source_size);
  return Status::OK();
}

Status IrEmitter::ElementTypesSameAndSupported(
    const HloInstruction& instruction,
    absl::Span<const HloInstruction* const> operands,
    absl::Span<const PrimitiveType> supported_types) {
  for (auto operand : operands) {
    TF_RET_CHECK(
        ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
  }

  TF_RET_CHECK(!operands.empty());
  PrimitiveType primitive_type = operands[0]->shape().element_type();
  if (!absl::c_linear_search(supported_types, primitive_type)) {
    return Unimplemented("unsupported operand type %s in op %s",
                         PrimitiveType_Name(primitive_type),
                         HloOpcodeString(instruction.opcode()));
  }
  return Status::OK();
}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  return EmitTargetElementLoop(
      hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}

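// Emits a call to a thread-local computation: each scalar parameter is
// spilled to a stack slot, a stack slot is allocated for the scalar return
// value, and the emitted function is invoked with a null buffer-table
// argument.  Returns the loaded scalar result.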
llvm::Value* IrEmitter::EmitThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  CHECK(absl::c_binary_search(thread_local_computations_, &callee));

  const Shape& return_shape = callee.root_instruction()->shape();

  // Lifting this restriction to allow "small" arrays should be easy.  Allowing
  // larger arrays is difficult because we allocate the buffer for this return
  // value on the stack.
  CHECK(ShapeUtil::IsScalar(return_shape));

  PrimitiveType return_type = return_shape.element_type();

  std::vector<llvm::Value*> parameter_addrs;
  for (llvm::Value* parameter : parameters) {
    CHECK(!parameter->getType()->isPointerTy());
    llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        parameter->getType(), "arg_addr", &b_);
    Store(parameter, parameter_addr);
    parameter_addrs.push_back(parameter_addr);
  }

  llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      llvm_ir::PrimitiveTypeToIrType(return_type, module_),
      absl::StrCat(name, "_retval_addr"), &b_,
      MinimumAlignmentForPrimitiveType(return_type));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           parameter_addrs, &b_, name,
           /*return_value_buffer=*/return_value_buffer,
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
           /*profile_counters_arg=*/GetProfileCountersArgument()));

  return Load(return_value_buffer);
}

void IrEmitter::EmitGlobalCall(const HloComputation& callee,
                               absl::string_view name) {
  CHECK(absl::c_binary_search(global_computations_, &callee));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           /*parameter_addresses=*/{}, &b_, name,
           /*return_value_buffer=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()),
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/GetBufferTableArgument(),
           /*profile_counters_arg=*/GetProfileCountersArgument()));
}

llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
    const HloComputation& callee) {
  const HloInstruction* root_inst = callee.root_instruction();
  if (root_inst->opcode() == HloOpcode::kOutfeed) {
    return llvm::Constant::getNullValue(b_.getInt8PtrTy());
  }

  const BufferAllocation::Slice root_buffer =
      assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
  return EmitBufferPointer(root_buffer, root_inst->shape());
}

}  // namespace cpu
}  // namespace xla