1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "TypeUtils.h"

#include <android-base/logging.h>

#include <algorithm>
#include <chrono>
#include <iterator>
#include <limits>
#include <memory>
#include <ostream>
#include <sstream>
#include <type_traits>
#include <utility>
#include <vector>

#include "OperandTypes.h"
#include "OperationTypes.h"
#include "Result.h"
#include "SharedMemory.h"
#include "Types.h"
35
36 namespace android::nn {
37 namespace {
38
// Returns the raw integral value stored in an enum object, typed as the enum's underlying type.
// Used below to print enum values that have no symbolic name.
template <typename Type>
constexpr auto underlyingType(Type object) -> std::underlying_type_t<Type> {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(object);
}
43
44 template <typename Type>
operator <<(std::ostream & os,const std::vector<Type> & vec)45 std::ostream& operator<<(std::ostream& os, const std::vector<Type>& vec) {
46 constexpr size_t kMaxVectorPrint = 20;
47 os << "[";
48 size_t count = 0;
49 for (const auto& element : vec) {
50 if (count > 0) {
51 os << ", ";
52 }
53 os << element;
54 count++;
55 if (count >= kMaxVectorPrint) {
56 return os << "...]";
57 }
58 }
59 return os << "]";
60 }
61
62 } // namespace
63
isExtension(OperandType type)64 bool isExtension(OperandType type) {
65 return getExtensionPrefix(underlyingType(type)) != 0;
66 }
67
isExtension(OperationType type)68 bool isExtension(OperationType type) {
69 return getExtensionPrefix(underlyingType(type)) != 0;
70 }
71
isNonExtensionScalar(OperandType operandType)72 bool isNonExtensionScalar(OperandType operandType) {
73 CHECK(!isExtension(operandType));
74 switch (operandType) {
75 case OperandType::FLOAT32:
76 case OperandType::INT32:
77 case OperandType::UINT32:
78 case OperandType::BOOL:
79 case OperandType::FLOAT16:
80 case OperandType::SUBGRAPH:
81 case OperandType::OEM:
82 return true;
83 case OperandType::TENSOR_FLOAT32:
84 case OperandType::TENSOR_INT32:
85 case OperandType::TENSOR_QUANT8_ASYMM:
86 case OperandType::TENSOR_QUANT16_SYMM:
87 case OperandType::TENSOR_FLOAT16:
88 case OperandType::TENSOR_BOOL8:
89 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
90 case OperandType::TENSOR_QUANT16_ASYMM:
91 case OperandType::TENSOR_QUANT8_SYMM:
92 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
93 case OperandType::TENSOR_OEM_BYTE:
94 return false;
95 }
96 return false;
97 }
98
getNonExtensionSize(OperandType operandType)99 size_t getNonExtensionSize(OperandType operandType) {
100 CHECK(!isExtension(operandType));
101 switch (operandType) {
102 case OperandType::SUBGRAPH:
103 case OperandType::OEM:
104 return 0;
105 case OperandType::TENSOR_QUANT8_ASYMM:
106 case OperandType::BOOL:
107 case OperandType::TENSOR_BOOL8:
108 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
109 case OperandType::TENSOR_QUANT8_SYMM:
110 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
111 case OperandType::TENSOR_OEM_BYTE:
112 return 1;
113 case OperandType::TENSOR_QUANT16_SYMM:
114 case OperandType::TENSOR_FLOAT16:
115 case OperandType::FLOAT16:
116 case OperandType::TENSOR_QUANT16_ASYMM:
117 return 2;
118 case OperandType::FLOAT32:
119 case OperandType::INT32:
120 case OperandType::UINT32:
121 case OperandType::TENSOR_FLOAT32:
122 case OperandType::TENSOR_INT32:
123 return 4;
124 }
125 return 0;
126 }
127
getNonExtensionSize(OperandType operandType,const Dimensions & dimensions)128 std::optional<size_t> getNonExtensionSize(OperandType operandType, const Dimensions& dimensions) {
129 CHECK(!isExtension(operandType)) << "Size of extension operand data is unknown";
130 size_t size = getNonExtensionSize(operandType);
131 if (isNonExtensionScalar(operandType)) {
132 return size;
133 } else if (dimensions.empty()) {
134 return 0;
135 }
136 for (Dimension dimension : dimensions) {
137 if (dimension != 0 && size > std::numeric_limits<size_t>::max() / dimension) {
138 return std::nullopt;
139 }
140 size *= dimension;
141 }
142 return size;
143 }
144
getNonExtensionSize(const Operand & operand)145 std::optional<size_t> getNonExtensionSize(const Operand& operand) {
146 return getNonExtensionSize(operand.type, operand.dimensions);
147 }
148
// Reassembles an offset from two 32-bit halves (the inverse of getIntsFromOffset()).
// The bit patterns of `lower` and `higher` are treated as unsigned words and combined as
// (higher << 32) | lower.
//
// NOTE: the result is returned as size_t. On targets where size_t is narrower than 64 bits,
// the high word is silently dropped by the final conversion.
size_t getOffsetFromInts(int lower, int higher) {
    // Signed-to-unsigned conversion is defined as reduction modulo 2^32. For the two's-complement
    // int32_t this preserves the bit pattern exactly, so no pointer reinterpretation is needed.
    const uint32_t lowBits = static_cast<uint32_t>(static_cast<int32_t>(lower));
    const uint32_t highBits = static_cast<uint32_t>(static_cast<int32_t>(higher));
    const uint64_t offset = (static_cast<uint64_t>(highBits) << 32) | lowBits;
    return static_cast<size_t>(offset);
}
157
// Splits an offset into its low and high 32-bit words, returned as signed ints carrying the
// same bit patterns: {low, high}. Inverse of getOffsetFromInts().
std::pair<int32_t, int32_t> getIntsFromOffset(size_t offset) {
    const auto wide = static_cast<uint64_t>(offset);
    const uint32_t words[2] = {static_cast<uint32_t>(wide & 0xffffffff),
                               static_cast<uint32_t>(wide >> 32)};
    // View each unsigned word's bits as the corresponding signed type.
    const int32_t low = *reinterpret_cast<const int32_t*>(&words[0]);
    const int32_t high = *reinterpret_cast<const int32_t*>(&words[1]);
    return {low, high};
}
166
countNumberOfConsumers(size_t numberOfOperands,const std::vector<nn::Operation> & operations)167 Result<std::vector<uint32_t>> countNumberOfConsumers(size_t numberOfOperands,
168 const std::vector<nn::Operation>& operations) {
169 std::vector<uint32_t> numberOfConsumers(numberOfOperands, 0);
170 for (const auto& operation : operations) {
171 for (uint32_t operandIndex : operation.inputs) {
172 if (operandIndex >= numberOfConsumers.size()) {
173 return NN_ERROR()
174 << "countNumberOfConsumers: tried to access out-of-bounds operand ("
175 << operandIndex << " vs " << numberOfConsumers.size() << ")";
176 }
177 numberOfConsumers[operandIndex]++;
178 }
179 }
180 return numberOfConsumers;
181 }
182
combineDimensions(const Dimensions & lhs,const Dimensions & rhs)183 Result<Dimensions> combineDimensions(const Dimensions& lhs, const Dimensions& rhs) {
184 if (rhs.empty()) return lhs;
185 if (lhs.empty()) return rhs;
186 if (lhs.size() != rhs.size()) {
187 std::ostringstream os;
188 os << "Incompatible ranks: " << lhs << " and " << rhs;
189 return NN_ERROR() << os.str();
190 }
191 Dimensions combined = lhs;
192 for (size_t i = 0; i < lhs.size(); i++) {
193 if (lhs[i] == 0) {
194 combined[i] = rhs[i];
195 } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
196 std::ostringstream os;
197 os << "Incompatible dimensions: " << lhs << " and " << rhs;
198 return NN_ERROR() << os.str();
199 }
200 }
201 return combined;
202 }
203
getMemorySizes(const Model & model)204 std::pair<size_t, std::vector<size_t>> getMemorySizes(const Model& model) {
205 const size_t operandValuesSize = model.operandValues.size();
206
207 std::vector<size_t> poolSizes;
208 poolSizes.reserve(model.pools.size());
209 std::transform(model.pools.begin(), model.pools.end(), std::back_inserter(poolSizes),
210 [](const SharedMemory& memory) { return getSize(memory); });
211
212 return std::make_pair(operandValuesSize, std::move(poolSizes));
213 }
214
roundUp(size_t size,size_t multiple)215 size_t roundUp(size_t size, size_t multiple) {
216 CHECK(multiple != 0);
217 CHECK((multiple & (multiple - 1)) == 0) << multiple << " is not a power of two";
218 return (size + (multiple - 1)) & ~(multiple - 1);
219 }
220
// Picks the byte alignment for a value of the given length:
// lengths 0-1 need none (1), lengths 2-3 align to 2, anything larger aligns to 4.
size_t getAlignmentForLength(size_t length) {
    if (length >= 4) {
        return 4;
    }
    if (length >= 2) {
        return 2;
    }
    return 1;
}
230
operator <<(std::ostream & os,const DeviceStatus & deviceStatus)231 std::ostream& operator<<(std::ostream& os, const DeviceStatus& deviceStatus) {
232 switch (deviceStatus) {
233 case DeviceStatus::AVAILABLE:
234 return os << "AVAILABLE";
235 case DeviceStatus::BUSY:
236 return os << "BUSY";
237 case DeviceStatus::OFFLINE:
238 return os << "OFFLINE";
239 case DeviceStatus::UNKNOWN:
240 return os << "UNKNOWN";
241 }
242 return os << "DeviceStatus{" << underlyingType(deviceStatus) << "}";
243 }
244
operator <<(std::ostream & os,const ExecutionPreference & executionPreference)245 std::ostream& operator<<(std::ostream& os, const ExecutionPreference& executionPreference) {
246 switch (executionPreference) {
247 case ExecutionPreference::LOW_POWER:
248 return os << "LOW_POWER";
249 case ExecutionPreference::FAST_SINGLE_ANSWER:
250 return os << "FAST_SINGLE_ANSWER";
251 case ExecutionPreference::SUSTAINED_SPEED:
252 return os << "SUSTAINED_SPEED";
253 }
254 return os << "ExecutionPreference{" << underlyingType(executionPreference) << "}";
255 }
256
operator <<(std::ostream & os,const DeviceType & deviceType)257 std::ostream& operator<<(std::ostream& os, const DeviceType& deviceType) {
258 switch (deviceType) {
259 case DeviceType::UNKNOWN:
260 return os << "UNKNOWN";
261 case DeviceType::OTHER:
262 return os << "OTHER";
263 case DeviceType::CPU:
264 return os << "CPU";
265 case DeviceType::GPU:
266 return os << "GPU";
267 case DeviceType::ACCELERATOR:
268 return os << "ACCELERATOR";
269 }
270 return os << "DeviceType{" << underlyingType(deviceType) << "}";
271 }
272
operator <<(std::ostream & os,const MeasureTiming & measureTiming)273 std::ostream& operator<<(std::ostream& os, const MeasureTiming& measureTiming) {
274 switch (measureTiming) {
275 case MeasureTiming::NO:
276 return os << "NO";
277 case MeasureTiming::YES:
278 return os << "YES";
279 }
280 return os << "MeasureTiming{" << underlyingType(measureTiming) << "}";
281 }
282
operator <<(std::ostream & os,const OperandType & operandType)283 std::ostream& operator<<(std::ostream& os, const OperandType& operandType) {
284 switch (operandType) {
285 case OperandType::FLOAT32:
286 return os << "FLOAT32";
287 case OperandType::INT32:
288 return os << "INT32";
289 case OperandType::UINT32:
290 return os << "UINT32";
291 case OperandType::TENSOR_FLOAT32:
292 return os << "TENSOR_FLOAT32";
293 case OperandType::TENSOR_INT32:
294 return os << "TENSOR_INT32";
295 case OperandType::TENSOR_QUANT8_ASYMM:
296 return os << "TENSOR_QUANT8_ASYMM";
297 case OperandType::BOOL:
298 return os << "BOOL";
299 case OperandType::TENSOR_QUANT16_SYMM:
300 return os << "TENSOR_QUANT16_SYMM";
301 case OperandType::TENSOR_FLOAT16:
302 return os << "TENSOR_FLOAT16";
303 case OperandType::TENSOR_BOOL8:
304 return os << "TENSOR_BOOL8";
305 case OperandType::FLOAT16:
306 return os << "FLOAT16";
307 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
308 return os << "TENSOR_QUANT8_SYMM_PER_CHANNEL";
309 case OperandType::TENSOR_QUANT16_ASYMM:
310 return os << "TENSOR_QUANT16_ASYMM";
311 case OperandType::TENSOR_QUANT8_SYMM:
312 return os << "TENSOR_QUANT8_SYMM";
313 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
314 return os << "TENSOR_QUANT8_ASYMM_SIGNED";
315 case OperandType::SUBGRAPH:
316 return os << "SUBGRAPH";
317 case OperandType::OEM:
318 return os << "OEM";
319 case OperandType::TENSOR_OEM_BYTE:
320 return os << "TENSOR_OEM_BYTE";
321 }
322 if (isExtension(operandType)) {
323 return os << "Extension OperandType " << underlyingType(operandType);
324 }
325 return os << "OperandType{" << underlyingType(operandType) << "}";
326 }
327
operator <<(std::ostream & os,const Operand::LifeTime & lifetime)328 std::ostream& operator<<(std::ostream& os, const Operand::LifeTime& lifetime) {
329 switch (lifetime) {
330 case Operand::LifeTime::TEMPORARY_VARIABLE:
331 return os << "TEMPORARY_VARIABLE";
332 case Operand::LifeTime::SUBGRAPH_INPUT:
333 return os << "SUBGRAPH_INPUT";
334 case Operand::LifeTime::SUBGRAPH_OUTPUT:
335 return os << "SUBGRAPH_OUTPUT";
336 case Operand::LifeTime::CONSTANT_COPY:
337 return os << "CONSTANT_COPY";
338 case Operand::LifeTime::CONSTANT_REFERENCE:
339 return os << "CONSTANT_REFERENCE";
340 case Operand::LifeTime::NO_VALUE:
341 return os << "NO_VALUE";
342 case Operand::LifeTime::SUBGRAPH:
343 return os << "SUBGRAPH";
344 case Operand::LifeTime::POINTER:
345 return os << "POINTER";
346 }
347 return os << "Operand::LifeTime{" << underlyingType(lifetime) << "}";
348 }
349
operator <<(std::ostream & os,const OperationType & operationType)350 std::ostream& operator<<(std::ostream& os, const OperationType& operationType) {
351 switch (operationType) {
352 case OperationType::ADD:
353 return os << "ADD";
354 case OperationType::AVERAGE_POOL_2D:
355 return os << "AVERAGE_POOL_2D";
356 case OperationType::CONCATENATION:
357 return os << "CONCATENATION";
358 case OperationType::CONV_2D:
359 return os << "CONV_2D";
360 case OperationType::DEPTHWISE_CONV_2D:
361 return os << "DEPTHWISE_CONV_2D";
362 case OperationType::DEPTH_TO_SPACE:
363 return os << "DEPTH_TO_SPACE";
364 case OperationType::DEQUANTIZE:
365 return os << "DEQUANTIZE";
366 case OperationType::EMBEDDING_LOOKUP:
367 return os << "EMBEDDING_LOOKUP";
368 case OperationType::FLOOR:
369 return os << "FLOOR";
370 case OperationType::FULLY_CONNECTED:
371 return os << "FULLY_CONNECTED";
372 case OperationType::HASHTABLE_LOOKUP:
373 return os << "HASHTABLE_LOOKUP";
374 case OperationType::L2_NORMALIZATION:
375 return os << "L2_NORMALIZATION";
376 case OperationType::L2_POOL_2D:
377 return os << "L2_POOL_2D";
378 case OperationType::LOCAL_RESPONSE_NORMALIZATION:
379 return os << "LOCAL_RESPONSE_NORMALIZATION";
380 case OperationType::LOGISTIC:
381 return os << "LOGISTIC";
382 case OperationType::LSH_PROJECTION:
383 return os << "LSH_PROJECTION";
384 case OperationType::LSTM:
385 return os << "LSTM";
386 case OperationType::MAX_POOL_2D:
387 return os << "MAX_POOL_2D";
388 case OperationType::MUL:
389 return os << "MUL";
390 case OperationType::RELU:
391 return os << "RELU";
392 case OperationType::RELU1:
393 return os << "RELU1";
394 case OperationType::RELU6:
395 return os << "RELU6";
396 case OperationType::RESHAPE:
397 return os << "RESHAPE";
398 case OperationType::RESIZE_BILINEAR:
399 return os << "RESIZE_BILINEAR";
400 case OperationType::RNN:
401 return os << "RNN";
402 case OperationType::SOFTMAX:
403 return os << "SOFTMAX";
404 case OperationType::SPACE_TO_DEPTH:
405 return os << "SPACE_TO_DEPTH";
406 case OperationType::SVDF:
407 return os << "SVDF";
408 case OperationType::TANH:
409 return os << "TANH";
410 case OperationType::BATCH_TO_SPACE_ND:
411 return os << "BATCH_TO_SPACE_ND";
412 case OperationType::DIV:
413 return os << "DIV";
414 case OperationType::MEAN:
415 return os << "MEAN";
416 case OperationType::PAD:
417 return os << "PAD";
418 case OperationType::SPACE_TO_BATCH_ND:
419 return os << "SPACE_TO_BATCH_ND";
420 case OperationType::SQUEEZE:
421 return os << "SQUEEZE";
422 case OperationType::STRIDED_SLICE:
423 return os << "STRIDED_SLICE";
424 case OperationType::SUB:
425 return os << "SUB";
426 case OperationType::TRANSPOSE:
427 return os << "TRANSPOSE";
428 case OperationType::ABS:
429 return os << "ABS";
430 case OperationType::ARGMAX:
431 return os << "ARGMAX";
432 case OperationType::ARGMIN:
433 return os << "ARGMIN";
434 case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM:
435 return os << "AXIS_ALIGNED_BBOX_TRANSFORM";
436 case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM:
437 return os << "BIDIRECTIONAL_SEQUENCE_LSTM";
438 case OperationType::BIDIRECTIONAL_SEQUENCE_RNN:
439 return os << "BIDIRECTIONAL_SEQUENCE_RNN";
440 case OperationType::BOX_WITH_NMS_LIMIT:
441 return os << "BOX_WITH_NMS_LIMIT";
442 case OperationType::CAST:
443 return os << "CAST";
444 case OperationType::CHANNEL_SHUFFLE:
445 return os << "CHANNEL_SHUFFLE";
446 case OperationType::DETECTION_POSTPROCESSING:
447 return os << "DETECTION_POSTPROCESSING";
448 case OperationType::EQUAL:
449 return os << "EQUAL";
450 case OperationType::EXP:
451 return os << "EXP";
452 case OperationType::EXPAND_DIMS:
453 return os << "EXPAND_DIMS";
454 case OperationType::GATHER:
455 return os << "GATHER";
456 case OperationType::GENERATE_PROPOSALS:
457 return os << "GENERATE_PROPOSALS";
458 case OperationType::GREATER:
459 return os << "GREATER";
460 case OperationType::GREATER_EQUAL:
461 return os << "GREATER_EQUAL";
462 case OperationType::GROUPED_CONV_2D:
463 return os << "GROUPED_CONV_2D";
464 case OperationType::HEATMAP_MAX_KEYPOINT:
465 return os << "HEATMAP_MAX_KEYPOINT";
466 case OperationType::INSTANCE_NORMALIZATION:
467 return os << "INSTANCE_NORMALIZATION";
468 case OperationType::LESS:
469 return os << "LESS";
470 case OperationType::LESS_EQUAL:
471 return os << "LESS_EQUAL";
472 case OperationType::LOG:
473 return os << "LOG";
474 case OperationType::LOGICAL_AND:
475 return os << "LOGICAL_AND";
476 case OperationType::LOGICAL_NOT:
477 return os << "LOGICAL_NOT";
478 case OperationType::LOGICAL_OR:
479 return os << "LOGICAL_OR";
480 case OperationType::LOG_SOFTMAX:
481 return os << "LOG_SOFTMAX";
482 case OperationType::MAXIMUM:
483 return os << "MAXIMUM";
484 case OperationType::MINIMUM:
485 return os << "MINIMUM";
486 case OperationType::NEG:
487 return os << "NEG";
488 case OperationType::NOT_EQUAL:
489 return os << "NOT_EQUAL";
490 case OperationType::PAD_V2:
491 return os << "PAD_V2";
492 case OperationType::POW:
493 return os << "POW";
494 case OperationType::PRELU:
495 return os << "PRELU";
496 case OperationType::QUANTIZE:
497 return os << "QUANTIZE";
498 case OperationType::QUANTIZED_16BIT_LSTM:
499 return os << "QUANTIZED_16BIT_LSTM";
500 case OperationType::RANDOM_MULTINOMIAL:
501 return os << "RANDOM_MULTINOMIAL";
502 case OperationType::REDUCE_ALL:
503 return os << "REDUCE_ALL";
504 case OperationType::REDUCE_ANY:
505 return os << "REDUCE_ANY";
506 case OperationType::REDUCE_MAX:
507 return os << "REDUCE_MAX";
508 case OperationType::REDUCE_MIN:
509 return os << "REDUCE_MIN";
510 case OperationType::REDUCE_PROD:
511 return os << "REDUCE_PROD";
512 case OperationType::REDUCE_SUM:
513 return os << "REDUCE_SUM";
514 case OperationType::ROI_ALIGN:
515 return os << "ROI_ALIGN";
516 case OperationType::ROI_POOLING:
517 return os << "ROI_POOLING";
518 case OperationType::RSQRT:
519 return os << "RSQRT";
520 case OperationType::SELECT:
521 return os << "SELECT";
522 case OperationType::SIN:
523 return os << "SIN";
524 case OperationType::SLICE:
525 return os << "SLICE";
526 case OperationType::SPLIT:
527 return os << "SPLIT";
528 case OperationType::SQRT:
529 return os << "SQRT";
530 case OperationType::TILE:
531 return os << "TILE";
532 case OperationType::TOPK_V2:
533 return os << "TOPK_V2";
534 case OperationType::TRANSPOSE_CONV_2D:
535 return os << "TRANSPOSE_CONV_2D";
536 case OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM:
537 return os << "UNIDIRECTIONAL_SEQUENCE_LSTM";
538 case OperationType::UNIDIRECTIONAL_SEQUENCE_RNN:
539 return os << "UNIDIRECTIONAL_SEQUENCE_RNN";
540 case OperationType::RESIZE_NEAREST_NEIGHBOR:
541 return os << "RESIZE_NEAREST_NEIGHBOR";
542 case OperationType::QUANTIZED_LSTM:
543 return os << "QUANTIZED_LSTM";
544 case OperationType::IF:
545 return os << "IF";
546 case OperationType::WHILE:
547 return os << "WHILE";
548 case OperationType::ELU:
549 return os << "ELU";
550 case OperationType::HARD_SWISH:
551 return os << "HARD_SWISH";
552 case OperationType::FILL:
553 return os << "FILL";
554 case OperationType::RANK:
555 return os << "RANK";
556 case OperationType::OEM_OPERATION:
557 return os << "OEM_OPERATION";
558 }
559 if (isExtension(operationType)) {
560 return os << "Extension OperationType " << underlyingType(operationType);
561 }
562 return os << "OperationType{" << underlyingType(operationType) << "}";
563 }
564
operator <<(std::ostream & os,const Request::Argument::LifeTime & lifetime)565 std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime) {
566 switch (lifetime) {
567 case Request::Argument::LifeTime::POOL:
568 return os << "POOL";
569 case Request::Argument::LifeTime::NO_VALUE:
570 return os << "NO_VALUE";
571 case Request::Argument::LifeTime::POINTER:
572 return os << "POINTER";
573 }
574 return os << "Request::Argument::LifeTime{" << underlyingType(lifetime) << "}";
575 }
576
operator <<(std::ostream & os,const Priority & priority)577 std::ostream& operator<<(std::ostream& os, const Priority& priority) {
578 switch (priority) {
579 case Priority::LOW:
580 return os << "LOW";
581 case Priority::MEDIUM:
582 return os << "MEDIUM";
583 case Priority::HIGH:
584 return os << "HIGH";
585 }
586 return os << "Priority{" << underlyingType(priority) << "}";
587 }
588
operator <<(std::ostream & os,const ErrorStatus & errorStatus)589 std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) {
590 switch (errorStatus) {
591 case ErrorStatus::NONE:
592 return os << "NONE";
593 case ErrorStatus::DEVICE_UNAVAILABLE:
594 return os << "DEVICE_UNAVAILABLE";
595 case ErrorStatus::GENERAL_FAILURE:
596 return os << "GENERAL_FAILURE";
597 case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
598 return os << "OUTPUT_INSUFFICIENT_SIZE";
599 case ErrorStatus::INVALID_ARGUMENT:
600 return os << "INVALID_ARGUMENT";
601 case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
602 return os << "MISSED_DEADLINE_TRANSIENT";
603 case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
604 return os << "MISSED_DEADLINE_PERSISTENT";
605 case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
606 return os << "RESOURCE_EXHAUSTED_TRANSIENT";
607 case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
608 return os << "RESOURCE_EXHAUSTED_PERSISTENT";
609 case ErrorStatus::DEAD_OBJECT:
610 return os << "DEAD_OBJECT";
611 }
612 return os << "ErrorStatus{" << underlyingType(errorStatus) << "}";
613 }
614
operator <<(std::ostream & os,const FusedActivationFunc & activation)615 std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation) {
616 switch (activation) {
617 case FusedActivationFunc::NONE:
618 return os << "NONE";
619 case FusedActivationFunc::RELU:
620 return os << "RELU";
621 case FusedActivationFunc::RELU1:
622 return os << "RELU1";
623 case FusedActivationFunc::RELU6:
624 return os << "RELU6";
625 }
626 return os << "FusedActivationFunc{" << underlyingType(activation) << "}";
627 }
628
operator <<(std::ostream & os,const OutputShape & outputShape)629 std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) {
630 return os << "OutputShape{.dimensions=" << outputShape.dimensions
631 << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}";
632 }
633
operator <<(std::ostream & os,const Timing & timing)634 std::ostream& operator<<(std::ostream& os, const Timing& timing) {
635 return os << "Timing{.timeOnDevice=" << timing.timeOnDevice
636 << ", .timeInDriver=" << timing.timeInDriver << "}";
637 }
638
operator <<(std::ostream & os,const Capabilities::PerformanceInfo & performanceInfo)639 std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo) {
640 return os << "Capabilities::PerformanceInfo{.execTime=" << performanceInfo.execTime
641 << ", .powerUsage=" << performanceInfo.powerUsage << "}";
642 }
643
operator <<(std::ostream & os,const Capabilities::OperandPerformance & operandPerformance)644 std::ostream& operator<<(std::ostream& os,
645 const Capabilities::OperandPerformance& operandPerformance) {
646 return os << "Capabilities::OperandPerformance{.type=" << operandPerformance.type
647 << ", .info=" << operandPerformance.info << "}";
648 }
649
operator <<(std::ostream & os,const Capabilities::OperandPerformanceTable & operandPerformances)650 std::ostream& operator<<(std::ostream& os,
651 const Capabilities::OperandPerformanceTable& operandPerformances) {
652 return os << operandPerformances.asVector();
653 }
654
operator <<(std::ostream & os,const Capabilities & capabilities)655 std::ostream& operator<<(std::ostream& os, const Capabilities& capabilities) {
656 return os << "Capabilities{.relaxedFloat32toFloat16PerformanceScalar="
657 << capabilities.relaxedFloat32toFloat16PerformanceScalar
658 << ", .relaxedFloat32toFloat16PerformanceTensor="
659 << capabilities.relaxedFloat32toFloat16PerformanceTensor
660 << ", .operandPerformance=" << capabilities.operandPerformance
661 << ", .ifPerformance=" << capabilities.ifPerformance
662 << ", .whilePerformance=" << capabilities.whilePerformance << "}";
663 }
664
operator <<(std::ostream & os,const Extension::OperandTypeInformation & operandTypeInformation)665 std::ostream& operator<<(std::ostream& os,
666 const Extension::OperandTypeInformation& operandTypeInformation) {
667 return os << "Extension::OperandTypeInformation{.type=" << operandTypeInformation.type
668 << ", .isTensor=" << (operandTypeInformation.isTensor ? "true" : "false")
669 << ", .byteSize=" << operandTypeInformation.byteSize << "}";
670 }
671
operator <<(std::ostream & os,const Extension & extension)672 std::ostream& operator<<(std::ostream& os, const Extension& extension) {
673 return os << "Extension{.name=" << extension.name
674 << ", .operandTypes=" << extension.operandTypes << "}";
675 }
676
operator <<(std::ostream & os,const DataLocation & location)677 std::ostream& operator<<(std::ostream& os, const DataLocation& location) {
678 const auto printPointer = [&os](const std::variant<const void*, void*>& pointer) {
679 os << (std::holds_alternative<const void*>(pointer) ? "<constant " : "<mutable ");
680 os << std::visit(
681 [](const auto* ptr) {
682 return ptr == nullptr ? "null pointer>" : "non-null pointer>";
683 },
684 pointer);
685 };
686 os << "DataLocation{.pointer=";
687 printPointer(location.pointer);
688 return os << ", .poolIndex=" << location.poolIndex << ", .offset=" << location.offset
689 << ", .length=" << location.length << ", .padding=" << location.padding << "}";
690 }
691
operator <<(std::ostream & os,const Operand::SymmPerChannelQuantParams & symmPerChannelQuantParams)692 std::ostream& operator<<(std::ostream& os,
693 const Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
694 return os << "Operand::SymmPerChannelQuantParams{.scales=" << symmPerChannelQuantParams.scales
695 << ", .channelDim=" << symmPerChannelQuantParams.channelDim << "}";
696 }
697
operator <<(std::ostream & os,const Operand::ExtraParams & extraParams)698 std::ostream& operator<<(std::ostream& os, const Operand::ExtraParams& extraParams) {
699 os << "Operand::ExtraParams{";
700 if (std::holds_alternative<Operand::NoParams>(extraParams)) {
701 os << "<no params>";
702 } else if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(extraParams)) {
703 os << std::get<Operand::SymmPerChannelQuantParams>(extraParams);
704 } else if (std::holds_alternative<Operand::ExtensionParams>(extraParams)) {
705 os << std::get<Operand::ExtensionParams>(extraParams);
706 }
707 return os << "}";
708 }
709
operator <<(std::ostream & os,const Operand & operand)710 std::ostream& operator<<(std::ostream& os, const Operand& operand) {
711 return os << "Operand{.type=" << operand.type << ", .dimensions=" << operand.dimensions
712 << ", .scale=" << operand.scale << ", .zeroPoint=" << operand.zeroPoint
713 << ", lifetime=" << operand.lifetime << ", .location=" << operand.location
714 << ", .extraParams=" << operand.extraParams << "}";
715 }
716
operator <<(std::ostream & os,const Operation & operation)717 std::ostream& operator<<(std::ostream& os, const Operation& operation) {
718 return os << "Operation{.type=" << operation.type << ", .inputs=" << operation.inputs
719 << ", .outputs=" << operation.outputs << "}";
720 }
721
operator <<(std::ostream & os,const Handle & handle)722 static std::ostream& operator<<(std::ostream& os, const Handle& handle) {
723 return os << "<handle with " << handle.fds.size() << " fds and " << handle.ints.size()
724 << " ints>";
725 }
726
operator <<(std::ostream & os,const SharedHandle & handle)727 std::ostream& operator<<(std::ostream& os, const SharedHandle& handle) {
728 if (handle == nullptr) {
729 return os << "<empty handle>";
730 }
731 return os << *handle;
732 }
733
operator <<(std::ostream & os,const Memory::Ashmem & memory)734 static std::ostream& operator<<(std::ostream& os, const Memory::Ashmem& memory) {
735 return os << "Ashmem{.fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
736 << ", .size=" << memory.size << "}";
737 }
738
operator <<(std::ostream & os,const Memory::Fd & memory)739 static std::ostream& operator<<(std::ostream& os, const Memory::Fd& memory) {
740 return os << "Fd{.size=" << memory.size << ", .prot=" << memory.prot
741 << ", .fd=" << (memory.fd.ok() ? "<valid fd>" : "<invalid fd>")
742 << ", .offset=" << memory.offset << "}";
743 }
744
operator <<(std::ostream & os,const Memory::HardwareBuffer & memory)745 static std::ostream& operator<<(std::ostream& os, const Memory::HardwareBuffer& memory) {
746 if (memory.handle.get() == nullptr) {
747 return os << "<empty HardwareBuffer::Handle>";
748 }
749 return os << (isAhwbBlob(memory) ? "<AHardwareBuffer blob>" : "<non-blob AHardwareBuffer>");
750 }
751
operator <<(std::ostream & os,const Memory::Unknown & memory)752 static std::ostream& operator<<(std::ostream& os, const Memory::Unknown& memory) {
753 return os << "Unknown{.handle=" << memory.handle << ", .size=" << memory.size
754 << ", .name=" << memory.name << "}";
755 }
756
operator <<(std::ostream & os,const Memory & memory)757 std::ostream& operator<<(std::ostream& os, const Memory& memory) {
758 os << "Memory{.handle=";
759 std::visit([&os](const auto& x) { os << x; }, memory.handle);
760 return os << "}";
761 }
762
operator <<(std::ostream & os,const SharedMemory & memory)763 std::ostream& operator<<(std::ostream& os, const SharedMemory& memory) {
764 if (memory == nullptr) {
765 return os << "<empty memory>";
766 }
767 return os << *memory;
768 }
769
operator <<(std::ostream & os,const MemoryPreference & memoryPreference)770 std::ostream& operator<<(std::ostream& os, const MemoryPreference& memoryPreference) {
771 return os << "MemoryPreference{.alignment=" << memoryPreference.alignment
772 << ", .padding=" << memoryPreference.padding << "}";
773 }
774
operator <<(std::ostream & os,const Model::Subgraph & subgraph)775 std::ostream& operator<<(std::ostream& os, const Model::Subgraph& subgraph) {
776 std::vector<Operand> operands;
777 std::vector<Operation> operations;
778 std::vector<uint32_t> inputIndexes;
779 std::vector<uint32_t> outputIndexes;
780 return os << "Model::Subgraph{.operands=" << subgraph.operands
781 << ", .operations=" << subgraph.operations
782 << ", .inputIndexes=" << subgraph.inputIndexes
783 << ", .outputIndexes=" << subgraph.outputIndexes << "}";
784 }
785
operator <<(std::ostream & os,const Model::OperandValues & operandValues)786 std::ostream& operator<<(std::ostream& os, const Model::OperandValues& operandValues) {
787 return os << "Model::OperandValues{<" << operandValues.size() << "bytes>}";
788 }
789
operator <<(std::ostream & os,const Model::ExtensionNameAndPrefix & extensionNameAndPrefix)790 std::ostream& operator<<(std::ostream& os,
791 const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
792 return os << "Model::ExtensionNameAndPrefix{.name=" << extensionNameAndPrefix.name
793 << ", .prefix=" << extensionNameAndPrefix.prefix << "}";
794 }
795
operator <<(std::ostream & os,const Model & model)796 std::ostream& operator<<(std::ostream& os, const Model& model) {
797 return os << "Model{.main=" << model.main << ", .referenced=" << model.referenced
798 << ", .operandValues=" << model.operandValues << ", .pools=" << model.pools
799 << ", .relaxComputationFloat32toFloat16="
800 << (model.relaxComputationFloat32toFloat16 ? "true" : "false")
801 << ", extensionNameToPrefix=" << model.extensionNameToPrefix << "}";
802 }
803
operator <<(std::ostream & os,const BufferDesc & bufferDesc)804 std::ostream& operator<<(std::ostream& os, const BufferDesc& bufferDesc) {
805 return os << "BufferDesc{.dimensions=" << bufferDesc.dimensions << "}";
806 }
807
operator <<(std::ostream & os,const BufferRole & bufferRole)808 std::ostream& operator<<(std::ostream& os, const BufferRole& bufferRole) {
809 return os << "BufferRole{.modelIndex=" << bufferRole.modelIndex
810 << ", .ioIndex=" << bufferRole.ioIndex << ", .probability=" << bufferRole.probability
811 << "}";
812 }
813
operator <<(std::ostream & os,const Request::Argument & requestArgument)814 std::ostream& operator<<(std::ostream& os, const Request::Argument& requestArgument) {
815 return os << "Request::Argument{.lifetime=" << requestArgument.lifetime
816 << ", .location=" << requestArgument.location
817 << ", .dimensions=" << requestArgument.dimensions << "}";
818 }
819
operator <<(std::ostream & os,const Request::MemoryPool & memoryPool)820 std::ostream& operator<<(std::ostream& os, const Request::MemoryPool& memoryPool) {
821 os << "Request::MemoryPool{";
822 if (std::holds_alternative<SharedMemory>(memoryPool)) {
823 os << std::get<SharedMemory>(memoryPool);
824 } else if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
825 const auto& token = std::get<Request::MemoryDomainToken>(memoryPool);
826 if (token == Request::MemoryDomainToken{}) {
827 os << "<invalid MemoryDomainToken>";
828 } else {
829 os << "MemoryDomainToken=" << underlyingType(token);
830 }
831 } else if (std::holds_alternative<SharedBuffer>(memoryPool)) {
832 const auto& buffer = std::get<SharedBuffer>(memoryPool);
833 os << (buffer != nullptr ? "<non-null IBuffer>" : "<null IBuffer>");
834 }
835 return os << "}";
836 }
837
operator <<(std::ostream & os,const Request & request)838 std::ostream& operator<<(std::ostream& os, const Request& request) {
839 return os << "Request{.inputs=" << request.inputs << ", .outputs=" << request.outputs
840 << ", .pools=" << request.pools << "}";
841 }
842
operator <<(std::ostream & os,const SyncFence::FenceState & fenceState)843 std::ostream& operator<<(std::ostream& os, const SyncFence::FenceState& fenceState) {
844 switch (fenceState) {
845 case SyncFence::FenceState::ACTIVE:
846 return os << "ACTIVE";
847 case SyncFence::FenceState::SIGNALED:
848 return os << "SIGNALED";
849 case SyncFence::FenceState::ERROR:
850 return os << "ERROR";
851 case SyncFence::FenceState::UNKNOWN:
852 return os << "UNKNOWN";
853 }
854 return os << "SyncFence::FenceState{" << underlyingType(fenceState) << "}";
855 }
856
operator <<(std::ostream & os,const TimePoint & timePoint)857 std::ostream& operator<<(std::ostream& os, const TimePoint& timePoint) {
858 return os << timePoint.time_since_epoch() << " since epoch";
859 }
860
operator <<(std::ostream & os,const OptionalTimePoint & optionalTimePoint)861 std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTimePoint) {
862 if (!optionalTimePoint.has_value()) {
863 return os << "<no time point>";
864 }
865 return os << optionalTimePoint.value();
866 }
867
operator <<(std::ostream & os,const Duration & timeoutDuration)868 std::ostream& operator<<(std::ostream& os, const Duration& timeoutDuration) {
869 return os << timeoutDuration.count() << "ns";
870 }
871
operator <<(std::ostream & os,const OptionalDuration & optionalTimeoutDuration)872 std::ostream& operator<<(std::ostream& os, const OptionalDuration& optionalTimeoutDuration) {
873 if (!optionalTimeoutDuration.has_value()) {
874 return os << "<no duration>";
875 }
876 return os << optionalTimeoutDuration.value();
877 }
878
operator <<(std::ostream & os,const Version & version)879 std::ostream& operator<<(std::ostream& os, const Version& version) {
880 switch (version) {
881 case Version::ANDROID_OC_MR1:
882 return os << "ANDROID_OC_MR1";
883 case Version::ANDROID_P:
884 return os << "ANDROID_P";
885 case Version::ANDROID_Q:
886 return os << "ANDROID_Q";
887 case Version::ANDROID_R:
888 return os << "ANDROID_R";
889 case Version::ANDROID_S:
890 return os << "ANDROID_S";
891 case Version::CURRENT_RUNTIME:
892 return os << "CURRENT_RUNTIME";
893 }
894 return os << "Version{" << underlyingType(version) << "}";
895 }
896
operator <<(std::ostream & os,const HalVersion & halVersion)897 std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) {
898 switch (halVersion) {
899 case HalVersion::UNKNOWN:
900 return os << "UNKNOWN HAL version";
901 case HalVersion::V1_0:
902 return os << "HAL version 1.0";
903 case HalVersion::V1_1:
904 return os << "HAL version 1.1";
905 case HalVersion::V1_2:
906 return os << "HAL version 1.2";
907 case HalVersion::V1_3:
908 return os << "HAL version 1.3";
909 case HalVersion::AIDL_UNSTABLE:
910 return os << "HAL uses unstable AIDL";
911 }
912 return os << "HalVersion{" << underlyingType(halVersion) << "}";
913 }
914
operator ==(const Timing & a,const Timing & b)915 bool operator==(const Timing& a, const Timing& b) {
916 return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver;
917 }
operator !=(const Timing & a,const Timing & b)918 bool operator!=(const Timing& a, const Timing& b) {
919 return !(a == b);
920 }
921
operator ==(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)922 bool operator==(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
923 return a.execTime == b.execTime && a.powerUsage == b.powerUsage;
924 }
operator !=(const Capabilities::PerformanceInfo & a,const Capabilities::PerformanceInfo & b)925 bool operator!=(const Capabilities::PerformanceInfo& a, const Capabilities::PerformanceInfo& b) {
926 return !(a == b);
927 }
928
operator ==(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)929 bool operator==(const Capabilities::OperandPerformance& a,
930 const Capabilities::OperandPerformance& b) {
931 return a.type == b.type && a.info == b.info;
932 }
operator !=(const Capabilities::OperandPerformance & a,const Capabilities::OperandPerformance & b)933 bool operator!=(const Capabilities::OperandPerformance& a,
934 const Capabilities::OperandPerformance& b) {
935 return !(a == b);
936 }
937
operator ==(const Capabilities & a,const Capabilities & b)938 bool operator==(const Capabilities& a, const Capabilities& b) {
939 return a.relaxedFloat32toFloat16PerformanceScalar ==
940 b.relaxedFloat32toFloat16PerformanceScalar &&
941 a.relaxedFloat32toFloat16PerformanceTensor ==
942 b.relaxedFloat32toFloat16PerformanceTensor &&
943 a.operandPerformance.asVector() == b.operandPerformance.asVector() &&
944 a.ifPerformance == b.ifPerformance && a.whilePerformance == b.whilePerformance;
945 }
operator !=(const Capabilities & a,const Capabilities & b)946 bool operator!=(const Capabilities& a, const Capabilities& b) {
947 return !(a == b);
948 }
949
operator ==(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)950 bool operator==(const Extension::OperandTypeInformation& a,
951 const Extension::OperandTypeInformation& b) {
952 return a.type == b.type && a.isTensor == b.isTensor && a.byteSize == b.byteSize;
953 }
operator !=(const Extension::OperandTypeInformation & a,const Extension::OperandTypeInformation & b)954 bool operator!=(const Extension::OperandTypeInformation& a,
955 const Extension::OperandTypeInformation& b) {
956 return !(a == b);
957 }
958
operator ==(const Extension & a,const Extension & b)959 bool operator==(const Extension& a, const Extension& b) {
960 return a.name == b.name && a.operandTypes == b.operandTypes;
961 }
operator !=(const Extension & a,const Extension & b)962 bool operator!=(const Extension& a, const Extension& b) {
963 return !(a == b);
964 }
965
operator ==(const MemoryPreference & a,const MemoryPreference & b)966 bool operator==(const MemoryPreference& a, const MemoryPreference& b) {
967 return a.alignment == b.alignment && a.padding == b.padding;
968 }
operator !=(const MemoryPreference & a,const MemoryPreference & b)969 bool operator!=(const MemoryPreference& a, const MemoryPreference& b) {
970 return !(a == b);
971 }
972
operator ==(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)973 bool operator==(const Operand::SymmPerChannelQuantParams& a,
974 const Operand::SymmPerChannelQuantParams& b) {
975 return a.scales == b.scales && a.channelDim == b.channelDim;
976 }
operator !=(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)977 bool operator!=(const Operand::SymmPerChannelQuantParams& a,
978 const Operand::SymmPerChannelQuantParams& b) {
979 return !(a == b);
980 }
981
operator ==(const DataLocation & a,const DataLocation & b)982 static bool operator==(const DataLocation& a, const DataLocation& b) {
983 constexpr auto toTuple = [](const DataLocation& location) {
984 return std::tie(location.pointer, location.poolIndex, location.offset, location.length,
985 location.padding);
986 };
987 return toTuple(a) == toTuple(b);
988 }
989
operator ==(const Operand & a,const Operand & b)990 bool operator==(const Operand& a, const Operand& b) {
991 constexpr auto toTuple = [](const Operand& operand) {
992 return std::tie(operand.type, operand.dimensions, operand.scale, operand.zeroPoint,
993 operand.lifetime, operand.location, operand.extraParams);
994 };
995 return toTuple(a) == toTuple(b);
996 }
operator !=(const Operand & a,const Operand & b)997 bool operator!=(const Operand& a, const Operand& b) {
998 return !(a == b);
999 }
1000
operator ==(const Operation & a,const Operation & b)1001 bool operator==(const Operation& a, const Operation& b) {
1002 constexpr auto toTuple = [](const Operation& operation) {
1003 return std::tie(operation.type, operation.inputs, operation.outputs);
1004 };
1005 return toTuple(a) == toTuple(b);
1006 }
operator !=(const Operation & a,const Operation & b)1007 bool operator!=(const Operation& a, const Operation& b) {
1008 return !(a == b);
1009 }
1010
1011 } // namespace android::nn
1012