#
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Model: z = if (value) then (x + y) else (x - y)
# where value is a constant
#
# The IF condition is a constant "cond" operand, so each generated example
# takes exactly one branch: the then-branch model performs ADD and the
# else-branch model performs SUB.

# Element data for the x and y operands (12 elements each, matching the
# 3x4 dimensions in ValueType below).
x_data = list(range(1, 13))      # 1, 2, ..., 12
y_data = list(range(8, -4, -1))  # 8, 7, ..., -3

# Expected z for each constant condition value, computed element-wise.
output_data = {
    True: [a + b for a, b in zip(x_data, y_data)],   # then-branch: ADD
    False: [a - b for a, b in zip(x_data, y_data)],  # else-branch: SUB
}

# Operand descriptors in the form [operand type name, dimensions].
ValueType = ["TENSOR_FLOAT32", [3, 4]]
BoolType = ["TENSOR_BOOL8", [1]]


def _MakeBinaryOperands():
  """Creates fresh ValueType operands: inputs "x", "y" and output "z".

  The tuple elements are evaluated left to right, so the operands are
  created in the order x, y, z.
  """
  return (Input("x", ValueType),
          Input("y", ValueType),
          Output("z", ValueType))


def MakeBranchModel(operation_name):
  """Returns a one-operation model computing z = operation_name(x, y).

  The trailing constant 0 is forwarded to the operation unchanged. For
  ADD/SUB it is presumably the fused-activation code (NONE) — TODO confirm.
  """
  lhs, rhs, result = _MakeBinaryOperands()
  branch = Model().Operation(operation_name, lhs, rhs, 0)
  return branch.To(result)


def Test(value, name):
  """Registers an IF example whose constant condition equals `value`.

  Builds the ADD (then) and SUB (else) branch models and wires them into an
  IF operation over inputs x and y. It then adds an Example called `name`
  that expects output_data[value], with relaxed, float16, int32, quant8 and
  quant8_signed variations. The lifetime variation is disabled.
  """
  lhs, rhs, result = _MakeBinaryOperands()
  cond = Parameter("cond", BoolType, [value])
  then_branch = MakeBranchModel("ADD")
  else_branch = MakeBranchModel("SUB")
  if_model = (Model()
              .Operation("IF", cond, then_branch, else_branch, lhs, rhs)
              .To(result))

  # Both quantized variations share the same scale and zero point.
  quant_params = dict(scale=1.0, zeroPoint=100)
  quant8 = DataTypeConverter("quant8", **quant_params)
  quant8_signed = DataTypeConverter("quant8_signed", **quant_params)

  feeds = {lhs: x_data, rhs: y_data, result: output_data[value]}
  example = Example(feeds, model=if_model, name=name)
  example.AddVariations("relaxed", "float16", "int32", quant8, quant8_signed)
  example.DisableLifeTimeVariation()


# Emit the True/False examples once per constant-operand lifetime:
#   CONSTANT_COPY      -> Configuration.use_shm_for_weights = False
#   CONSTANT_REFERENCE -> Configuration.use_shm_for_weights = True
for use_shm, true_name, false_name in (
    (False, "copy_true", "copy_false"),           # CONSTANT_COPY
    (True, "reference_true", "reference_false"),  # CONSTANT_REFERENCE
):
  Configuration.use_shm_for_weights = use_shm
  Test(value=True, name=true_name)
  Test(value=False, name=false_name)