/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
17 
18 #include <cstdlib>
19 #include <functional>
20 #include <numeric>
21 #include <optional>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/stream_executor/blas.h"
37 
38 namespace xla {
39 namespace gpu {
40 
41 namespace {
42 
SetFortranLayout(Shape * shape)43 void SetFortranLayout(Shape* shape) {
44   LayoutUtil::SetToDefaultLayout(shape);
45   int n = shape->mutable_layout()->minor_to_major_size();
46   CHECK_GE(n, 2);
47   std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
48             shape->mutable_layout()->mutable_minor_to_major()->at(1));
49 }
50 
CreateCholesky(GpuSolverContext * context,HloInstruction * operand,const CholeskyOptions & options,const OpMetadata & metadata)51 StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
52                                          HloInstruction* operand,
53                                          const CholeskyOptions& options,
54                                          const OpMetadata& metadata) {
55   HloComputation* computation = operand->parent();
56 
57   Shape a_shape = operand->shape();
58   int ndim = a_shape.dimensions_size();
59   CHECK_GE(ndim, 2);
60   int64_t n = a_shape.dimensions(ndim - 1);
61 
62   std::vector<int64_t> batch_dims(a_shape.dimensions().begin(),
63                                   a_shape.dimensions().end() - 2);
64   std::vector<int64_t> batch_dim_ids(batch_dims.size());
65   absl::c_iota(batch_dim_ids, 0);
66   int64_t batch_size = absl::c_accumulate(batch_dims, 1, std::multiplies<>{});
67 
68   // Find the workspace size.
69   se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
70                                               : se::blas::UpperLower::kUpper;
71   int64_t workspace_size;  // Number of elements of size a_shape.element_type()
72   TF_ASSIGN_OR_RETURN(
73       workspace_size,
74       context->PotrfBufferSize(a_shape.element_type(), uplo, n, n, batch_size));
75 
76   // TODO(phawkins): Ideally we would relax this constraint. What we actually
77   // want is that:
78   // a) the batch dimensions are major, in no particular order.
79   // b) the two minor dimensions are in fortran (column-major) order,
80 
81   SetFortranLayout(&a_shape);
82 
83   // This call returns a tuple of (cholesky_result, workspace, info) where:
84   // * cholesky_result is the result of the Cholesky decomposition,
85   // * workspace is temporary scratch memory used by cuSolver.
86   // * info contains the Potrf success/failure status.
87   // Currently we have no meaningful way to report an error, so we simply
88   // discard the success/failure information. Obviously this is suboptimal.
89   Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims);
90   Shape call_shape = ShapeUtil::MakeTupleShape(
91       {a_shape,
92        ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
93        info_shape});
94 
95   HloInstruction* custom_call =
96       computation->AddInstruction(HloInstruction::CreateCustomCall(
97           call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
98   custom_call->set_metadata(metadata);
99   TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
100   HloInstruction* out = computation->AddInstruction(
101       HloInstruction::CreateGetTupleElement(a_shape, custom_call, 0));
102   HloInstruction* info = computation->AddInstruction(
103       HloInstruction::CreateGetTupleElement(info_shape, custom_call, 2));
104 
105   // If info was non-zero, indicating that the Cholesky decomposition failed,
106   // returns an array full of NaNs for the corresponding batch element.
107   HloInstruction* zero = computation->AddInstruction(
108       HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
109   HloInstruction* zeros =
110       computation->AddInstruction(HloInstruction::CreateBroadcast(
111           info_shape, zero, /*broadcast_dimensions=*/{}));
112   HloInstruction* ok = computation->AddInstruction(
113       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims),
114                                     info, zeros, ComparisonDirection::kEq));
115   ok = computation->AddInstruction(HloInstruction::CreateBroadcast(
116       ShapeUtil::MakeShape(PRED, a_shape.dimensions()), ok,
117       /*broadcast_dimensions=*/batch_dim_ids));
118 
119   TF_ASSIGN_OR_RETURN(Literal nan_literal,
120                       LiteralUtil::NanValue(a_shape.element_type()));
121   HloInstruction* nan = computation->AddInstruction(
122       HloInstruction::CreateConstant(std::move(nan_literal)));
123   HloInstruction* nans =
124       computation->AddInstruction(HloInstruction::CreateBroadcast(
125           a_shape, nan, /*broadcast_dimensions=*/{}));
126 
127   HloInstruction* select =
128       computation->AddInstruction(HloInstruction::CreateTernary(
129           a_shape, HloOpcode::kSelect, ok, out, nans));
130   return select;
131 }
132 
133 // Tries to rewrite a single convolution into a call to cudnn.
RunOnInstruction(GpuSolverContext * context,HloInstruction * instruction)134 StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
135                                 HloInstruction* instruction) {
136   if (instruction->opcode() != HloOpcode::kCholesky) {
137     return false;
138   }
139 
140   TF_ASSIGN_OR_RETURN(
141       HloInstruction * custom_call,
142       CreateCholesky(context, instruction->mutable_operand(0),
143                      instruction->cholesky_options(), instruction->metadata()));
144 
145   VLOG(1) << "Replacing " << instruction->ToString() << " with "
146           << custom_call->ToString();
147 
148   TF_RETURN_IF_ERROR(
149       instruction->parent()->ReplaceInstruction(instruction, custom_call));
150   return true;
151 }
152 
153 }  // namespace
154 
155 // Rewrites the convolutions in the given computation into calls to cudnn.
156 // Returns true if it made any changes.
RunOnComputation(HloComputation * computation)157 StatusOr<bool> GpusolverRewriter::RunOnComputation(
158     HloComputation* computation) {
159   std::vector<HloInstruction*> cusolver_calls;
160   for (auto* hlo : computation->instructions()) {
161     if (hlo->opcode() == HloOpcode::kCholesky) {
162       cusolver_calls.push_back(hlo);
163     }
164   }
165 
166   if (cusolver_calls.empty()) {
167     return false;
168   }
169 
170   TF_ASSIGN_OR_RETURN(GpuSolverContext context,
171                       GpuSolverContext::Create(/*stream=*/nullptr));
172 
173   bool changed = false;
174   for (HloInstruction* instruction : cusolver_calls) {
175     TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(&context, instruction));
176     changed |= result;
177   }
178   return changed;
179 }
180 
// Defaulted constructor: the rewriter needs no construction-time setup.
GpusolverRewriter::GpusolverRewriter() = default;
182 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)183 StatusOr<bool> GpusolverRewriter::Run(
184     HloModule* module,
185     const absl::flat_hash_set<absl::string_view>& execution_threads) {
186   bool changed = false;
187   for (HloComputation* computation :
188        module->MakeNonfusionComputations(execution_threads)) {
189     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
190     changed |= result;
191   }
192   return changed;
193 }
194 
195 }  // namespace gpu
196 }  // namespace xla
197