1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/kernels/tensor_flag_utils.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
20
21 namespace tensorflow {
22 namespace tensor_flag_utils {
23
ValidateSparseMatrixShardingConfig(const Tensor & config)24 Status ValidateSparseMatrixShardingConfig(const Tensor& config) {
25 if (TensorShapeUtils::IsScalar(config.shape())) {
26 const float scalar_config = config.template scalar<float>()();
27 if (0 < scalar_config && scalar_config <= 1.0) {
28 return Status::OK();
29 }
30 return Status(
31 error::INVALID_ARGUMENT,
32 absl::StrCat("Expected config to be in range (0, 1] but instead found ",
33 scalar_config));
34 }
35 if (!TensorShapeUtils::IsMatrix(config.shape())) {
36 return Status(error::INVALID_ARGUMENT,
37 absl::StrCat("Expected config to be either scalar or matrix "
38 "but instead found tensor of rank ",
39 config.dims()));
40 }
41 if (config.dim_size(1) != 3) {
42 return Status(
43 error::INVALID_ARGUMENT,
44 absl::StrCat(
45 "Expected config matrix to have dim(1) = 3 but instead found ",
46 config.dim_size(1)));
47 }
48
49 auto config_matrix = config.matrix<float>();
50 for (int i = 0; i < config.dim_size(0); ++i) {
51 if (0 > config_matrix(i, 0)) {
52 return errors::InvalidArgument(
53 "First column of fraction_rows_per_thread_config "
54 "should "
55 "have non-negative values but found ",
56 config_matrix(i, 0), " in row ", i);
57 }
58 if (0 > config_matrix(i, 1)) {
59 return errors::InvalidArgument(
60 "Second column of fraction_rows_per_thread_config "
61 "should "
62 "have non-negative values but found ",
63 config_matrix(i, 1), " in row ", i);
64 }
65 if (!(0 < config_matrix(i, 2) && config_matrix(i, 2) <= 1)) {
66 return errors::InvalidArgument(
67 "Last column of fraction_rows_per_thread_config should "
68 "have values in the range (0, 1] but found ",
69 config_matrix(i, 2), " in row ", i);
70 }
71 }
72 return Status::OK();
73 }
74
75 template <typename MatrixType, typename K>
FindConfigValueForKey(const typename TTypes<MatrixType>::ConstMatrix & config_mat,const std::pair<K,K> & key)76 MatrixType FindConfigValueForKey(
77 const typename TTypes<MatrixType>::ConstMatrix& config_mat,
78 const std::pair<K, K>& key) {
79 const int last_row_index = config_mat.dimension(0) - 1;
80 for (int i = 0; i < last_row_index; ++i) {
81 if (key.first >= config_mat(i, 0) && key.second >= config_mat(i, 1)) {
82 return config_mat(i, 2);
83 }
84 }
85 return config_mat(last_row_index, 2);
86 }
87
ValidateScalarQuantityShardingConfig(const Tensor & config)88 Status ValidateScalarQuantityShardingConfig(const Tensor& config) {
89 if (TensorShapeUtils::IsScalar(config.shape())) {
90 const float scalar_config = config.template scalar<float>()();
91 if (0 < scalar_config && scalar_config <= 1.0) {
92 return Status::OK();
93 }
94 return Status(
95 error::INVALID_ARGUMENT,
96 absl::StrCat("Expected config to be in range (0, 1] but instead found ",
97 scalar_config));
98 }
99 if (!TensorShapeUtils::IsMatrix(config.shape())) {
100 return Status(error::INVALID_ARGUMENT,
101 absl::StrCat("Expected config to be either scalar or matrix "
102 "but instead found tensor of rank ",
103 config.dims()));
104 }
105 if (config.dim_size(1) != 2) {
106 return Status(
107 error::INVALID_ARGUMENT,
108 absl::StrCat(
109 "Expected config matrix to have dim(1) = 2 but instead found ",
110 config.dim_size(1)));
111 }
112
113 auto config_matrix = config.matrix<float>();
114 for (int i = 0; i < config.dim_size(0); ++i) {
115 if (0 > config_matrix(i, 0)) {
116 return errors::InvalidArgument(
117 "First column of fraction_rows_per_thread_config "
118 "should "
119 "have non-negative values but found ",
120 config_matrix(i, 0), " in row ", i);
121 }
122 if (!(0 < config_matrix(i, 1) && config_matrix(i, 1) <= 1)) {
123 return errors::InvalidArgument(
124 "Last column of fraction_rows_per_thread_config should "
125 "have values in the range (0, 1] but found ",
126 config_matrix(i, 1), " in row ", i);
127 }
128 }
129 return Status::OK();
130 }
131
132 template <typename MatrixType, typename K>
FindConfigValueForKey(const typename TTypes<MatrixType>::ConstMatrix & config_mat,const K key)133 MatrixType FindConfigValueForKey(
134 const typename TTypes<MatrixType>::ConstMatrix& config_mat, const K key) {
135 const int last_row_index = config_mat.dimension(0) - 1;
136 for (int i = 0; i < last_row_index; ++i) {
137 if (key >= config_mat(i, 0)) {
138 return config_mat(i, 1);
139 }
140 }
141 return config_mat(last_row_index, 1);
142 }
143
// Returns the first value of the fixed-width bucket that contains `value`.
//
// Buckets have width `bucket_size` (b) and start at 1: every value in
// [1, b] maps to 1, every value in [b + 1, 2b] maps to b + 1, and so on.
// No input validation is done: bucket_size must be nonzero (it is a
// divisor), and values below 1 fall outside the pattern (e.g. 0 maps to
// 1 - bucket_size).
template <typename Tindices>
Tindices GetLinearBucket(const Tindices value, const Tindices bucket_size) {
  // Distance from the first element of a bucket to its last element.
  const Tindices bucket_span = bucket_size - 1;
  // Number of whole buckets up to and including the one holding `value`;
  // times the width, that is the bucket's last element.
  const Tindices buckets_through_value = (value + bucket_span) / bucket_size;
  return buckets_through_value * bucket_size - bucket_span;
}
150
// Returns the first value of the geometric bucket that contains `value`.
//
// With base b = bucket_size, the value 1 forms its own bucket and, for every
// k >= 0, all values in [b^k + 1, b^(k+1)] map to b^k + 1. For b = 2 the
// buckets are {1}, {2}, {3, 4}, {5..8}, {9..16}, ...
//
// Degenerate inputs: bucket_size <= 1 and value <= 1 both return 1.
//
// The computation is exact integer arithmetic. A log/floor/pow formulation
// misplaces exact powers when the log ratio rounds down (value = 101 with
// b = 10 would land in the bucket starting at 11). It also loses precision
// for large 64-bit values, and it produces NaN for non-positive inputs.
template <typename Tindices>
Tindices GetPowerBucket(const Tindices value, const Tindices bucket_size) {
  if (bucket_size <= 1 || value <= 1) {
    return 1;
  }
  // Find the largest power p = b^k with p <= value - 1. Testing
  // p <= (value - 1) / b is equivalent to p * b <= value - 1 for positive
  // integers, and it keeps every multiplication below from overflowing.
  const Tindices limit = (value - 1) / bucket_size;
  Tindices power = 1;
  while (power <= limit) {
    power *= bucket_size;
  }
  return power + 1;
}
161
162 #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \
163 template float FindConfigValueForKey<float, TypeIndex>( \
164 const TTypes<float>::ConstMatrix& config_mat, \
165 const std::pair<TypeIndex, TypeIndex>& key); \
166 template float FindConfigValueForKey<float, TypeIndex>( \
167 const TTypes<float>::ConstMatrix& config_mat, const TypeIndex key); \
168 template int64 FindConfigValueForKey<int64, TypeIndex>( \
169 const TTypes<int64>::ConstMatrix& config_mat, const TypeIndex key);
170
171 REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
172 REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
173 REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
174 REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
175 REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
176 REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
177
178 template int32 GetLinearBucket(const int32 value, const int32 bucket_size);
179
180 template int64 GetLinearBucket(const int64 value, const int64 bucket_size);
181
182 template int32 GetPowerBucket(const int32 value, const int32 bucket_size);
183
184 template int64 GetPowerBucket(const int64 value, const int64 bucket_size);
185
186 } // namespace tensor_flag_utils
187 } // namespace tensorflow
188