/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace m = match;

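// Folds a broadcast feeding a cuBLAS GEMM custom call into the GEMM itself
// when the broadcast introduces exactly the batch dimensions: the GEMM then
// reads the un-broadcast operand directly, and the batch dimensions are
// dropped from that side's dot dimension numbers. A hypothetical before/after
// sketch in HLO (names and shapes are illustrative, not from a real module):
//
//   p = f32[3,4] parameter(0)
//   q = f32[2,4,5] parameter(1)
//   b = f32[2,3,4] broadcast(p), dimensions={1,2}
//   g = f32[2,3,5] custom-call(b, q), custom_call_target="__cublas$gemm"
//       // lhs batch dims = {0}, lhs contracting dim = 2
//
// is rewritten so that `g` takes `p` as operand 0, with the LHS batch
// dimensions cleared and the LHS contracting dimension shifted from 2 to 1.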
class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleCustomCall(HloInstruction *instr) override {
    HloInstruction *existing_gemm;
    HloInstruction *bcast;
    if (Match(instr, m::Op(&existing_gemm)
                         .WithCustomCallTarget(kGemmCallTarget)
                         .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
        Match(instr, m::Op(&existing_gemm)
                         .WithCustomCallTarget(kGemmCallTarget)
                         .WithOperand(1, m::Broadcast(&bcast, m::Op())))) {
      TF_ASSIGN_OR_RETURN(auto config,
                          existing_gemm->backend_config<GemmBackendConfig>());
      DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
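      // Index of the broadcast among the GEMM's operands, and the number of
      // dimensions the broadcast adds on top of its operand's rank.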
      int bcast_operand_index = instr->operand_index(bcast);
      int num_bcast_dims = (bcast->shape().dimensions_size() -
                            bcast->operand(0)->shape().dimensions_size());
      int num_batch_dims = dim_nums->lhs_batch_dimensions_size();

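      // Batch dimensions of the dot on the same side as the broadcast.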
      const tensorflow::protobuf::RepeatedField<int64_t> &batch_dimensions =
          (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
                                     : dim_nums->lhs_batch_dimensions();
      // This optimization is only valid if the set of broadcast dimensions
      // is exactly the set of batch dimensions. First, check that all newly
      // broadcast dimensions have been inserted on the left, i.e. all new
      // dimensions must be in [0, num_bcast_dims) or, equivalently, all
      // original dimensions must be >= num_bcast_dims.
      for (int64_t bcast_dim : bcast->dimensions()) {
        if (bcast_dim < num_bcast_dims) {
          return OkStatus();
        }
        // bcast_dim must not itself be a batch dimension.
        if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
          return OkStatus();
        }
      }

      // Then check that all batch dimensions are being broadcast, and that
      // there is at least one batch dimension.
      CHECK_GT(num_bcast_dims, 0);
      if (num_bcast_dims != num_batch_dims) {
        return OkStatus();
      }

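      // Fold the broadcast: the checks above guarantee the batch dimensions
      // are exactly the leading num_batch_dims dimensions, so removing them
      // shifts the contracting dimension left by num_batch_dims and leaves
      // that side with no batch dimensions.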
      if (bcast_operand_index == 1) {
        CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
        dim_nums->set_rhs_contracting_dimensions(
            0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
        dim_nums->clear_rhs_batch_dimensions();
      } else {
        CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
        dim_nums->set_lhs_contracting_dimensions(
            0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
        dim_nums->clear_lhs_batch_dimensions();
      }
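      // Wire the GEMM directly to the broadcast's operand (the operand shape
      // changes, hence ReplaceOperandWithDifferentShape) and persist the
      // updated dot dimension numbers in the backend config.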
      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
          bcast_operand_index, bcast->mutable_operand(0)));
      TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
      MarkAsChanged();
    }
    return OkStatus();
  }
};

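// Applies the folding visitor to a single computation and reports whether
// anything changed.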
static StatusOr<bool> RunOnComputation(HloComputation *computation) {
  GemmBroadcastFoldingVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

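// Entry point of the pass: runs broadcast folding over every non-fusion
// computation on the given execution threads.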
StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  bool changed = false;
  for (HloComputation *computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla