1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Emits an HLO module in a text form suitable for diffing.
17
18 #include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
19
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "tensorflow/compiler/xla/debug_options_flags.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_parser.h"
32 #include "tensorflow/core/lib/io/path.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/regexp.h"
37
38 namespace xla {
39 namespace {
40
OverrideConfig(const hlo_module_loader_details::Config & ovr_config,HloModuleConfig * config)41 Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config,
42 HloModuleConfig* config) {
43 config->set_replica_count(ovr_config.num_replicas);
44 config->set_num_partitions(ovr_config.num_partitions);
45 return OkStatus();
46 }
47
48 } // namespace
49
StripLogHeaders(const std::string & hlo_string)50 std::string StripLogHeaders(const std::string& hlo_string) {
51 // I0521 12:04:45.883483 1509 service.cc:186] ...
52 static RE2* matcher = new RE2(
53 "[IWEF]\\d{4} "
54 "\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)");
55 absl::string_view matches[4];
56 std::vector<std::string> lines = absl::StrSplit(hlo_string, '\n');
57 for (auto& line : lines) {
58 if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) {
59 line = std::string(matches[1]);
60 }
61 }
62 return absl::StrJoin(lines, "\n",
63 [](std::string* out, const std::string& line) {
64 absl::StrAppend(out, line);
65 });
66 }
67
LoadModuleFromData(const std::string & data,const std::string & format,hlo_module_loader_details::Config ovr_config,const std::function<void (HloModuleConfig *)> & config_modifier_hook)68 StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
69 const std::string& data, const std::string& format,
70 hlo_module_loader_details::Config ovr_config,
71 const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
72 DebugOptions debug_options = GetDebugOptionsFromFlags();
73 std::unique_ptr<HloModule> module;
74 if (format == "hlo" || format == "txt") {
75 std::string hlo_string = StripLogHeaders(data);
76 HloModuleConfig config;
77 config.set_debug_options(debug_options);
78 TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
79 if (config_modifier_hook) {
80 config_modifier_hook(&config);
81 }
82 TF_ASSIGN_OR_RETURN(module,
83 ParseAndReturnUnverifiedModule(hlo_string, config));
84 } else {
85 HloSnapshot proto;
86 if (format == "pb") {
87 if (!proto.ParseFromString(data) &&
88 !proto.mutable_hlo()->ParseFromString(data) &&
89 !proto.mutable_hlo()->mutable_hlo_module()->ParseFromString(data)) {
90 return InvalidArgument("Failed to parse input as HLO protobuf binary");
91 }
92 } else if (format == "pbtxt") {
93 if (!tensorflow::protobuf::TextFormat::ParseFromString(data, &proto) &&
94 !tensorflow::protobuf::TextFormat::ParseFromString(
95 data, proto.mutable_hlo()) &&
96 !tensorflow::protobuf::TextFormat::ParseFromString(
97 data, proto.mutable_hlo()->mutable_hlo_module())) {
98 return InvalidArgument("Failed to parse input as HLO protobuf text");
99 }
100 } else {
101 return InvalidArgument(
102 "Invalid format from file extension: '%s'. Expected: hlo, txt, pb, "
103 "or pbtxt",
104 format);
105 }
106 TF_ASSIGN_OR_RETURN(HloModuleConfig config,
107 HloModule::CreateModuleConfigFromProto(
108 proto.hlo().hlo_module(), debug_options));
109 TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
110 if (config_modifier_hook) {
111 config_modifier_hook(&config);
112 }
113 TF_ASSIGN_OR_RETURN(
114 module, HloModule::CreateFromProto(proto.hlo().hlo_module(), config));
115 }
116 return std::move(module);
117 }
118
LoadModuleFromFile(const std::string & path,hlo_module_loader_details::Config ovr_config,std::string format,const std::function<void (HloModuleConfig *)> & config_modifier_hook)119 StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
120 const std::string& path, hlo_module_loader_details::Config ovr_config,
121 std::string format,
122 const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
123 std::string data;
124 if (format.empty()) {
125 format = std::string(tensorflow::io::Extension(path));
126 }
127 TF_RETURN_IF_ERROR(
128 tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &data));
129 return LoadModuleFromData(data, format, ovr_config, config_modifier_hook);
130 }
131
132 } // namespace xla
133