• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/kernels/concat_lib_cpu.h"
19 #include <vector>
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/kernels/concat_lib.h"
22 
23 namespace tensorflow {
24 
25 namespace {
26 template <typename T>
27 struct MemCpyCopier {
Copytensorflow::__anon143cf8a00111::MemCpyCopier28   inline void Copy(T* dst, const T* src, int input_index, size_t n) {
29     if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
30       memcpy(dst, src, n * sizeof(T));
31     } else {
32       for (size_t k = 0; k < n; ++k) {
33         *dst++ = *src++;
34       }
35     }
36   }
37 };
38 template <>
39 struct MemCpyCopier<ResourceHandle> {
Copytensorflow::__anon143cf8a00111::MemCpyCopier40   inline void Copy(ResourceHandle* dst, const ResourceHandle* src,
41                    int input_index, size_t n) {
42     for (size_t k = 0; k < n; ++k) {
43       *dst++ = *src++;
44     }
45   }
46 };
47 
48 template <typename T>
EstimateBytesPerElement(const std::vector<std::unique_ptr<typename TTypes<T,2>::ConstMatrix>> & inputs)49 int64 EstimateBytesPerElement(
50     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
51         inputs) {
52   return sizeof(T);
53 }
54 
55 // EstimateBytesPerElement for strings estimates the total bytes involved in
56 // concatenating the strings in the "inputs" matrices (higher-level code
57 // reshapes all the inputs to matrices), by sampling the lengths of the actual
58 // strings in the various tensors.
59 template <>
EstimateBytesPerElement(const std::vector<std::unique_ptr<typename TTypes<tstring,2>::ConstMatrix>> & inputs)60 int64 EstimateBytesPerElement<tstring>(
61     const std::vector<
62         std::unique_ptr<typename TTypes<tstring, 2>::ConstMatrix>>& inputs) {
63   // randomly sample a few input strings to get a sense of the average size
64   // of each element
65   int num_samples = 0;
66   int64 num_bytes_in_samples = 0;
67   for (const auto& input : inputs) {
68     const auto dim0 = input->dimension(0);
69     const auto dim1 = input->dimension(1);
70     const auto zero = dim0 - dim0;  // Make type match
71     if (dim0 > 0 && dim1 > 0) {
72       // Draw 9 samples of string sizes from the input, in this sort of pattern
73       // ("*" is sample), to get an estimate of the lengths of each string
74       // element in the tensors:
75       //
76       //    *...*...*
77       //    .........
78       //    *...*...*
79       //    .........
80       //    *...*...*
81       for (auto i : {zero, dim0 / 2, dim0 - 1}) {
82         for (auto j : {zero, dim1 / 2, dim1 - 1}) {
83           num_bytes_in_samples += (*input)(i, j).size();
84           num_samples++;
85         }
86       }
87     }
88   }
89   // We don't use sizeof(std::string) as the overhead, since that would
90   // overestimate the memory touched for copying a string.
91   int64 string_overhead = sizeof(char*) + sizeof(size_t);
92   return string_overhead +
93          ((num_samples > 0) ? (num_bytes_in_samples / num_samples) : 0);
94 }
95 
96 }  // namespace
97 
98 template <typename T>
ConcatCPU(DeviceBase * d,const std::vector<std::unique_ptr<typename TTypes<T,2>::ConstMatrix>> & inputs,typename TTypes<T,2>::Matrix * output)99 void ConcatCPU(
100     DeviceBase* d,
101     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
102         inputs,
103     typename TTypes<T, 2>::Matrix* output) {
104   int64 cost_per_unit = EstimateBytesPerElement<T>(inputs);
105   ConcatCPUImpl<T>(d, inputs, cost_per_unit, MemCpyCopier<T>(), output);
106 }
107 
108 #define REGISTER(T)                                                            \
109   template void ConcatCPU<T>(                                                  \
110       DeviceBase*,                                                             \
111       const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
112       typename TTypes<T, 2>::Matrix* output);
113 TF_CALL_ALL_TYPES(REGISTER)
114 REGISTER(quint8)
115 REGISTER(qint8)
116 REGISTER(quint16)
117 REGISTER(qint16)
118 REGISTER(qint32)
119 
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
    !defined(__ANDROID_TYPES_FULL__)
// Primarily used for SavedModel support on mobile. Registering it here only
// if __ANDROID_TYPES_FULL__ is not defined (which already registers string)
// to avoid duplicate registration.
// REGISTER already expands to a complete declaration ending in ';', so no
// extra semicolon follows the invocation (it would form an empty declaration
// and trip -Wpedantic / extra-semi warnings).
REGISTER(tstring)
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
        // !defined(__ANDROID_TYPES_FULL__)

// REGISTER is a very generic macro name; undefine it after its last use so it
// cannot leak into anything that follows in this translation unit.
#undef REGISTER
129 
130 }  // namespace tensorflow
131