• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
34 
35 namespace xla {
36 
37 // Analysis which allocates HloBuffers to HloValues.
38 class HloAliasAnalysis {
39  public:
40   // The callgraph of the given HloModule must be flattened
41   // (xla::FlattenCallGraph) prior to running the analysis.
42   static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(
43       const HloModule* module,
44       const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr);
45 
46   std::string ToString() const;
47 
48   // Return the buffer containing the given value.
GetBufferContainingValue(const HloValue & value)49   const HloBuffer& GetBufferContainingValue(const HloValue& value) const {
50     return *value_to_buffer_.at(&value);
51   }
GetBufferContainingValue(const HloValue & value)52   HloBuffer& GetBufferContainingValue(const HloValue& value) {
53     return *value_to_buffer_.at(&value);
54   }
55 
56   // Return the HloBuffer with the given ID.
GetBuffer(HloBuffer::Id buffer_id)57   const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const {
58     return buffers_.at(buffer_id);
59   }
GetBuffer(HloBuffer::Id buffer_id)60   HloBuffer& GetBuffer(HloBuffer::Id buffer_id) {
61     return buffers_.at(buffer_id);
62   }
63 
64   // Returns the unique buffer at the given position. CHECK fails if the buffer
65   // set at that position does not contain exactly one buffer.
66   const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
67                                      const ShapeIndex& index = {}) const;
68   HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
69                                const ShapeIndex& index = {});
70 
71   // Compute the set of buffers at the given instruction and index and return as
72   // a vector. This set is exactly the union of the buffers containing the
73   // HloValues at this position.
74   std::vector<const HloBuffer*> ComputeBuffersAt(
75       const HloInstruction* instruction, const ShapeIndex& index = {}) const;
76 
77   // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This
78   // vector is lazily computed. Mutating operations on HloAliasAnalysis may
79   // invalidate the underlying vector requiring recomputation.
buffers()80   const std::vector<HloBuffer>& buffers() const { return buffers_; }
81 
82   // Returns the underlying dataflow analysis used by this alias analysis.
dataflow_analysis()83   HloDataflowAnalysis& dataflow_analysis() const { return *dataflow_analysis_; }
84 
85   // Returns true if a buffer lives out of the module.
BufferLivesOut(const HloBuffer & buffer)86   bool BufferLivesOut(const HloBuffer& buffer) const {
87     return live_out_buffers_.contains(&buffer);
88   }
89 
90   // Returns true if a hlo value lives out of the module.
ValueLivesOut(const HloValue & value)91   bool ValueLivesOut(const HloValue& value) const {
92     return live_out_buffers_.contains(&GetBufferContainingValue(value));
93   }
94 
LiveOutBuffers()95   std::vector<const HloBuffer*> LiveOutBuffers() const {
96     std::vector<const HloBuffer*> results(live_out_buffers_.begin(),
97                                           live_out_buffers_.end());
98     absl::c_sort(results, HloBuffer::IdLessThan);
99     return results;
100   }
101 
102  protected:
103   explicit HloAliasAnalysis(const HloModule* module);
104 
105   // Verify various invariants of the alias analysis.
106   Status Verify() const;
107 
108   const HloModule* module_;
109 
110   // A set of buffers that live out the module.
111   absl::flat_hash_set<const HloBuffer*> live_out_buffers_;
112 
113   // The underlying dataflow analysis used by this alias analysis.
114   std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
115 
116   // A map indicating which buffer a value is contained in.
117   absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_;
118 
119   // A lazily constructed vector containing all HloBuffers sorted by
120   // HloBuffer::Id.
121   std::vector<HloBuffer> buffers_;
122 };
123 
124 }  // namespace xla
125 
126 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
127