• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "TypeUtils.h"
18 
19 #include <android-base/logging.h>
20 
21 #include <algorithm>
22 #include <chrono>
23 #include <limits>
24 #include <memory>
25 #include <ostream>
26 #include <type_traits>
27 #include <utility>
28 #include <vector>
29 
30 #include "OperandTypes.h"
31 #include "OperationTypes.h"
32 #include "Result.h"
33 #include "SharedMemory.h"
34 #include "Types.h"
35 
36 namespace android::nn {
37 namespace {
38 
// Converts an enumeration value to its underlying integral representation.
// Usable in constant expressions; the return type is exactly std::underlying_type_t<Type>.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type object) {
    using Underlying = std::underlying_type_t<Type>;
    const Underlying raw = static_cast<Underlying>(object);
    return raw;
}
43 
// Streams a vector as "[e0, e1, ...]", printing at most kMaxVectorPrint elements.
// When elements are omitted, the output ends in "...]" right after the last printed element.
// A vector with exactly kMaxVectorPrint elements is printed in full with a plain "]", so the
// "..." marker only appears when something was actually dropped.
// Byte-sized integers (signed/unsigned char, i.e. int8_t/uint8_t) are printed as numbers rather
// than raw characters so that byte buffers are readable in logs. Plain `char` is still streamed
// as a character.
template <typename Type>
std::ostream& operator<<(std::ostream& os, const std::vector<Type>& vec) {
    constexpr size_t kMaxVectorPrint = 20;
    os << "[";
    size_t count = 0;
    for (const auto& element : vec) {
        if (count == kMaxVectorPrint) {
            // Only reached when at least one element remains unprinted.
            return os << "...]";
        }
        if (count > 0) {
            os << ", ";
        }
        if constexpr (std::is_same_v<Type, unsigned char> || std::is_same_v<Type, signed char>) {
            // Streaming (u)int8_t directly would emit a character, often non-printable.
            os << static_cast<int>(element);
        } else {
            os << element;
        }
        count++;
    }
    return os << "]";
}
61 
62 }  // namespace
63 
isExtension(OperandType type)64 bool isExtension(OperandType type) {
65     return getExtensionPrefix(underlyingType(type)) != 0;
66 }
67 
isExtension(OperationType type)68 bool isExtension(OperationType type) {
69     return getExtensionPrefix(underlyingType(type)) != 0;
70 }
71 
isNonExtensionScalar(OperandType operandType)72 bool isNonExtensionScalar(OperandType operandType) {
73     CHECK(!isExtension(operandType));
74     switch (operandType) {
75         case OperandType::FLOAT32:
76         case OperandType::INT32:
77         case OperandType::UINT32:
78         case OperandType::BOOL:
79         case OperandType::FLOAT16:
80         case OperandType::SUBGRAPH:
81         case OperandType::OEM:
82             return true;
83         case OperandType::TENSOR_FLOAT32:
84         case OperandType::TENSOR_INT32:
85         case OperandType::TENSOR_QUANT8_ASYMM:
86         case OperandType::TENSOR_QUANT16_SYMM:
87         case OperandType::TENSOR_FLOAT16:
88         case OperandType::TENSOR_BOOL8:
89         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
90         case OperandType::TENSOR_QUANT16_ASYMM:
91         case OperandType::TENSOR_QUANT8_SYMM:
92         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
93         case OperandType::TENSOR_OEM_BYTE:
94             return false;
95     }
96     return false;
97 }
98 
getNonExtensionSize(OperandType operandType)99 size_t getNonExtensionSize(OperandType operandType) {
100     CHECK(!isExtension(operandType));
101     switch (operandType) {
102         case OperandType::SUBGRAPH:
103         case OperandType::OEM:
104             return 0;
105         case OperandType::TENSOR_QUANT8_ASYMM:
106         case OperandType::BOOL:
107         case OperandType::TENSOR_BOOL8:
108         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
109         case OperandType::TENSOR_QUANT8_SYMM:
110         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
111         case OperandType::TENSOR_OEM_BYTE:
112             return 1;
113         case OperandType::TENSOR_QUANT16_SYMM:
114         case OperandType::TENSOR_FLOAT16:
115         case OperandType::FLOAT16:
116         case OperandType::TENSOR_QUANT16_ASYMM:
117             return 2;
118         case OperandType::FLOAT32:
119         case OperandType::INT32:
120         case OperandType::UINT32:
121         case OperandType::TENSOR_FLOAT32:
122         case OperandType::TENSOR_INT32:
123             return 4;
124     }
125     return 0;
126 }
127 
getNonExtensionSize(OperandType operandType,const Dimensions & dimensions)128 std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions) {
129     CHECK(!isExtension(operandType)) << "Size of extension operand data is unknown";
130     size_t size = getNonExtensionSize(operandType);
131     if (isNonExtensionScalar(operandType)) {
132         return size;
133     } else if (dimensions.empty()) {
134         return 0;
135     }
136     for (Dimension dimension : dimensions) {
137         if (dimension != 0 && size > std::numeric_limits<size_t>::max() / dimension) {
138             return std::nullopt;
139         }
140         size *= dimension;
141     }
142     return size;
143 }
144 
getNonExtensionSize(const Operand & operand)145 std::optional<size_t> getNonExtensionSize(const Operand& operand) {
146     return getNonExtensionSize(operand.type, operand.dimensions);
147 }
148 
// Reassembles a 64-bit offset from two 32-bit signed halves: `lower` supplies bits 0-31 and
// `higher` bits 32-63. Each half's bit pattern is reused verbatim (negative values are
// not sign-extended). Inverse of getIntsFromOffset.
size_t getOffsetFromInts(int lower, int higher) {
    // Signed -> unsigned conversion is modular, i.e. it keeps the two's-complement bit pattern.
    const uint64_t low = static_cast<uint32_t>(static_cast<int32_t>(lower));
    const uint64_t high = static_cast<uint32_t>(static_cast<int32_t>(higher));
    const uint64_t combined = (high << 32) | low;
    return combined;
}
157 
// Splits a 64-bit offset into two 32-bit signed halves: first = bits 0-31, second = bits 32-63.
// Bit patterns are preserved exactly. Inverse of getOffsetFromInts.
std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset) {
    const uint64_t value = offset;
    const uint32_t halves[2] = {static_cast<uint32_t>(value), static_cast<uint32_t>(value >> 32)};
    // Viewing a uint32_t through its corresponding signed type is a permitted alias.
    const auto asSigned = [](const uint32_t& bits) {
        return *reinterpret_cast<const int32_t*>(&bits);
    };
    return {asSigned(halves[0]), asSigned(halves[1])};
}
166 
countNumberOfConsumers(size_t numberOfOperands,const std::vector<nn::Operation> & operations)167 Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
168                                                      const std::vector<nn::Operation>& operations) {
169     std::vector<uint32_t> numberOfConsumers(numberOfOperands, 0);
170     for (const auto& operation : operations) {
171         for (uint32_t operandIndex : operation.inputs) {
172             if (operandIndex >= numberOfConsumers.size()) {
173                 return NN_ERROR()
174                        << "countNumberOfConsumers: tried to access out-of-bounds operand ("
175                        << operandIndex << " vs " << numberOfConsumers.size() << ")";
176             }
177             numberOfConsumers[operandIndex]++;
178         }
179     }
180     return numberOfConsumers;
181 }
182 
combineDimensions(const Dimensions & lhs,const Dimensions & rhs)183 Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs) {
184     if (rhs.empty()) return lhs;
185     if (lhs.empty()) return rhs;
186     if (lhs.size() != rhs.size()) {
187         std::ostringstream os;
188         os << "Incompatible ranks: " << lhs << " and " << rhs;
189         return NN_ERROR() << os.str();
190     }
191     Dimensions combined = lhs;
192     for (size_t i = 0; i < lhs.size(); i++) {
193         if (lhs[i] == 0) {
194             combined[i] = rhs[i];
195         } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
196             std::ostringstream os;
197             os << "Incompatible dimensions: " << lhs << " and " << rhs;
198             return NN_ERROR() << os.str();
199         }
200     }
201     return combined;
202 }
203 
getMemorySizes(const Model & model)204 std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model) {
205     const size_t operandValuesSize = model.operandValues.size();
206 
207     std::vector<size_t> poolSizes;
208     poolSizes.reserve(model.pools.size());
209     std::transform(model.pools.begin(), model.pools.end(), std::back_inserter(poolSizes),
210                    [](const SharedMemory& memory) { return getSize(memory); });
211 
212     return std::make_pair(operandValuesSize, std::move(poolSizes));
213 }
214 
roundUp(size_t size,size_t multiple)215 size_t roundUp(size_t size, size_t multiple) {
216     CHECK(multiple != 0);
217     CHECK((multiple & (multiple - 1)) == 0) << multiple << " is not a power of two";
218     return (size + (multiple - 1)) & ~(multiple - 1);
219 }
220 
// Picks the alignment (in bytes) used for a value of `length` bytes:
// 4 for lengths of at least 4, 2 for lengths 2-3, and 1 (no alignment) for 0-1.
size_t getAlignmentForLength(size_t length) {
    if (length >= 4) {
        return 4;
    }
    if (length >= 2) {
        return 2;
    }
    return 1;
}
230 
operator <<(std::ostream & os,const DeviceStatus & deviceStatus)231 std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus) {
232     switch (deviceStatus) {
233         case DeviceStatus::AVAILABLE:
234             return os << "AVAILABLE";
235         case DeviceStatus::BUSY:
236             return os << "BUSY";
237         case DeviceStatus::OFFLINE:
238             return os << "OFFLINE";
239         case DeviceStatus::UNKNOWN:
240             return os << "UNKNOWN";
241     }
242     return os << "DeviceStatus{" << underlyingType(deviceStatus) << "}";
243 }
244 
operator <<(std::ostream & os,const ExecutionPreference & executionPreference)245 std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference) {
246     switch (executionPreference) {
247         case ExecutionPreference::LOW_POWER:
248             return os << "LOW_POWER";
249         case ExecutionPreference::FAST_SINGLE_ANSWER:
250             return os << "FAST_SINGLE_ANSWER";
251         case ExecutionPreference::SUSTAINED_SPEED:
252             return os << "SUSTAINED_SPEED";
253     }
254     return os << "ExecutionPreference{" << underlyingType(executionPreference) << "}";
255 }
256 
operator <<(std::ostream & os,const DeviceType & deviceType)257 std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) {
258     switch (deviceType) {
259         case DeviceType::UNKNOWN:
260             return os << "UNKNOWN";
261         case DeviceType::OTHER:
262             return os << "OTHER";
263         case DeviceType::CPU:
264             return os << "CPU";
265         case DeviceType::GPU:
266             return os << "GPU";
267         case DeviceType::ACCELERATOR:
268             return os << "ACCELERATOR";
269     }
270     return os << "DeviceType{" << underlyingType(deviceType) << "}";
271 }
272 
operator <<(std::ostream & os,const MeasureTiming & measureTiming)273 std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming) {
274     switch (measureTiming) {
275         case MeasureTiming::NO:
276             return os << "NO";
277         case MeasureTiming::YES:
278             return os << "YES";
279     }
280     return os << "MeasureTiming{" << underlyingType(measureTiming) << "}";
281 }
282 
operator <<(std::ostream & os,const OperandType & operandType)283 std::ostream& operator<<(std::ostream& os, const OperandType& operandType) {
284     switch (operandType) {
285         case OperandType::FLOAT32:
286             return os << "FLOAT32";
287         case OperandType::INT32:
288             return os << "INT32";
289         case OperandType::UINT32:
290             return os << "UINT32";
291         case OperandType::TENSOR_FLOAT32:
292             return os << "TENSOR_FLOAT32";
293         case OperandType::TENSOR_INT32:
294             return os << "TENSOR_INT32";
295         case OperandType::TENSOR_QUANT8_ASYMM:
296             return os << "TENSOR_QUANT8_ASYMM";
297         case OperandType::BOOL:
298             return os << "BOOL";
299         case OperandType::TENSOR_QUANT16_SYMM:
300             return os << "TENSOR_QUANT16_SYMM";
301         case OperandType::TENSOR_FLOAT16:
302             return os << "TENSOR_FLOAT16";
303         case OperandType::TENSOR_BOOL8:
304             return os << "TENSOR_BOOL8";
305         case OperandType::FLOAT16:
306             return os << "FLOAT16";
307         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
308             return os << "TENSOR_QUANT8_SYMM_PER_CHANNEL";
309         case OperandType::TENSOR_QUANT16_ASYMM:
310             return os << "TENSOR_QUANT16_ASYMM";
311         case OperandType::TENSOR_QUANT8_SYMM:
312             return os << "TENSOR_QUANT8_SYMM";
313         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
314             return os << "TENSOR_QUANT8_ASYMM_SIGNED";
315         case OperandType::SUBGRAPH:
316             return os << "SUBGRAPH";
317         case OperandType::OEM:
318             return os << "OEM";
319         case OperandType::TENSOR_OEM_BYTE:
320             return os << "TENSOR_OEM_BYTE";
321     }
322     if (isExtension(operandType)) {
323         return os << "Extension OperandType " << underlyingType(operandType);
324     }
325     return os << "OperandType{" << underlyingType(operandType) << "}";
326 }
327 
operator <<(std::ostream & os,const Operand::LifeTime & lifetime)328 std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime) {
329     switch (lifetime) {
330         case Operand::LifeTime::TEMPORARY_VARIABLE:
331             return os << "TEMPORARY_VARIABLE";
332         case Operand::LifeTime::SUBGRAPH_INPUT:
333             return os << "SUBGRAPH_INPUT";
334         case Operand::LifeTime::SUBGRAPH_OUTPUT:
335             return os << "SUBGRAPH_OUTPUT";
336         case Operand::LifeTime::CONSTANT_COPY:
337             return os << "CONSTANT_COPY";
338         case Operand::LifeTime::CONSTANT_REFERENCE:
339             return os << "CONSTANT_REFERENCE";
340         case Operand::LifeTime::NO_VALUE:
341             return os << "NO_VALUE";
342         case Operand::LifeTime::SUBGRAPH:
343             return os << "SUBGRAPH";
344         case Operand::LifeTime::POINTER:
345             return os << "POINTER";
346     }
347     return os << "Operand::LifeTime{" << underlyingType(lifetime) << "}";
348 }
349 
operator <<(std::ostream & os,const OperationType & operationType)350 std::ostream& operator<<(std::ostream& os, const OperationType& operationType) {
351     switch (operationType) {
352         case OperationType::ADD:
353             return os << "ADD";
354         case OperationType::AVERAGE_POOL_2D:
355             return os << "AVERAGE_POOL_2D";
356         case OperationType::CONCATENATION:
357             return os << "CONCATENATION";
358         case OperationType::CONV_2D:
359             return os << "CONV_2D";
360         case OperationType::DEPTHWISE_CONV_2D:
361             return os << "DEPTHWISE_CONV_2D";
362         case OperationType::DEPTH_TO_SPACE:
363             return os << "DEPTH_TO_SPACE";
364         case OperationType::DEQUANTIZE:
365             return os << "DEQUANTIZE";
366         case OperationType::EMBEDDING_LOOKUP:
367             return os << "EMBEDDING_LOOKUP";
368         case OperationType::FLOOR:
369             return os << "FLOOR";
370         case OperationType::FULLY_CONNECTED:
371             return os << "FULLY_CONNECTED";
372         case OperationType::HASHTABLE_LOOKUP:
373             return os << "HASHTABLE_LOOKUP";
374         case OperationType::L2_NORMALIZATION:
375             return os << "L2_NORMALIZATION";
376         case OperationType::L2_POOL_2D:
377             return os << "L2_POOL_2D";
378         case OperationType::LOCAL_RESPONSE_NORMALIZATION:
379             return os << "LOCAL_RESPONSE_NORMALIZATION";
380         case OperationType::LOGISTIC:
381             return os << "LOGISTIC";
382         case OperationType::LSH_PROJECTION:
383             return os << "LSH_PROJECTION";
384         case OperationType::LSTM:
385             return os << "LSTM";
386         case OperationType::MAX_POOL_2D:
387             return os << "MAX_POOL_2D";
388         case OperationType::MUL:
389             return os << "MUL";
390         case OperationType::RELU:
391             return os << "RELU";
392         case OperationType::RELU1:
393             return os << "RELU1";
394         case OperationType::RELU6:
395             return os << "RELU6";
396         case OperationType::RESHAPE:
397             return os << "RESHAPE";
398         case OperationType::RESIZE_BILINEAR:
399             return os << "RESIZE_BILINEAR";
400         case OperationType::RNN:
401             return os << "RNN";
402         case OperationType::SOFTMAX:
403             return os << "SOFTMAX";
404         case OperationType::SPACE_TO_DEPTH:
405             return os << "SPACE_TO_DEPTH";
406         case OperationType::SVDF:
407             return os << "SVDF";
408         case OperationType::TANH:
409             return os << "TANH";
410         case OperationType::BATCH_TO_SPACE_ND:
411             return os << "BATCH_TO_SPACE_ND";
412         case OperationType::DIV:
413             return os << "DIV";
414         case OperationType::MEAN:
415             return os << "MEAN";
416         case OperationType::PAD:
417             return os << "PAD";
418         case OperationType::SPACE_TO_BATCH_ND:
419             return os << "SPACE_TO_BATCH_ND";
420         case OperationType::SQUEEZE:
421             return os << "SQUEEZE";
422         case OperationType::STRIDED_SLICE:
423             return os << "STRIDED_SLICE";
424         case OperationType::SUB:
425             return os << "SUB";
426         case OperationType::TRANSPOSE:
427             return os << "TRANSPOSE";
428         case OperationType::ABS:
429             return os << "ABS";
430         case OperationType::ARGMAX:
431             return os << "ARGMAX";
432         case OperationType::ARGMIN:
433             return os << "ARGMIN";
434         case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM:
435             return os << "AXIS_ALIGNED_BBOX_TRANSFORM";
436         case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM:
437             return os << "BIDIRECTIONAL_SEQUENCE_LSTM";
438         case OperationType::BIDIRECTIONAL_SEQUENCE_RNN:
439             return os << "BIDIRECTIONAL_SEQUENCE_RNN";
440         case OperationType::BOX_WITH_NMS_LIMIT:
441             return os << "BOX_WITH_NMS_LIMIT";
442         case OperationType::CAST:
443             return os << "CAST";
444         case OperationType::CHANNEL_SHUFFLE:
445             return os << "CHANNEL_SHUFFLE";
446         case OperationType::DETECTION_POSTPROCESSING:
447             return os << "DETECTION_POSTPROCESSING";
448         case OperationType::EQUAL:
449             return os << "EQUAL";
450         case OperationType::EXP:
451             return os << "EXP";
452         case OperationType::EXPAND_DIMS:
453             return os << "EXPAND_DIMS";
454         case OperationType::GATHER:
455             return os << "GATHER";
456         case OperationType::GENERATE_PROPOSALS:
457             return os << "GENERATE_PROPOSALS";
458         case OperationType::GREATER:
459             return os << "GREATER";
460         case OperationType::GREATER_EQUAL:
461             return os << "GREATER_EQUAL";
462         case OperationType::GROUPED_CONV_2D:
463             return os << "GROUPED_CONV_2D";
464         case OperationType::HEATMAP_MAX_KEYPOINT:
465             return os << "HEATMAP_MAX_KEYPOINT";
466         case OperationType::INSTANCE_NORMALIZATION:
467             return os << "INSTANCE_NORMALIZATION";
468         case OperationType::LESS:
469             return os << "LESS";
470         case OperationType::LESS_EQUAL:
471             return os << "LESS_EQUAL";
472         case OperationType::LOG:
473             return os << "LOG";
474         case OperationType::LOGICAL_AND:
475             return os << "LOGICAL_AND";
476         case OperationType::LOGICAL_NOT:
477             return os << "LOGICAL_NOT";
478         case OperationType::LOGICAL_OR:
479             return os << "LOGICAL_OR";
480         case OperationType::LOG_SOFTMAX:
481             return os << "LOG_SOFTMAX";
482         case OperationType::MAXIMUM:
483             return os << "MAXIMUM";
484         case OperationType::MINIMUM:
485             return os << "MINIMUM";
486         case OperationType::NEG:
487             return os << "NEG";
488         case OperationType::NOT_EQUAL:
489             return os << "NOT_EQUAL";
490         case OperationType::PAD_V2:
491             return os << "PAD_V2";
492         case OperationType::POW:
493             return os << "POW";
494         case OperationType::PRELU:
495             return os << "PRELU";
496         case OperationType::QUANTIZE:
497             return os << "QUANTIZE";
498         case OperationType::QUANTIZED_16BIT_LSTM:
499             return os << "QUANTIZED_16BIT_LSTM";
500         case OperationType::RANDOM_MULTINOMIAL:
501             return os << "RANDOM_MULTINOMIAL";
502         case OperationType::REDUCE_ALL:
503             return os << "REDUCE_ALL";
504         case OperationType::REDUCE_ANY:
505             return os << "REDUCE_ANY";
506         case OperationType::REDUCE_MAX:
507             return os << "REDUCE_MAX";
508         case OperationType::REDUCE_MIN:
509             return os << "REDUCE_MIN";
510         case OperationType::REDUCE_PROD:
511             return os << "REDUCE_PROD";
512         case OperationType::REDUCE_SUM:
513             return os << "REDUCE_SUM";
514         case OperationType::ROI_ALIGN:
515             return os << "ROI_ALIGN";
516         case OperationType::ROI_POOLING:
517             return os << "ROI_POOLING";
518         case OperationType::RSQRT:
519             return os << "RSQRT";
520         case OperationType::SELECT:
521             return os << "SELECT";
522         case OperationType::SIN:
523             return os << "SIN";
524         case OperationType::SLICE:
525             return os << "SLICE";
526         case OperationType::SPLIT:
527             return os << "SPLIT";
528         case OperationType::SQRT:
529             return os << "SQRT";
530         case OperationType::TILE:
531             return os << "TILE";
532         case OperationType::TOPK_V2:
533             return os << "TOPK_V2";
534         case OperationType::TRANSPOSE_CONV_2D:
535             return os << "TRANSPOSE_CONV_2D";
536         case OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM:
537             return os << "UNIDIRECTIONAL_SEQUENCE_LSTM";
538         case OperationType::UNIDIRECTIONAL_SEQUENCE_RNN:
539             return os << "UNIDIRECTIONAL_SEQUENCE_RNN";
540         case OperationType::RESIZE_NEAREST_NEIGHBOR:
541             return os << "RESIZE_NEAREST_NEIGHBOR";
542         case OperationType::QUANTIZED_LSTM:
543             return os << "QUANTIZED_LSTM";
544         case OperationType::IF:
545             return os << "IF";
546         case OperationType::WHILE:
547             return os << "WHILE";
548         case OperationType::ELU:
549             return os << "ELU";
550         case OperationType::HARD_SWISH:
551             return os << "HARD_SWISH";
552         case OperationType::FILL:
553             return os << "FILL";
554         case OperationType::RANK:
555             return os << "RANK";
556         case OperationType::OEM_OPERATION:
557             return os << "OEM_OPERATION";
558     }
559     if (isExtension(operationType)) {
560         return os << "Extension OperationType " << underlyingType(operationType);
561     }
562     return os << "OperationType{" << underlyingType(operationType) << "}";
563 }
564 
operator <<(std::ostream & os,const Request::Argument::LifeTime & lifetime)565 std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime) {
566     switch (lifetime) {
567         case Request::Argument::LifeTime::POOL:
568             return os << "POOL";
569         case Request::Argument::LifeTime::NO_VALUE:
570             return os << "NO_VALUE";
571         case Request::Argument::LifeTime::POINTER:
572             return os << "POINTER";
573     }
574     return os << "Request::Argument::LifeTime{" << underlyingType(lifetime) << "}";
575 }
576 
operator <<(std::ostream & os,const Priority & priority)577 std::ostream& operator<<(std::ostream& os, const Priority& priority) {
578     switch (priority) {
579         case Priority::LOW:
580             return os << "LOW";
581         case Priority::MEDIUM:
582             return os << "MEDIUM";
583         case Priority::HIGH:
584             return os << "HIGH";
585     }
586     return os << "Priority{" << underlyingType(priority) << "}";
587 }
588 
operator <<(std::ostream & os,const ErrorStatus & errorStatus)589 std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) {
590     switch (errorStatus) {
591         case ErrorStatus::NONE:
592             return os << "NONE";
593         case ErrorStatus::DEVICE_UNAVAILABLE:
594             return os << "DEVICE_UNAVAILABLE";
595         case ErrorStatus::GENERAL_FAILURE:
596             return os << "GENERAL_FAILURE";
597         case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
598             return os << "OUTPUT_INSUFFICIENT_SIZE";
599         case ErrorStatus::INVALID_ARGUMENT:
600             return os << "INVALID_ARGUMENT";
601         case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
602             return os << "MISSED_DEADLINE_TRANSIENT";
603         case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
604             return os << "MISSED_DEADLINE_PERSISTENT";
605         case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
606             return os << "RESOURCE_EXHAUSTED_TRANSIENT";
607         case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
608             return os << "RESOURCE_EXHAUSTED_PERSISTENT";
609         case ErrorStatus::DEAD_OBJECT:
610             return os << "DEAD_OBJECT";
611     }
612     return os << "ErrorStatus{" << underlyingType(errorStatus) << "}";
613 }
614 
operator <<(std::ostream & os,const FusedActivationFunc & activation)615 std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation) {
616     switch (activation) {
617         case FusedActivationFunc::NONE:
618             return os << "NONE";
619         case FusedActivationFunc::RELU:
620             return os << "RELU";
621         case FusedActivationFunc::RELU1:
622             return os << "RELU1";
623         case FusedActivationFunc::RELU6:
624             return os << "RELU6";
625     }
626     return os << "FusedActivationFunc{" << underlyingType(activation) << "}";
627 }
628 
operator <<(std::ostream & os,const OutputShape & outputShape)629 std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) {
630     return os << "OutputShape{.dimensions=" << outputShape.dimensions
631               << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}";
632 }
633 
operator <<(std::ostream & os,const Timing & timing)634 std::ostream& operator<<(std::ostream& os, const Timing& timing) {
635     return os << "Timing{.timeOnDevice=" << timing.timeOnDevice
636               << ", .timeInDriver=" << timing.timeInDriver << "}";
637 }
638 
operator <<(std::ostream & os,const Capabilities::PerformanceInfo & performanceInfo)639 std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo) {
640     return os << "Capabilities::PerformanceInfo{.execTime=" << performanceInfo.execTime
641               << ", .powerUsage=" << performanceInfo.powerUsage << "}";
642 }
643 
operator <<(std::ostream & os,const Capabilities::OperandPerformance & operandPerformance)644 std::ostream& operator<<(std::ostream& os,
645                          const Capabilities::OperandPerformance& operandPerformance) {
646     return os << "Capabilities::OperandPerformance{.type=" << operandPerformance.type
647               << ", .info=" << operandPerformance.info << "}";
648 }
649 
operator <<(std::ostream & os,const Capabilities::OperandPerformanceTable & operandPerformances)650 std::ostream& operator<<(std::ostream& os,
651                          const Capabilities::OperandPerformanceTable& operandPerformances) {
652     return os << operandPerformances.asVector();
653 }
654 
operator <<(std::ostream & os,const Capabilities & capabilities)655 std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities) {
656     return os << "Capabilities{.relaxedFloat32toFloat16PerformanceScalar="
657               << capabilities.relaxedFloat32toFloat16PerformanceScalar
658               << ", .relaxedFloat32toFloat16PerformanceTensor="
659               << capabilities.relaxedFloat32toFloat16PerformanceTensor
660               << ", .operandPerformance=" << capabilities.operandPerformance
661               << ", .ifPerformance=" << capabilities.ifPerformance
662               << ", .whilePerformance=" << capabilities.whilePerformance << "}";
663 }
664 
operator <<(std::ostream & os,const Extension::OperandTypeInformation & operandTypeInformation)665 std::ostream& operator<<(std::ostream& os,
666                          const Extension::OperandTypeInformation& operandTypeInformation) {
667     return os << "Extension::OperandTypeInformation{.type=" << operandTypeInformation.type
668               << ", .isTensor=" << (operandTypeInformation.isTensor ? "true" : "false")
669               << ", .byteSize=" << operandTypeInformation.byteSize << "}";
670 }
671 
operator <<(std::ostream & os,const Extension & extension)672 std::ostream& operator<<(std::ostream& os, const Extension& extension) {
673     return os << "Extension{.name=" << extension.name
674               << ", .operandTypes=" << extension.operandTypes << "}";
675 }
676 
operator <<(std::ostream & os,const DataLocation & location)677 std::ostream& operator<<(std::ostream& os, const DataLocation& location) {
678     const auto printPointer = [&os](const std::variant<const void*, void*>& pointer) {
679         os << (std::holds_alternative<const void*>(pointer) ? "<constant " : "<mutable ");
680         os << std::visit(
681                 [](const auto* ptr) {
682                     return ptr == nullptr ? "null pointer>" : "non-null pointer>";
683                 },
684                 pointer);
685     };
686     os << "DataLocation{.pointer=";
687     printPointer(location.pointer);
688     return os << ", .poolIndex=" << location.poolIndex << ", .offset=" << location.offset
689               << ", .length=" << location.length << ", .padding=" << location.padding << "}";
690 }
691 
operator <<(std::ostream & os,const Operand::SymmPerChannelQuantParams & symmPerChannelQuantParams)692 std::ostream& operator<<(std::ostream& os,
693                          const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
694     return os << "Operand::SymmPerChannelQuantParams{.scales=" << symmPerChannelQuantParams.scales
695               << ", .channelDim=" << symmPerChannelQuantParams.channelDim << "}";
696 }
697 
operator <<(std::ostream & os,const Operand::ExtraParams & extraParams)698 std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams) {
699     os << "Operand::ExtraParams{";
700     if (std::holds_alternative<Operand::NoParams>(extraParams)) {
701         os << "<no params>";
702     } else if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(extraParams)) {
703         os << std::get<Operand::SymmPerChannelQuantParams>(extraParams);
704     } else if (std::holds_alternative<Operand::ExtensionParams>(extraParams)) {
705         os << std::get<Operand::ExtensionParams>(extraParams);
706     }
707     return os << "}";
708 }
709 
operator <<(std::ostream & os,const Operand & operand)710 std::ostream& operator<<(std::ostream& os, const Operand& operand) {
711     return os << "Operand{.type=" << operand.type << ", .dimensions=" << operand.dimensions
712               << ", .scale=" << operand.scale << ", .zeroPoint=" << operand.zeroPoint
713               << ", lifetime=" << operand.lifetime << ", .location=" << operand.location
714               << ", .extraParams=" << operand.extraParams << "}";
715 }
716 
operator <<(std::ostream & os,const Operation & operation)717 std::ostream& operator<<(std::ostream& os, const Operation& operation) {
718     return os << "Operation{.type=" << operation.type << ", .inputs=" << operation.inputs
719               << ", .outputs=" << operation.outputs << "}";
720 }
721 
operator <<(std::ostream & os,const Handle & handle)722 static std::ostream& operator<<(std::ostream& os, const Handle& handle) {
723     return os << "<handle with " << handle.fds.size() << " fds and " << handle.ints.size()
724               << " ints>";
725 }
726 
operator <<(std::ostream & os,const SharedHandle & handle)727 std::ostream& operator<<(std::ostream& os, const SharedHandle& handle) {
728     if (handle == nullptr) {
729         return os << "<empty handle>";
730     }
731     return os << *handle;
732 }
733 
operator <<(std::ostream & os,const Memory::Ashmem & memory)734 static std::ostream& operator<<(std::ostream& os, const Memory::Ashmem& memory) {
735     return os << "Ashmem{.fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
736               << ", .size=" << memory.size << "}";
737 }
738 
operator <<(std::ostream & os,const Memory::Fd & memory)739 static std::ostream& operator<<(std::ostream& os, const Memory::Fd& memory) {
740     return os << "Fd{.size=" << memory.size << ", .prot=" << memory.prot
741               << ", .fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
742               << ", .offset=" << memory.offset << "}";
743 }
744 
operator <<(std::ostream & os,const Memory::HardwareBuffer & memory)745 static std::ostream& operator<<(std::ostream& os, const Memory::HardwareBuffer& memory) {
746     if (memory.handle.get() == nullptr) {
747         return os << "<empty HardwareBuffer::Handle>";
748     }
749     return os << (isAhwbBlob(memory) ? "<AHardwareBuffer blob>" : "<non-blob AHardwareBuffer>");
750 }
751 
operator <<(std::ostream & os,const Memory::Unknown & memory)752 static std::ostream& operator<<(std::ostream& os, const Memory::Unknown& memory) {
753     return os << "Unknown{.handle=" << memory.handle << ", .size=" << memory.size
754               << ", .name=" << memory.name << "}";
755 }
756 
operator <<(std::ostream & os,const Memory & memory)757 std::ostream& operator<<(std::ostream& os, const Memory& memory) {
758     os << "Memory{.handle=";
759     std::visit([&os](const auto& x) { os << x; }, memory.handle);
760     return os << "}";
761 }
762 
operator <<(std::ostream & os,const SharedMemory & memory)763 std::ostream& operator<<(std::ostream& os, const SharedMemory& memory) {
764     if (memory == nullptr) {
765         return os << "<empty memory>";
766     }
767     return os << *memory;
768 }
769 
operator <<(std::ostream & os,const MemoryPreference & memoryPreference)770 std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference) {
771     return os << "MemoryPreference{.alignment=" << memoryPreference.alignment
772               << ", .padding=" << memoryPreference.padding << "}";
773 }
774 
operator <<(std::ostream & os,const Model::Subgraph & subgraph)775 std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph) {
776     std::vector<Operand> operands;
777     std::vector<Operation> operations;
778     std::vector<uint32_t> inputIndexes;
779     std::vector<uint32_t> outputIndexes;
780     return os << "Model::Subgraph{.operands=" << subgraph.operands
781               << ", .operations=" << subgraph.operations
782               << ", .inputIndexes=" << subgraph.inputIndexes
783               << ", .outputIndexes=" << subgraph.outputIndexes << "}";
784 }
785 
operator <<(std::ostream & os,const Model::OperandValues & operandValues)786 std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues) {
787     return os << "Model::OperandValues{<" << operandValues.size() << "bytes>}";
788 }
789 
operator <<(std::ostream & os,const Model::ExtensionNameAndPrefix & extensionNameAndPrefix)790 std::ostream& operator<<(std::ostream& os,
791                          const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
792     return os << "Model::ExtensionNameAndPrefix{.name=" << extensionNameAndPrefix.name
793               << ", .prefix=" << extensionNameAndPrefix.prefix << "}";
794 }
795 
operator <<(std::ostream & os,const Model & model)796 std::ostream& operator<<(std::ostream& os, const Model& model) {
797     return os << "Model{.main=" << model.main << ", .referenced=" << model.referenced
798               << ", .operandValues=" << model.operandValues << ", .pools=" << model.pools
799               << ", .relaxComputationFloat32toFloat16="
800               << (model.relaxComputationFloat32toFloat16 ? "true" : "false")
801               << ", extensionNameToPrefix=" << model.extensionNameToPrefix << "}";
802 }
803 
operator <<(std::ostream & os,const BufferDesc & bufferDesc)804 std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc) {
805     return os << "BufferDesc{.dimensions=" << bufferDesc.dimensions << "}";
806 }
807 
operator <<(std::ostream & os,const BufferRole & bufferRole)808 std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole) {
809     return os << "BufferRole{.modelIndex=" << bufferRole.modelIndex
810               << ", .ioIndex=" << bufferRole.ioIndex << ", .probability=" << bufferRole.probability
811               << "}";
812 }
813 
operator <<(std::ostream & os,const Request::Argument & requestArgument)814 std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument) {
815     return os << "Request::Argument{.lifetime=" << requestArgument.lifetime
816               << ", .location=" << requestArgument.location
817               << ", .dimensions=" << requestArgument.dimensions << "}";
818 }
819 
operator <<(std::ostream & os,const Request::MemoryPool & memoryPool)820 std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool) {
821     os << "Request::MemoryPool{";
822     if (std::holds_alternative<SharedMemory>(memoryPool)) {
823         os << std::get<SharedMemory>(memoryPool);
824     } else if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
825         const auto& token = std::get<Request::MemoryDomainToken>(memoryPool);
826         if (token == Request::MemoryDomainToken{}) {
827             os << "<invalid MemoryDomainToken>";
828         } else {
829             os << "MemoryDomainToken=" << underlyingType(token);
830         }
831     } else if (std::holds_alternative<SharedBuffer>(memoryPool)) {
832         const auto& buffer = std::get<SharedBuffer>(memoryPool);
833         os << (buffer != nullptr ? "<non-null IBuffer>" : "<null IBuffer>");
834     }
835     return os << "}";
836 }
837 
operator <<(std::ostream & os,const Request & request)838 std::ostream& operator<<(std::ostream& os, const Request& request) {
839     return os << "Request{.inputs=" << request.inputs << ", .outputs=" << request.outputs
840               << ", .pools=" << request.pools << "}";
841 }
842 
operator <<(std::ostream & os,const SyncFence::FenceState & fenceState)843 std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState) {
844     switch (fenceState) {
845         case SyncFence::FenceState::ACTIVE:
846             return os << "ACTIVE";
847         case SyncFence::FenceState::SIGNALED:
848             return os << "SIGNALED";
849         case SyncFence::FenceState::ERROR:
850             return os << "ERROR";
851         case SyncFence::FenceState::UNKNOWN:
852             return os << "UNKNOWN";
853     }
854     return os << "SyncFence::FenceState{" << underlyingType(fenceState) << "}";
855 }
856 
operator <<(std::ostream & os,const TimePoint & timePoint)857 std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint) {
858     return os << timePoint.time_since_epoch() << " since epoch";
859 }
860 
operator <<(std::ostream & os,const OptionalTimePoint & optionalTimePoint)861 std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint) {
862     if (!optionalTimePoint.has_value()) {
863         return os << "<no time point>";
864     }
865     return os << optionalTimePoint.value();
866 }
867 
operator <<(std::ostream & os,const Duration & timeoutDuration)868 std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration) {
869     return os << timeoutDuration.count() << "ns";
870 }
871 
operator <<(std::ostream & os,const OptionalDuration & optionalTimeoutDuration)872 std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration) {
873     if (!optionalTimeoutDuration.has_value()) {
874         return os << "<no duration>";
875     }
876     return os << optionalTimeoutDuration.value();
877 }
878 
operator <<(std::ostream & os,const Version & version)879 std::ostream& operator<<(std::ostream& os, const Version& version) {
880     switch (version) {
881         case Version::ANDROID_OC_MR1:
882             return os << "ANDROID_OC_MR1";
883         case Version::ANDROID_P:
884             return os << "ANDROID_P";
885         case Version::ANDROID_Q:
886             return os << "ANDROID_Q";
887         case Version::ANDROID_R:
888             return os << "ANDROID_R";
889         case Version::ANDROID_S:
890             return os << "ANDROID_S";
891         case Version::CURRENT_RUNTIME:
892             return os << "CURRENT_RUNTIME";
893     }
894     return os << "Version{" << underlyingType(version) << "}";
895 }
896 
operator <<(std::ostream & os,const HalVersion & halVersion)897 std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) {
898     switch (halVersion) {
899         case HalVersion::UNKNOWN:
900             return os << "UNKNOWN HAL version";
901         case HalVersion::V1_0:
902             return os << "HAL version 1.0";
903         case HalVersion::V1_1:
904             return os << "HAL version 1.1";
905         case HalVersion::V1_2:
906             return os << "HAL version 1.2";
907         case HalVersion::V1_3:
908             return os << "HAL version 1.3";
909         case HalVersion::AIDL_UNSTABLE:
910             return os << "HAL uses unstable AIDL";
911     }
912     return os << "HalVersion{" << underlyingType(halVersion) << "}";
913 }
914 
operator ==(const Timing & a,const Timing & b)915 bool operator==(const Timing& a, const Timing& b) {
916     return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver;
917 }
operator !=(const Timing & a,const Timing & b)918 bool operator!=(const Timing& a, const Timing& b) {
919     return !(a == b);
920 }
921 
operator ==(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)922 bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
923     return a.execTime == b.execTime && a.powerUsage == b.powerUsage;
924 }
operator !=(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)925 bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
926     return !(a == b);
927 }
928 
operator ==(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)929 bool operator==(const Capabilities::OperandPerformance& a,
930                 const Capabilities::OperandPerformance& b) {
931     return a.type == b.type && a.info == b.info;
932 }
operator !=(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)933 bool operator!=(const Capabilities::OperandPerformance& a,
934                 const Capabilities::OperandPerformance& b) {
935     return !(a == b);
936 }
937 
operator ==(const Capabilities & a,const Capabilities & b)938 bool operator==(const Capabilities& a, const Capabilities& b) {
939     return a.relaxedFloat32toFloat16PerformanceScalar ==
940                    b.relaxedFloat32toFloat16PerformanceScalar &&
941            a.relaxedFloat32toFloat16PerformanceTensor ==
942                    b.relaxedFloat32toFloat16PerformanceTensor &&
943            a.operandPerformance.asVector() == b.operandPerformance.asVector() &&
944            a.ifPerformance == b.ifPerformance && a.whilePerformance == b.whilePerformance;
945 }
operator !=(const Capabilities & a,const Capabilities & b)946 bool operator!=(const Capabilities& a, const Capabilities& b) {
947     return !(a == b);
948 }
949 
operator ==(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)950 bool operator==(const Extension::OperandTypeInformation& a,
951                 const Extension::OperandTypeInformation& b) {
952     return a.type == b.type && a.isTensor == b.isTensor && a.byteSize == b.byteSize;
953 }
operator !=(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)954 bool operator!=(const Extension::OperandTypeInformation& a,
955                 const Extension::OperandTypeInformation& b) {
956     return !(a == b);
957 }
958 
operator ==(const Extension & a,const Extension & b)959 bool operator==(const Extension& a, const Extension& b) {
960     return a.name == b.name && a.operandTypes == b.operandTypes;
961 }
operator !=(const Extension & a,const Extension & b)962 bool operator!=(const Extension& a, const Extension& b) {
963     return !(a == b);
964 }
965 
operator ==(const MemoryPreference & a,const MemoryPreference & b)966 bool operator==(const MemoryPreference& a, const MemoryPreference& b) {
967     return a.alignment == b.alignment && a.padding == b.padding;
968 }
operator !=(const MemoryPreference & a,const MemoryPreference & b)969 bool operator!=(const MemoryPreference& a, const MemoryPreference& b) {
970     return !(a == b);
971 }
972 
operator ==(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)973 bool operator==(const Operand::SymmPerChannelQuantParams& a,
974                 const Operand::SymmPerChannelQuantParams& b) {
975     return a.scales == b.scales && a.channelDim == b.channelDim;
976 }
operator !=(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)977 bool operator!=(const Operand::SymmPerChannelQuantParams& a,
978                 const Operand::SymmPerChannelQuantParams& b) {
979     return !(a == b);
980 }
981 
operator ==(const DataLocation & a,const DataLocation & b)982 static bool operator==(const DataLocation& a, const DataLocation& b) {
983     constexpr auto toTuple = [](const DataLocation& location) {
984         return std::tie(location.pointer, location.poolIndex, location.offset, location.length,
985                         location.padding);
986     };
987     return toTuple(a) == toTuple(b);
988 }
989 
operator ==(const Operand & a,const Operand & b)990 bool operator==(const Operand& a, const Operand& b) {
991     constexpr auto toTuple = [](const Operand& operand) {
992         return std::tie(operand.type, operand.dimensions, operand.scale, operand.zeroPoint,
993                         operand.lifetime, operand.location, operand.extraParams);
994     };
995     return toTuple(a) == toTuple(b);
996 }
operator !=(const Operand & a,const Operand & b)997 bool operator!=(const Operand& a, const Operand& b) {
998     return !(a == b);
999 }
1000 
operator ==(const Operation & a,const Operation & b)1001 bool operator==(const Operation& a, const Operation& b) {
1002     constexpr auto toTuple = [](const Operation& operation) {
1003         return std::tie(operation.type, operation.inputs, operation.outputs);
1004     };
1005     return toTuple(a) == toTuple(b);
1006 }
operator !=(const Operation & a,const Operation & b)1007 bool operator!=(const Operation& a, const Operation& b) {
1008     return !(a == b);
1009 }
1010 
1011 }  // namespace android::nn
1012