#include <c10/core/DeviceType.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch::jit {

namespace {

using Tensor = at::Tensor;

bool supportedLinearNode(Node* n) {
  if (n->kind() == aten::linear) {
    return true;
  } else {
    return false;
  }
}

bool FoldFrozenLinearBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    // Recurse into sub-blocks first.
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenLinearBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedLinearNode(n->inputs().at(0)->node())) {
      auto linear = n->inputs().at(0)->node();
      auto bn = n;
      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");

      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();

      int64_t linear_out_features = linear_w.size(0);
      int64_t bn_num_features = bn_rm.size(0);

      // Linear-BN needs to be fused while preserving the shapes of the linear
      // weight/bias. To preserve those shapes, the channel dim of bn needs to
      // be broadcastable with the last dim of linear, because bn operates
      // over the channel dim, (N, C_in, H, W), while linear operates over the
      // last dim, (*, H_in). To be broadcastable, the number of features in
      // bn and the number of output features from linear must satisfy one of
      // the following conditions:
      //   1. they are equal, or
      //   2. the number of features in bn is 1.
      // Otherwise, skip the folding path.
      if (!(linear_out_features == bn_num_features || bn_num_features == 1)) {
        continue;
      }

      // implementation taken from torch/nn/utils/fusion.py
      Tensor linear_b;
      if (linear->namedInput("bias")->type() == NoneType::get()) {
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = linear_w.scalar_type();
        at::DeviceType weight_device = linear_w.device().type();
        // If the weight is half/bfloat16 on CUDA while the running stats are
        // float, synthesize the zero bias in the weight's dtype.
        if (weight_device == at::kCUDA &&
            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        linear_b =
            at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (n->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }
      LinearBNParameters params;
      params.linear_w = linear_w;
      params.linear_b = linear_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out =
          computeUpdatedLinearWeightAndBias(params);
      WithInsertPoint guard(linear);
      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto linear_w_value = linear->namedInput("weight");
      auto linear_b_value = linear->namedInput("bias");

      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");

      linear->replaceInputWith(linear_w_value, fused_linear_w);
      linear->replaceInputWith(linear_b_value, fused_linear_b);

      bn->output()->replaceAllUsesWith(linear->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

} // namespace

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace torch::jit
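
// Illustrative sketch only: the pass above delegates the folding arithmetic
// to computeUpdatedLinearWeightAndBias, whose implementation is not part of
// this file. Assuming that arithmetic mirrors torch/nn/utils/fusion.py (which
// the comment in the pass cites), the fused parameters can be derived as
// below. The helper foldLinearBNSketch is hypothetical and is not used by the
// pass; it relies only on at::Tensor, which the headers above already pull in.
namespace {

inline std::tuple<at::Tensor, at::Tensor> foldLinearBNSketch(
    const at::Tensor& linear_w, // [out_features, in_features]
    const at::Tensor& linear_b, // [out_features]
    const at::Tensor& bn_rm, // running_mean, [C]
    const at::Tensor& bn_rv, // running_var, [C]
    double bn_eps,
    const at::Tensor& bn_w, // gamma, [C]
    const at::Tensor& bn_b) { // beta, [C]
  // Per-feature scale: gamma / sqrt(running_var + eps).
  auto bn_scale = bn_w.mul(bn_rv.add(bn_eps).rsqrt());
  // W' = W * scale (scaled per output row); broadcasting also covers the
  // C == 1 case that the pass admits above.
  auto fused_w = linear_w.mul(bn_scale.unsqueeze(-1));
  // b' = (b - running_mean) * scale + beta.
  auto fused_b = linear_b.sub(bn_rm).mul(bn_scale).add(bn_b);
  return std::make_tuple(fused_w, fused_b);
}

} // namespace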