• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
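// Loads an HLO module from HLO text (optionally still carrying glog line
// headers) or from a binary/text protobuf holding an HloSnapshot, its `hlo`
// field, or its `hlo.hlo_module` field.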
17 
18 #include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
19 
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "tensorflow/compiler/xla/debug_options_flags.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_parser.h"
32 #include "tensorflow/core/lib/io/path.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/regexp.h"
37 
38 namespace xla {
39 namespace {
40 
OverrideConfig(const hlo_module_loader_details::Config & ovr_config,HloModuleConfig * config)41 Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config,
42                       HloModuleConfig* config) {
43   config->set_replica_count(ovr_config.num_replicas);
44   config->set_num_partitions(ovr_config.num_partitions);
45   return OkStatus();
46 }
47 
48 }  // namespace
49 
StripLogHeaders(const std::string & hlo_string)50 std::string StripLogHeaders(const std::string& hlo_string) {
51   // I0521 12:04:45.883483    1509 service.cc:186] ...
52   static RE2* matcher = new RE2(
53       "[IWEF]\\d{4} "
54       "\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)");
55   absl::string_view matches[4];
56   std::vector<std::string> lines = absl::StrSplit(hlo_string, '\n');
57   for (auto& line : lines) {
58     if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) {
59       line = std::string(matches[1]);
60     }
61   }
62   return absl::StrJoin(lines, "\n",
63                        [](std::string* out, const std::string& line) {
64                          absl::StrAppend(out, line);
65                        });
66 }
67 
LoadModuleFromData(const std::string & data,const std::string & format,hlo_module_loader_details::Config ovr_config,const std::function<void (HloModuleConfig *)> & config_modifier_hook)68 StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
69     const std::string& data, const std::string& format,
70     hlo_module_loader_details::Config ovr_config,
71     const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
72   DebugOptions debug_options = GetDebugOptionsFromFlags();
73   std::unique_ptr<HloModule> module;
74   if (format == "hlo" || format == "txt") {
75     std::string hlo_string = StripLogHeaders(data);
76     HloModuleConfig config;
77     config.set_debug_options(debug_options);
78     TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
79     if (config_modifier_hook) {
80       config_modifier_hook(&config);
81     }
82     TF_ASSIGN_OR_RETURN(module,
83                         ParseAndReturnUnverifiedModule(hlo_string, config));
84   } else {
85     HloSnapshot proto;
86     if (format == "pb") {
87       if (!proto.ParseFromString(data) &&
88           !proto.mutable_hlo()->ParseFromString(data) &&
89           !proto.mutable_hlo()->mutable_hlo_module()->ParseFromString(data)) {
90         return InvalidArgument("Failed to parse input as HLO protobuf binary");
91       }
92     } else if (format == "pbtxt") {
93       if (!tensorflow::protobuf::TextFormat::ParseFromString(data, &proto) &&
94           !tensorflow::protobuf::TextFormat::ParseFromString(
95               data, proto.mutable_hlo()) &&
96           !tensorflow::protobuf::TextFormat::ParseFromString(
97               data, proto.mutable_hlo()->mutable_hlo_module())) {
98         return InvalidArgument("Failed to parse input as HLO protobuf text");
99       }
100     } else {
101       return InvalidArgument(
102           "Invalid format from file extension: '%s'. Expected: hlo, txt, pb, "
103           "or pbtxt",
104           format);
105     }
106     TF_ASSIGN_OR_RETURN(HloModuleConfig config,
107                         HloModule::CreateModuleConfigFromProto(
108                             proto.hlo().hlo_module(), debug_options));
109     TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
110     if (config_modifier_hook) {
111       config_modifier_hook(&config);
112     }
113     TF_ASSIGN_OR_RETURN(
114         module, HloModule::CreateFromProto(proto.hlo().hlo_module(), config));
115   }
116   return std::move(module);
117 }
118 
LoadModuleFromFile(const std::string & path,hlo_module_loader_details::Config ovr_config,std::string format,const std::function<void (HloModuleConfig *)> & config_modifier_hook)119 StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
120     const std::string& path, hlo_module_loader_details::Config ovr_config,
121     std::string format,
122     const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
123   std::string data;
124   if (format.empty()) {
125     format = std::string(tensorflow::io::Extension(path));
126   }
127   TF_RETURN_IF_ERROR(
128       tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &data));
129   return LoadModuleFromData(data, format, ovr_config, config_modifier_hook);
130 }
131 
132 }  // namespace xla
133