1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 18 19 #include <ostream> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/service/hlo_value.h" 24 #include "tensorflow/compiler/xla/shape_tree.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace xla { 30 31 // A container which can hold one or more HloValues. An HLO buffer abstractly 32 // represents the allocation which HLO instructions write into and read 33 // from. Generally there is a one-to-one correspondence between HloBuffers and 34 // HloValue where each HloValue in the module is held in a unique HloBuffer. An 35 // exception is the while instruction which updates the loop state in-place. In 36 // this case, we have a single HloBuffer for each HloPosition in the loop state, 37 // but multiple HloValues. For example: 38 // 39 // %init = ... 40 // %while = While(%init, body, condition) 41 // 42 // body: 43 // %body_param = Param(0) 44 // ... 45 // %body_root = ... 46 // 47 // condition: 48 // %cond_param = Param(0) 49 // ... 50 // 51 // For simplicity, assume that %while is array-shaped. In this case, we have a 52 // single HloBuffer which holds the following HloValues: HloValue{%init}, 53 // HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and 54 // HloValue{%cond_param}. 55 // 56 // HloBuffers may appear at different HloPositions in the module mirroring the 57 // same property of HloValues. For example: 58 // 59 // %sub = Sub(...) 60 // %add = Add(...) 61 // %tuple = Tuple(%add, %sub) 62 // %gte = GetTupleElement(%tuple, 0) 63 // 64 // In this case, the HloBuffer containing %add appears at the following 65 // positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and 66 // HloPosition{%gte, {}}. 67 // 68 // Different HloPositions which share the same HloBuffer indicate mandatory 69 // aliasing in the HLO module. These positions must share the same memory 70 // allocation for correctness (the backends rely on this property). This differs 71 // from incidental aliasing introduced by memory reuse in BufferAssignment where 72 // different instructions may happen to get the same allocation. 73 class HloBuffer { 74 public: 75 using Id = int64; 76 77 // Predicate comparing HloBuffers by increasing id, useful for std::sort. IdLessThan(const HloBuffer * a,const HloBuffer * b)78 static bool IdLessThan(const HloBuffer* a, const HloBuffer* b) { 79 return a->id() < b->id(); 80 } 81 82 // Predicate comparing HloBuffers by equal id, useful for std::unique. IdEqual(const HloBuffer * a,const HloBuffer * b)83 static bool IdEqual(const HloBuffer* a, const HloBuffer* b) { 84 return a->id() == b->id(); 85 } 86 HloBuffer(Id id,absl::Span<const HloValue * const> values)87 HloBuffer(Id id, absl::Span<const HloValue* const> values) 88 : id_(id), values_(values.begin(), values.end()) {} 89 90 // Return the unique identifier for this HloBuffer. id()91 Id id() const { return id_; } 92 93 // Return all values contained in this buffer. values()94 const std::vector<const HloValue*>& values() const { return values_; } 95 96 // Memory space color. Used to indicate the memory space that the hlo buffer 97 // needs to live in. color()98 BufferValue::Color color() const { 99 // Invariant: All values in the buffer should have the same color. 100 BufferValue::Color result = values()[0]->color(); 101 for (const HloValue* value : values()) { 102 DCHECK_EQ(result, value->color()); 103 } 104 return result; 105 } 106 107 // Return the unique HLO value in the buffer. CHECK fails if the buffer does 108 // not contain exactly one value. GetUniqueValue()109 const HloValue& GetUniqueValue() const { 110 CHECK_EQ(values_.size(), 1); 111 return *values_[0]; 112 } 113 114 std::vector<HloPosition> ComputePositions() const; 115 116 string ToString() const; 117 118 bool operator==(const HloBuffer& other) const; 119 bool operator!=(const HloBuffer& other) const { return !(*this == other); } 120 121 private: 122 // Unique identifier for this HloBuffer. 123 Id id_; 124 125 // The set of values contained in this buffer. Vector contains no duplicates 126 // and is sorted stably by HloValue::Id. 127 std::vector<const HloValue*> values_; 128 }; 129 130 std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); 131 132 } // namespace xla 133 134 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ 135