• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helpers for parsing tensors as runtime flags.
17 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_
18 #define TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_
19 
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_types.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace tensorflow {
28 namespace tensor_flag_utils {
29 
30 // Converts tensor.vec<Tindices> to an std::vector<Tindices> object, appends
31 // the value num_nonzero_entries_in_sparse_mat, and returns the result.
32 template <typename Tindices>
33 std::vector<Tindices> ParseRowStartIndices(
34     const tensorflow::Tensor& tensor,
35     const Tindices num_nonzero_entries_in_sparse_mat);
36 
37 // Returns Status::OK() if and only if config is a float scalar or a matrix with
38 // dimensions M x 3. If config is a scalar then config must be in the range
39 // [0, 1.0). If config is a matrix then config must have shape M x 3, all of
40 // its entries must be positive, and entries in the last column may not
41 // exceed 1.0. If config is a matrix then it may not be empty.
42 Status ValidateSparseMatrixShardingConfig(const Tensor& config);
43 
44 // Returns Status::OK() if and only if config is a float scalar or a non-empty
45 // matrix with dimensions M x 2.
46 Status ValidateScalarQuantityShardingConfig(const Tensor& config);
47 
48 // Returns the last entry of the first row in config_mat for which the first
49 // two entries are no smaller than the respective entries in key. If no such
50 // row exists then returns the last entry in the last row in config_mat.
51 // config_mat may not be empty.
52 template <typename MatrixType, typename K>
53 MatrixType FindConfigValueForKey(
54     const typename TTypes<MatrixType>::ConstMatrix& config_mat,
55     const std::pair<K, K>& key);
56 
57 // Returns the last entry of the first row in config_mat for which the first
58 // two entries are no smaller than the respective entries in key. If no such
59 // row exists then returns the last entry in the last row in config_mat.
60 // config_mat may not be empty.
61 template <typename MatrixType, typename K>
62 MatrixType FindConfigValueForKey(
63     const typename TTypes<MatrixType>::ConstMatrix& config_mat, const K key);
64 
65 // Returns largest multiple of bucket_size less than value.
66 // Expects 1 <= bucket_size <= value.
67 template <typename Tindices>
68 Tindices GetLinearBucket(const Tindices value, const Tindices bucket_size);
69 
70 // Returns the largest power of bucket_size less than value.
71 // Expects 1 <= bucket_size <= value. If bucket_size = 1, returns 1.
72 template <typename Tindices>
73 Tindices GetPowerBucket(const Tindices value, const Tindices bucket_size);
74 
75 }  // namespace tensor_flag_utils
76 }  // namespace tensorflow
77 
78 #endif  // TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_
79