/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace m = match;

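// Visitor that folds a broadcast feeding a cuBLAS gemm custom call into the
// gemm itself: when the broadcast only prepends the batch dimensions to one
// operand, the gemm can read the un-broadcast operand directly and drop the
// batch dimensions on that side of the dot.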
class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleCustomCall(HloInstruction *instr) override {
    HloInstruction *existing_gemm;
    HloInstruction *bcast;
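    // Match a gemm custom call whose lhs (operand 0) or rhs (operand 1) is
    // produced by a broadcast instruction.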
    if (Match(instr, m::Op(&existing_gemm)
                         .WithCustomCallTarget(kGemmCallTarget)
                         .WithOperand(0, m::Broadcast(&bcast, m::Op()))) ||
        (Match(instr, m::Op(&existing_gemm)
                          .WithCustomCallTarget(kGemmCallTarget)
                          .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) {
      TF_ASSIGN_OR_RETURN(auto config,
                          existing_gemm->backend_config<GemmBackendConfig>());
      DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers();
      int bcast_operand_index = instr->operand_index(bcast);
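      // Number of dimensions the broadcast adds, i.e. the rank difference
      // between its output and its input.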
      int num_bcast_dims = (bcast->shape().dimensions_size() -
                            bcast->operand(0)->shape().dimensions_size());
      int num_batch_dims = dim_nums->lhs_batch_dimensions_size();

      const tensorflow::protobuf::RepeatedField<int64_t> &batch_dimensions =
          (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions()
                                     : dim_nums->lhs_batch_dimensions();
      // This optimization is only valid if the set of broadcasted dimensions
      // is exactly the set of batch dimensions. First, check that all newly
      // broadcast dimensions have been inserted on the left i.e. all new
      // dimensions must be in [0, num_bcast_dims) or equivalently all original
      // dimensions are >= num_bcast_dims.
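      // Illustrative case (not taken from the original source): broadcasting
      // f32[3,4] into f32[8,3,4] with broadcast dimensions {1,2} against a
      // single batch dimension {0} is foldable, whereas a broadcast along a
      // non-batch dimension is not.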
      for (int64_t bcast_dim : bcast->dimensions()) {
        if (bcast_dim < num_bcast_dims) {
          return OkStatus();
        }
        // bcast_dim should not be in batch_dimensions.
        if (absl::c_linear_search(batch_dimensions, bcast_dim)) {
          return OkStatus();
        }
      }

      // Then check that all batch dimensions are being broadcast, and that
      // there is at least one batch dimension.
      CHECK_GT(num_bcast_dims, 0);
      if (num_bcast_dims != num_batch_dims) {
        return OkStatus();
      }

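      // The broadcast operand loses its leading batch dimensions, so shift
      // the contracting dimension index left by num_batch_dims and clear the
      // batch dimensions on that side of the dot.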
      if (bcast_operand_index == 1) {
        CHECK_EQ(dim_nums->rhs_contracting_dimensions_size(), 1);
        dim_nums->set_rhs_contracting_dimensions(
            0, dim_nums->rhs_contracting_dimensions(0) - num_batch_dims);
        dim_nums->clear_rhs_batch_dimensions();
      } else {
        CHECK_EQ(dim_nums->lhs_contracting_dimensions_size(), 1);
        dim_nums->set_lhs_contracting_dimensions(
            0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims);
        dim_nums->clear_lhs_batch_dimensions();
      }
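      // Feed the gemm the broadcast's input directly and store the updated
      // dot dimension numbers back into the gemm's backend config.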
      TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape(
          bcast_operand_index, bcast->mutable_operand(0)));
      TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
      MarkAsChanged();
    }
    return OkStatus();
  }
};

static StatusOr<bool> RunOnComputation(HloComputation *computation) {
  GemmBroadcastFoldingVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

StatusOr<bool> GemmBroadcastFoldingRewriter::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  bool changed = false;
  for (HloComputation *computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla