• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
18 
#include <ostream>
#include <string>
#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
22 
23 #if GOOGLE_CUDA && GOOGLE_TENSORRT
24 
25 namespace tensorflow {
26 namespace tensorrt {
27 namespace segment {
28 
29 // ClusterBatchSize is a data structure to record the batch size we have seen
30 // for a cluster during segmentation.
31 //
32 // With the help of shape inference, all the dynamic batch sizes are converted
33 // to a negative integer number.
34 // If the number is -1, then nothing is known about the dynamic batch size.
35 // Ideally, we should not put nodes with -1 batch size into the same cluster,
36 // as they will likely have different batch sizes at runtime. However, we
37 // currently treat -1 as an equivalent class for simple implementation. We may
38 // need to revise this if it causes performance issues.
39 // If the number is strictly less than -1, then it represents a equivalent
40 // class. It is infered that all the nodes with the same equivalent class
41 // (strictly less than -1) shall have the same batch size at runtime.
42 //
43 // When constructing clusters for implicit batch mode, we support both
44 // dynamic batch sizes and static batch sizes. As all the nodes inside the same
45 // cluster shall have the same batch size at runtime, we restrict nodes inside a
46 // cluster to either have the same dynamic batch size equivalent class or the
47 // same static batch size value.
48 //
49 // Besides, all the nodes with an annotated max batch size inside the same
50 // cluster shall have the same annotated max batch size. (It is allowed if
51 // part or all the nodes inside the cluster doesn't have annotated max batch
52 // size). Static batch sizes are treated as max batch size annotations. The
53 // converter max batch size is used for an OP with a dynamic batch size and no
54 // annotated max batch size.
55 //
56 // cluster:  a = a1[1,3] + a1[1,3]
57 // ClusterBatchSize: batch_size_ = 1
58 //                   max_batch_size_ = 1
59 //
60 // cluster:  b = b1[-1,3] + b2[-1, 3]
61 // ClusterBatchSize: batch_size_ = -1
62 //                   max_batch_size_ = null
63 //
64 // cluster:  c = c1[-2,3] + c2[-2, 3](max_batch_size=100)
65 // ClusterBatchSize: batch_size_ = -2
66 //                   max_batch_size_ = 100
67 //
68 // When constructing cluster for explicit batch mode, all ClusterBatchSize is
69 // irrelevant.
70 //
71 
72 class ClusterBatchSize {
73  public:
74   ClusterBatchSize();
75 
76   bool operator==(const ClusterBatchSize& other);
77   bool operator!=(const ClusterBatchSize& other) { return !(*this == other); }
78 
79   // Sets the batch size assuming that the object doesn't have a batch size yet:
80   //   A non-negative input representing a static batch size value.
81   //   A negative input representing a dynamic batch size equivalent class.
82   ClusterBatchSize& SetBatchSize(int batch_size);
83   bool HasBatchSize() const;
84   int GetBatchSize() const;
85 
86   // Sets the max batch size assuming that the object doesn't have a max batch
87   // size yet.
88   ClusterBatchSize& SetMaxBatchSize(int max_batch_size);
89   absl::optional<int> GetOptionalMaxBatchSize() const;
90 
91   // Merge `other` into the current ClusterBatchSize if the two are not
92   // conflicting. Two ClusterBatchSizes are conflicting iff they both have a
93   // value and their values are different.
94   bool MergeIfCompatible(const ClusterBatchSize& other);
95 
96   // Returns a string for the batch size and the annotated max batch size.
97   // For the batch size:
98   //   If the object has a static batch size, return a string representing a
99   //     non-negative integer.
100   //   If the object has a dynamic batch size, return a string representing a
101   //     negative integer as an equivalent class.
102   //   If the object doesn't have a batch size yet, return "?".
103   // For the annotated max batch size:
104   //   If the cluster has annotated max batch size in at least one of the nodes,
105   //     return a string representing the annotated max batch size. Otherwise,
106   //     return "?".
107   std::string ToString() const;
108 
109  private:
110   ClusterBatchSize& SetBatchSize(const absl::optional<int>& batch_size);
111   ClusterBatchSize& SetMaxBatchSize(const absl::optional<int>& batch_size);
112 
113   absl::optional<int> batch_size_;
114   absl::optional<int> max_batch_size_;
115 };
116 
117 inline std::ostream& operator<<(std::ostream& os,
118                                 const ClusterBatchSize& batch_size) {
119   return os << batch_size.ToString();
120 }
121 
122 // Represents the accumulated properties of a cluster during segmentation,
123 // including information about batch size and device assignment. Clusters shall
124 // have compatible properties in order to be merged together.
125 class ClusterProperty {
126  public:
ClusterProperty()127   ClusterProperty() {}
128   ClusterProperty(const ClusterBatchSize& batch_size,
129                   const DeviceNameUtils::ParsedName& device_name);
130 
131   // Returns the batch size of the cluster and compresses the path from this
132   // object to the root object.
BatchSize()133   const ClusterBatchSize& BatchSize() const { return batch_size_; }
134 
135   // Returns the device name of the cluster and compresses the path from this
136   // object to the root object.
DeviceName()137   const DeviceNameUtils::ParsedName& DeviceName() const { return device_name_; }
138 
139   Status Merge(const ClusterProperty& other);
140 
141  private:
142   ClusterBatchSize batch_size_;
143   DeviceNameUtils::ParsedName device_name_;
144 };
145 
146 // Represents a disjoint set of copyable value with type T and accumulated
147 // property of the values with type P. Most of the methods in this class are
148 // side-effecting as they also compress the path from the object to the parent
149 // of its containing set.
150 template <typename T, typename P = ClusterProperty>
151 class UnionFind {
152  public:
UnionFind()153   UnionFind() : size_(1), parent_(nullptr) {}
UnionFind(const T & v,const P & p)154   UnionFind(const T& v, const P& p)
155       : size_(1), parent_(nullptr), value_(v), property_(p) {}
UnionFind(const T & v,P && p)156   UnionFind(const T& v, P&& p)
157       : size_(1), parent_(nullptr), value_(v), property_(p) {}
158 
159   // Returns the number of elements in the set and compresses the path from
160   // this object to the root of the set.
Size()161   int Size() { return FindRoot()->size_; }
162 
163   // Returns the accumulated property of all the elements in the set and
164   // compresses the path from this object to the root of the set.
Property()165   const P& Property() { return FindRoot()->property_; }
166 
167   // Merges this set with 'other'. This updates the size_ and property_ of the
168   // set. The size_ and property_ of 'other' becomes inaccessible as only the
169   // size_ and property_ of the root of the set is accessible.
170   Status Merge(UnionFind* other);
171 
172   // Retrieves the value for the root of the set.
ParentValue()173   const T& ParentValue() { return FindRoot()->value_; }
174 
175   // Returns the value for the object.
Value()176   const T& Value() const { return value_; }
177 
178  private:
179   // Returns the root object for the set and compresses the path from this
180   // object to the root object.
181   UnionFind* FindRoot();
182 
183   int size_;
184   UnionFind* parent_;
185   T value_;
186   P property_;
187 };
188 
189 template <typename T, typename P>
Merge(UnionFind * other)190 Status UnionFind<T, P>::Merge(UnionFind* other) {
191   UnionFind<T>* a = FindRoot();
192   UnionFind<T>* b = other->FindRoot();
193   if (a == b) return Status::OK();
194 
195   P merged_property(a->property_);
196   TF_RETURN_IF_ERROR(merged_property.Merge(b->property_));
197   b->parent_ = a;
198   a->size_ += b->size_;
199   a->property_ = std::move(merged_property);
200   return Status::OK();
201 }
202 
203 template <typename T, typename P>
FindRoot()204 UnionFind<T, P>* UnionFind<T, P>::FindRoot() {
205   if (!parent_) return this;
206   // Path compression: update intermediate nodes to point to the root of the
207   // equivalence class.
208   parent_ = parent_->FindRoot();
209   return parent_;
210 }
211 
212 }  // namespace segment
213 }  // namespace tensorrt
214 }  // namespace tensorflow
215 
216 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
217 
218 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
219