• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/tools/hlo_extractor.h"

#include <stdio.h>
#include <unistd.h>

#include <deque>
#include <memory>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status.h"
29 
30 namespace xla {
31 namespace {
32 
33 // Visitor that build a new HLO module with an entry computation and a root that
34 // is provided to the visit function. Only HLOs that are reachable from the new
35 // root instruction are included in the new module.
36 //
37 // The constructor allows specifying a set of boundary HLOs to prune the HLO
38 // graph. HLOs at the boundary are replaced with parameters. Can be nullptr
39 // which means no boundary, i.e. no HLOs are replaced with parameters.
40 class ExtractionVisitor : public ConstDfsHloVisitorWithDefault {
41  public:
ExtractionVisitor(const HloModule & old_module,absl::flat_hash_set<const HloInstruction * > * boundary)42   explicit ExtractionVisitor(
43       const HloModule& old_module,
44       absl::flat_hash_set<const HloInstruction*>* boundary)
45       : old_module_(old_module),
46         module_(absl::make_unique<HloModule>("extracted", config_)),
47         clone_context_(module_.get()),
48         builder_("entry_computation"),
49         boundary_(boundary) {}
50 
HandleParameter(const HloInstruction * parameter)51   Status HandleParameter(const HloInstruction* parameter) override {
52     // Entry parameters need renumbering.
53     auto new_parameter = HloInstruction::CreateParameter(
54         parameter_number_++, parameter->shape(), parameter->name());
55     clone_context_.MapInstruction(parameter, new_parameter.get());
56     builder_.AddInstruction(std::move(new_parameter));
57     return Status::OK();
58   }
59 
DefaultAction(const HloInstruction * hlo)60   Status DefaultAction(const HloInstruction* hlo) override {
61     // Replace instructions at the boundary with parameters, but leave constants
62     // untouched.
63     if (boundary_ != nullptr && boundary_->count(hlo) > 0) {
64       auto new_parameter = HloInstruction::CreateParameter(
65           parameter_number_, hlo->shape(), hlo->name());
66       parameter_number_++;
67       clone_context_.MapInstruction(hlo, new_parameter.get());
68       builder_.AddInstruction(std::move(new_parameter));
69       return Status::OK();
70     }
71     std::vector<HloInstruction*> new_operands;
72     for (auto operand : hlo->operands()) {
73       new_operands.push_back(clone_context_.GetInstruction(operand));
74     }
75     auto instruction =
76         hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_);
77     builder_.AddInstruction(std::move(instruction));
78     return Status::OK();
79   }
80 
FinishVisit(const HloInstruction *)81   Status FinishVisit(const HloInstruction* /*root*/) override {
82     module_->AddEntryComputation(builder_.Build());
83     // Rename HLOs so that their name matches the original. By default,
84     // HLOs get new unique names when adding a new entry computation to
85     // a module.
86     for (auto computation : old_module_.MakeComputationPostOrder()) {
87       for (auto old_instruction : computation->MakeInstructionPostOrder()) {
88         if (auto new_instruction =
89                 clone_context_.FindInstruction(old_instruction)) {
90           new_instruction->SetAndSanitizeName(old_instruction->name());
91         }
92       }
93     }
94     return Status::OK();
95   }
96 
module()97   HloModule* module() { return module_.get(); }
98 
ConsumeModule()99   std::unique_ptr<HloModule> ConsumeModule() { return std::move(module_); }
100 
101  private:
102   const HloModule& old_module_;
103   HloModuleConfig config_;
104   std::unique_ptr<HloModule> module_;
105   HloCloneContext clone_context_;
106   HloComputation::Builder builder_;
107   absl::flat_hash_set<const HloInstruction*>* boundary_;
108   int64 parameter_number_ = 0;
109 };
110 
ComputeBoundary(const HloInstruction * root,int64 limit,absl::flat_hash_set<const HloInstruction * > * boundary)111 void ComputeBoundary(const HloInstruction* root, int64 limit,
112                      absl::flat_hash_set<const HloInstruction*>* boundary) {
113   std::deque<const HloInstruction*> worklist;
114   absl::flat_hash_map<const HloInstruction*, int64> visited;
115   worklist.push_back(root);
116   visited.emplace(root, 0);
117   while (!worklist.empty()) {
118     auto hlo = worklist.front();
119     worklist.pop_front();
120     int64 hops = visited[hlo];
121     if (hops > limit) {
122       boundary->insert(hlo);
123       continue;
124     }
125     for (const HloInstruction* operand : hlo->operands()) {
126       if (visited.count(operand)) {
127         continue;
128       }
129       worklist.push_back(operand);
130       visited.emplace(operand, hops + 1);
131     }
132   }
133 }
134 
135 }  // namespace
136 
ExtractModule(HloInstruction * instruction,int64 height)137 std::unique_ptr<HloModule> ExtractModule(HloInstruction* instruction,
138                                          int64 height) {
139   absl::flat_hash_set<const HloInstruction*> boundary;
140   if (height != -1) {
141     ComputeBoundary(instruction, height, &boundary);
142   }
143   ExtractionVisitor visitor(*instruction->GetModule(), &boundary);
144   CHECK(instruction->Accept(&visitor).ok());
145 
146   // The first pass may leave unused parameter instructions. Do another
147   // extraction pass to remove unused parameters. This is done because
148   // HloComputation does not allow removing parameters after the computation has
149   // been built.
150   ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr);
151   TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept(
152       &cleanup_visitor));
153 
154   HloVerifier verifier(/*layout_sensitive=*/false,
155                        /*allow_mixed_precision=*/true);
156   TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status());
157   return cleanup_visitor.ConsumeModule();
158 }
159 
160 }  // namespace xla
161