• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/tensor_flag_utils.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 
21 namespace tensorflow {
22 namespace tensor_flag_utils {
23 
ValidateSparseMatrixShardingConfig(const Tensor & config)24 Status ValidateSparseMatrixShardingConfig(const Tensor& config) {
25   if (TensorShapeUtils::IsScalar(config.shape())) {
26     const float scalar_config = config.template scalar<float>()();
27     if (0 < scalar_config && scalar_config <= 1.0) {
28       return Status::OK();
29     }
30     return Status(
31         error::INVALID_ARGUMENT,
32         absl::StrCat("Expected config to be in range (0, 1] but instead found ",
33                      scalar_config));
34   }
35   if (!TensorShapeUtils::IsMatrix(config.shape())) {
36     return Status(error::INVALID_ARGUMENT,
37                   absl::StrCat("Expected config to be either scalar or matrix "
38                                "but instead found tensor of rank ",
39                                config.dims()));
40   }
41   if (config.dim_size(1) != 3) {
42     return Status(
43         error::INVALID_ARGUMENT,
44         absl::StrCat(
45             "Expected config matrix to have dim(1) = 3 but instead found ",
46             config.dim_size(1)));
47   }
48 
49   auto config_matrix = config.matrix<float>();
50   for (int i = 0; i < config.dim_size(0); ++i) {
51     if (0 > config_matrix(i, 0)) {
52       return errors::InvalidArgument(
53           "First column of fraction_rows_per_thread_config "
54           "should "
55           "have non-negative values but found ",
56           config_matrix(i, 0), " in row ", i);
57     }
58     if (0 > config_matrix(i, 1)) {
59       return errors::InvalidArgument(
60           "Second column of fraction_rows_per_thread_config "
61           "should "
62           "have non-negative values but found ",
63           config_matrix(i, 1), " in row ", i);
64     }
65     if (!(0 < config_matrix(i, 2) && config_matrix(i, 2) <= 1)) {
66       return errors::InvalidArgument(
67           "Last column of fraction_rows_per_thread_config should "
68           "have values in the range (0, 1] but found ",
69           config_matrix(i, 2), " in row ", i);
70     }
71   }
72   return Status::OK();
73 }
74 
75 template <typename MatrixType, typename K>
FindConfigValueForKey(const typename TTypes<MatrixType>::ConstMatrix & config_mat,const std::pair<K,K> & key)76 MatrixType FindConfigValueForKey(
77     const typename TTypes<MatrixType>::ConstMatrix& config_mat,
78     const std::pair<K, K>& key) {
79   const int last_row_index = config_mat.dimension(0) - 1;
80   for (int i = 0; i < last_row_index; ++i) {
81     if (key.first >= config_mat(i, 0) && key.second >= config_mat(i, 1)) {
82       return config_mat(i, 2);
83     }
84   }
85   return config_mat(last_row_index, 2);
86 }
87 
ValidateScalarQuantityShardingConfig(const Tensor & config)88 Status ValidateScalarQuantityShardingConfig(const Tensor& config) {
89   if (TensorShapeUtils::IsScalar(config.shape())) {
90     const float scalar_config = config.template scalar<float>()();
91     if (0 < scalar_config && scalar_config <= 1.0) {
92       return Status::OK();
93     }
94     return Status(
95         error::INVALID_ARGUMENT,
96         absl::StrCat("Expected config to be in range (0, 1] but instead found ",
97                      scalar_config));
98   }
99   if (!TensorShapeUtils::IsMatrix(config.shape())) {
100     return Status(error::INVALID_ARGUMENT,
101                   absl::StrCat("Expected config to be either scalar or matrix "
102                                "but instead found tensor of rank ",
103                                config.dims()));
104   }
105   if (config.dim_size(1) != 2) {
106     return Status(
107         error::INVALID_ARGUMENT,
108         absl::StrCat(
109             "Expected config matrix to have dim(1) = 2 but instead found ",
110             config.dim_size(1)));
111   }
112 
113   auto config_matrix = config.matrix<float>();
114   for (int i = 0; i < config.dim_size(0); ++i) {
115     if (0 > config_matrix(i, 0)) {
116       return errors::InvalidArgument(
117           "First column of fraction_rows_per_thread_config "
118           "should "
119           "have non-negative values but found ",
120           config_matrix(i, 0), " in row ", i);
121     }
122     if (!(0 < config_matrix(i, 1) && config_matrix(i, 1) <= 1)) {
123       return errors::InvalidArgument(
124           "Last column of fraction_rows_per_thread_config should "
125           "have values in the range (0, 1] but found ",
126           config_matrix(i, 1), " in row ", i);
127     }
128   }
129   return Status::OK();
130 }
131 
132 template <typename MatrixType, typename K>
FindConfigValueForKey(const typename TTypes<MatrixType>::ConstMatrix & config_mat,const K key)133 MatrixType FindConfigValueForKey(
134     const typename TTypes<MatrixType>::ConstMatrix& config_mat, const K key) {
135   const int last_row_index = config_mat.dimension(0) - 1;
136   for (int i = 0; i < last_row_index; ++i) {
137     if (key >= config_mat(i, 0)) {
138       return config_mat(i, 1);
139     }
140   }
141   return config_mat(last_row_index, 1);
142 }
143 
// Returns the smallest value of the linear bucket containing `value`.
//
// Buckets are [1, b], [b + 1, 2b], [2b + 1, 3b], ... for b = bucket_size,
// i.e. the result is ceil(value / b) * b - (b - 1).
// Assumes bucket_size > 0 (TODO confirm at call sites).
template <typename Tindices>
Tindices GetLinearBucket(const Tindices value, const Tindices bucket_size) {
  if (value > 0) {
    // Same result as the formula below for positive values, but never forms
    // `value + bucket_size - 1`, which overflows (UB for signed types) when
    // `value` is close to the type's maximum.
    return (value - 1) / bucket_size * bucket_size + 1;
  }
  // Non-positive values keep the original formula; it cannot overflow here.
  const Tindices next_multiple_of_bucket_size =
      (value + bucket_size - 1) / bucket_size * bucket_size;
  return next_multiple_of_bucket_size - (bucket_size - 1);
}
150 
// Returns the smallest value of the power bucket containing `value`.
//
// For bucket_size b > 1 the result is p + 1, where p is the largest power of
// b with p <= value - 1. The buckets are therefore {1}, {2, ..., b},
// {b + 1, ..., b^2}, {b^2 + 1, ..., b^3}, ...
// Returns 1 when bucket_size <= 1 or value <= 1.
//
// Computed with exact integer arithmetic. Evaluating
// floor(log(b * (value - 1)) / log(b)) in floating point misrounds at exact
// powers (log(1000) / log(10) == 2.9999999999999996), overflows on
// b * (value - 1), and loses precision for large 64-bit values.
template <typename Tindices>
Tindices GetPowerBucket(const Tindices value, const Tindices bucket_size) {
  if (bucket_size <= 1 || value <= 1) {
    return 1;
  }
  const Tindices limit = value - 1;
  Tindices power = 1;
  // Comparing against limit / bucket_size (not power * bucket_size <= limit)
  // keeps power * bucket_size from ever overflowing.
  while (power <= limit / bucket_size) {
    power *= bucket_size;
  }
  return power + 1;
}
161 
162 #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex)                           \
163   template float FindConfigValueForKey<float, TypeIndex>(                   \
164       const TTypes<float>::ConstMatrix& config_mat,                         \
165       const std::pair<TypeIndex, TypeIndex>& key);                          \
166   template float FindConfigValueForKey<float, TypeIndex>(                   \
167       const TTypes<float>::ConstMatrix& config_mat, const TypeIndex key);   \
168   template int64 FindConfigValueForKey<int64, TypeIndex>(                   \
169       const TTypes<int64>::ConstMatrix& config_mat, const TypeIndex key);
170 
171 REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
172 REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
173 REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
174 REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
175 REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
176 REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
177 
178 template int32 GetLinearBucket(const int32 value, const int32 bucket_size);
179 
180 template int64 GetLinearBucket(const int64 value, const int64 bucket_size);
181 
182 template int32 GetPowerBucket(const int32 value, const int32 bucket_size);
183 
184 template int64 GetPowerBucket(const int64 value, const int64 bucket_size);
185 
186 }  // namespace tensor_flag_utils
187 }  // namespace tensorflow
188