• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
17 
#include <algorithm>

#include "tensorflow/lite/kernels/internal/types.h"
19 
20 namespace tflite {
21 
22 namespace reference_ops {
23 
24 // Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
25 //
26 // For example, if sequence of dimensions of one input is
27 // ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
28 // we can consolidate these as
29 // ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
30 //
31 // The category is updated in the less-frequent case of shapes that are
32 // not suited to a fivefold-loop broadcast.
33 //
34 // Falls back to generic pattern when it does not know how to process properly.
35 //
36 // Returns true iff there is some sort of broadcast, which includes five-fold
37 // patterns and falling back to generic broadcast.
ProcessBroadcastShapes(const RuntimeShape & shape0,const RuntimeShape & shape1,tflite::ArithmeticParams * params)38 inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
39                                    const RuntimeShape& shape1,
40                                    tflite::ArithmeticParams* params) {
41   const int dims_count =
42       std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
43 
44   params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
45   RuntimeShape scalar_shape(dims_count, 1);
46 
47   auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
48   auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
49 
50   // Check for "exact" match, implicitly accepting any scalar shapes.
51   if (extended_shape0 == extended_shape1) {
52     params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
53     return false;
54   }
55 
56   for (int i = dims_count - 1; i >= 0; --i) {
57     if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
58       continue;
59     } else if (extended_shape0.Dims(i) == 1) {
60       params->broadcast_category =
61           BroadcastableOpCategory::kFirstInputBroadcastsFast;
62       break;
63     } else if (extended_shape1.Dims(i) == 1) {
64       params->broadcast_category =
65           BroadcastableOpCategory::kSecondInputBroadcastsFast;
66       break;
67     } else {
68       // This case is erroneous: there is a dimension that does not match and
69       // is not a broadcast from one shape to the other.
70       params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
71       return true;
72     }
73   }
74 
75   if (params->broadcast_category !=
76           BroadcastableOpCategory::kFirstInputBroadcastsFast &&
77       params->broadcast_category !=
78           BroadcastableOpCategory::kSecondInputBroadcastsFast) {
79     return false;
80   }
81 
82   // From this point it is assumed contractually that corresponding dimensions
83   // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
84   const bool swap_inputs = params->broadcast_category ==
85                            BroadcastableOpCategory::kSecondInputBroadcastsFast;
86   const RuntimeShape* shape_a =
87       swap_inputs ? &extended_shape1 : &extended_shape0;
88   const RuntimeShape* shape_b =
89       swap_inputs ? &extended_shape0 : &extended_shape1;
90 
91   int i = dims_count - 1;
92   params->broadcast_shape[0] = 1;
93   params->broadcast_shape[1] = 1;
94   params->broadcast_shape[2] = 1;
95   params->broadcast_shape[3] = 1;
96   params->broadcast_shape[4] = 1;
97   // y_0 is greedy: include dims if both or neither equal 1: in other words,
98   // test for equality rather than (shape_a->Dims(i) != 1).
99   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
100     params->broadcast_shape[4] *= shape_b->Dims(i);
101     --i;
102   }
103   // Here either input_a or input_b has dim of 1 (if i >= 0).  If it is input_b
104   // that has the unit dimension, the next two loops are not entered.
105   while (i >= 0 && shape_a->Dims(i) == 1) {
106     params->broadcast_shape[3] *= shape_b->Dims(i);
107     --i;
108   }
109   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
110     params->broadcast_shape[2] *= shape_a->Dims(i);
111     --i;
112   }
113   // Here either input_a or input_b has dim of 1 (if i >= 0).
114   while (i >= 0 && shape_b->Dims(i) == 1) {
115     params->broadcast_shape[1] *= shape_a->Dims(i);
116     --i;
117   }
118   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
119     params->broadcast_shape[0] *= shape_b->Dims(i);
120     --i;
121   }
122 
123   // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
124   // loop.
125   if (i >= 0) {
126     params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
127   }
128   return true;
129 }
130 
131 }  // namespace reference_ops
132 }  // namespace tflite
133 
134 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
135