/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace gpu {

IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
                                 const HloComputation& nested_computation,
                                 IrEmitterContext* ir_emitter_context)
    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
      nested_computation_(nested_computation) {}

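// Creates an emitter for the nested computation and eagerly emits the
// computation's constants into the LLVM module, so later codegen can refer
// to them.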
StatusOr<std::unique_ptr<IrEmitterNested>> IrEmitterNested::Create(
    const HloModuleConfig& hlo_module_config,
    const HloComputation& nested_computation,
    IrEmitterContext* ir_emitter_context) {
  std::unique_ptr<IrEmitterNested> emitter(new IrEmitterNested(
      hlo_module_config, nested_computation, ir_emitter_context));
  TF_RETURN_IF_ERROR(emitter->EmitConstants(nested_computation, false));
  return emitter;
}

// A nested function serves the same purpose on a GPU as a thread-local
// function does on a CPU.
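//
// The emitted function returns void; it takes one pointer argument per
// computation parameter, plus a trailing out-parameter that receives the
// value of the root instruction.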
Status IrEmitterNested::CodegenNestedComputation() {
  std::vector<const HloInstruction*> io_hlos;
  std::vector<llvm::Type*> argument_types;
  std::vector<int64> argument_dereferenceable_bytes;
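  // Each parameter is passed by pointer; record each parameter's byte size
  // so its argument can be marked dereferenceable below.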
  for (const HloInstruction* param :
       nested_computation_.parameter_instructions()) {
    io_hlos.push_back(param);
    const Shape& param_shape = param->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
    int64 param_size =
        llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
    argument_dereferenceable_bytes.push_back(param_size);
  }

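  // Add the trailing out-parameter that receives the root's value.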
  const HloInstruction* root = nested_computation_.root_instruction();
  {
    const Shape& root_shape = root->shape();
    argument_types.push_back(
        llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
    int64 root_size = llvm_ir::ByteSizeOf(
        root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
    argument_dereferenceable_bytes.push_back(root_size);
  }

  llvm::FunctionType* function_type =
      llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
  llvm::Function* function = llvm::Function::Create(
      function_type,                       // The function type.
      llvm::GlobalValue::InternalLinkage,  // The linkage type.
      ir_emitter_context_->name_uniquer()->GetUniqueName(
          llvm_ir::SanitizeFunctionName(
              nested_computation_.name())),  // The name of the function.
      ir_emitter_context_->llvm_module());   // The parent LLVM module.
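  // LLVM attribute indices for arguments are 1-based (index 0 refers to the
  // return value), hence arg_no + 1.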
  for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
       ++arg_no) {
    int64 arg_size = argument_dereferenceable_bytes[arg_no];
    if (arg_size > 0) {
      function->addDereferenceableAttr(arg_no + 1, arg_size);
    }
  }

  // TODO(b/65380986): Investigate if adding fast math flags for generated
  // kernels makes sense.

  llvm::BasicBlock* entry_bb =
      llvm::BasicBlock::Create(function->getContext(), "entry", function);
  // Emit a "return void" at the end of entry_bb, and set the insert point
  // before that return instruction.
  llvm::ReturnInst* ret_instr =
      llvm::ReturnInst::Create(function->getContext(), entry_bb);
  b_.SetInsertPoint(ret_instr);

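  // Collect the non-I/O HLOs (everything other than the parameters), with
  // the root first; the root's value is computed into a temporary buffer and
  // copied to the out-parameter in the epilogue below.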
  std::vector<const HloInstruction*> non_io_hlos;
  non_io_hlos.push_back(root);
  for (const auto* hlo : nested_computation_.instructions()) {
    if (hlo->opcode() != HloOpcode::kParameter &&
        hlo != nested_computation_.root_instruction()) {
      non_io_hlos.push_back(hlo);
    }
  }
  bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);

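  // Emit the body of the computation: Accept performs a postorder walk from
  // the root, invoking the matching Handle* visitor for each instruction.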
  TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
  b_.SetInsertPoint(ret_instr);

  // Function epilogue: copy the output value back.
  {
    // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
    const HloInstruction* root_instruction =
        nested_computation_.root_instruction();
    llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
    const Shape& return_shape = root_instruction->shape();

    // Last argument is the out parameter.
    llvm::Argument* out_parameter = std::prev(function->arg_end(), 1);

    if (ShapeUtil::IsScalar(return_shape)) {
      llvm::Value* ret_value = Load(root_value, "load_ret_value");
      Store(ret_value,
            BitCast(out_parameter, root_value->getType(),
                    "bitcast_ret_value"));
    } else {
      CHECK(return_shape.IsTuple());
      llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
      llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
      llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);

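      // Copy each top-level element of the tuple from the temporary result
      // into the out-parameter buffer.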
      for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
        const Shape& element_shape = return_shape.tuple_shapes(i);
        llvm::Value* destination =
            llvm_ir::EmitGetTupleElement(element_shape,
                                         /*index=*/i,
                                         /*alignment=*/1, tuple_ptr, &b_);
        llvm::Value* source =
            llvm_ir::EmitGetTupleElement(element_shape,
                                         /*index=*/i,
                                         /*alignment=*/1, root_value, &b_);
        Store(Load(source), destination);
      }
    }
  }
  b_.SetInsertPoint(ret_instr);
  emitted_function_ = function;
  return Status::OK();
}

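// Parameters generate no code of their own: their base pointers were bound
// to the function's arguments in CodegenNestedComputation.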
Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
  return Status::OK();
}

Status IrEmitterNested::EmitTargetElementLoop(
    const HloInstruction& hlo,
    const llvm_ir::ElementGenerator& element_generator) {
  // For multi-output fusion (MOF), we give the loop emitter an array for
  // every output it should generate.
  if (hlo.shape().IsTuple()) {
    std::vector<llvm_ir::IrArray> target_arrays =
        ConstructIrArrayForOutputs(hlo);
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, target_arrays, &b_)
            .EmitLoop());
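    // Write the tuple index table: store each output buffer's address into
    // the tuple buffer for `hlo`.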
    llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
    return Status::OK();
  }
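  // Single-output case: emit the loop directly into hlo's buffer.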
  return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
      .EmitLoop();
}

}  // namespace gpu
}  // namespace xla