/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_

#include <algorithm>
#include <cstring>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

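// Computes a float fully-connected layer whose weights are randomly sparse.
// `sparsity` describes the weight matrix in a CSR-like layout: dimension 0 is
// dense (dense_size output rows), and dimension 1 stores the nonzero entries
// of each row through array_segments / array_indices.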
inline void FullyConnectedSparseWeight(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("Random Sparse");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  const int output_elements = output_shape.FlatSize();
  const int output_dims_count = output_shape.DimensionsCount();
  const int weights_dims_count = weights_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
  const int w0_size = sparsity.dim_metadata[0].dense_size;
  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;

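  // The accumulation loop below adds into output_data, so clear it first.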
  for (int i = 0; i < output_elements; ++i) {
    output_data[i] = 0.f;
  }

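  // CSR-style traversal: for output row idx_0, the segment
  // [w1_segments[idx_0], w1_segments[idx_0 + 1]) indexes its nonzero weights,
  // and w1_indices maps each nonzero back to its input (column) index.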
  for (int b = 0; b < batches; ++b) {
    for (int idx_0 = 0; idx_0 < w0_size; ++idx_0) {
      for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1) {
        int idx_1 = w1_indices[pw1];
        output_data[b * output_depth + idx_0] +=
            weights_data[pw1] * input_data[b * accum_depth + idx_1];
      }
    }
  }

  for (int b = 0; b < batches; ++b) {
    for (int i = 0; i < output_depth; ++i) {
      float total = output_data[b * output_depth + i];
      const float bias_value = bias_data ? bias_data[i] : 0;
      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
          total + bias_value, output_activation_min, output_activation_max);
    }
  }
}

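// Computes the batch slice [thread_start, thread_end) of a float
// fully-connected layer whose weights use the 1x4 block-sparse format. The
// sparse matrix/batch-vector product is delegated to tensor_utils below; the
// bias addition and activation clamping are done here.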
inline void FullyConnectedSparseWeight1x4Impl(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data, int thread_start,
    int thread_end, const CpuBackendContext& cpu_backend_context) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("1x4 Block Sparse");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  const int input_dims_count = input_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int weights_dims_count = weights_shape.DimensionsCount();
  const int batches = thread_end - thread_start;
  const int input_depth = MatchingDim(weights_shape, weights_dims_count - 1,
                                      input_shape, input_dims_count - 1);
  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;

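  // Multiply the block-sparse weight matrix (dense shape
  // weights_shape.Dims(0) x weights_shape.Dims(1)) by this slice's input rows,
  // accumulating into the corresponding output rows.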
  tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4(
      weights_data, w1_segments, w1_indices, weights_shape.Dims(0),
      weights_shape.Dims(1), input_data + thread_start * input_depth, batches,
      output_data + thread_start * output_depth);

  ruy::profiler::ScopeLabel activation_label("activation function");
  for (int b = thread_start; b < thread_end; ++b) {
    for (int i = 0; i < output_depth; ++i) {
      float total = output_data[b * output_depth + i];
      const float bias_value = bias_data ? bias_data[i] : 0;
      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
          total + bias_value, output_activation_min, output_activation_max);
    }
  }
}

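// cpu_backend_threadpool task wrapping FullyConnectedSparseWeight1x4Impl for
// one batch slice; one task is created per worker thread.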
struct FullyConnectedSparseWeight1x4Task : cpu_backend_threadpool::Task {
  FullyConnectedSparseWeight1x4Task(
      const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
      const RuntimeShape& input_shape, const float* input_data,
      const RuntimeShape& weights_shape, const float* weights_data,
      const RuntimeShape& bias_shape, const float* bias_data,
      const RuntimeShape& output_shape, float* output_data, int thread_start,
      int thread_end, const CpuBackendContext& cpu_backend_context_x)
      : sparsity(sparsity),
        params(params),
        input_shape(input_shape),
        input_data(input_data),
        weights_shape(weights_shape),
        weights_data(weights_data),
        bias_shape(bias_shape),
        bias_data(bias_data),
        output_shape(output_shape),
        output_data(output_data),
        thread_start(thread_start),
        thread_end(thread_end),
        cpu_backend_context(cpu_backend_context_x) {}

  void Run() override {
    FullyConnectedSparseWeight1x4Impl(
        sparsity, params, input_shape, input_data, weights_shape, weights_data,
        bias_shape, bias_data, output_shape, output_data, thread_start,
        thread_end, cpu_backend_context);
  }

 private:
  const TfLiteSparsity& sparsity;
  const FullyConnectedParams& params;
  const RuntimeShape& input_shape;
  const float* input_data;
  const RuntimeShape& weights_shape;
  const float* weights_data;
  const RuntimeShape& bias_shape;
  const float* bias_data;
  const RuntimeShape& output_shape;
  float* output_data;
  int thread_start;
  int thread_end;
  const CpuBackendContext& cpu_backend_context;
};

// The multi-threaded kernel slices the workload along the batch dimension. If
// there are not enough batches of data, the number of threads used is equal
// to the batch size. We can improve this later with slicing along the row
// dimension of the weight.
inline void FullyConnectedSparseWeight1x4(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data,
    CpuBackendContext* cpu_backend_context) {
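  // The 1x4 kernel accumulates into output_data, so zero it up front.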
  const int output_elements = output_shape.FlatSize();
  memset(output_data, 0, output_elements * sizeof(float));

  const int max_threads = cpu_backend_context->max_num_threads();
  const int batches =
      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
  const int thread_count = std::max(1, std::min(batches, max_threads));
  if (thread_count == 1) {
    return FullyConnectedSparseWeight1x4Impl(
        sparsity, params, input_shape, input_data, weights_shape, weights_data,
        bias_shape, bias_data, output_shape, output_data, 0, batches,
        *cpu_backend_context);
  }
  std::vector<FullyConnectedSparseWeight1x4Task> tasks;
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // This makes sure the workload is relatively balanced when batches is not
    // a multiple of thread_count. The first mod(batches, thread_count) tasks
    // need to process one more batch than the rest.
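    // For example, 10 batches split over 4 threads become the ranges
    // [0, 3), [3, 6), [6, 8), [8, 10).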
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;

    tasks.emplace_back(sparsity, params, input_shape, input_data, weights_shape,
                       weights_data, bias_shape, bias_data, output_shape,
                       output_data, thread_start, thread_end,
                       *cpu_backend_context);
    thread_start = thread_end;
  }
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
}

}  // namespace optimized_ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_