#include <torch/csrc/jit/passes/onnx/peephole.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

#include <ATen/ScalarOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/ones_like_native.h>
#endif

#include <optional>

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

bool isRNN(const Node* node) {
  auto k = node->kind();
  return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}

bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
    if (perm[i] != i) {
      return false;
    }
  }
  return true;
}

// Returns a vector `ret` such that transposing by `ret` is equivalent
// to transposing by `t1` and then by `t2`.
//
// This fires in the case that we have transpose ops T1 -> T2. We are
// fusing the transpose op T1 into T2 and discarding T1. We assume the elements
// of the permutation in `t1` are raw indices into its input, since a previous
// iteration would have folded all the transposes up to that point. Thus,
// `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
// the input tensor index contained in t1 at position `t2[i]`".
std::vector<int64_t> composeTransposes(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  TORCH_INTERNAL_ASSERT(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    TORCH_INTERNAL_ASSERT(i < int64_t(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}

std::vector<size_t> getBroadcastPositions(Node* node) {
  // Most of the element-wise ops in ONNX support numpy broadcasting.
  // Only Gemm supports one-directional broadcasting, which broadcasts the
  // bias to the product.
  static std::unordered_map<NodeKind, std::vector<size_t>>
      broadcast_positions = {
          {onnx::Add, {0, 1}},
          {onnx::Div, {0, 1}},
          {onnx::Mul, {0, 1}},
          {onnx::Pow, {0, 1}},
          {onnx::Sub, {0, 1}},
          {onnx::Gemm, {2}},
          {onnx::Equal, {0, 1}},
          {onnx::Greater, {0, 1}},
          {onnx::Less, {0, 1}},
      };
  static std::vector<size_t> no_positions;
  std::vector<size_t> positions;

  auto iter = broadcast_positions.find(node->kind());
  if (iter != broadcast_positions.end()) {
    // skip optional input if not provided
    for (size_t position : iter->second) {
      if (position < node->inputs().size()) {
        positions.emplace_back(position);
      }
    }
    return positions;
  }
  return no_positions;
}

// Determine whether `from` can broadcast to `to`, and if so at which
// position. `from` must be a suffix of `to`, except that any
// occurrences of 1 in `from` are treated as wildcards.
std::optional<size_t> fusibleExpandTo(
    at::IntArrayRef from,
    at::IntArrayRef to) {
  if (from.size() > to.size()) {
    return std::nullopt;
  }

  for (const auto i : c10::irange(from.size())) {
    auto fdim = from[from.size() - 1 - i];
    auto tdim = to[to.size() - 1 - i];
    if (fdim != 1 && fdim != tdim) {
      return std::nullopt;
    }
  }

  return to.size() - from.size();
}
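
// Worked examples for the two helpers above (illustrative only):
//
//   composeTransposes({2, 0, 1}, {1, 2, 0}) == {0, 1, 2}: ret[i] = t1[t2[i]],
//   so a transpose followed by its inverse folds to the identity permutation,
//   which eliminateNopTranspose below can then remove.
//
//   fusibleExpandTo(/*from=*/{3, 1, 5}, /*to=*/{2, 3, 4, 5}) returns 1:
//   aligned from the right, 5 matches 5, 1 is a wildcard against 4, and 3
//   matches 3, so the broadcast starts at position 4 - 3 = 1.
//   fusibleExpandTo(/*from=*/{3, 2}, /*to=*/{3, 4}) returns std::nullopt,
//   because 2 != 4 and 2 is not a wildcard.
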
// Fuses expand calls into ONNX operators, because it is
// easier for non-strided backends to more efficiently do broadcasts if this
// is local information. This optimization is not useful for PyTorch as
// 'expand' is free.
void fuseBroadcast(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseBroadcast(child_block);
    }

    auto broadcast_positions = getBroadcastPositions(n);
    if (!broadcast_positions.empty()) {
      TORCH_INTERNAL_ASSERT(!n->hasAttribute(attr::axis));
    }

    for (size_t position : broadcast_positions) {
      auto* expand_node = n->input(position)->node();

      // Confirm it is an expand node.
      if (expand_node->kind() != aten::expand ||
          expand_node->input(1)->node()->kind() != onnx::Constant ||
          expand_node->input(2)->node()->kind() != onnx::Constant) {
        continue;
      }

      auto* unexpanded_input = expand_node->input(0);

      // We need to know what the type pre-expand is. We should basically
      // always have this information (because expands are only ever traced,
      // not generated from symbolic), but if for some reason we don't
      // have it, we need to skip.
      if (!unexpanded_input->isCompleteTensor() ||
          !n->output()->isCompleteTensor()) {
        continue;
      }

      // Not all broadcasts are supported by ONNX broadcast.
      std::optional<size_t> axis = fusibleExpandTo(
          unexpanded_input->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value(), // from
          n->output()
              ->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value()); // to
      if (axis == std::nullopt) {
        continue;
      }

      n->replaceInput(position, unexpanded_input);
      if (!expand_node->hasUses()) {
        expand_node->destroy();
      }
    }
  }
}

void fuseConsecutiveTransposes(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseConsecutiveTransposes(child_block);
    }
    if (n->kind() == onnx::Transpose &&
        n->input()->node()->kind() == onnx::Transpose &&
        n->owningBlock() == n->input()->node()->owningBlock()) {
      auto origInput = n->input();
      n->is_(
          attr::perm,
          composeTransposes(
              origInput->node()->is(attr::perm), n->is(attr::perm)));
      n->replaceInput(0, origInput->node()->input());
      if (origInput->uses().empty()) {
        origInput->node()->destroy();
      }
      continue;
    }
  }
}

void eliminateNopTranspose(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      eliminateNopTranspose(child_block);
    }
    if (n->kind() == onnx::Transpose) {
      if (isNopTranspose(n->is(attr::perm))) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}

void fuseTransposeIntoGemm(Block* b) {
  static const std::vector<int64_t> simpleTransPerm({1, 0});

  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseTransposeIntoGemm(child_block);
    }
    if (n->kind() == onnx::Gemm) {
      for (size_t i : {0, 1}) {
        auto inp = n->inputs()[i];
        auto trans = i == 0 ? attr::transA : attr::transB;
        if (inp->node()->kind() == onnx::Transpose &&
            inp->node()->is(attr::perm) == simpleTransPerm) {
          n->replaceInput(i, inp->node()->input());
          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
          if (inp->uses().empty()) {
            inp->node()->destroy();
          }
        }
      }
    }
  }
}
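
// Example of the Gemm fusion above (illustrative IR, not verbatim exporter
// output):
//   %wt = onnx::Transpose[perm=[1, 0]](%w)
//   %y  = onnx::Gemm(%x, %wt, %bias)
// becomes
//   %y = onnx::Gemm[transB=1](%x, %w, %bias)
// A second 2D Transpose feeding the same input would toggle transB back to 0.
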
// Why this is here:
//
//   PyTorch has a "packed" representation of sequences, as well as a
//   "padded" representation. ONNX has only one representation,
//   corresponding to PyTorch's "padded". Therefore, we need to remove
//   any use of packed sequences before exporting.
//
// What this does:
//
//   This code uses the observation that
//       RNN(PackPadded(x)) == PackPadded(RNN(x))
//   and converts the first form to the second whenever possible,
//   "pushing" the packing operation past the RNN operation. Then,
//   the removeNopPacking pass removes the packing operations
//   entirely by pairing them with their inverse PadPacked. If the
//   input graph does not pair the operations, export will fail.
void pushPackingPastRnn(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      pushPackingPastRnn(child_block);
    }

    if (n->kind() != prim::PackPadded) {
      continue;
    }
    if (n->outputs().at(0)->uses().size() != 1) {
      // For now, only handle the case where there is one consumer.
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }

    if (rnn->owningBlock() != n->owningBlock()) {
      continue;
    }

    // Packing only has an effect on a network when its outputs are actually
    // used, so we can remove it here.
    if (rnn->outputs().at(0)->uses().empty() &&
        n->outputs().at(1)->uses().size() == 1) {
      n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
      n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
      it.destroyCurrent();
      continue;
    }

    // The rnn is followed by a transpose and a reshape (if
    // bidirectional), or by a squeeze (if unidirectional).
    Node* next = rnn->outputs().at(0)->uses().at(0).user;
    if (next->kind() == onnx::Transpose) {
      next = next->outputs().at(0)->uses().at(0).user;
      if (next->kind() != onnx::Reshape) {
        continue;
      }
    } else if (next->kind() != onnx::Squeeze) {
      continue;
    }

    // Remove PackPadded from in front of the RNN.
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

    Value* batch_sizes = n->outputs().at(1);
    while (!batch_sizes->uses().empty()) {
      Use use_0 = batch_sizes->uses().at(0);
      Node* user = use_0.user;

      // Make the calculation of max_batch_size not depend on batch_sizes.
      // This looks for a pattern generated by code such as
      // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
      //
      // Replace onnx::Gather[axis=0](batch_sizes, 0)
      // with    onnx::Gather[axis=0](onnx::Shape(rnn_input), 1)
      if (use_0.offset == 0 && user->kind() == onnx::Gather &&
          user->i(attr::axis) == 0 &&
          user->inputs().at(1)->node()->kind() == onnx::Constant &&
          user->inputs().at(1)->node()->hasAttribute(attr::value)) {
        const at::Tensor& const_val_t =
            user->inputs().at(1)->node()->t(attr::value);
        if (const_val_t.item().toInt() != 0) {
          // We'll likely produce an invalid graph if this happens.
          break;
        }
        Value* rnn_input = rnn->inputs().at(0);
        Node* shape = b->owningGraph()->create(onnx::Shape);
        shape->insertAfter(rnn_input->node());
        shape->addInput(rnn_input);
        shape->copyMetadata(n);
        batch_sizes->replaceFirstUseWith(shape->output());

        // A new Constant node is needed, as the old one might be shared
        // with other users of the constant 0.
        Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
        gather_indices->t_(attr::value, at::native::ones_like(const_val_t));
        gather_indices->copyMetadata(n);
        gather_indices->insertBefore(user);
        user->replaceInput(1, gather_indices->output());
      }
      // Make the RNN not depend on batch_sizes.
      else if (user == rnn) {
        batch_sizes->replaceFirstUseWith(n->inputs().at(1));
      } else {
        // If there are other uses that are not:
        // * PadPacked (which will be removed in removeNopPacking),
        // * dead code (which will be removed in dead code elimination),
        // then we have likely produced an invalid graph, since there will be
        // a use of the output of PackPadded, but the PackPadded (and that
        // output) will be removed.
        break;
      }
    }

    // And insert a new PackPadded after the RNN.
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
    newPackPadded->copyMetadata(n);
    newPackPadded->insertAfter(next);
    newPackPadded->copyMetadata(next);

    // Make things consume from the new PackPadded.
    next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
    n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

    // Set up the new PackPadded's inputs.
    newPackPadded->addInput(next->outputs().at(0));
    newPackPadded->addInput(n->inputs().at(1));
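
    // Illustrative numbers for the type fixup below (shapes assumed): with
    // an RNN input of type Float(35, 4, 10) and hidden_size=20, the output
    // type is reset to Float(35, 4, 40) for a bidirectional RNN (the
    // Reshape case), or Float(35, 4, 20) for a unidirectional one (the
    // Squeeze case).
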
    // See https://github.com/pytorch/pytorch/issues/9043 for a full
    // description. Since PackPadded is for now treated in an unhygienic
    // way, PyTorch ends up propagating an incorrect type.
    // Until a long-term cleanup comes around, we can fix this by
    // resetting the size to the correct value.
    TensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<TensorType>();
    if (oldType && oldType->isComplete()) {
      std::vector<int64_t> new_sizes;
      new_sizes.push_back(*oldType->sizes()[0]);
      new_sizes.push_back(*oldType->sizes()[1]);
      if (next->kind() == onnx::Reshape) {
        // bidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size) * 2);
      } else {
        // unidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size));
      }
      TensorTypePtr newType = TensorType::createContiguous(
          *oldType->scalarType(), *oldType->device(), new_sizes);
      next->outputs().at(0)->setType(newType);
    }

    it.destroyCurrent();
  }
}

// Despite the name, this actually removes the PadPacked node and leaves
// the PackPadded node. The PackPadded should become dead code, which will
// be eliminated later.
void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != prim::PackPadded) {
      continue;
    }
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}

void hackFixupPadPackedShapes(Block* graph) {
  // FIXME: the input to the fictional PadPacked node has an incorrect
  // shape. For now, just copy the shape of PadPacked to the shape of its
  // input.
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    input->outputs()[0]->setType(n->outputs()[0]->type());
  }
}

void fixDefaultRNNState(
    Graph* graph,
    Node* n,
    int input_index,
    int opset_version) {
  auto initial_state = n->inputs()[input_index];

  // The RNN code in PyTorch accepts an optional hidden state.
  // 1- When it is provided as an input, everything works great.
  // 2- When it is not provided, it is default-initialized by constructing a
  //    new Variable, which gets traced as a ConstantOfShape with the
  //    expected shape.
  // 3- When the batch size is fixed, everything works great as well.
  // 4- When h0 and c0 are specified but are not inputs of the model (they
  //    are Constants) and the batch size is variable, the model should be
  //    saved with a batch size of 1 (or an error will occur), and we save
  //    the value of h0 and c0 with a batch size of 1. When the model is then
  //    called with a different batch size value, h0 and c0 are broadcast to
  //    get the right shape.
  // Recognize that last pattern here (4) and fix the shape.
  // Note that for multi-layer RNNs there will be a Slice operation between
  // the Constant and the RNN.
  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
      (initial_state->node()->kind() == onnx::Slice &&
       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);

  if (!needsFixing) {
    return;
  }

  Node* shape_of_input = graph->create(onnx::Shape, 1);
  shape_of_input->copyMetadata(n);
  shape_of_input->insertBefore(n);
  shape_of_input->addInput(n->inputs()[0]);

  Node* gather_indices = graph->create(onnx::Constant, 1);
  gather_indices->copyMetadata(n);
  gather_indices->insertBefore(n);
  gather_indices->t_(attr::value, at::scalar_to_tensor(at::Scalar(1)));

  Node* batch_size = graph->create(onnx::Gather, 1);
  batch_size->copyMetadata(n);
  batch_size->insertBefore(n);
  batch_size->addInput(shape_of_input->outputs()[0]);
  batch_size->addInput(gather_indices->outputs()[0]);

  Node* unsqueezed_batch_size =
      createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version);

  Node* hidden_size = graph->create(onnx::Constant, 1);
  hidden_size->copyMetadata(n);
  hidden_size->insertBefore(n);
  hidden_size->t_(
      attr::value,
      at::full(
          {1},
          n->i(attr::hidden_size),
          at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());

  Node* num_directions = graph->create(onnx::Constant, 1);
  num_directions->copyMetadata(n);
  num_directions->insertBefore(n);
  num_directions->t_(
      attr::value,
      scalar_to_tensor(at::Scalar(
          n->hasAttribute(attr::direction) &&
                  n->s(attr::direction) == "bidirectional"
              ? 2
              : 1)));

  Node* unsqueezed_num_directions = createONNXUnsqueeze(
      graph, n, num_directions->outputs()[0], 0, opset_version);

  Node* concated_dims = graph->create(onnx::Concat, 1);
  concated_dims->copyMetadata(n);
  concated_dims->insertBefore(n);
  concated_dims->i_(attr::axis, 0);
  concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* fixed_init_state = graph->create(onnx::Expand, 1);
  fixed_init_state->copyMetadata(n);
  fixed_init_state->insertBefore(n);
  fixed_init_state->addInput(initial_state);
  fixed_init_state->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, fixed_init_state->outputs()[0]);

  if (initial_state->uses().empty()) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block, opset_version);
    }

    if (!isRNN(n)) {
      continue;
    }
    // Hidden state is the sixth input for RNN, LSTM, GRU.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.RNN
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5, opset_version);
  }
}

void fixDefaultLstmCellState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block, opset_version);
    }

    if (n->kind() != onnx::LSTM) {
      continue;
    }
    // Cell state is the seventh input for LSTM.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.LSTM
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6, opset_version);
  }
}

static bool isSafeToSpeculate(Node* n) {
  return n->kind() == onnx::Transpose;
}
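
// Example of the speculation performed below (illustrative IR): a Transpose
// nested in an If block, reading a value defined outside that block,
//   %y = prim::If(%cond)
//     block0():
//       %t = onnx::Transpose[perm=[1, 0]](%x)
//       ...
// is moved to just before the prim::If node, so it runs unconditionally
// regardless of %cond.
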
// Moves ops outside of control flow blocks so that they are always executed,
// no matter the result of the control flow conditions.
// Needed only so that the split pass of the ONNX optimizer will put the ops
// into the init_net.
// TODO: Once the code in caffe2/python/onnx/backend.py no longer calls
// optimize_onnx, delete this function.
static void speculateOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it; // note: increment first so that it is safe to move the node if
          // needed

    for (auto b : n->blocks()) {
      speculateOps(b);
    }
    if (!isSafeToSpeculate(n)) {
      continue;
    }
    // XXX - only works for nodes with a single input.
    // Move node n outside of the control flow it is nested in.
    auto node_input = n->input()->node();
    if (node_input->owningBlock() == n->owningBlock()) {
      continue;
    }
    // Skip if an output of this node is part of the block output.
    bool is_block_output = false;
    for (auto node_output : n->outputs()) {
      for (auto node_output_use : node_output->uses()) {
        if (node_output_use.user == n->owningBlock()->return_node()) {
          is_block_output = true;
          break;
        }
      }
      if (is_block_output) {
        break;
      }
    }
    if (is_block_output) {
      continue;
    }
    // Find the control flow node in the same block as node_input that
    // contains node n.
    auto control_flow_node = n->owningBlock()->owningNode();
    while (control_flow_node->owningBlock() != node_input->owningBlock()) {
      control_flow_node = control_flow_node->owningBlock()->owningNode();
    }
    // Put the node right before this flow node.
    n->moveBefore(control_flow_node);
  }
}

static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
  node->removeInput(i);
  for (auto* to_val : to) {
    TORCH_INTERNAL_ASSERT(to_val->owningGraph() == node->owningGraph());
    node->insertInput(i++, to_val);
  }
}

static void eraseListConstruct(Block* block, int opset_version);

static void eraseListConstruct(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListConstruct(b, opset_version);
  }
  std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

  auto block = n->owningBlock();
  size_t i = 0;
  for (auto* input : n->inputs()) {
    if (input->node()->kind() == prim::ListConstruct) {
      auto* lc_node = input->node();
      TypePtr elem =
          lc_node->output()->type()->castRaw<ListType>()->getElementType();
      if (elem->cast<TensorType>() &&
          isValidToTransformToONNXConcatNode(lc_node)) {
        auto concat_node = transformToONNXConcatNode(
            block->owningGraph(), input->node(), false, opset_version);
        concat_node->copyMetadata(n);
        // Make the concat node output the new input; the ListConstruct
        // should then become dead.
        replacements.emplace_back(
            i, std::vector<Value*>({concat_node->output()}));
      } else {
        if (opset_version >= OPSET_VERSION_11) {
          c10::Symbol seq_node_kind = !lc_node->inputs().empty()
              ? onnx::SequenceConstruct
              : onnx::SequenceEmpty;
          Node* seq_node = block->owningGraph()->create(
              seq_node_kind, {lc_node->inputs()}, 1);
          seq_node->copyMetadata(n);
          seq_node->insertBefore(lc_node);
          seq_node->output()->copyMetadata(lc_node->output());
          seq_node->copyMetadata(lc_node);
          lc_node->replaceAllUsesWith(seq_node);
        }
      }
    }
    i++;
  }

  for (auto ritr = replacements.rbegin(); ritr != replacements.rend();
       ++ritr) {
    replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
  }
}

static void eraseListConstruct(Block* block, int opset_version) {
  // TODO: Fix this pass/maybe get rid of this part.
  // Tensor lists might be used for meshgrid and such ops as well.
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListConstruct(n, opset_version);
  }
  eraseListConstruct(block->return_node(), opset_version);
}
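
// Example of the lowering above (illustrative IR, opset >= 11): a tensor
// list that is valid to transform into a concatenation (see
// isValidToTransformToONNXConcatNode) becomes onnx::Concat; otherwise
//   %l : Tensor[] = prim::ListConstruct(%a, %b)
// becomes
//   %l : Tensor[] = onnx::SequenceConstruct(%a, %b)
// and an empty ListConstruct becomes onnx::SequenceEmpty().
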
static void eraseListUnpack(Block* block, int opset_version);

// Replace prim::ListUnpack with onnx::SequenceAt.
static void eraseListUnpack(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListUnpack(b, opset_version);
  }

  if (n->kind() == prim::ListUnpack) {
    if (opset_version < OPSET_VERSION_11) {
      // onnx::SequenceAt was introduced in ONNX opset version 11.
      throw std::runtime_error(
          "Unsupported: ONNX export of prim::ListUnpack in opset " +
          std::to_string(opset_version) + ". Please try opset version 11.");
    }

    auto g = n->owningGraph();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto seq_idx_n = g->create(onnx::Constant, 1);
      seq_idx_n->t_(attr::value, at::scalar_to_tensor(at::Scalar(int64_t(i))));
      seq_idx_n->insertBefore(n);

      auto seq_at_n = g->create(onnx::SequenceAt, 1);
      seq_at_n->addInput(n->input());
      seq_at_n->addInput(seq_idx_n->output());
      seq_at_n->output()->setType(n->output(i)->type());
      seq_at_n->insertBefore(n);
      seq_at_n->copyMetadata(n);
      n->output(i)->replaceAllUsesWith(seq_at_n->output());
    }
  }
}

static void eraseListUnpack(Block* block, int opset_version) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListUnpack(n, opset_version);
  }
}

// From:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%unpacked);
//
// To:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%x)
//
// The ListConstruct and ListUnpack may now be dead code.
static void fuseListConstructListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseListConstructListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack &&
        it->input()->node()->kind() == prim::ListConstruct) {
      for (const auto i : c10::irange(it->outputs().size())) {
        auto output = it->outputs().at(i);
        output->replaceAllUsesWith(it->input()->node()->inputs().at(i));
      }
    }
  }
}

// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
static void eraseTupleConstruct(Block* block) {
  std::vector<Value*> new_block_outputs;
  bool found_tuple_construct = false;
  // TupleConstruct is generated from the symbolics in the quantized domain,
  // and consumed by other quantized operators. Any remaining TupleConstruct
  // should be at the output of the blocks.
  for (auto* output : block->outputs()) {
    auto output_node = output->node();
    if (output_node->kind() == prim::TupleConstruct) {
      found_tuple_construct = true;
      for (auto* input : output_node->inputs()) {
        new_block_outputs.emplace_back(input);
      }
    } else {
      new_block_outputs.emplace_back(output);
    }
  }
  if (found_tuple_construct) {
    block->removeAllOutputs();
    for (auto* output : new_block_outputs) {
      block->registerOutput(output);
    }
  }
}

void removeMaxPoolUnusedOutput(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      removeMaxPoolUnusedOutput(child_block);
    }
    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
        it->eraseOutput(1);
      }
    }
  }
}
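
// Example for removeMaxPoolUnusedOutput above (illustrative IR):
//   %y, %indices = onnx::MaxPool[kernel_shape=[2, 2]](%x)
// where %indices has no uses has its second output erased, leaving
//   %y = onnx::MaxPool[kernel_shape=[2, 2]](%x)
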
// This optimization fuses LogSoftmax and NegativeLogLikelihoodLoss operators
// into one operator: SoftmaxCrossEntropyLoss. Depending on the dimensions of
// the input and on the attributes, there will be different subgraphs of
// LogSoftmax and NegativeLogLikelihoodLoss.
static void fuseLogSoftmaxNllLoss(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseLogSoftmaxNllLoss(child_block);
    }
    if (it->kind() == onnx::NegativeLogLikelihoodLoss) {
      auto prev = it->input(0)->node();
      Node* origNllLossNode = *it;
      Node* origLogSoftmaxNode = nullptr;

      // Check for patterns, especially in cases with autocasting enabled,
      // in which a Cast node is inserted before the
      // NegativeLogLikelihoodLoss node; this would otherwise make the
      // patterns below unrecognizable to this function.
      // For example, if the input is 2D:
      // graph(%input : Half(3, 5),
      //       %target : Long(3)):
      //   %4 : Half(3, 5) = onnx::LogSoftmax[axis=1](%input)
      //   %8 : Float = onnx::Cast[to=1](%4)
      //   %9 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%8, %target)
      //   return (%9)
      Node* castNode = nullptr;
      if (prev->kind() == onnx::Cast) {
        castNode = prev;
        prev = prev->input(0)->node();
      }
      if (prev->kind() == onnx::LogSoftmax) {
        // If the input is 2D:
        // graph(%input : Float(3, 5),
        //       %target : Long(3)):
        //   %4 : Float(3, 5) = onnx::LogSoftmax[axis=1](%input)
        //   %8 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%4, %target)
        //   return (%8)
        origLogSoftmaxNode = prev;
      } else if (
          prev->kind() == onnx::Transpose &&
          prev->input(0)->node()->kind() == onnx::LogSoftmax) {
        // If the input is 4D:
        // graph(%input : Float(3, 5, 2, 7),
        //       %target : Long(3, 2, 7)):
        //   %4 : Tensor = onnx::Transpose[perm=[0, 3, 2, 1]](%input)
        //   %5 : Tensor = onnx::LogSoftmax[axis=3](%4)
        //   %6 : Float(3, 5, 2, 7) = onnx::Transpose[perm=[0, 3, 2, 1]](%5)
        //   %10 : Float(3, 2, 7) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%6, %target)
        //   return (%10)
        origLogSoftmaxNode = prev->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        if (!transpose->inputs().empty()) {
          origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        }
      } else if (
          prev->kind() == onnx::Reshape &&
          prev->input(0)->node()->kind() == onnx::Transpose &&
          prev->input(0)->node()->input(0)->node()->kind() ==
              onnx::LogSoftmax) {
        // If the input is 3D or higher than 4D:
        // graph(%input : Float(3, 5, 2),
        //       %target.1 : Long(3, 2)):
        //   %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]](%input)
        //   %5 : Tensor = onnx::LogSoftmax[axis=2](%4)
        //   %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]](%5)
        //   %8 : Tensor = onnx::Shape(%6)
        //   %10 : Tensor = onnx::Constant[value={0}]()
        //   %11 : Long() = onnx::Gather[axis=0](%8, %10)
        //   %13 : Tensor = onnx::Shape(%6)
        //   %15 : Tensor = onnx::Constant[value={1}]()
        //   %16 : Long() = onnx::Gather[axis=0](%13, %15)
        //   ...
        //   %22 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %21)
        //   ...
        //   %26 : Long(3, 1, 2) = onnx::Reshape(%target.1, %25)
        //   %30 : Float() = onnx::NegativeLogLikelihoodLoss[reduction="sum"](%22, %26)
        //   return (%30)
        origLogSoftmaxNode = prev->input(0)->node()->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        TORCH_INTERNAL_ASSERT(transpose->kind() == onnx::Transpose);
        origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        auto reshape = origNllLossNode->input(1)->node();
        TORCH_INTERNAL_ASSERT(reshape->kind() == onnx::Reshape);
        origNllLossNode->replaceInput(1, reshape->inputs().at(0));
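
        // At this point the LogSoftmax reads the original %input directly
        // and the loss node reads the un-reshaped %target.1; the Transpose
        // and Reshape wrappers have been bypassed, so the fused node created
        // below sees the tensors the traced code started from.
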
        if (origNllLossNode->s(attr::reduction) == "none") {
          // When reduction="none", a different graph is created, and that
          // graph doesn't end with a NegativeLogLikelihoodLoss node as in
          // all the other cases.
          // graph(%input : Float(3, 5, 2), %target.1 : Long(3, 2)):
          //   %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]](%input)
          //   %5 : Tensor = onnx::LogSoftmax[axis=2](%4)
          //   %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]](%5)
          //   ...
          //   %27 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %26)
          //   %31 : Long(3, 1, 2) = onnx::Reshape(%target.1, %30)
          //   %35 : Float(3, 1, 2) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%27, %31)
          //   %36 : int[] = prim::ListConstruct(%11, %21)
          //   %37 : Float(3, 2) = onnx::Reshape(%35, %36)
          //   return (%37)
          auto nllloss_output = origNllLossNode->output(0)->uses()[0].user;
          TORCH_INTERNAL_ASSERT(nllloss_output->kind() == onnx::Reshape);
          // Make the output of the Reshape the output of nllloss.
          nllloss_output->replaceAllUsesWith(origNllLossNode);
          origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0));
        }
      } else {
        continue;
      }

      // If the pattern indeed contains a Cast node before the
      // NegativeLogLikelihoodLoss node, place a Cast node at the beginning
      // of the pattern instead.
      if (castNode != nullptr) {
        auto onnx_type = castNode->i(attr::to);
        Node* cast_node = b->owningGraph()->create(onnx::Cast, 1);
        cast_node->addInput(origLogSoftmaxNode->inputs().at(0));
        cast_node->i_(attr::to, onnx_type);
        cast_node->insertBefore(origLogSoftmaxNode);
        cast_node->copyMetadata(castNode);
        origLogSoftmaxNode->replaceInputWith(
            origLogSoftmaxNode->inputs().at(0), cast_node->output());
      }

      Node* softmaxCrossEntropyNode = b->owningGraph()->create(
          onnx::SoftmaxCrossEntropyLoss, it->outputs().size());
      for (size_t i = 0; i < softmaxCrossEntropyNode->outputs().size(); ++i) {
        softmaxCrossEntropyNode->outputs()[i]->copyMetadata(it->outputs()[i]);
      }
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      softmaxCrossEntropyNode->copyAttributes(*origNllLossNode);
      softmaxCrossEntropyNode->insertBefore(origNllLossNode);
      softmaxCrossEntropyNode->addInput(origLogSoftmaxNode->inputs().at(0));
      softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(1));
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      // The optional weight input is provided.
      if (origNllLossNode->inputs().size() == 3) {
        softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(2));
      }

      it->replaceAllUsesWith(softmaxCrossEntropyNode);
      it->removeAllInputs();
      it.destroyCurrent();
    }
  }
}
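
// Net effect of fuseLogSoftmaxNllLoss above on the simplest (2D) pattern
// (illustrative IR):
//   %4 = onnx::LogSoftmax[axis=1](%input)
//   %8 = onnx::NegativeLogLikelihoodLoss[reduction="mean"](%4, %target)
// becomes
//   %8 = onnx::SoftmaxCrossEntropyLoss[reduction="mean"](%input, %target)
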
// This optimization removes consecutive SplitToSequence and
// ConcatFromSequence operators. The optimization only happens when
// 1. The output of SplitToSequence is not used by any other nodes.
// 2. The attributes keepdims and axis of SplitToSequence match the
//    attributes new_axis and axis of ConcatFromSequence.
// In that case, the two ops combined are a no-op and can be safely removed.
static void removeSequenceSplitConcat(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      removeSequenceSplitConcat(child_block);
    }
    if (it->kind() == onnx::ConcatFromSequence &&
        it->input()->node()->kind() == onnx::SplitToSequence) {
      if (it->input()->uses().size() > 1) {
        continue;
      }

      auto split_node = it->input()->node();
      auto concat_node = *it;

      const auto split_axis =
          split_node->hasAttribute(attr::axis) ? split_node->i(attr::axis) : 0;
      const auto split_keepdims = split_node->hasAttribute(attr::keepdims)
          ? split_node->i(attr::keepdims)
          : 1;
      const auto concat_axis = concat_node->i(attr::axis);
      const auto concat_new_axis = concat_node->hasAttribute(attr::new_axis)
          ? concat_node->i(attr::new_axis)
          : 0;
      const bool has_input_split = split_node->inputs().size() == 2;

      if (has_input_split) {
        continue;
      }

      if (split_keepdims == concat_new_axis) {
        continue;
      }

      if (split_axis != concat_axis) {
        continue;
      }

      concat_node->output()->replaceAllUsesWith(split_node->input());
    }
  }
}

// Work around the limitation of ONNX that a block input cannot be used
// directly as a block output. Insert an Identity node inside the block
// and have the block return the output of the Identity instead.
static void insertIdentityForInputUsedAsOutput(Block* b) {
  for (auto out : b->outputs()) {
    auto n = out->node();
    if (nullptr != n && n->kind() == prim::Param) {
      Node* id_node = b->owningGraph()->create(onnx::Identity);
      id_node->insertBefore(b->return_node());
      id_node->addInput(out);
      id_node->output()->setType(out->type());
      b->return_node()->replaceInputWith(out, id_node->output());
    }
  }

  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      insertIdentityForInputUsedAsOutput(child_block);
    }
  }
}

// This optimization does ONNX-specific peephole optimizations.
//
// Before you write an optimization here, ask yourself, "Could I do this
// optimization on ATen operators"? If so, you should seriously consider
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX export.
void PeepholeOptimizeONNX(
    std::shared_ptr<Graph>& graph,
    int opset_version,
    bool fixed_batch_size) {
  // TODO: decide on fixpoint strategy
  // TODO: make it easier not to do O(k) iterations over the graph, where
  // k is the number of distinct peephole optimizations
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  // We only need to fix the size of the hidden state and cell state if the
  // batch size is variable.
  if (!fixed_batch_size) {
    fixDefaultRnnHiddenState(graph->block(), opset_version);
    fixDefaultLstmCellState(graph->block(), opset_version);
  }
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
  fuseTransposeIntoGemm(graph->block());
  speculateOps(graph->block());
  fuseListConstructListUnpack(graph->block());
  fuseLogSoftmaxNllLoss(graph->block());
  eraseListConstruct(graph->block(), opset_version);
  eraseTupleConstruct(graph->block());
  EliminateDeadCode(
      graph->block(),
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
  eraseListUnpack(graph->block(), opset_version);
  removeMaxPoolUnusedOutput(graph->block());
  removeSequenceSplitConcat(graph->block());
  insertIdentityForInputUsedAsOutput(graph->block());

  GRAPH_DUMP("After PeepholeOptimizeONNX", graph);
}

} // namespace torch::jit