1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
17
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
20
21 #if GOOGLE_CUDA && GOOGLE_TENSORRT
22
23 namespace tensorflow {
24 namespace tensorrt {
25 namespace segment {
26
27 namespace {
28 template <typename T>
CheckIfCompatible(const absl::optional<T> & a,const absl::optional<T> & b)29 inline bool CheckIfCompatible(const absl::optional<T>& a,
30 const absl::optional<T>& b) {
31 if (a.has_value() && b.has_value()) {
32 return *a == *b;
33 }
34 return true;
35 }
36
37 template <typename T>
UnifyValues(absl::optional<T> & a,absl::optional<T> & b)38 inline bool UnifyValues(absl::optional<T>& a, absl::optional<T>& b) {
39 if (a.has_value()) {
40 b = a;
41 } else {
42 a = b;
43 }
44 return true;
45 }
46
47 template <typename T>
MergeCompatible(const absl::optional<T> & a,const absl::optional<T> & b)48 inline absl::optional<T> MergeCompatible(const absl::optional<T>& a,
49 const absl::optional<T>& b) {
50 DCHECK(CheckIfCompatible(a, b));
51 return a.has_value() ? a : b;
52 }
53
54 } // namespace
55
ClusterBatchSize()56 ClusterBatchSize::ClusterBatchSize()
57 : batch_size_(absl::nullopt), max_batch_size_(absl::nullopt) {}
58
operator ==(const ClusterBatchSize & other)59 bool ClusterBatchSize::operator==(const ClusterBatchSize& other) {
60 return batch_size_ == other.batch_size_ &&
61 max_batch_size_ == other.max_batch_size_;
62 }
63
SetBatchSize(int batch_size)64 ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) {
65 SetBatchSize(static_cast<absl::optional<int>>(batch_size));
66 return *this;
67 }
68
SetBatchSize(const absl::optional<int> & batch_size)69 ClusterBatchSize& ClusterBatchSize::SetBatchSize(
70 const absl::optional<int>& batch_size) {
71 batch_size_ = MergeCompatible<int>(batch_size_, batch_size);
72 if (batch_size_.has_value() && batch_size_.value() >= 0) {
73 SetMaxBatchSize(batch_size_);
74 }
75 return *this;
76 }
77
HasBatchSize() const78 bool ClusterBatchSize::HasBatchSize() const { return batch_size_.has_value(); }
79
GetBatchSize() const80 int ClusterBatchSize::GetBatchSize() const {
81 DCHECK(HasBatchSize());
82 return batch_size_.value();
83 }
84
SetMaxBatchSize(int max_batch_size)85 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(int max_batch_size) {
86 SetBatchSize(static_cast<absl::optional<int>>(max_batch_size));
87 return *this;
88 }
89
SetMaxBatchSize(const absl::optional<int> & max_batch_size)90 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(
91 const absl::optional<int>& max_batch_size) {
92 max_batch_size_ = MergeCompatible<int>(max_batch_size_, max_batch_size);
93 return *this;
94 }
95
GetOptionalMaxBatchSize() const96 absl::optional<int> ClusterBatchSize::GetOptionalMaxBatchSize() const {
97 return max_batch_size_;
98 }
99
MergeIfCompatible(const ClusterBatchSize & other)100 bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) {
101 if (!CheckIfCompatible(batch_size_, other.batch_size_) ||
102 !CheckIfCompatible(max_batch_size_, other.max_batch_size_)) {
103 return false;
104 }
105
106 SetBatchSize(other.batch_size_);
107 SetMaxBatchSize(other.max_batch_size_);
108 return true;
109 }
110
ToString() const111 string ClusterBatchSize::ToString() const {
112 string s;
113 const auto append_optional_num = [&](const absl::optional<int>& num) {
114 if (num.has_value()) {
115 absl::StrAppendFormat(&s, "%d", num.value());
116 } else {
117 absl::StrAppendFormat(&s, "?");
118 }
119 };
120 absl::StrAppendFormat(&s, "batch_size=");
121 append_optional_num(batch_size_);
122 absl::StrAppendFormat(&s, ", max_batch_size=");
123 append_optional_num(max_batch_size_);
124 return s;
125 }
126
ClusterProperty(const ClusterBatchSize & batch_size,const DeviceNameUtils::ParsedName & device_name)127 ClusterProperty::ClusterProperty(const ClusterBatchSize& batch_size,
128 const DeviceNameUtils::ParsedName& device_name)
129 : batch_size_(batch_size), device_name_(device_name) {}
130
Merge(const ClusterProperty & other)131 Status ClusterProperty::Merge(const ClusterProperty& other) {
132 ClusterBatchSize merged_batch_size(batch_size_);
133 if (!merged_batch_size.MergeIfCompatible(other.batch_size_)) {
134 return errors::Internal(
135 "trying to merge clusters with incompatible batch sizes.");
136 }
137
138 absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
139 MergeIfCompatible(device_name_, other.device_name_);
140 if (!merged_device_name.has_value()) {
141 return errors::Internal(
142 "trying to merge clusters with incompatible device assignment.");
143 }
144
145 batch_size_ = std::move(merged_batch_size);
146 device_name_ = std::move(merged_device_name.value());
147 return Status::OK();
148 }
149
150 } // namespace segment
151 } // namespace tensorrt
152 } // namespace tensorflow
153
154 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
155