• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Converter.h"
18 
19 #include <algorithm>
20 #include <random>
21 #include <utility>
22 #include <vector>
23 
24 namespace android::nn::fuzz {
25 namespace {
26 
27 using namespace test_helper;
28 using namespace android_nn_fuzz;
29 
30 constexpr uint32_t kMaxSize = 65536;
31 
convert(OperandType type)32 TestOperandType convert(OperandType type) {
33     return static_cast<TestOperandType>(type);
34 }
35 
convert(OperationType type)36 TestOperationType convert(OperationType type) {
37     return static_cast<TestOperationType>(type);
38 }
39 
convert(OperandLifeTime lifetime)40 TestOperandLifeTime convert(OperandLifeTime lifetime) {
41     return static_cast<TestOperandLifeTime>(lifetime);
42 }
43 
convert(const Scales & scales)44 std::vector<float> convert(const Scales& scales) {
45     const auto& repeatedScale = scales.scale();
46     return std::vector<float>(repeatedScale.begin(), repeatedScale.end());
47 }
48 
convert(const SymmPerChannelQuantParams & params)49 TestSymmPerChannelQuantParams convert(const SymmPerChannelQuantParams& params) {
50     std::vector<float> scales = convert(params.scales());
51     const uint32_t channelDim = params.channel_dim();
52     return {.scales = std::move(scales), .channelDim = channelDim};
53 }
54 
convert(const Dimensions & dimensions)55 std::vector<uint32_t> convert(const Dimensions& dimensions) {
56     const auto& repeatedDimension = dimensions.dimension();
57     return std::vector<uint32_t>(repeatedDimension.begin(), repeatedDimension.end());
58 }
59 
convert(bool makeEmpty,const Buffer & buffer)60 TestBuffer convert(bool makeEmpty, const Buffer& buffer) {
61     if (makeEmpty) {
62         return TestBuffer();
63     }
64     const uint32_t randomSeed = buffer.random_seed();
65     std::default_random_engine generator{randomSeed};
66     std::uniform_int_distribution<uint32_t> dist{0, kMaxSize};
67     const uint32_t size = dist(generator);
68     return TestBuffer::createFromRng<uint32_t>(size, &generator);
69 }
70 
convert(const Operand & operand)71 TestOperand convert(const Operand& operand) {
72     const TestOperandType type = convert(operand.type());
73     std::vector<uint32_t> dimensions = convert(operand.dimensions());
74     const float scale = operand.scale();
75     const int32_t zeroPoint = operand.zero_point();
76     const TestOperandLifeTime lifetime = convert(operand.lifetime());
77     auto channelQuant = convert(operand.channel_quant());
78     const bool isIgnored = false;
79     const bool makeEmpty = (lifetime == TestOperandLifeTime::NO_VALUE ||
80                             lifetime == TestOperandLifeTime::TEMPORARY_VARIABLE);
81     TestBuffer data = convert(makeEmpty, operand.data());
82     return {.type = type,
83             .dimensions = std::move(dimensions),
84             .numberOfConsumers = 0,
85             .scale = scale,
86             .zeroPoint = zeroPoint,
87             .lifetime = lifetime,
88             .channelQuant = std::move(channelQuant),
89             .isIgnored = isIgnored,
90             .data = std::move(data)};
91 }
92 
convert(const Operands & operands)93 std::vector<TestOperand> convert(const Operands& operands) {
94     std::vector<TestOperand> testOperands;
95     testOperands.reserve(operands.operand_size());
96     const auto& repeatedOperand = operands.operand();
97     std::transform(repeatedOperand.begin(), repeatedOperand.end(), std::back_inserter(testOperands),
98                    [](const auto& operand) { return convert(operand); });
99     return testOperands;
100 }
101 
convert(const Indexes & indexes)102 std::vector<uint32_t> convert(const Indexes& indexes) {
103     const auto& repeatedIndex = indexes.index();
104     return std::vector<uint32_t>(repeatedIndex.begin(), repeatedIndex.end());
105 }
106 
convert(const Operation & operation)107 TestOperation convert(const Operation& operation) {
108     const TestOperationType type = convert(operation.type());
109     std::vector<uint32_t> inputs = convert(operation.inputs());
110     std::vector<uint32_t> outputs = convert(operation.outputs());
111     return {.type = type, .inputs = std::move(inputs), .outputs = std::move(outputs)};
112 }
113 
convert(const Operations & operations)114 std::vector<TestOperation> convert(const Operations& operations) {
115     std::vector<TestOperation> testOperations;
116     testOperations.reserve(operations.operation_size());
117     const auto& repeatedOperation = operations.operation();
118     std::transform(repeatedOperation.begin(), repeatedOperation.end(),
119                    std::back_inserter(testOperations),
120                    [](const auto& operation) { return convert(operation); });
121     return testOperations;
122 }
123 
convert(const Model & model)124 TestModel convert(const Model& model) {
125     std::vector<TestOperand> operands = convert(model.operands());
126     std::vector<TestOperation> operations = convert(model.operations());
127     std::vector<uint32_t> inputIndexes = convert(model.input_indexes());
128     std::vector<uint32_t> outputIndexes = convert(model.output_indexes());
129     const bool isRelaxed = model.is_relaxed();
130     return {.main = {.operands = std::move(operands),
131                      .operations = std::move(operations),
132                      .inputIndexes = std::move(inputIndexes),
133                      .outputIndexes = std::move(outputIndexes)},
134             .isRelaxed = isRelaxed};
135 }
136 
137 }  // anonymous namespace
138 
convertToTestModel(const Test & model)139 TestModel convertToTestModel(const Test& model) {
140     return convert(model.model());
141 }
142 
143 }  // namespace android::nn::fuzz
144