/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h"

#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"

namespace xla {
namespace gpu {

Status GpuHloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
  if (custom_call->custom_call_target() == gpu::kGemmCallTarget) {
    // The naming conventions and meanings of gemm parameters are documented
    // here:
    // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm
    TF_ASSIGN_OR_RETURN(auto gemm_config,
                        custom_call->backend_config<gpu::GemmBackendConfig>());

    // Technically, in addition to the dot product (A * B), cuBLAS gemm also
    // performs additional scaling (by factor 'alpha') and addition with a
    // scaled third matrix (beta * C), which will introduce additional
    // multiplications and additions. But total FLOPS will be dominated by the
    // dot product, so we don't include these extra multiplications and
    // additions in the FLOPS calculation.

    // Also, this calculation assumes that the strides for the gemm are
    // properly set such that none of the inputs in a batch overlap with any
    // other batches. If they do, this will undercount the FLOPS, because it
    // assumes that the strides are implicit in the sizes of the batch
    // dimensions.

    // Finally, this is technically incorrect if the element type of this
    // gemm is an integer type, because in that case no floating point
    // operations are involved at all! But we still calculate FLOPS because the
    // number is sometimes required for ad-hoc calculations.
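    //
    // (For reference, GetDotFlops roughly counts two flops, one multiply and
    // one add, per output element per contracted index, i.e. about
    // 2 * B * M * N * K for a batched [B, M, K] x [B, K, N] gemm.)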
    current_properties_[kFlopsKey] =
        GetDotFlops(custom_call->operand(0)->shape(), custom_call->shape(),
                    gemm_config.dot_dimension_numbers());
    return OkStatus();
  }

  if (IsCustomCallToDnnConvolution(*custom_call)) {
    // As with dots, this flops calculation has the following inaccuracies.
    //
    //  - We may have a fused conv which does additional ops (multiplying by a
    //    scalar `alpha`, adding a bias or side-input, doing a relu, etc). But
    //    we can safely ignore this because the overall computation is
    //    dominated by the convolution itself.
    //
    //  - cudnn may use complex conv algorithms that do fewer (or more!) flops
    //    than we calculate.
    //
    //  - for int8_t convs, these aren't *fl*ops, but we fudge it.
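    //
    // (Roughly, the base-class estimate charges two flops, a multiply and an
    // add, for every multiply-accumulate in the convolution, i.e. for each
    // contributing combination of output position, kernel tap, and input
    // channel.)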
    current_properties_[kFlopsKey] = GetConvolutionFlops(custom_call);

    // conv custom-calls return a tuple (real_output, temp_bytes). Count just
    // the real_output in output bytes accessed. The main purpose of
    // hlo_cost_analysis is to figure out if ops are running "as fast as
    // possible", and if we were to include temp memory in here, we'd
    // essentially be *rewarding* convs that use additional temp memory!
    if (custom_call->shape().IsTuple()) {
      SetOutputBytesAccessed(
          options_.shape_size(custom_call->shape().tuple_shapes(0)));
    }
    return OkStatus();
  }

  return HloCostAnalysis::HandleCustomCall(custom_call);
}

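// Like the base implementation, but conv custom-calls wrap the real result in
// a tuple together with a scratch buffer, so the actual result shape is peeled
// off before delegating to HloCostAnalysis::GetConvolutionFlops.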
int64_t GpuHloCostAnalysis::GetConvolutionFlops(
    const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& result_shape = [&]() -> const Shape& {
    // convolution custom-calls return a tuple of (actual_result, temp_buffer).
    const Shape& shape = convolution->shape();
    if (IsCustomCallToDnnConvolution(*convolution) &&
        convolution->shape().IsTuple()) {
      return shape.tuple_shapes(0);
    }
    return shape;
  }();

  return HloCostAnalysis::GetConvolutionFlops(convolution, lhs_shape, rhs_shape,
                                              result_shape);
}

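// Nested computations (e.g. fusion bodies or control-flow branches) should get
// the same GPU-specific handling, so return a GpuHloCostAnalysis that shares
// this analysis's options.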
std::unique_ptr<HloCostAnalysis>
GpuHloCostAnalysis::CreateNestedCostAnalysis() {
  return std::make_unique<GpuHloCostAnalysis>(options_);
}

}  // namespace gpu
}  // namespace xla