/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/compile.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tfcompile {

namespace {

// Compiles the XLA computation into executable code.
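// Queries the client for the computation's program shape, uses it to build
// the argument and result layouts, runs CompileAheadOfTime on a single
// AotXlaComputationInstance, and records the compilation result, entry point
// name, and target pointer size in `compile_result`.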
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::XlaComputation& computation,
                  const xla::cpu::CpuAotCompilationOptions& aot_opts,
                  CompileResult* compile_result) {
  // Retrieves arg and result layouts from the computation.
  // TODO(toddw): Should we let the user choose the major/minor ordering?
  xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
      client->GetComputationShape(computation);
  if (!pshape_or.ok()) {
    return errors::Unknown("Couldn't get XLA program shape: ",
                           pshape_or.status().error_message());
  }
  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
  xla::ProgramShapeProto* pshape = &compile_result->program_shape;

  // AotXlaComputationInstance::argument_layouts is a vector of Shape
  // pointers. Accumulate the Shape objects themselves in a separate vector
  // while building the vector of pointers.
  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
  for (int i = 0; i < pshape->parameters_size(); ++i) {
    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
    arg_layout_ptrs[i] = &arg_layouts[i];
  }
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layout_ptrs);
  xla::Shape result_shape(pshape->result());
  instance.result_layout = &result_shape;
  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!aot_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           aot_or.status().error_message());
  }
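  // A single instance was compiled above, so the returned vector holds exactly
  // one AotCompilationResult; downcast it to the CPU-specific result type.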
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return Status::OK();
}

}  // namespace

Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
  se::Platform* cpu_platform =
      se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();
  xla::XlaComputation computation;
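  // Lower the GraphDef to an xla::XlaComputation. With --mlir_components=Bridge
  // the MLIR-based bridge is used; with the flag unset the direct tf2xla path
  // is used, and any other value is rejected below.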
  if (flags.mlir_components == "Bridge") {
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
  } else {
    if (!flags.mlir_components.empty()) {
      return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
    }
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  }
  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());
    // Serialize the HloSnapshot deterministically so that all the outputs of a
    // tf_library genrule are deterministic.
    const size_t size = module->ByteSizeLong();
    auto serialized = absl::make_unique<char[]>(size);
    TF_RET_CHECK(
        SerializeToBufferDeterministic(*module, serialized.get(), size));
    TF_RETURN_IF_ERROR(
        WriteStringToFile(Env::Default(), flags.out_session_module,
                          absl::string_view(serialized.get(), size)));
  }
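  // Ahead-of-time compilation options come straight from the target flags:
  // LLVM triple, CPU, features, and the generated entry point symbol name.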
  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);

  return CompileXla(client, computation, aot_opts, compile_result);
}

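// Reads the protobuf from `fname`, parsing it as text when the filename ends
// in ".pbtxt" and as a binary proto otherwise.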
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

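// Guards the one-time LLVM target registration performed by
// InitializeTargets() below.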
static absl::once_flag targets_init;

static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}

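// Entry point for the tfcompile driver. A typical invocation supplies at
// least the graph, config, and C++ class flags checked below, e.g.
// (illustrative only; see flags.h for the full flag set):
//   tfcompile --graph=graph.pb --config=config.pbtxt \
//             --cpp_class="mynamespace::MyClass"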
Status Main(const MainFlags& flags) {
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
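  // With --dump_fetch_nodes set, just print the comma-separated fetch node
  // names and exit without compiling.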
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  TF_RETURN_IF_ERROR(
      CompileGraph(std::move(graph_def), config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
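  // Configure header and metadata generation from the remaining flags.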
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}

}  // namespace tfcompile
}  // namespace tensorflow