/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

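// Pads the non-batch dimensions of `dot`'s operands and result up to
// multiples of 8 and slices the result back to its original shape.  Padding
// with zeros preserves the product on the original elements.  For example,
// f16[5,3] dot f16[3,7] -> f16[5,7] is rewritten to
// f16[8,8] dot f16[8,8] -> f16[8,8], followed by a slice back to f16[5,7].
// Returns true iff the graph was changed.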
static StatusOr<bool> PadForTensorCores(HloDotInstruction* dot) {
  auto* lhs = dot->mutable_operand(0);
  auto* rhs = dot->mutable_operand(1);

  Shape lshape = lhs->shape();
  Shape rshape = rhs->shape();
  Shape result_shape = dot->shape();

  if (lshape.element_type() != PrimitiveType::F16 ||
      rshape.element_type() != PrimitiveType::F16) {
    return false;
  }

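  // cuBLAS can only dispatch an f16 GEMM to Tensor Core kernels when the
  // matrix dimensions are multiples of 8, hence the rounding target below.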
  auto pad_dim = [](Shape& s, int64 dim) {
    s.set_dimensions(dim, RoundUpToNearest<int64>(s.dimensions(dim), 8));
  };

  auto pad_matrix_dims = [&pad_dim](Shape s) {
    // Since the dot instruction is canonicalized, the last two dimensions
    // for each operand represent non-batch dimensions, and the others are
    // the same for both operands and correspond to batch dimensions.
    pad_dim(s, s.rank() - 2);
    pad_dim(s, s.rank() - 1);
    return s;
  };

  Shape new_lshape = pad_matrix_dims(lshape);
  Shape new_rshape = pad_matrix_dims(rshape);
  Shape new_result_shape = pad_matrix_dims(result_shape);

  if (new_lshape == lshape && new_rshape == rshape) {
    return false;
  }

  VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
  VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
          << new_result_shape;

  auto create_padding_config = [](const Shape& shape, const Shape& new_shape) {
    PaddingConfig padding_config;
    for (int i = 0; i < shape.rank(); ++i) {
      auto dimension = padding_config.add_dimensions();
      dimension->set_edge_padding_high(new_shape.dimensions()[i] -
                                       shape.dimensions()[i]);
      dimension->set_edge_padding_low(0);
      dimension->set_interior_padding(0);
    }
    return padding_config;
  };

  auto l_padding_config = create_padding_config(lshape, new_lshape);
  auto r_padding_config = create_padding_config(rshape, new_rshape);

  HloComputation* parent = dot->parent();

  HloInstruction* zero_float = parent->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<half>((half)0.0)));
  zero_float->set_metadata(dot->metadata());

  HloInstruction* lpad = parent->AddInstruction(
      HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
  lpad->set_metadata(dot->metadata());

  HloInstruction* rpad = parent->AddInstruction(
      HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
  rpad->set_metadata(dot->metadata());

  HloInstruction* new_dot = parent->AddInstruction(
      dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));

  std::vector<int64> start_indices(result_shape.rank(), 0);
  std::vector<int64> strides(result_shape.rank(), 1);
  HloInstruction* slice = parent->AddInstruction(
      HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
                                  result_shape.dimensions(), strides));
  slice->set_metadata(dot->metadata());

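  // If `dot` was the computation's root (and thus has no users), make the
  // slice the new root after the replacement.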
  bool is_root = dot->user_count() == 0;

  TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));

  if (is_root) {
    parent->set_root_instruction(slice);
  }

  return true;
}

namespace {

// We need this check because PadForTensorCores assumes that the dot
// instruction is canonicalized.
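// In canonical form all batch dimensions come first, in order, followed by
// the two non-batch (row/column) dimensions; e.g. for rank-4 operands the
// batch dimensions of both sides must be exactly {0, 1}.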
bool CheckCanonical(HloDotInstruction* dot) {
  auto dimension_numbers = dot->dot_dimension_numbers();

  if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
          dot->operand(0)->shape().rank() ||
      dimension_numbers.rhs_batch_dimensions_size() + 2 !=
          dot->operand(1)->shape().rank()) {
    LOG(ERROR) << "Dot is not canonical: Expected all dimensions except the "
                  "last two to be batch dimensions.";
    return false;
  }

  std::vector<int64> canonical_batch_dims(
      dimension_numbers.lhs_batch_dimensions_size());
  absl::c_iota(canonical_batch_dims, 0);
  if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
                     canonical_batch_dims) ||
      !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
                     canonical_batch_dims)) {
    LOG(ERROR) << "Dot is not canonical: Expected batch dimensions to be the "
                  "leading dimensions, in order.";
    return false;
  }

  return true;
}

}  // namespace

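// Collects the dots in `comp` that will be lowered to cuBLAS GEMM calls and
// are therefore candidates for padding; non-canonical dots are skipped.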
static std::vector<HloDotInstruction*> GetRelevantDots(HloComputation* comp) {
  std::vector<HloDotInstruction*> dots;

  for (HloInstruction* instr : comp->instructions()) {
    if (IsMatrixMultiplication(*instr)) {
      HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
      if (CheckCanonical(dot)) {
        dots.push_back(dot);
      }
    }
  }
  return dots;
}

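// Pass entry point: pads every eligible canonical dot in each non-fusion
// computation of `module`.  Returns true iff any dot was rewritten.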
StatusOr<bool> CublasGemmPadForTensorCores::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* comp : module->MakeNonfusionComputations()) {
    for (HloDotInstruction* dot : GetRelevantDots(comp)) {
      TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(dot));
      changed |= result;
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla
177