1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_parser.h"
19
20 namespace xla {
21
22 /*static*/ StatusOr<std::unique_ptr<HloModule>>
CreateModuleFromString(const absl::string_view hlo_string,const DebugOptions & debug_options)23 HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
24 const DebugOptions& debug_options) {
25 HloModuleConfig config;
26 config.set_debug_options(debug_options);
27 return ParseAndReturnUnverifiedModule(hlo_string, config);
28 }
29
30 namespace {
31
32 // Creates an HloModule from the given proto.
HloProtoToModule(const HloProto & proto,const DebugOptions & debug_options)33 StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
34 const HloProto& proto, const DebugOptions& debug_options) {
35 TF_ASSIGN_OR_RETURN(HloModuleConfig config,
36 HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
37 debug_options));
38 TF_ASSIGN_OR_RETURN(auto module,
39 HloModule::CreateFromProto(proto.hlo_module(), config));
40 return std::move(module);
41 }
42
43 } // namespace
44
45 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromBinaryProtoFile(const std::string & filename,const DebugOptions & debug_options)46 HloRunnerInterface::ReadModuleFromBinaryProtoFile(
47 const std::string& filename, const DebugOptions& debug_options) {
48 HloProto proto;
49 TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
50 filename, &proto));
51 return HloProtoToModule(proto, debug_options);
52 }
53
54 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromTextProtoFile(const std::string & filename,const DebugOptions & debug_options)55 HloRunnerInterface::ReadModuleFromTextProtoFile(
56 const std::string& filename, const DebugOptions& debug_options) {
57 HloProto proto;
58 TF_RETURN_IF_ERROR(
59 tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
60 return HloProtoToModule(proto, debug_options);
61 }
62
63 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromHloTextFile(const std::string & filename,const DebugOptions & debug_options)64 HloRunnerInterface::ReadModuleFromHloTextFile(
65 const std::string& filename, const DebugOptions& debug_options) {
66 string hlo_string;
67 TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
68 filename, &hlo_string));
69 HloModuleConfig config;
70 config.set_debug_options(debug_options);
71 return ParseAndReturnUnverifiedModule(hlo_string, config);
72 }
73
74 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromModuleBinaryProtofile(const std::string & filename,const DebugOptions & debug_options)75 HloRunnerInterface::ReadModuleFromModuleBinaryProtofile(
76 const std::string& filename, const DebugOptions& debug_options) {
77 HloModuleProto module_proto;
78 TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
79 filename, &module_proto));
80
81 TF_ASSIGN_OR_RETURN(
82 HloModuleConfig module_config,
83 HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
84
85 return HloModule::CreateFromProto(module_proto, module_config);
86 }
87
Execute(std::unique_ptr<HloModule> module,absl::Span<const Literal> arguments,bool run_hlo_passes,ExecutionProfile * profile)88 StatusOr<Literal> HloRunnerInterface::Execute(
89 std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
90 bool run_hlo_passes, ExecutionProfile* profile) {
91 // Construct a vector of plain pointers for the arguments.
92 std::vector<const Literal*> argument_pointers;
93 argument_pointers.reserve(arguments.size());
94 for (const auto& argument : arguments) {
95 argument_pointers.push_back(&argument);
96 }
97 return Execute(
98 /*module=*/std::move(module),
99 /*arguments=*/argument_pointers,
100 /*run_hlo_passes=*/run_hlo_passes,
101 /*profile=*/profile);
102 }
103
ExecuteWithExecutable(Executable * executable,absl::Span<const Literal> arguments,ExecutionProfile * profile)104 StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
105 Executable* executable, absl::Span<const Literal> arguments,
106 ExecutionProfile* profile) {
107 // Construct a vector of plain pointers for the arguments.
108 std::vector<const Literal*> argument_pointers;
109 argument_pointers.reserve(arguments.size());
110 for (const auto& argument : arguments) {
111 argument_pointers.push_back(&argument);
112 }
113 return ExecuteWithExecutable(executable, argument_pointers, nullptr);
114 }
115
UpdateEntryComputationLayout(HloModule * module,DeviceShapeRepresentationFn shape_representation_fn)116 void HloRunnerInterface::UpdateEntryComputationLayout(
117 HloModule* module, DeviceShapeRepresentationFn shape_representation_fn) {
118 CHECK(shape_representation_fn != nullptr);
119 // Make sure entry computation shapes are in device representation.
120 for (int i = 0; i < module->entry_computation_layout().parameter_count();
121 i++) {
122 Shape shape =
123 module->entry_computation_layout().parameter_layout(i).shape();
124 *module->mutable_entry_computation_layout()->mutable_parameter_layout(i) =
125 ShapeLayout(shape_representation_fn(shape));
126 }
127 *module->mutable_entry_computation_layout()->mutable_result_layout() =
128 ShapeLayout(shape_representation_fn(
129 module->entry_computation_layout().result_layout().shape()));
130 }
131
132 } // namespace xla
133