1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_parser.h"
19
20 namespace xla {
21
22 /*static*/ StatusOr<std::unique_ptr<HloModule>>
CreateModuleFromString(const absl::string_view hlo_string,const DebugOptions & debug_options)23 HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
24 const DebugOptions& debug_options) {
25 HloModuleConfig config;
26 config.set_debug_options(debug_options);
27 return ParseAndReturnUnverifiedModule(hlo_string, config);
28 }
29
30 namespace {
31
32 // Creates an HloModule from the given proto.
HloProtoToModule(const HloProto & proto,const DebugOptions & debug_options)33 StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
34 const HloProto& proto, const DebugOptions& debug_options) {
35 TF_ASSIGN_OR_RETURN(HloModuleConfig config,
36 HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
37 debug_options));
38 TF_ASSIGN_OR_RETURN(auto module,
39 HloModule::CreateFromProto(proto.hlo_module(), config));
40 return std::move(module);
41 }
42
43 } // namespace
44
45 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromBinaryProtoFile(const std::string & filename,const DebugOptions & debug_options)46 HloRunnerInterface::ReadModuleFromBinaryProtoFile(
47 const std::string& filename, const DebugOptions& debug_options) {
48 HloProto proto;
49 TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
50 filename, &proto));
51 return HloProtoToModule(proto, debug_options);
52 }
53
54 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromTextProtoFile(const std::string & filename,const DebugOptions & debug_options)55 HloRunnerInterface::ReadModuleFromTextProtoFile(
56 const std::string& filename, const DebugOptions& debug_options) {
57 HloProto proto;
58 TF_RETURN_IF_ERROR(
59 tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
60 return HloProtoToModule(proto, debug_options);
61 }
62
63 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromHloTextFile(const std::string & filename,const DebugOptions & debug_options)64 HloRunnerInterface::ReadModuleFromHloTextFile(
65 const std::string& filename, const DebugOptions& debug_options) {
66 string hlo_string;
67 TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
68 filename, &hlo_string));
69 HloModuleConfig config;
70 config.set_debug_options(debug_options);
71 return ParseAndReturnUnverifiedModule(hlo_string, config);
72 }
73
74 /*static*/ StatusOr<std::unique_ptr<HloModule>>
ReadModuleFromModuleBinaryProtofile(const std::string & filename,const DebugOptions & debug_options)75 HloRunnerInterface::ReadModuleFromModuleBinaryProtofile(
76 const std::string& filename, const DebugOptions& debug_options) {
77 HloModuleProto module_proto;
78 TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
79 filename, &module_proto));
80
81 TF_ASSIGN_OR_RETURN(
82 HloModuleConfig module_config,
83 HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
84
85 return HloModule::CreateFromProto(module_proto, module_config);
86 }
87
Execute(std::unique_ptr<HloModule> module,absl::Span<const Literal> arguments,bool run_hlo_passes,ExecutionProfile * profile)88 StatusOr<Literal> HloRunnerInterface::Execute(
89 std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
90 bool run_hlo_passes, ExecutionProfile* profile) {
91 // Construct a vector of plain pointers for the arguments.
92 std::vector<const Literal*> argument_pointers;
93 argument_pointers.reserve(arguments.size());
94 for (const auto& argument : arguments) {
95 argument_pointers.push_back(&argument);
96 }
97 return Execute(
98 /*module=*/std::move(module),
99 /*arguments=*/argument_pointers,
100 /*run_hlo_passes=*/run_hlo_passes,
101 /*profile=*/profile);
102 }
103
ExecuteWithExecutable(std::unique_ptr<Executable> executable,absl::Span<const Literal> arguments,ExecutionProfile * profile)104 StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
105 std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
106 ExecutionProfile* profile) {
107 // Construct a vector of plain pointers for the arguments.
108 std::vector<const Literal*> argument_pointers;
109 argument_pointers.reserve(arguments.size());
110 for (const auto& argument : arguments) {
111 argument_pointers.push_back(&argument);
112 }
113 return ExecuteWithExecutable(std::move(executable), argument_pointers,
114 nullptr);
115 }
116
117 } // namespace xla
118