1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 18 19 #include <functional> 20 #include <string> 21 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/service/hlo.pb.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/logging.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace xla { 33 34 // Abstract class describing a value used by one of the dataflow analyses - 35 // TuplePointsToAnalysis or HloDataflowAnalysis. 36 // TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. 37 // 38 // XLA arrays are trivially a single BufferValue. Tuples are made up of more 39 // than one BufferValue: an BufferValue for the pointer vector, and an 40 // BufferValue for each child element. 41 // 42 // Every BufferValue is defined by a particular instruction and most 43 // instructions define only a single BufferValue. 
// Instructions which define a single BufferValue include array-shaped
// instructions such as Add, but also Tuple-shaped instructions such as Tuple.
// The Tuple instruction defines a single BufferValue, which is a vector of
// pointers to the values containing the Tuple instruction's operands. Though
// the result of the Tuple instruction includes multiple values, only the
// top-level BufferValue (the vector of pointers) is defined by the Tuple
// instruction. The values containing the tuple elements are defined by earlier
// instructions, usually the operands of the Tuple instruction.
//
// Instructions which construct both the tuple *and* the tuple elements define
// more than one BufferValue. This includes (at least) tuple-shaped Constant,
// Parameter, Infeed and While instructions. These tuple-shaped instructions do
// not assemble a tuple from existing BufferValues like the Tuple instruction
// does, but rather define all the BufferValues in the tuple.
//
// Some instructions, such as Bitcast, define no buffers. These instructions
// simply forward buffers from their operands.
//
// The BufferValue object describes which HLO instruction defines a buffer and
// where within that instruction's output shape the buffer is defined. The
// location within the output shape is indicated by BufferValue::index(), which
// is defined identically to the index used in ShapeUtil::GetSubshape().
// Examples:
//
//   %add = Add(%foo, %bar)
//   %tuple_constant = Constant({1, {42, 43}})
//
// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds
// the array result of the add operation.
The nested-tuple-shaped 73 // %tuple_constant defines 5 buffers described by the following BufferValue 74 // objects: 75 // 76 // BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of 77 // // pointers to BufferValues at 78 // // indices {0} and {1} 79 // BufferValue(%tuple_constant, {0}) // Holds value "1" 80 // BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of 81 // // pointers to BufferValues at 82 // // indices {1, 0} and {1, 1} 83 // BufferValue(%tuple_constant, {1, 0}) // Holds value "42" 84 // BufferValue(%tuple_constant, {1, 1}) // Holds value "43" 85 86 class BufferValue { 87 public: 88 using Color = int64; 89 90 // Id is a unique identifier for the BufferValue to facilitate efficient 91 // collections of BufferValues with stable iteration order. 92 using Id = int64; 93 94 // Functions which return the size and alignment of a logical buffer in bytes. 95 using SizeFunction = std::function<int64(const BufferValue&)>; 96 using AlignmentFunction = std::function<int64(BufferValue::Color)>; 97 98 virtual ~BufferValue(); 99 id()100 Id id() const { return id_; } 101 102 // Return the instruction that defines the buffer. 103 virtual HloInstruction* instruction() const = 0; 104 105 // Return the index within the output of the instruction where the buffer is 106 // defined. Index used defined as in ShapeUtil::GetSubshape() 107 virtual const ShapeIndex& index() const = 0; 108 109 // Return the color of the BufferValue. Differently colored buffers can not be 110 // parts of the same allocation. 
111 ABSL_DEPRECATED("Use Layout::memory_space instead.") color()112 Color color() const { 113 CHECK_NE(color_, kInvalidColor) 114 << "Should not query the color of a buffer that was never colored"; 115 return color_; 116 } 117 118 ABSL_DEPRECATED("Use Layout::memory_space instead.") set_color(Color color)119 void set_color(Color color) { 120 CHECK_NE(color, kInvalidColor) 121 << "Should not set the color of a buffer to the invalid color"; 122 color_ = color; 123 } 124 125 ABSL_DEPRECATED("Use Layout::memory_space instead.") has_color()126 bool has_color() const { return color_ != kInvalidColor; } 127 128 // Return the shape of the buffer. This reference points into the shape field 129 // of the instruction defining the buffer. Therefore, the returned shape will 130 // contain the layout of instruction, if any. 131 virtual const Shape& shape() const = 0; 132 133 // Returns true if this buffer is the top-level output buffer of the defining 134 // HLO instruction. This is equivalent to index == {}. IsTopLevel()135 bool IsTopLevel() const { return index().empty(); } 136 137 // Whether this buffer contains a tuple. IsTuple()138 bool IsTuple() const { return is_tuple_; } 139 140 // Whether this buffer contains an array. IsArray()141 bool IsArray() const { return is_array_; } 142 143 // operator< is required for std::set. 144 bool operator<(const BufferValue& other) const { return id_ < other.id_; } 145 146 bool operator==(const BufferValue& other) const { return id_ == other.id_; } 147 bool operator!=(const BufferValue& other) const { return id_ != other.id_; } 148 149 virtual string ToString() const = 0; 150 151 // TODO(lauj) rename LogicalBufferProto to BufferValueProto. 152 LogicalBufferProto ToProto(const SizeFunction& size_fn) const; 153 154 // Returns the LogicalBufferProto::Location that serializes the given 155 // instruction and index. 
156 static LogicalBufferProto::Location ToLocationProto( 157 const HloInstruction& instruction, const ShapeIndex& index); 158 159 const Color kInvalidColor = -1; 160 161 protected: 162 BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); 163 164 private: 165 // The defining instruction and index are not stored here; they can be found 166 // in the LogicalBuffer and HloValue subclasses. This class exists only to 167 // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by 168 // allowing abstract use of LogicalBuffer or HloValue. After those migrations 169 // are complete, this class should be deleted (b/78906445). Because we plan to 170 // delete LogicalBuffer and this class, we don't refactor all the shared 171 // features from LogicalBuffer and HloValue into this class. 172 Id id_ : 62; 173 bool is_array_ : 1; 174 bool is_tuple_ : 1; 175 Color color_ = kInvalidColor; 176 }; 177 178 std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); 179 180 } // namespace xla 181 182 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ 183