1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
18
#include <ostream>
#include <string>
#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
22
23 #if GOOGLE_CUDA && GOOGLE_TENSORRT
24
25 namespace tensorflow {
26 namespace tensorrt {
27 namespace segment {
28
29 // ClusterBatchSize is a data structure to record the batch size we have seen
30 // for a cluster during segmentation.
31 //
32 // With the help of shape inference, all the dynamic batch sizes are converted
33 // to a negative integer number.
// If the number is -1, then nothing is known about the dynamic batch size.
// Ideally, we should not put nodes with -1 batch size into the same cluster,
// as they will likely have different batch sizes at runtime. However, we
// currently treat -1 as an equivalence class to keep the implementation simple.
// We may need to revise this if it causes performance issues.
// If the number is strictly less than -1, then it represents an equivalence
// class. It is inferred that all the nodes with the same equivalence class
// (strictly less than -1) shall have the same batch size at runtime.
42 //
// When constructing clusters for implicit batch mode, we support both
// dynamic batch sizes and static batch sizes. As all the nodes inside the same
// cluster shall have the same batch size at runtime, we restrict nodes inside a
// cluster to either have the same dynamic batch size equivalence class or the
// same static batch size value.
48 //
// In addition, all the nodes with an annotated max batch size inside the same
// cluster shall have the same annotated max batch size. (It is allowed for
// some or all of the nodes inside the cluster to have no annotated max batch
// size.) Static batch sizes are treated as max batch size annotations. The
// converter max batch size is used for an OP with a dynamic batch size and no
// annotated max batch size.
55 //
// cluster: a = a1[1,3] + a2[1,3]
57 // ClusterBatchSize: batch_size_ = 1
58 // max_batch_size_ = 1
59 //
60 // cluster: b = b1[-1,3] + b2[-1, 3]
61 // ClusterBatchSize: batch_size_ = -1
62 // max_batch_size_ = null
63 //
64 // cluster: c = c1[-2,3] + c2[-2, 3](max_batch_size=100)
65 // ClusterBatchSize: batch_size_ = -2
66 // max_batch_size_ = 100
67 //
// When constructing clusters for explicit batch mode, ClusterBatchSize is
// irrelevant.
70 //
71
72 class ClusterBatchSize {
73 public:
74 ClusterBatchSize();
75
76 bool operator==(const ClusterBatchSize& other);
77 bool operator!=(const ClusterBatchSize& other) { return !(*this == other); }
78
79 // Sets the batch size assuming that the object doesn't have a batch size yet:
80 // A non-negative input representing a static batch size value.
81 // A negative input representing a dynamic batch size equivalent class.
82 ClusterBatchSize& SetBatchSize(int batch_size);
83 bool HasBatchSize() const;
84 int GetBatchSize() const;
85
86 // Sets the max batch size assuming that the object doesn't have a max batch
87 // size yet.
88 ClusterBatchSize& SetMaxBatchSize(int max_batch_size);
89 absl::optional<int> GetOptionalMaxBatchSize() const;
90
91 // Merge `other` into the current ClusterBatchSize if the two are not
92 // conflicting. Two ClusterBatchSizes are conflicting iff they both have a
93 // value and their values are different.
94 bool MergeIfCompatible(const ClusterBatchSize& other);
95
96 // Returns a string for the batch size and the annotated max batch size.
97 // For the batch size:
98 // If the object has a static batch size, return a string representing a
99 // non-negative integer.
100 // If the object has a dynamic batch size, return a string representing a
101 // negative integer as an equivalent class.
102 // If the object doesn't have a batch size yet, return "?".
103 // For the annotated max batch size:
104 // If the cluster has annotated max batch size in at least one of the nodes,
105 // return a string representing the annotated max batch size. Otherwise,
106 // return "?".
107 std::string ToString() const;
108
109 private:
110 ClusterBatchSize& SetBatchSize(const absl::optional<int>& batch_size);
111 ClusterBatchSize& SetMaxBatchSize(const absl::optional<int>& batch_size);
112
113 absl::optional<int> batch_size_;
114 absl::optional<int> max_batch_size_;
115 };
116
117 inline std::ostream& operator<<(std::ostream& os,
118 const ClusterBatchSize& batch_size) {
119 return os << batch_size.ToString();
120 }
121
122 // Represents the accumulated properties of a cluster during segmentation,
123 // including information about batch size and device assignment. Clusters shall
124 // have compatible properties in order to be merged together.
125 class ClusterProperty {
126 public:
ClusterProperty()127 ClusterProperty() {}
128 ClusterProperty(const ClusterBatchSize& batch_size,
129 const DeviceNameUtils::ParsedName& device_name);
130
131 // Returns the batch size of the cluster and compresses the path from this
132 // object to the root object.
BatchSize()133 const ClusterBatchSize& BatchSize() const { return batch_size_; }
134
135 // Returns the device name of the cluster and compresses the path from this
136 // object to the root object.
DeviceName()137 const DeviceNameUtils::ParsedName& DeviceName() const { return device_name_; }
138
139 Status Merge(const ClusterProperty& other);
140
141 private:
142 ClusterBatchSize batch_size_;
143 DeviceNameUtils::ParsedName device_name_;
144 };
145
146 // Represents a disjoint set of copyable value with type T and accumulated
147 // property of the values with type P. Most of the methods in this class are
148 // side-effecting as they also compress the path from the object to the parent
149 // of its containing set.
150 template <typename T, typename P = ClusterProperty>
151 class UnionFind {
152 public:
UnionFind()153 UnionFind() : size_(1), parent_(nullptr) {}
UnionFind(const T & v,const P & p)154 UnionFind(const T& v, const P& p)
155 : size_(1), parent_(nullptr), value_(v), property_(p) {}
UnionFind(const T & v,P && p)156 UnionFind(const T& v, P&& p)
157 : size_(1), parent_(nullptr), value_(v), property_(p) {}
158
159 // Returns the number of elements in the set and compresses the path from
160 // this object to the root of the set.
Size()161 int Size() { return FindRoot()->size_; }
162
163 // Returns the accumulated property of all the elements in the set and
164 // compresses the path from this object to the root of the set.
Property()165 const P& Property() { return FindRoot()->property_; }
166
167 // Merges this set with 'other'. This updates the size_ and property_ of the
168 // set. The size_ and property_ of 'other' becomes inaccessible as only the
169 // size_ and property_ of the root of the set is accessible.
170 Status Merge(UnionFind* other);
171
172 // Retrieves the value for the root of the set.
ParentValue()173 const T& ParentValue() { return FindRoot()->value_; }
174
175 // Returns the value for the object.
Value()176 const T& Value() const { return value_; }
177
178 private:
179 // Returns the root object for the set and compresses the path from this
180 // object to the root object.
181 UnionFind* FindRoot();
182
183 int size_;
184 UnionFind* parent_;
185 T value_;
186 P property_;
187 };
188
189 template <typename T, typename P>
Merge(UnionFind * other)190 Status UnionFind<T, P>::Merge(UnionFind* other) {
191 UnionFind<T>* a = FindRoot();
192 UnionFind<T>* b = other->FindRoot();
193 if (a == b) return Status::OK();
194
195 P merged_property(a->property_);
196 TF_RETURN_IF_ERROR(merged_property.Merge(b->property_));
197 b->parent_ = a;
198 a->size_ += b->size_;
199 a->property_ = std::move(merged_property);
200 return Status::OK();
201 }
202
203 template <typename T, typename P>
FindRoot()204 UnionFind<T, P>* UnionFind<T, P>::FindRoot() {
205 if (!parent_) return this;
206 // Path compression: update intermediate nodes to point to the root of the
207 // equivalence class.
208 parent_ = parent_->FindRoot();
209 return parent_;
210 }
211
212 } // namespace segment
213 } // namespace tensorrt
214 } // namespace tensorflow
215
216 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
217
218 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
219