/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_

#include <functional>
#include <iosfwd>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Abstract class describing a value used by one of the dataflow analyses -
// TuplePointsToAnalysis or HloDataflowAnalysis.
// TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused.
//
// XLA arrays are trivially a single BufferValue. Tuples are made up of more
// than one BufferValue: an BufferValue for the pointer vector, and an
// BufferValue for each child element.
//
// Every BufferValue is defined by a particular instruction and most
// instructions define only a single BufferValue.
// Instructions which define a single BufferValue include array-shaped
// instructions such as Add, but also Tuple-shaped instructions such as Tuple.
// The Tuple instruction defines a single BufferValue which is a vector of
// pointers to the values containing the Tuple instruction's operands. Though
// the result of the Tuple instruction includes multiple values, only the
// top-level BufferValue (the vector of pointers) is defined by the Tuple
// instruction. The values containing the tuple elements are defined by earlier
// instructions, usually the operands of the Tuple instruction.
//
// Instructions which construct both the tuple *and* the tuple elements define
// more than one BufferValue. This includes (at least) tuple-shaped Constant,
// Parameter, Infeed and While instructions. These tuple-shaped instructions do
// not assemble a tuple from existing BufferValues like the Tuple instruction
// does, but rather define all the BufferValues in the tuple.
//
// Some instructions, such as Bitcast, define no buffers. These instructions
// simply forward buffers from their operands.
//
// The BufferValue object describes which HLO instruction defines a buffer and
// where within that instruction's output shape the buffer is defined. The
// location within the output shape is indicated by BufferValue::index(), which
// is defined identically to the index used in ShapeUtil::GetSubshape().
// Examples:
//
//   %add = Add(%foo, %bar)
//   %tuple_constant = Constant({1, {42, 43}})
//
// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds
// the array result of the add operation.
The nested-tuple-shaped 74 // %tuple_constant defines 5 buffers described by the following BufferValue 75 // objects: 76 // 77 // BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of 78 // // pointers to BufferValues at 79 // // indices {0} and {1} 80 // BufferValue(%tuple_constant, {0}) // Holds value "1" 81 // BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of 82 // // pointers to BufferValues at 83 // // indices {1, 0} and {1, 1} 84 // BufferValue(%tuple_constant, {1, 0}) // Holds value "42" 85 // BufferValue(%tuple_constant, {1, 1}) // Holds value "43" 86 87 class BufferValue { 88 public: 89 TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); 90 91 // Id is a unique identifier for the BufferValue to facilitate efficient 92 // collections of BufferValues with stable iteration order. 93 using Id = int64; 94 95 // Functions which return the size and alignment of a logical buffer in bytes. 96 using SizeFunction = std::function<int64(const BufferValue&)>; 97 using AlignmentFunction = std::function<int64(BufferValue::Color)>; 98 99 virtual ~BufferValue(); 100 id()101 Id id() const { return id_; } 102 103 // Return the instruction that defines the buffer. 104 virtual HloInstruction* instruction() const = 0; 105 106 // Return the index within the output of the instruction where the buffer is 107 // defined. Index used defined as in ShapeUtil::GetSubshape() 108 virtual const ShapeIndex& index() const = 0; 109 110 // Return the color of the BufferValue. Differently colored buffers can not be 111 // parts of the same allocation. 
color()112 Color color() const { 113 CHECK_NE(color_, kInvalidColor) 114 << "Should not query the color of a buffer that was never colored"; 115 return color_; 116 } 117 set_color(Color color)118 void set_color(Color color) { 119 CHECK_NE(color, kInvalidColor) 120 << "Should not set the color of a buffer to the invalid color"; 121 color_ = color; 122 } 123 has_color()124 bool has_color() const { return color_ != kInvalidColor; } 125 126 // Return the shape of the buffer. This reference points into the shape field 127 // of the instruction defining the buffer. Therefore, the returned shape will 128 // contain the layout of instruction, if any. 129 virtual const Shape& shape() const = 0; 130 131 // Returns true if this buffer is the top-level output buffer of the defining 132 // HLO instruction. This is equivalent to index == {}. IsTopLevel()133 bool IsTopLevel() const { return index().empty(); } 134 135 // Whether this buffer contains a tuple. IsTuple()136 bool IsTuple() const { return is_tuple_; } 137 138 // Whether this buffer contains an array. IsArray()139 bool IsArray() const { return is_array_; } 140 141 // operator< is required for std::set. 142 bool operator<(const BufferValue& other) const { return id_ < other.id_; } 143 144 bool operator==(const BufferValue& other) const { return id_ == other.id_; } 145 bool operator!=(const BufferValue& other) const { return id_ != other.id_; } 146 147 virtual string ToString() const = 0; 148 149 // TODO(lauj) rename LogicalBufferProto to BufferValueProto. 150 LogicalBufferProto ToProto(const SizeFunction& size_fn) const; 151 152 // Returns the LogicalBufferProto::Location that serializes the given 153 // instruction and index. 
154 static LogicalBufferProto::Location ToLocationProto( 155 const HloInstruction& instruction, const ShapeIndex& index); 156 157 const Color kInvalidColor = Color(-1); 158 159 protected: 160 BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); 161 162 private: 163 // The definining instruction and index are not stored here; they can be found 164 // in the LogicalBuffer and HloValue subclasses. This class exists only to 165 // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by 166 // allowing abstract use of LogicalBuffer or HloValue. After those migrations 167 // are complete, this class should be deleted (b/78906445). Because we plan to 168 // delete LogicalBuffer and this class, we don't refactor all the shared 169 // features from LogicalBuffer and HloValue into this class. 170 Id id_ : 62; 171 bool is_array_ : 1; 172 bool is_tuple_ : 1; 173 Color color_ = kInvalidColor; 174 }; 175 176 std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); 177 178 } // namespace xla 179 180 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 181