1# 2# Copyright (C) 2020 The Android Open Source Project 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# 16 17# Model: z = if (x) then (y + 10) else (y - 10) 18 19input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] 20output_add = [y + 10 for y in input_data] 21output_sub = [y - 10 for y in input_data] 22 23ValueType = ["TENSOR_FLOAT32", [3, 4]] 24BoolType = ["TENSOR_BOOL8", [1]] 25 26def MakeBranchModel(operation_name): 27 y = Input("y", ValueType) 28 z = Output("z", ValueType) 29 return Model().Operation(operation_name, y, [10.0], 0).To(z) 30 31def Test(x_data, y_data, z_data, name): 32 x = Input("x", BoolType) 33 y = Input("y", ValueType) 34 z = Output("z", ValueType) 35 then_model = MakeBranchModel("ADD") 36 else_model = MakeBranchModel("SUB") 37 model = Model().Operation("IF", x, then_model, else_model, y).To(z) 38 39 quant8 = DataTypeConverter("quant8", scale=1.0, zeroPoint=100) 40 quant8_signed = DataTypeConverter("quant8_signed", scale=1.0, zeroPoint=100) 41 42 example = Example({x: [x_data], y: y_data, z: z_data}, name=name) 43 example.AddVariations("relaxed", "float16", "int32", quant8, quant8_signed) 44 example.AddVariations(AllOutputsAsInternalCoverter()) 45 46Test(x_data=True, y_data=input_data, z_data=output_add, name="true") 47Test(x_data=False, y_data=input_data, z_data=output_sub, name="false") 48