• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
18 
19 #include <functional>
20 #include <string>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/service/hlo.pb.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/gtl/int_type.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
35 // Abstract class describing a value used by one of the dataflow analyses -
36 // TuplePointsToAnalysis or HloDataflowAnalysis.
37 // TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused.
38 //
39 // XLA arrays are trivially a single BufferValue. Tuples are made up of more
40 // than one BufferValue: an BufferValue for the pointer vector, and an
41 // BufferValue for each child element.
42 //
43 // Every BufferValue is defined by a particular instruction and most
44 // instructions define only a single BufferValue. Instructions which define a
45 // single BufferValue include array-shaped instructions such as Add but also
46 // includes Tuple-shaped instructions such as Tuple. The Tuple instruction
47 // defines a single BufferValue which is a vector of pointers to the values
48 // containing the Tuple instruction's operands. Though the result of the Tuple
49 // instruction includes multiple values only the top-level BufferValue (the
50 // vector of pointers) is defined by the Tuple instruction. The values
51 // containing the tuple elements are defined by earlier instructions, usually
52 // the operands of the Tuple instruction.
53 //
54 // Instructions which construct both the tuple *and* the tuple elements define
55 // more than one BufferValue. This includes (at least) tuple-shaped Constant,
56 // Parameter, Infeed and While instructions. These tuple-shaped instructions do
57 // not assemble a tuple from existing BufferValues like the Tuple instruction
58 // does, but rather define all the BufferValues in the tuple.
59 //
60 // Some instructions, such as Bitcast, define no buffers. These instructions
61 // simply forward buffers from their operands.
62 //
63 // The BufferValue object describes which HLO instruction defines a buffer and
64 // where within that instruction's output shape the buffer is defined. The
65 // location within the output shape is indicated by BufferValue::index() which
66 // is defined identically to the index used in ShapeUtil::GetSubshape().
67 // Examples:
68 //
69 // %add = Add(%foo, %bar)
70 // %tuple_constant = Constant({1, {42, 43}})
71 //
72 // %add defines a single array-shaped buffer BufferValue(%add, {}) which holds
73 // the array result of the add operation. The nested-tuple-shaped
74 // %tuple_constant defines 5 buffers described by the following BufferValue
75 // objects:
76 //
77 //   BufferValue(%tuple_constant, {})      // "Top-level" buffer: vector of
78 //                                         //  pointers to BufferValues at
79 //                                         //  indices {0} and {1}
80 //   BufferValue(%tuple_constant, {0})     // Holds value "1"
81 //   BufferValue(%tuple_constant, {1})     // Holds nested tuple: vector of
82 //                                         //  pointers to BufferValues at
83 //                                         //  indices {1, 0} and {1, 1}
84 //   BufferValue(%tuple_constant, {1, 0})  // Holds value "42"
85 //   BufferValue(%tuple_constant, {1, 1})  // Holds value "43"
86 
87 class BufferValue {
88  public:
89   TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64);
90 
91   // Id is a unique identifier for the BufferValue to facilitate efficient
92   // collections of BufferValues with stable iteration order.
93   using Id = int64;
94 
95   // Functions which return the size and alignment of a logical buffer in bytes.
96   using SizeFunction = std::function<int64(const BufferValue&)>;
97   using AlignmentFunction = std::function<int64(BufferValue::Color)>;
98 
99   virtual ~BufferValue();
100 
id()101   Id id() const { return id_; }
102 
103   // Return the instruction that defines the buffer.
104   virtual HloInstruction* instruction() const = 0;
105 
106   // Return the index within the output of the instruction where the buffer is
107   // defined. Index used defined as in ShapeUtil::GetSubshape()
108   virtual const ShapeIndex& index() const = 0;
109 
110   // Return the color of the BufferValue. Differently colored buffers can not be
111   // parts of the same allocation.
color()112   Color color() const {
113     CHECK_NE(color_, kInvalidColor)
114         << "Should not query the color of a buffer that was never colored";
115     return color_;
116   }
117 
set_color(Color color)118   void set_color(Color color) {
119     CHECK_NE(color, kInvalidColor)
120         << "Should not set the color of a buffer to the invalid color";
121     color_ = color;
122   }
123 
has_color()124   bool has_color() const { return color_ != kInvalidColor; }
125 
126   // Return the shape of the buffer. This reference points into the shape field
127   // of the instruction defining the buffer.  Therefore, the returned shape will
128   // contain the layout of instruction, if any.
129   virtual const Shape& shape() const = 0;
130 
131   // Returns true if this buffer is the top-level output buffer of the defining
132   // HLO instruction. This is equivalent to index == {}.
IsTopLevel()133   bool IsTopLevel() const { return index().empty(); }
134 
135   // Whether this buffer contains a tuple.
IsTuple()136   bool IsTuple() const { return is_tuple_; }
137 
138   // Whether this buffer contains an array.
IsArray()139   bool IsArray() const { return is_array_; }
140 
141   // operator< is required for std::set.
142   bool operator<(const BufferValue& other) const { return id_ < other.id_; }
143 
144   bool operator==(const BufferValue& other) const { return id_ == other.id_; }
145   bool operator!=(const BufferValue& other) const { return id_ != other.id_; }
146 
147   virtual string ToString() const = 0;
148 
149   // TODO(lauj) rename LogicalBufferProto to BufferValueProto.
150   LogicalBufferProto ToProto(const SizeFunction& size_fn) const;
151 
152   // Returns the LogicalBufferProto::Location that serializes the given
153   // instruction and index.
154   static LogicalBufferProto::Location ToLocationProto(
155       const HloInstruction& instruction, const ShapeIndex& index);
156 
157   const Color kInvalidColor = Color(-1);
158 
159  protected:
160   BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id);
161 
162  private:
163   // The definining instruction and index are not stored here; they can be found
164   // in the LogicalBuffer and HloValue subclasses. This class exists only to
165   // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by
166   // allowing abstract use of LogicalBuffer or HloValue. After those migrations
167   // are complete, this class should be deleted (b/78906445). Because we plan to
168   // delete LogicalBuffer and this class, we don't refactor all the shared
169   // features from LogicalBuffer and HloValue into this class.
170   Id id_ : 62;
171   bool is_array_ : 1;
172   bool is_tuple_ : 1;
173   Color color_ = kInvalidColor;
174 };
175 
176 std::ostream& operator<<(std::ostream& out, const BufferValue& buffer);
177 
178 }  // namespace xla
179 
180 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
181