/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/interpreter/compiler.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/eigh_expander.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace interpreter {

namespace {

// Handles custom_call ops during evaluation by routing them through the global
// CPU registry used by other CPU-based backends. An example of registering a
// compatible target appears after this function.
StatusOr<Literal> HandleEvaluatorCustomCall(
    HloInstruction* custom_call, absl::Span<const Literal*> operands) {
  // Find the target C function in the global registry.
  auto* registry = CustomCallTargetRegistry::Global();
  void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
  if (!target_fn) {
    return NotFound("Custom call target '%s' was not registered",
                    custom_call->custom_call_target());
  }

  // Populate pointers to operand and output literal data.
  std::vector<const void*> operand_data;
  operand_data.reserve(operands.size());
  for (const auto* literal : operands) {
    operand_data.push_back(literal->untyped_data());
  }
  auto output = Literal::CreateFromShape(custom_call->shape());
  void* output_data = output.untyped_data();

  // Call the target function matching the C ABI used by the CPU backends.
  auto* typed_fn = reinterpret_cast<void (*)(void*, const void**)>(target_fn);
  (*typed_fn)(output_data, operand_data.data());

  return std::move(output);
}
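
// Example (a minimal sketch, not part of this file): a custom-call target
// matching the C ABI expected above, registered for the "Host" platform so
// HandleEvaluatorCustomCall can find it. The name `AddF32x4` and the f32[4]
// operand shapes are hypothetical.
//
//   // Adds two f32[4] operands element-wise into the output buffer.
//   void AddF32x4(void* out, const void** ins) {
//     const float* lhs = static_cast<const float*>(ins[0]);
//     const float* rhs = static_cast<const float*>(ins[1]);
//     float* result = static_cast<float*>(out);
//     for (int i = 0; i < 4; ++i) result[i] = lhs[i] + rhs[i];
//   }
//
//   // At startup (e.g. from a static initializer):
//   xla::CustomCallTargetRegistry::Global()->Register(
//       "AddF32x4", reinterpret_cast<void*>(&AddF32x4), "Host");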

}  // namespace

Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
  HloPassPipeline pipeline("Interpreter");

  // Expand higher-level ops (dynamic indexing, linear-algebra decompositions,
  // unsupported comparison types, batch norm) into primitive HLO that the
  // evaluator implements directly.
  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  // Assign concrete layouts to the entry computation's parameters and result
  // so literals have a well-defined memory layout during evaluation.
  pipeline.AddPass<LayoutAssignment>(
      hlo_module->mutable_entry_computation_layout());

  return pipeline.Run(hlo_module).status();
}

StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
    const CompileOptions& /*options*/) {
  VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
  TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
  return std::move(hlo_module);
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
    const CompileOptions& /*options*/) {
  TF_RET_CHECK(stream_exec != nullptr);

  VLOG(1) << "Run backend " << hlo_module->name();

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(hlo_module.get()));

  auto evaluator = std::make_unique<HloEvaluator>();
  evaluator->set_use_fast_path(
      hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path());
  evaluator->set_custom_call_handler(HandleEvaluatorCustomCall);

  // Create the executable from the HloModule alone; the evaluator interprets
  // the module directly at execution time, so no code generation is needed.
  std::unique_ptr<Executable> executable =
      std::make_unique<InterpreterExecutable>(
          std::move(hlo_module), std::move(evaluator),
          std::move(dynamic_dimension_inference));

  return std::move(executable);
}

StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
    const CompileOptions& options) {
  if (module_group->empty()) {
    return std::vector<std::unique_ptr<Executable>>();
  }
  if (module_group->size() > 1) {
    return tensorflow::errors::Unimplemented(
        "Compilation of multiple HLO modules is not supported on Interpreter.");
  }
  if (stream_exec.size() != 1 || stream_exec[0].size() != 1) {
    return tensorflow::errors::Unimplemented(
        "Unexpected number of StreamExecutors.");
  }
  auto hlo_modules = module_group->ConsumeModules();
  TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
                                                stream_exec[0][0], options));
  TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
                                                  stream_exec[0][0], options));
  std::vector<std::unique_ptr<Executable>> ret;
  ret.push_back(std::move(executable));
  return std::move(ret);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
InterpreterCompiler::CompileAheadOfTime(
    std::unique_ptr<HloModuleGroup> module_group,
    const AotCompilationOptions& aot_options) {
  return tensorflow::errors::InvalidArgument(
      "AOT compilation not supported on Interpreter");
}

se::Platform::Id InterpreterCompiler::PlatformId() const {
  return se::interpreter::kXlaInterpreterPlatformId;
}

HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
    const {
  return InterpreterExecutable::ShapeSizeBytes;
}

// Registers this compiler and a default ComputationPlacer for the interpreter
// platform when the shared object is loaded.
static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      se::interpreter::kXlaInterpreterPlatformId, []() {
        return std::make_unique<xla::interpreter::InterpreterCompiler>();
      });
  xla::ComputationPlacer::RegisterComputationPlacer(
      se::interpreter::kXlaInterpreterPlatformId,
      []() { return std::make_unique<xla::ComputationPlacer>(); });
  return true;
}

static bool module_initialized = InitModule();
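
// Example (a minimal sketch, assuming the interpreter platform is registered
// under the name "Interpreter"): once InitModule has run, the compiler is
// obtained through the generic factory lookup rather than constructed
// directly.
//
//   TF_ASSIGN_OR_RETURN(
//       se::Platform* platform,
//       se::MultiPlatformManager::PlatformWithName("Interpreter"));
//   TF_ASSIGN_OR_RETURN(xla::Compiler* compiler,
//                       xla::Compiler::GetForPlatform(platform));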

}  // namespace interpreter
}  // namespace xla