1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/aot/compile.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/base/call_once.h"
24 #include "llvm-c/Target.h"
25 #include "tensorflow/compiler/aot/codegen.h"
26 #include "tensorflow/compiler/aot/flags.h"
27 #include "tensorflow/compiler/tf2xla/tf2xla.h"
28 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
29 #include "tensorflow/compiler/xla/client/client_library.h"
30 #include "tensorflow/compiler/xla/client/compile_only_client.h"
31 #include "tensorflow/compiler/xla/client/xla_computation.h"
32 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/io/path.h"
39 #include "tensorflow/core/lib/strings/proto_serialization.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/types.h"
43
44 namespace tensorflow {
45 namespace tfcompile {
46
47 namespace {
48
49 // Compiles the XLA computation into executable code.
CompileXla(xla::CompileOnlyClient * client,const xla::XlaComputation & computation,const xla::cpu::CpuAotCompilationOptions & aot_opts,CompileResult * compile_result)50 Status CompileXla(xla::CompileOnlyClient* client,
51 const xla::XlaComputation& computation,
52 const xla::cpu::CpuAotCompilationOptions& aot_opts,
53 CompileResult* compile_result) {
54 // Retrieves arg and result layouts from the computation.
55 // TODO(toddw): Should we let the user choose the major/minor ordering?
56 xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
57 client->GetComputationShape(computation);
58 if (!pshape_or.ok()) {
59 return errors::Unknown("Couldn't get XLA program shape: ",
60 pshape_or.status().error_message());
61 }
62 compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
63 xla::ProgramShapeProto* pshape = &compile_result->program_shape;
64
65 // AotXlaComputationInstance::argument_layouts is a vector of Shape
66 // pointers. Accumulate the Shape objects themselves in a separate vector
67 // while building the vector of pointers.
68 std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
69 std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
70 for (int i = 0; i < pshape->parameters_size(); ++i) {
71 arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
72 arg_layout_ptrs[i] = &arg_layouts[i];
73 }
74 xla::CompileOnlyClient::AotXlaComputationInstance instance;
75 instance.computation = &computation;
76 instance.argument_layouts = std::move(arg_layout_ptrs);
77 xla::Shape result_shape(pshape->result());
78 instance.result_layout = &result_shape;
79 xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
80 aot_or = client->CompileAheadOfTime({instance}, aot_opts);
81 if (!aot_or.ok()) {
82 return errors::Unknown("XLA compilation failed: ",
83 aot_or.status().error_message());
84 }
85 compile_result->aot =
86 xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
87 std::move(aot_or.ValueOrDie().back()));
88 compile_result->entry_point = aot_opts.entry_point_name();
89 compile_result->pointer_size =
90 xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
91 return Status::OK();
92 }
93
94 } // namespace
95
CompileGraph(GraphDef graph_def,const tf2xla::Config & config,const MainFlags & flags,CompileResult * compile_result)96 Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
97 const MainFlags& flags, CompileResult* compile_result) {
98 // Converts the graph into an XLA computation, and compiles the
99 // computation.
100 // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
101 se::Platform* cpu_platform =
102 se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
103 xla::CompileOnlyClient* client =
104 xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
105 .ValueOrDie();
106 xla::XlaComputation computation;
107 if (flags.mlir_components == "Bridge") {
108 TF_RETURN_IF_ERROR(
109 ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
110 } else {
111 if (!flags.mlir_components.empty()) {
112 return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
113 }
114 TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
115 client, &computation));
116 }
117 if (!flags.out_session_module.empty()) {
118 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
119 computation.Snapshot());
120 // Serialize the HloSnapshot deterministically so that all the outputs of a
121 // tf_library genrule are deterministic.
122 const size_t size = module->ByteSizeLong();
123 auto serialized = absl::make_unique<char[]>(size);
124 TF_RET_CHECK(
125 SerializeToBufferDeterministic(*module, serialized.get(), size));
126 TF_RETURN_IF_ERROR(
127 WriteStringToFile(Env::Default(), flags.out_session_module,
128 absl::string_view(serialized.get(), size)));
129 }
130 xla::cpu::CpuAotCompilationOptions aot_opts(
131 flags.target_triple, flags.target_cpu, flags.target_features,
132 flags.entry_point,
133 xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
134
135 return CompileXla(client, computation, aot_opts, compile_result);
136 }
137
ReadProtoFile(const string & fname,protobuf::Message * proto)138 static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
139 if (absl::EndsWith(fname, ".pbtxt")) {
140 return ReadTextProto(Env::Default(), fname, proto);
141 } else {
142 return ReadBinaryProto(Env::Default(), fname, proto);
143 }
144 }
145
// Guards one-time LLVM target registration; Main() may be invoked more than
// once per process, but targets must only be registered once.
static absl::once_flag targets_init;

// Registers every LLVM backend this binary may emit code for (target,
// target-info, MC layer, and assembly printer per architecture), so
// tfcompile can cross compile regardless of the host architecture.
static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
  // AArch64 support is conditional: it is only linked in when the build
  // defines TF_LLVM_AARCH64_AVAILABLE.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}
169
Main(const MainFlags & flags)170 Status Main(const MainFlags& flags) {
171 absl::call_once(targets_init, &InitializeTargets);
172
173 // Process config.
174 tf2xla::Config config;
175 if (flags.config.empty()) {
176 return errors::InvalidArgument("Must specify --config");
177 }
178 TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
179 TF_RETURN_IF_ERROR(ValidateConfig(config));
180 if (flags.dump_fetch_nodes) {
181 std::set<string> nodes;
182 for (const tf2xla::Fetch& fetch : config.fetch()) {
183 nodes.insert(fetch.id().node_name());
184 }
185 std::cout << absl::StrJoin(nodes, ",");
186 return Status::OK();
187 }
188
189 // Read and initialize the graph.
190 if (flags.graph.empty()) {
191 return errors::InvalidArgument("Must specify --graph");
192 }
193 GraphDef graph_def;
194 TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
195 CompileResult compile_result;
196 TF_RETURN_IF_ERROR(
197 CompileGraph(std::move(graph_def), config, flags, &compile_result));
198
199 // Write output files.
200 Env* env = Env::Default();
201 const std::vector<char>& obj = compile_result.aot->object_file_data();
202 TF_RETURN_IF_ERROR(
203 WriteStringToFile(env, flags.out_function_object,
204 absl::string_view(obj.data(), obj.size())));
205 CodegenOpts codegen_opts;
206 codegen_opts.gen_name_to_index = flags.gen_name_to_index;
207 codegen_opts.gen_program_shape = flags.gen_program_shape;
208 codegen_opts.target_triple = flags.target_triple;
209 if (flags.cpp_class.empty()) {
210 return errors::InvalidArgument("Must specify --cpp_class");
211 }
212 codegen_opts.gen_hlo_profile_printer_data =
213 xla::GetDebugOptionsFromFlags().xla_hlo_profile();
214 TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
215 &codegen_opts.namespaces));
216
217 MetadataResult metadata_result;
218 TF_RETURN_IF_ERROR(
219 GenerateMetadata(codegen_opts, compile_result, &metadata_result));
220 TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
221 metadata_result.object_file_data));
222 string header;
223 TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
224 metadata_result, &header));
225 TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
226 return Status::OK();
227 }
228
229 } // namespace tfcompile
230 } // namespace tensorflow
231