/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/interpreter/compiler.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/eigh_expander.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace interpreter {

namespace {

52 // Handles custom_call ops during evaluation by routing them through the global
53 // CPU registry used by other CPU-based backends.
HandleEvaluatorCustomCall(HloInstruction * custom_call,absl::Span<const Literal * > operands)54 StatusOr<Literal> HandleEvaluatorCustomCall(
55     HloInstruction* custom_call, absl::Span<const Literal*> operands) {
56   // Find the target C function in the global registry.
57   auto* registry = CustomCallTargetRegistry::Global();
58   void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
59   if (!target_fn) {
60     return NotFound("Custom call target '%s' was not registered",
61                     custom_call->custom_call_target());
62   }
63 
64   // Populate pointers to operand and output literal data.
65   std::vector<const void*> operand_data;
66   operand_data.reserve(operands.size());
67   for (const auto* literal : operands) {
68     operand_data.push_back(literal->untyped_data());
69   }
70   auto output = Literal::CreateFromShape(custom_call->shape());
71   void* output_data = output.untyped_data();
72 
73   // Call the target function matching the C ABI used by the CPU backends.
74   auto* typed_fn = reinterpret_cast<void (*)(void*, const void**)>(target_fn);
75   (*typed_fn)(output_data, operand_data.data());
76 
77   return std::move(output);
78 }
79 
}  // namespace

RunHloOptimization(HloModule * hlo_module)82 Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
83   HloPassPipeline pipeline("Interpreter");
84 
85   pipeline.AddPass<DynamicIndexSplitter>();
86   pipeline.AddPass<CholeskyExpander>();
87   pipeline.AddPass<QrExpander>();
88   pipeline.AddPass<EighExpander>();
89   pipeline.AddPass<ComparisonExpander>();
90   pipeline.AddPass<TriangularSolveExpander>();
91   pipeline.AddPass<BatchNormExpander>(
92       /*rewrite_training_op=*/true,
93       /*rewrite_inference_op=*/true,
94       /*rewrite_grad_op=*/true);
95   pipeline.AddPass<LayoutAssignment>(
96       hlo_module->mutable_entry_computation_layout());
97 
98   return pipeline.Run(hlo_module).status();
99 }
100 
RunHloPasses(std::unique_ptr<HloModule> hlo_module,se::StreamExecutor *,const CompileOptions &)101 StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
102     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
103     const CompileOptions& /*options*/) {
104   VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
105   TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
106   return std::move(hlo_module);
107 }
108 
RunBackend(std::unique_ptr<HloModule> hlo_module,se::StreamExecutor * stream_exec,const CompileOptions &)109 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
110     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
111     const CompileOptions& /*options*/) {
112   TF_RET_CHECK(stream_exec != nullptr);
113 
114   VLOG(1) << "Run backend " << hlo_module->name();
115 
116   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
117                       DynamicDimensionInference::Run(hlo_module.get()));
118 
119   auto evaluator = std::make_unique<HloEvaluator>();
120   evaluator->set_use_fast_path(
121       hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path());
122   evaluator->set_custom_call_handler(HandleEvaluatorCustomCall);
123 
124   // Create executable from only the Hlo module.
125   std::unique_ptr<Executable> executable =
126       std::make_unique<InterpreterExecutable>(
127           std::move(hlo_module), std::move(evaluator),
128           std::move(dynamic_dimension_inference));
129 
130   return std::move(executable);
131 }
132 
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,const CompileOptions & options)133 StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
134     std::unique_ptr<HloModuleGroup> module_group,
135     std::vector<std::vector<se::StreamExecutor*>> stream_exec,
136     const CompileOptions& options) {
137   if (module_group->empty()) {
138     return std::vector<std::unique_ptr<Executable>>();
139   }
140   if (module_group->size() > 1) {
141     return tensorflow::errors::Unimplemented(
142         "Compilation of multiple HLO modules is not supported on Interpreter.");
143   }
144   if (stream_exec.size() != 1 || stream_exec[0].size() != 1) {
145     return tensorflow::errors::Unimplemented(
146         "Unexpected number of StreamExecutor's.");
147   }
148   auto hlo_modules = module_group->ConsumeModules();
149   TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
150                                                 stream_exec[0][0], options));
151   TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
152                                                   stream_exec[0][0], options));
153   std::vector<std::unique_ptr<Executable>> ret;
154   ret.push_back(std::move(executable));
155   return std::move(ret);
156 }
157 
158 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & aot_options)159 InterpreterCompiler::CompileAheadOfTime(
160     std::unique_ptr<HloModuleGroup> module_group,
161     const AotCompilationOptions& aot_options) {
162   return tensorflow::errors::InvalidArgument(
163       "AOT compilation not supported on Interpreter");
164 }
165 
PlatformId() const166 se::Platform::Id InterpreterCompiler::PlatformId() const {
167   return se::interpreter::kXlaInterpreterPlatformId;
168 }
169 
ShapeSizeBytesFunction() const170 HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
171     const {
172   return InterpreterExecutable::ShapeSizeBytes;
173 }
174 
InitModule()175 static bool InitModule() {
176   xla::Compiler::RegisterCompilerFactory(
177       se::interpreter::kXlaInterpreterPlatformId, []() {
178         return std::make_unique<xla::interpreter::InterpreterCompiler>();
179       });
180   xla::ComputationPlacer::RegisterComputationPlacer(
181       se::interpreter::kXlaInterpreterPlatformId,
182       []() { return std::make_unique<xla::ComputationPlacer>(); });
183   return true;
184 }
186 static bool module_initialized = InitModule();

}  // namespace interpreter
}  // namespace xla