Home
last modified time | relevance | path

Searched refs:state_shape (Results 1 – 11 of 11) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Drng_bit_generator_expander.cc35 XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) { in GetPhiloxStateOp() argument
36 if (state_shape.dimensions(0) >= 3) { in GetPhiloxStateOp()
42 XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) { in GetPhiloxOutputStateOp() argument
43 if (state_shape.dimensions(0) < 3) { in GetPhiloxOutputStateOp()
57 const Shape& data_shape, const Shape& state_shape, in GetGeneratorComputation() argument
59 RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module}; in GetGeneratorComputation()
66 XlaOp state_param = Parameter(&builder, 0, state_shape, "state"); in GetGeneratorComputation()
76 key_op, GetPhiloxStateOp(state_param, state_shape), data_shape); in GetGeneratorComputation()
77 output.state = GetPhiloxOutputStateOp(output.state, state_shape); in GetGeneratorComputation()
111 const Shape& state_shape = rng->operand(0)->shape(); in ExpandInstruction() local
[all …]
Drng_bit_generator_expander.h43 Shape state_shape; member
49 return H::combine(std::move(h), c.state_shape, c.data_shape, c.algorithm, in AbslHashValue()
54 return data_shape == o.data_shape && state_shape == o.state_shape &&
62 const Shape& state_shape,
/external/tensorflow/tensorflow/lite/experimental/kernels/
Dgru_cell.cc33 const RuntimeShape& state_shape, const float* input_state, in GruCell() argument
47 const int n_output = state_shape.Dims(1); in GruCell()
55 concat_arrays_shapes.push_back(&state_shape); in GruCell()
74 auto h = MapAsArrayWithLastDimAsRows(input_state, state_shape); in GruCell()
Dgru_cell.h28 const RuntimeShape& state_shape, const float* input_state,
Dunidirectional_sequence_gru.cc45 const RuntimeShape state_shape = GetTensorShape(input_state); in GruImpl() local
67 input_shape, input_data, state_shape, input_state_data, in GruImpl()
/external/tensorflow/tensorflow/lite/kernels/
Dquant_basic_lstm_test.cc48 std::vector<int> state_shape{numBatches, outputSize}; in QuantizedLSTMOpModel() local
75 {TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0}); in QuantizedLSTMOpModel()
82 AddOutput({TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0}); in QuantizedLSTMOpModel()
/external/tensorflow/tensorflow/python/keras/layers/
Dconvolutional_recurrent.py357 state_shape = self.compute_output_shape(input_shape)
359 state_shape = state_shape[0]
361 state_shape = state_shape[:1].concatenate(state_shape[2:])
362 if None in state_shape:
378 result = list(state_shape)
Dwrappers.py559 state_shape = tf_utils.convert_shapes(output_shape[1:], to_tuples=False)
573 return output_shape + state_shape + copy.copy(state_shape)
574 return [output_shape] + state_shape + copy.copy(state_shape)
Drecurrent.py517 state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list()
518 return tensor_shape.TensorShape(state_shape)
519 state_shape = nest.map_structure(_get_state_shape, state_size)
520 return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
/external/tensorflow/tensorflow/lite/toco/
Dexport_tensorflow.cc2472 const auto& state_shape = state_array.shape(); in AddPlaceholderForRNNState() local
2473 const int kDims = state_shape.dimensions_count(); in AddPlaceholderForRNNState()
2475 shape->add_dim()->set_size(state_shape.dims(i)); in AddPlaceholderForRNNState()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc2210 TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); in RngBitGenerator()
2228 ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm, in RngBitGenerator()