• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tools/hlo_extractor.h"
17 
18 #include <stdio.h>
19 #include <unistd.h>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
27 #include "tensorflow/compiler/xla/status.h"
28 
29 namespace xla {
30 namespace {
31 
32 // Visitor that build a new HLO module with an entry computation and a root that
33 // is provided to the visit function. Only HLOs that are reachable from the new
34 // root instruction are included in the new module.
35 //
36 // The constructor allows specifying a set of boundary HLOs to prune the HLO
37 // graph. HLOs at the boundary are replaced with parameters. Can be nullptr
38 // which means no boundary, i.e. no HLOs are replaced with parameters.
39 class ExtractionVisitor : public ConstDfsHloVisitorWithDefault {
40  public:
ExtractionVisitor(const HloModule & old_module,absl::flat_hash_set<const HloInstruction * > * boundary)41   explicit ExtractionVisitor(
42       const HloModule& old_module,
43       absl::flat_hash_set<const HloInstruction*>* boundary)
44       : old_module_(old_module),
45         module_(absl::make_unique<HloModule>("extracted", config_)),
46         clone_context_(module_.get()),
47         builder_("entry_computation"),
48         boundary_(boundary) {}
49 
HandleParameter(const HloInstruction * parameter)50   Status HandleParameter(const HloInstruction* parameter) override {
51     // Entry parameters need renumbering.
52     auto new_parameter = HloInstruction::CreateParameter(
53         parameter_number_++, parameter->shape(), parameter->name());
54     clone_context_.MapInstruction(parameter, new_parameter.get());
55     builder_.AddInstruction(std::move(new_parameter));
56     return Status::OK();
57   }
58 
DefaultAction(const HloInstruction * hlo)59   Status DefaultAction(const HloInstruction* hlo) override {
60     // Replace instructions at the boundary with parameters, but leave constants
61     // untouched.
62     if (boundary_ != nullptr && boundary_->count(hlo) > 0) {
63       auto new_parameter = HloInstruction::CreateParameter(
64           parameter_number_, hlo->shape(), hlo->name());
65       parameter_number_++;
66       clone_context_.MapInstruction(hlo, new_parameter.get());
67       builder_.AddInstruction(std::move(new_parameter));
68       return Status::OK();
69     }
70     std::vector<HloInstruction*> new_operands;
71     for (auto operand : hlo->operands()) {
72       new_operands.push_back(clone_context_.GetInstruction(operand));
73     }
74     auto instruction =
75         hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_);
76     builder_.AddInstruction(std::move(instruction));
77     return Status::OK();
78   }
79 
FinishVisit(const HloInstruction *)80   Status FinishVisit(const HloInstruction* /*root*/) override {
81     module_->AddEntryComputation(builder_.Build());
82     // Rename HLOs so that their name matches the original. By default,
83     // HLOs get new unique names when adding a new entry computation to
84     // a module.
85     for (auto computation : old_module_.MakeComputationPostOrder()) {
86       for (auto old_instruction : computation->MakeInstructionPostOrder()) {
87         if (auto new_instruction =
88                 clone_context_.FindInstruction(old_instruction)) {
89           new_instruction->SetAndSanitizeName(old_instruction->name());
90         }
91       }
92     }
93     return Status::OK();
94   }
95 
module()96   HloModule* module() { return module_.get(); }
97 
ConsumeModule()98   std::unique_ptr<HloModule> ConsumeModule() { return std::move(module_); }
99 
100  private:
101   const HloModule& old_module_;
102   HloModuleConfig config_;
103   std::unique_ptr<HloModule> module_;
104   HloCloneContext clone_context_;
105   HloComputation::Builder builder_;
106   absl::flat_hash_set<const HloInstruction*>* boundary_;
107   int64 parameter_number_ = 0;
108 };
109 
ComputeBoundary(const HloInstruction * root,int64 limit,absl::flat_hash_set<const HloInstruction * > * boundary)110 void ComputeBoundary(const HloInstruction* root, int64 limit,
111                      absl::flat_hash_set<const HloInstruction*>* boundary) {
112   std::deque<const HloInstruction*> worklist;
113   absl::flat_hash_map<const HloInstruction*, int64> visited;
114   worklist.push_back(root);
115   visited.emplace(root, 0);
116   while (!worklist.empty()) {
117     auto hlo = worklist.front();
118     worklist.pop_front();
119     int64 hops = visited[hlo];
120     if (hops > limit) {
121       boundary->insert(hlo);
122       continue;
123     }
124     for (const HloInstruction* operand : hlo->operands()) {
125       if (visited.count(operand)) {
126         continue;
127       }
128       worklist.push_back(operand);
129       visited.emplace(operand, hops + 1);
130     }
131   }
132 }
133 
134 }  // namespace
135 
ExtractModule(HloInstruction * instruction,int64 height)136 std::unique_ptr<HloModule> ExtractModule(HloInstruction* instruction,
137                                          int64 height) {
138   absl::flat_hash_set<const HloInstruction*> boundary;
139   if (height != -1) {
140     ComputeBoundary(instruction, height, &boundary);
141   }
142   ExtractionVisitor visitor(*instruction->GetModule(), &boundary);
143   CHECK(instruction->Accept(&visitor).ok());
144 
145   // The first pass may leave unused parameter instructions. Do another
146   // extraction pass to remove unused parameters. This is done because
147   // HloComputation does not allow removing parameters after the computation has
148   // been built.
149   ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr);
150   TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept(
151       &cleanup_visitor));
152 
153   HloVerifier verifier(/*layout_sensitive=*/false,
154                        /*allow_mixed_precision=*/true);
155   TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status());
156   return cleanup_visitor.ConsumeModule();
157 }
158 
159 }  // namespace xla
160