#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def test(input0, axis, indices, output0, input_data, output_data):
    """Declare one GATHER test case together with its data-type variations.

    Args:
        input0: operand passed as the first argument of the GATHER operation.
        axis: value passed as the second argument of the GATHER operation.
        indices: value passed as the third argument of the GATHER operation.
        output0: operand that receives the result of the operation.
        input_data: values associated with `input0` in the example.
        output_data: values associated with `output0` in the example.
    """
    model = Model().Operation("GATHER", input0, axis, indices).To(output0)

    def retyped(*operand_type):
        # Apply the same target type description to both the input and the
        # output operand; each operand gets its own list, as in a literal.
        return DataTypeConverter().Identify({
            input0: list(operand_type),
            output0: list(operand_type),
        })

    # Converters are created in the same order in which they are later
    # handed to AddVariations.
    as_quant8 = retyped("TENSOR_QUANT8_ASYMM", 0.5, 127)
    as_int32 = retyped("TENSOR_INT32")
    as_float16 = retyped("TENSOR_FLOAT16")

    example = Example({input0: input_data, output0: output_data}, model=model)
    example.AddVariations("relaxed", as_quant8, as_int32, as_float16)


# {2, 2} input, axis 0, two indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{2, 2}"),
    axis=0,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 2}"),
    input_data=[-2.0, 0.2,
                0.7, 0.8],
    output_data=[0.7, 0.8,
                 -2.0, 0.2],
)

# {2, 2} input, axis 0, a single index given as a one-element list.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{2, 2}"),
    axis=0,
    # Unlike TensorFlow, 0-D arguments and outputs are not supported.
    indices=[1],
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2}"),
    input_data=[-2.0, 0.2,
                0.7, 0.8],
    output_data=[0.7, 0.8],
)

# {3} input, axis 0, a single index.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{3}"),
    axis=0,
    indices=[1],
    output0=Output("output0", "TENSOR_FLOAT32", "{1}"),
    input_data=[1, 2, 3],
    output_data=[2],
)

# {3} input, axis 0, two indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{3}"),
    axis=0,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2}"),
    input_data=[1, 2, 3],
    output_data=[2, 1],
)

# {1, 2, 2} input, axis 0, the same index listed twice.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 2}"),
    axis=0,
    indices=[0, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 2, 2}"),
    input_data=[-2.0, 0.2,
                0.7, 0.8],
    output_data=[-2.0, 0.2,
                 0.7, 0.8,
                 -2.0, 0.2,
                 0.7, 0.8],
)

# {4, 1} input, axis 0, two indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{4, 1}"),
    axis=0,
    indices=[1, 3],
    output0=Output("output0", "TENSOR_FLOAT32", "{2, 1}"),
    input_data=[-2.0, 0.2, 0.7, 0.8],
    output_data=[0.2, 0.8],
)

# {1, 2, 3} input, axis 1, two indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    axis=1,
    indices=[1, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    input_data=[1, 2, 3,
                4, 5, 6],
    output_data=[4, 5, 6,
                 1, 2, 3],
)

# {1, 2, 3} input, negative axis (-1), two indices.
test(
    input0=Input("input0", "TENSOR_FLOAT32", "{1, 2, 3}"),
    axis=-1,
    indices=[2, 0],
    output0=Output("output0", "TENSOR_FLOAT32", "{1, 2, 2}"),
    input_data=[1, 2, 3,
                4, 5, 6],
    output_data=[3, 1,
                 6, 4],
)