/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Device functions don't have f16 math functions, so we convert the operands
  // to f32 before calling the function and then convert the result back to
  // f16.
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64_t i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
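  // The munged name is target specific; on NVPTX, for example, kExp typically
  // resolves to the libdevice functions __nv_expf (F32) or __nv_exp (F64).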
  const string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
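  // For example, "llvm.sqrt" with an F32 output type becomes the intrinsic
  // name "llvm.sqrt.f32".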
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input must have the same type
  // as the output.
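  // (e.g. fmod is (F32, F32) -> F32; mixed-precision calls are rejected).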
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

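  // The calls are marked ReadNone (no memory access) and NoUnwind, which lets
  // LLVM freely reorder, CSE, and dead-code-eliminate them.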
  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast(
    const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) {
  Shape shape = hlo->shape();
  Shape operand_shape = hlo->operand(0)->shape();

  // Decode the layouts of the result and source shapes from the protobuf
  // attached to backend_config_.
  BitcastBackendConfig bitcast_config;
  CHECK(bitcast_config.ParseFromString(hlo->raw_backend_config_string()));

  *shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.result_layout());
  *operand_shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.source_layout());
  return index.SourceIndexOfBitcast(shape, operand_shape, b());
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

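  // Fast min/max lowers to llvm.maxnum/llvm.minnum, which follow IEEE-754
  // maxNum/minNum semantics: if exactly one operand is a NaN, the other
  // operand is returned. NaNs therefore may not propagate the way the
  // default (slower) min/max lowering guarantees.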
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is being requested, assume performance is less important and use
  // the more numerically precise tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);

  // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

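  // EmitFastTanh emits a branch-free rational (polynomial-over-polynomial)
  // approximation; copysign(1.0, input) supplies the correctly signed
  // saturation value for large |input|.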
  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                value->getType(), "tanh");
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
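  // hypot(re, im) computes sqrt(re^2 + im^2) while avoiding overflow and
  // underflow in the intermediate squares.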
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}

llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
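  // The linear thread id is block_id * threads_per_block + thread_id_in_block
  // (the CUDA expression blockIdx.x * blockDim.x + threadIdx.x), computed in
  // 128-bit integers so the multiply-add cannot overflow.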
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

}  // namespace gpu
}  // namespace xla