Home
last modified time | relevance | path

Searched refs:rnn_desc (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/stream_executor/cuda/
Dcuda_dnn.cc506 cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) { in CreatePersistentRnnPlan() argument
509 cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result)); in CreatePersistentRnnPlan()
1039 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1058 CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc, in CudnnRnnDescriptor() argument
1068 : rnn_desc_(std::move(rnn_desc)), in CudnnRnnDescriptor()
1099 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor(); in Create() local
1139 /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode, in Create()
1151 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), in Create()
1156 CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); in Create()
1160 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), in Create()
[all …]
Dcuda_dnn.h74 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
92 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
110 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
128 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
153 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
178 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
641 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
660 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
/external/tensorflow/tensorflow/core/kernels/
Dcudnn_rnn_ops.cc559 std::unique_ptr<RnnDescriptor> rnn_desc; member
786 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, in DoForward() argument
831 ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc, in DoForward()
847 OpKernelContext* context, const RnnDescriptor& rnn_desc, in DoBackward() argument
911 rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data, in DoBackward()
1010 std::unique_ptr<RnnDescriptor>* rnn_desc) { in ExtractCudnnRNNParamsInfo() argument
1049 *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); in ExtractCudnnRNNParamsInfo()
1059 std::unique_ptr<RnnDescriptor>* rnn_desc, in CreateRnnDescriptor() argument
1071 *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); in CreateRnnDescriptor()
1085 RnnStateCache* cache, RnnDescriptor** rnn_desc, in GetCachedRnnDescriptor() argument
[all …]
/external/tensorflow/tensorflow/stream_executor/rocm/
Drocm_dnn.cc1730 const MIOpenRnnDescriptor& rnn_desc);
1973 const MIOpenRnnDescriptor& rnn_desc, in ExtractAndCheckRnnForward() argument
1987 model_dims->num_layers = rnn_desc.num_layers(); in ExtractAndCheckRnnForward()
1990 model_dims->hidden_size = rnn_desc.hidden_size(); in ExtractAndCheckRnnForward()
1993 (rnn_desc.direction_mode() == miopenRNNbidirection) ? 2 : 1; in ExtractAndCheckRnnForward()
2033 miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc, in CheckRNNParameterSize() argument
2037 miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, in CheckRNNParameterSize()
2039 rnn_desc.data_type() /*dataType*/); in CheckRNNParameterSize()
2045 rnn_desc.ParamsSizeInBytes(); in CheckRNNParameterSize()
2049 const MIOpenRnnDescriptor& rnn_desc, in CreateRnnWorkspace() argument
[all …]
Drocm_dnn.h101 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
119 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
137 bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
155 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
180 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
205 bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
738 bool DoRnnForwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
755 bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
/external/tensorflow/tensorflow/stream_executor/
Ddnn.h2185 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward() argument
2206 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward() argument
2227 virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnForward() argument
2289 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward() argument
2317 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward() argument
2345 Stream* stream, const dnn::RnnDescriptor& rnn_desc, in DoRnnBackward() argument
Dstream.h1779 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1798 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1816 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
1837 const dnn::RnnDescriptor &rnn_desc,
1862 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
1887 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
Dstream.cc4539 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward() argument
4559 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, in ThenRnnForward()
4573 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward() argument
4592 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, in ThenRnnForward()
4606 const dnn::RnnDescriptor &rnn_desc, in ThenRnnForward() argument
4626 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, in ThenRnnForward()
4640 const dnn::RnnDescriptor &rnn_desc, in ThenRnnBackward() argument
4667 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, in ThenRnnBackward()
4685 const dnn::RnnDescriptor &rnn_desc, in ThenRnnBackward() argument
4711 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, in ThenRnnBackward()
[all …]