1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Helpers for parsing tensors as runtime flags. 17 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ 18 #define TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ 19 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace tensorflow { 28 namespace tensor_flag_utils { 29 30 // Converts tensor.vec<Tindices> to an std::vector<Tindices> object, appends 31 // the value num_nonzero_entries_in_sparse_mat, and returns the result. 32 template <typename Tindices> 33 std::vector<Tindices> ParseRowStartIndices( 34 const tensorflow::Tensor& tensor, 35 const Tindices num_nonzero_entries_in_sparse_mat); 36 37 // Returns Status::OK() if and only if config is a float scalar or a matrix with 38 // dimensions M x 3. If config is a scalar then config must be in the range 39 // [0, 1.0). If config is a matrix then config must have shape M x 3, all of 40 // its entries must be positive, and entries in the last column may not 41 // exceed 1.0. If config is a matrix then it may not be empty. 42 Status ValidateSparseMatrixShardingConfig(const Tensor& config); 43 44 // Returns Status::OK() if and only if config is a float scalar or a non-empty 45 // matrix with dimensions M x 2. 46 Status ValidateScalarQuantityShardingConfig(const Tensor& config); 47 48 // Returns the last entry of the first row in config_mat for which the first 49 // two entries are no smaller than the respective entries in key. If no such 50 // row exists then returns the last entry in the last row in config_mat. 51 // config_mat may not be empty. 52 template <typename MatrixType, typename K> 53 MatrixType FindConfigValueForKey( 54 const typename TTypes<MatrixType>::ConstMatrix& config_mat, 55 const std::pair<K, K>& key); 56 57 // Returns the last entry of the first row in config_mat for which the first 58 // two entries are no smaller than the respective entries in key. If no such 59 // row exists then returns the last entry in the last row in config_mat. 60 // config_mat may not be empty. 61 template <typename MatrixType, typename K> 62 MatrixType FindConfigValueForKey( 63 const typename TTypes<MatrixType>::ConstMatrix& config_mat, const K key); 64 65 // Returns largest multiple of bucket_size less than value. 66 // Expects 1 <= bucket_size <= value. 67 template <typename Tindices> 68 Tindices GetLinearBucket(const Tindices value, const Tindices bucket_size); 69 70 // Returns the largest power of bucket_size less than value. 71 // Expects 1 <= bucket_size <= value. If bucket_size = 1, returns 1. 72 template <typename Tindices> 73 Tindices GetPowerBucket(const Tindices value, const Tindices bucket_size); 74 75 } // namespace tensor_flag_utils 76 } // namespace tensorflow 77 78 #endif // TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ 79