/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

PadForTensorCores(HloDotInstruction * dot)29 static StatusOr<bool> PadForTensorCores(HloDotInstruction* dot) {
30   auto* lhs = dot->mutable_operand(0);
31   auto* rhs = dot->mutable_operand(1);
32 
33   Shape lshape = lhs->shape();
34   Shape rshape = rhs->shape();
35   Shape result_shape = dot->shape();
36 
37   if (lshape.element_type() != PrimitiveType::F16 ||
38       rshape.element_type() != PrimitiveType::F16) {
39     return false;
40   }
41 
42   auto pad_dim = [](Shape& s, int64 dim) {
43     s.set_dimensions(dim, RoundUpToNearest<int64>(s.dimensions(dim), 8));
44   };
45 
46   auto pad_matrix_dims = [&pad_dim](Shape s) {
47     // Since the dot instruction is canonicalized, the last two dimensions for
48     // each operand represent non-batch dimensions, and the others are the same
49     // for both operands and correspond to batch dimensions.
50     pad_dim(s, s.rank() - 2);
51     pad_dim(s, s.rank() - 1);
52     return s;
53   };
54 
55   Shape new_lshape = pad_matrix_dims(lshape);
56   Shape new_rshape = pad_matrix_dims(rshape);
57   Shape new_result_shape = pad_matrix_dims(result_shape);
58 
59   if (new_lshape == lshape && new_rshape == rshape) {
60     return false;
61   }
62 
63   VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
64   VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
65           << new_result_shape;
66 
67   auto create_padding_config = [](Shape& shape, Shape& new_shape) {
68     PaddingConfig padding_config;
69     for (int i = 0; i < shape.rank(); ++i) {
70       auto dimension = padding_config.add_dimensions();
71       dimension->set_edge_padding_high(new_shape.dimensions()[i] -
72                                        shape.dimensions()[i]);
73       dimension->set_edge_padding_low(0);
74       dimension->set_interior_padding(0);
75     }
76     return padding_config;
77   };
78 
79   auto l_padding_config = create_padding_config(lshape, new_lshape);
80   auto r_padding_config = create_padding_config(rshape, new_rshape);
81 
82   HloComputation* parent = dot->parent();
83 
84   HloInstruction* zero_float = parent->AddInstruction(
85       HloInstruction::CreateConstant(LiteralUtil::CreateR0<half>((half)0.0)));
86   zero_float->set_metadata(dot->metadata());
87 
88   HloInstruction* lpad = parent->AddInstruction(
89       HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
90   lpad->set_metadata(dot->metadata());
91 
92   HloInstruction* rpad = parent->AddInstruction(
93       HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
94   rpad->set_metadata(dot->metadata());
95 
96   HloInstruction* new_dot = parent->AddInstruction(
97       dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
98 
99   std::vector<int64> start_indices(result_shape.rank(), 0);
100   std::vector<int64> strides(result_shape.rank(), 1);
101   HloInstruction* slice = parent->AddInstruction(
102       HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
103                                   result_shape.dimensions(), strides));
104   slice->set_metadata(dot->metadata());
105 
106   bool is_root = dot->user_count() == 0;
107 
108   TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
109 
110   if (is_root) {
111     parent->set_root_instruction(slice);
112   }
113 
114   return true;
115 }

namespace {

119 // We need this check because PadForTensorCores works in the assumption that
120 // the dot instruction is canonicalized.
CheckCanonical(HloDotInstruction * dot)121 bool CheckCanonical(HloDotInstruction* dot) {
122   auto dimension_numbers = dot->dot_dimension_numbers();
123 
124   if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
125           dot->operand(0)->shape().rank() ||
126       dimension_numbers.rhs_batch_dimensions_size() + 2 !=
127           dot->operand(1)->shape().rank()) {
128     LOG(ERROR) << "Dot is not canonical: Expected all dimensions but 2 to be "
129                   "batch_dimensions.";
130     return false;
131   }
132 
133   std::vector<int64> canonical_batch_dims(
134       dimension_numbers.lhs_batch_dimensions_size());
135   absl::c_iota(canonical_batch_dims, 0);
136   if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
137                      canonical_batch_dims) ||
138       !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
139                      canonical_batch_dims)) {
140     LOG(ERROR) << "Dot is not canonical: Expected batch dimensions to be all "
141                   "dimensions except for the last 2 ones.";
142     return false;
143   }
144 
145   return true;
146 }

}  // namespace

GetRelevantDots(HloComputation * comp)150 static std::vector<HloDotInstruction*> GetRelevantDots(HloComputation* comp) {
151   std::vector<HloDotInstruction*> convs;
152 
153   for (HloInstruction* instr : comp->instructions()) {
154     if (IsMatrixMultiplication(*instr)) {
155       HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
156       if (CheckCanonical(dot)) {
157         convs.push_back(dot);
158       }
159     }
160   }
161   return convs;
162 }
Run(HloModule * module)164 StatusOr<bool> CublasGemmPadForTensorCores::Run(HloModule* module) {
165   bool changed = false;
166   for (HloComputation* comp : module->MakeNonfusionComputations()) {
167     for (HloDotInstruction* dot : GetRelevantDots(comp)) {
168       TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(dot));
169       changed |= result;
170     }
171   }
172   return changed;
173 }

}  // namespace gpu
}  // namespace xla