• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/jit/get_compiler_ir.h"

#include <deque>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
33 
34 namespace tensorflow {
35 
GetLocalExecutable(const XlaCompiler::Options & options,const XlaCompiler::CompileOptions & compile_options,const NameAttrList & function,XlaCompilationCache * cache,absl::Span<XlaCompiler::Argument const> args,const XlaCompiler & compiler)36 static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
37     const XlaCompiler::Options& options,
38     const XlaCompiler::CompileOptions& compile_options,
39     const NameAttrList& function, XlaCompilationCache* cache,
40     absl::Span<XlaCompiler::Argument const> args, const XlaCompiler& compiler) {
41   const XlaCompiler::CompilationResult* compilation_result = nullptr;
42   xla::LocalExecutable* executable = nullptr;
43   TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options,
44                                     XlaCompilationCache::CompileMode::kStrict,
45                                     &compilation_result, &executable));
46   return executable;
47 }
48 
GetCompilerIr(IrExportStage stage,ProcessFunctionLibraryRuntime * pflr,absl::string_view func_name,Device * dev,EagerContext * context,absl::Span<const TensorHandle * const> inputs_handles)49 xla::StatusOr<std::string> GetCompilerIr(
50     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
51     absl::string_view func_name, Device* dev, EagerContext* context,
52     absl::Span<const TensorHandle* const> inputs_handles) {
53   NameAttrList function;
54   function.set_name(std::string{func_name});
55 
56   FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
57   ResourceMgr* rmgr = dev->resource_manager();
58 
59   const FunctionBody* fbody = nullptr;
60   std::vector<int> constant_arg_indices;
61   std::vector<int> resource_arg_indices;
62   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
63       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
64 
65   MemoryTypeVector input_memory_types =
66       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
67   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
68 
69   std::deque<Tensor> inputs_storage;
70   std::vector<const Tensor*> inputs;
71   inputs.reserve(inputs_handles.size());
72   for (int i = 0; i < inputs_handles.size(); i++) {
73     const TensorHandle* th = inputs_handles[i];
74     const Tensor* t;
75     // Handle owns the tensor.
76     TF_RETURN_IF_ERROR(th->Tensor(&t));
77     if (absl::c_binary_search(constant_arg_indices, i)) {
78       // Need to make sure it's on the host.
79       inputs_storage.emplace_back(t->dtype(), t->shape());
80       TF_RETURN_IF_ERROR(
81           th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
82       inputs.push_back(&inputs_storage.back());
83     } else {
84       inputs.push_back(t);
85     }
86   }
87 
88   std::vector<VariableInfo> variable_infos;
89   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
90       rmgr, dev, inputs, resource_arg_indices, &variable_infos));
91   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
92 
93   XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);
94 
95   XlaCompilationCache* cache;
96   TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
97       rmgr->default_container(), "xla_cache", &cache,
98       [&](XlaCompilationCache** cache_write_into) {
99         return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
100       }));
101   core::ScopedUnref cache_ref(cache);
102 
103   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
104 
105   XlaCompiler::Options options =
106       GenerateCompilerOptions(*cache, *flr, dev,
107                               /*stream=*/nullptr, platform_info,
108                               /*has_ref_vars=*/false, &tf_allocator_adapter);
109 
110   XlaCompiler::CompileOptions compile_options;
111   compile_options.always_return_tuple = false;
112   compile_options.alias_resource_update = true;
113 
114   XlaCompiler compiler(options);
115 
116   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
117       XlaComputationLaunchContext::BuildXlaCompilerArguments(
118           constant_arg_indices, inputs, variable_infos, dev);
119   TF_RETURN_IF_ERROR(args.status());
120 
121   switch (stage) {
122     case IrExportStage::HLO:
123     case IrExportStage::HLO_SERIALIZED: {
124       XlaCompiler::CompilationResult result;
125       TF_RETURN_IF_ERROR(
126           compiler.CompileFunction(compile_options, function, *args, &result));
127 
128       TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
129                           result.computation->GetProgramShape());
130       xla::HloModuleConfig config(program_shape);
131       TF_ASSIGN_OR_RETURN(
132           std::unique_ptr<xla::HloModule> new_module,
133           xla::HloModule::CreateFromProto(result.computation->proto(), config));
134 
135       if (stage == IrExportStage::HLO_SERIALIZED) {
136         return new_module->ToProto().SerializeAsString();
137       } else {
138         return new_module->ToString();
139       }
140     }
141     case IrExportStage::OPTIMIZED_HLO:
142     case IrExportStage::OPTIMIZED_HLO_SERIALIZED: {
143       xla::StatusOr<xla::LocalExecutable*> executable = GetLocalExecutable(
144           options, compile_options, function, cache, *args, compiler);
145       TF_RETURN_IF_ERROR(executable.status());
146       xla::Executable* new_executable = (*executable)->executable();
147       if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) {
148         return new_executable->module().ToProto().SerializeAsString();
149       } else {
150         return new_executable->module().ToString();
151       }
152     }
153     case IrExportStage::OPTIMIZED_HLO_DOT: {
154       xla::StatusOr<xla::LocalExecutable*> executable = GetLocalExecutable(
155           options, compile_options, function, cache, *args, compiler);
156       TF_RETURN_IF_ERROR(executable.status());
157       xla::StatusOr<std::string> graph = xla::RenderGraph(
158           *(*executable)->executable()->module().entry_computation(),
159           "Visualization",
160           /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
161           /*hlo_execution_profile=*/nullptr,
162           /*hlo_render_options=*/{});
163       TF_RETURN_IF_ERROR(graph.status());
164       return *graph;
165     }
166   }
167 }
168 
169 }  // namespace tensorflow
170