Searched refs:lstm_node (Results 1 – 2 of 2) sorted by relevance
841 lstm_node = match.output_node()847 lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS852 if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):857 lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"864 lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16870 lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float161079 lstm_node = match.output_node()1087 with graph.inserting_before(lstm_node):1117 lstm_node.replace_all_uses_with(packed_lstm_node)1118 packed_lstm_node.meta.update(lstm_node.meta)[all …]
1393 Node* lstm_node = graph->NewNode(); in ParseBasic() local1394 lstm_node->operation.type = ToString(OperationType::LSTM); in ParseBasic()1397 lstm_node->operation.attributes = lstm_attr; in ParseBasic()1415 RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id)); in ParseBasic()1416 RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state in ParseBasic()1417 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state in ParseBasic()1418 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation in ParseBasic()