• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
17 
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
20 
21 #if GOOGLE_CUDA && GOOGLE_TENSORRT
22 
23 namespace tensorflow {
24 namespace tensorrt {
25 namespace segment {
26 
27 namespace {
28 template <typename T>
CheckIfCompatible(const absl::optional<T> & a,const absl::optional<T> & b)29 inline bool CheckIfCompatible(const absl::optional<T>& a,
30                               const absl::optional<T>& b) {
31   if (a.has_value() && b.has_value()) {
32     return *a == *b;
33   }
34   return true;
35 }
36 
37 template <typename T>
UnifyValues(absl::optional<T> & a,absl::optional<T> & b)38 inline bool UnifyValues(absl::optional<T>& a, absl::optional<T>& b) {
39   if (a.has_value()) {
40     b = a;
41   } else {
42     a = b;
43   }
44   return true;
45 }
46 
47 template <typename T>
MergeCompatible(const absl::optional<T> & a,const absl::optional<T> & b)48 inline absl::optional<T> MergeCompatible(const absl::optional<T>& a,
49                                          const absl::optional<T>& b) {
50   DCHECK(CheckIfCompatible(a, b));
51   return a.has_value() ? a : b;
52 }
53 
54 }  // namespace
55 
ClusterBatchSize()56 ClusterBatchSize::ClusterBatchSize()
57     : batch_size_(absl::nullopt), max_batch_size_(absl::nullopt) {}
58 
operator ==(const ClusterBatchSize & other)59 bool ClusterBatchSize::operator==(const ClusterBatchSize& other) {
60   return batch_size_ == other.batch_size_ &&
61          max_batch_size_ == other.max_batch_size_;
62 }
63 
SetBatchSize(int batch_size)64 ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) {
65   SetBatchSize(static_cast<absl::optional<int>>(batch_size));
66   return *this;
67 }
68 
SetBatchSize(const absl::optional<int> & batch_size)69 ClusterBatchSize& ClusterBatchSize::SetBatchSize(
70     const absl::optional<int>& batch_size) {
71   batch_size_ = MergeCompatible<int>(batch_size_, batch_size);
72   if (batch_size_.has_value() && batch_size_.value() >= 0) {
73     SetMaxBatchSize(batch_size_);
74   }
75   return *this;
76 }
77 
HasBatchSize() const78 bool ClusterBatchSize::HasBatchSize() const { return batch_size_.has_value(); }
79 
GetBatchSize() const80 int ClusterBatchSize::GetBatchSize() const {
81   DCHECK(HasBatchSize());
82   return batch_size_.value();
83 }
84 
SetMaxBatchSize(int max_batch_size)85 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(int max_batch_size) {
86   SetBatchSize(static_cast<absl::optional<int>>(max_batch_size));
87   return *this;
88 }
89 
SetMaxBatchSize(const absl::optional<int> & max_batch_size)90 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(
91     const absl::optional<int>& max_batch_size) {
92   max_batch_size_ = MergeCompatible<int>(max_batch_size_, max_batch_size);
93   return *this;
94 }
95 
GetOptionalMaxBatchSize() const96 absl::optional<int> ClusterBatchSize::GetOptionalMaxBatchSize() const {
97   return max_batch_size_;
98 }
99 
MergeIfCompatible(const ClusterBatchSize & other)100 bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) {
101   if (!CheckIfCompatible(batch_size_, other.batch_size_) ||
102       !CheckIfCompatible(max_batch_size_, other.max_batch_size_)) {
103     return false;
104   }
105 
106   SetBatchSize(other.batch_size_);
107   SetMaxBatchSize(other.max_batch_size_);
108   return true;
109 }
110 
ToString() const111 string ClusterBatchSize::ToString() const {
112   string s;
113   const auto append_optional_num = [&](const absl::optional<int>& num) {
114     if (num.has_value()) {
115       absl::StrAppendFormat(&s, "%d", num.value());
116     } else {
117       absl::StrAppendFormat(&s, "?");
118     }
119   };
120   absl::StrAppendFormat(&s, "batch_size=");
121   append_optional_num(batch_size_);
122   absl::StrAppendFormat(&s, ", max_batch_size=");
123   append_optional_num(max_batch_size_);
124   return s;
125 }
126 
ClusterProperty(const ClusterBatchSize & batch_size,const DeviceNameUtils::ParsedName & device_name)127 ClusterProperty::ClusterProperty(const ClusterBatchSize& batch_size,
128                                  const DeviceNameUtils::ParsedName& device_name)
129     : batch_size_(batch_size), device_name_(device_name) {}
130 
Merge(const ClusterProperty & other)131 Status ClusterProperty::Merge(const ClusterProperty& other) {
132   ClusterBatchSize merged_batch_size(batch_size_);
133   if (!merged_batch_size.MergeIfCompatible(other.batch_size_)) {
134     return errors::Internal(
135         "trying to merge clusters with incompatible batch sizes.");
136   }
137 
138   absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
139       MergeIfCompatible(device_name_, other.device_name_);
140   if (!merged_device_name.has_value()) {
141     return errors::Internal(
142         "trying to merge clusters with incompatible device assignment.");
143   }
144 
145   batch_size_ = std::move(merged_batch_size);
146   device_name_ = std::move(merged_device_name.value());
147   return Status::OK();
148 }
149 
150 }  // namespace segment
151 }  // namespace tensorrt
152 }  // namespace tensorflow
153 
154 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
155