Home
last modified time | relevance | path

Searched refs:weight_shape (Results 1 – 22 of 22) sorted by relevance

/external/ComputeLibrary/tests/validation/fixtures/
DQLSTMLayerNormalizationFixture.h48 …void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType dat… in setup() argument
55 _target = compute_target(input_shape, weight_shape, bias_shape); in setup()
56 _reference = compute_reference(input_shape, weight_shape, bias_shape); in setup()
100 …TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const T… in compute_target() argument
103 TensorType weight = create_tensor<TensorType>(weight_shape, _data_type, 1, _qinfo); in compute_target()
116 … compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorSh… in compute_reference() argument
120 SimpleTensor<T> weight{ weight_shape, _data_type, 1, _qinfo }; in compute_reference()
/external/tensorflow/tensorflow/lite/kernels/
Dlsh_projection_test.cc35 std::initializer_list<int> weight_shape) { in LSHProjectionOpModel() argument
38 if (weight_shape.size() > 0) { in LSHProjectionOpModel()
46 if (weight_shape.size() > 0) { in LSHProjectionOpModel()
47 BuildInterpreter({hash_shape, input_shape, weight_shape}); in LSHProjectionOpModel()
Dembedding_lookup_test.cc42 std::initializer_list<int> weight_shape, in BaseEmbeddingLookupOpModel() argument
49 BuildInterpreter({index_shape, weight_shape}); in BaseEmbeddingLookupOpModel()
91 std::initializer_list<int> weight_shape, in HybridEmbeddingLookupOpModel() argument
93 : BaseEmbeddingLookupOpModel(index_shape, weight_shape, type) {} in HybridEmbeddingLookupOpModel()
Dquant_basic_lstm_test.cc47 std::vector<int> weight_shape{4 * outputSize, outputSize + inputSize}; in QuantizedLSTMOpModel() local
68 weights_ = AddConstInput<uint8_t>({TensorType_UINT8, weight_shape, 0.0f, in QuantizedLSTMOpModel()
Dunidirectional_sequence_lstm.cc694 const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); in PrecomputeZeroPointTimesWeightWithBias() local
695 TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); in PrecomputeZeroPointTimesWeightWithBias()
696 const int row = weight_shape.Dims(0); in PrecomputeZeroPointTimesWeightWithBias()
697 const int col = weight_shape.Dims(1); in PrecomputeZeroPointTimesWeightWithBias()
Dlstm.cc1134 const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); in PrecomputeZeroPointTimesWeightWithBias() local
1135 TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); in PrecomputeZeroPointTimesWeightWithBias()
1136 const int row = weight_shape.Dims(0); in PrecomputeZeroPointTimesWeightWithBias()
1137 const int col = weight_shape.Dims(1); in PrecomputeZeroPointTimesWeightWithBias()
Dlstm_eval.cc55 tflite::RuntimeShape weight_shape({m_rows, m_cols}); in MatrixBatchVectorMultiplyAccumulate() local
60 float_fc_params, input_shape, vector, weight_shape, matrix, in MatrixBatchVectorMultiplyAccumulate()
64 float_fc_params, input_shape, vector, weight_shape, matrix, in MatrixBatchVectorMultiplyAccumulate()
/external/ComputeLibrary/src/graph/mutators/
DInPlaceOperationMutator.cpp106 const auto weight_shape = weight_tensor->desc().shape; in try_in_place_depthwiseconv() local
136 …const bool is_1x1 = weight_shape[weights_width_idx] == 1U && weight_shape[weights_heig… in try_in_place_depthwiseconv()
/external/tensorflow/tensorflow/tools/graph_transforms/
Dquantize_weights_test.cc40 const TensorShape& weight_shape, in BuildGraphDef() argument
50 Tensor weights_data(DT_FLOAT, weight_shape); in BuildGraphDef()
Dfold_old_batch_norms_test.cc373 auto weight_shape = in TestFoldFusedBatchNormsWithConcat() local
375 Tensor weights0_data(DT_FLOAT, weight_shape); in TestFoldFusedBatchNormsWithConcat()
384 Tensor weights1_data(DT_FLOAT, weight_shape); in TestFoldFusedBatchNormsWithConcat()
/external/libpalmrejection/ui/events/ozone/evdev/touch_filter/palm_model/
Donedevice_train_palm_detection_filter_inference_beta.cc187 const int32_t* __restrict weight_shape, in MatMul() argument
195 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in MatMul()
196 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in MatMul()
200 const int32_t num_inputs = weight_shape[0]; in MatMul()
201 const int32_t num_outputs = weight_shape[1]; in MatMul()
339 const int32_t* __restrict weight_shape, in FullyConnected() argument
349 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in FullyConnected()
351 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in FullyConnected()
355 const int32_t num_inputs = weight_shape[0]; in FullyConnected()
356 const int32_t num_outputs = weight_shape[1]; in FullyConnected()
Donedevice_train_palm_detection_filter_inference.cc188 const int32_t* __restrict weight_shape, in MatMul() argument
196 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in MatMul()
197 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in MatMul()
201 const int32_t num_inputs = weight_shape[0]; in MatMul()
202 const int32_t num_outputs = weight_shape[1]; in MatMul()
340 const int32_t* __restrict weight_shape, in FullyConnected() argument
350 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in FullyConnected()
352 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in FullyConnected()
356 const int32_t num_inputs = weight_shape[0]; in FullyConnected()
357 const int32_t num_outputs = weight_shape[1]; in FullyConnected()
Donedevice_train_palm_detection_filter_inference_v2.cc187 const int32_t* __restrict weight_shape, in MatMul() argument
195 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in MatMul()
196 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in MatMul()
200 const int32_t num_inputs = weight_shape[0]; in MatMul()
201 const int32_t num_outputs = weight_shape[1]; in MatMul()
339 const int32_t* __restrict weight_shape, in FullyConnected() argument
349 ConstMatrixMap<T>(weight_values, weight_shape[1], weight_shape[0]); in FullyConnected()
351 auto result = MatrixMap<T>(output_values, weight_shape[1], input_shape[0]); in FullyConnected()
355 const int32_t num_inputs = weight_shape[0]; in FullyConnected()
356 const int32_t num_outputs = weight_shape[1]; in FullyConnected()
/external/tensorflow/tensorflow/compiler/mlir/lite/utils/
Dlstm_utils_test.cc49 SmallVector<int64_t, 2> weight_shape{3, 12}; in createLstmCompositeFunc() local
55 auto weight_type = RankedTensorType::get(weight_shape, builder->getF32Type()); in createLstmCompositeFunc()
/external/tensorflow/tensorflow/python/ops/
Dnn_batchnorm_test.py637 weight_shape = [1] * len(shape)
638 weight_shape[idx] = shape[idx]
640 self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
643 weight_shape = shape[-(idx + 1):]
645 shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
/external/tensorflow/tensorflow/lite/delegates/hexagon/builders/
Dconv_2d_helpers.cc211 RuntimeShape weight_shape = {weights_height_size, weights_width_size, in SplitWeightsForDwConv() local
213 optimized_ops::Split(split_params, weight_shape, converted_data.data(), in SplitWeightsForDwConv()
/external/tensorflow/tensorflow/core/grappler/utils/
Dpattern_utils_test.cc35 auto weight_shape = ops::Placeholder::Shape({32, 64}); in GetMatMulBiasAddGeluGraph() local
39 auto weight = Placeholder(s.WithOpName("weight"), DT_FLOAT, weight_shape); in GetMatMulBiasAddGeluGraph()
/external/tensorflow/tensorflow/core/grappler/optimizers/
Dmkl_remapper_test.cc757 auto weight_shape = in VerifyFused() local
762 auto weight_placeholder_shape = ops::Placeholder::Shape(weight_shape); in VerifyFused()
783 Tensor weight_t = Tensor(DataTypeToEnum<T>::v(), weight_shape); in VerifyFused()
/external/tensorflow/tensorflow/python/keras/engine/
Dbase_layer_v1.py1295 weight_shape = weight.shape if hasattr(weight, 'shape') else ()
1297 if not ref_shape.is_compatible_with(weight_shape):
1300 'shape %s' % (ref_shape, weight_shape))
Dbase_layer.py1793 weight_shape = weight.shape if hasattr(weight, 'shape') else ()
1795 if not ref_shape.is_compatible_with(weight_shape):
1798 'shape %s' % (ref_shape, weight_shape))
/external/tensorflow/tensorflow/lite/delegates/nnapi/
Dnnapi_delegate_test.cc2101 std::initializer_list<int> weight_shape) { in LSHProjectionOpModel() argument
2104 if (weight_shape.size() > 0) { in LSHProjectionOpModel()
2112 if (weight_shape.size() > 0) { in LSHProjectionOpModel()
2113 BuildInterpreterWithNNAPI({hash_shape, input_shape, weight_shape}); in LSHProjectionOpModel()
4830 std::initializer_list<int> weight_shape, in BaseEmbeddingLookupOpModel() argument
4836 BuildInterpreterWithNNAPI({index_shape, weight_shape}); in BaseEmbeddingLookupOpModel()
/external/tensorflow/tensorflow/lite/g3doc/examples/convert/
Doperation_fusion.md251 if len(weight_shape) > 2: