#include #include #include #include #include namespace torch { namespace jit { namespace fuser { namespace onednn { using opkind = dnnl::graph::op::kind; static void fixConvOptionalBias(Node* node) { if (node->namedInput("bias")->mustNotBeNone() == false) { // Replace non-existent optional bias with const None auto g = node->owningGraph(); auto n = g->createNone(); auto v = n->insertBefore(node)->output(); node->replaceInput(2, v); } } static std::optional getDimensions(Value* v) { if (v->type()->isSubtypeOf(TensorType::get())) { return v->type()->cast()->sizes().size(); } else { return std::nullopt; } } // PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped as // Wildcards instead. They make the integration code with PyTorch simpler by // passing every op to the oneDNN Graph library in the add_op call - // no need to check beforehand whether the op is supported by oneDNN Graph or // not oneDNN Graph ops separated by wildcards don't end up in the same // partition. static Operator makeWildcardOp(Node* node) { auto o = Operator(node, opkind::Wildcard); // wildcard op contains only topology info for (size_t i = 0; i < node->inputs().size(); i++) { o.setInput(0, i); } for (size_t i = 0; i < node->outputs().size(); i++) { o.setOutput(i); } return o; } // If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph // op, then we create a wildcard op corresponding to that PyTorch op instead. #define REQUIRE(cond) \ if (!(cond)) { \ GRAPH_DEBUG("Unsupported condition " #cond "\n"); \ return makeWildcardOp(node); \ } Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) { return Operator(node, kind).setInput(0).setOutput(dnnl_graph_, 0); } Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) { REQUIRE( node->input(0)->type()->isSubtypeOf(TensorType::get()) && node->input(1)->type()->isSubtypeOf(TensorType::get())) return Operator(node, kind).setInput(0, 1).setOutput(dnnl_graph_, 0); } // Map a PyTorch op to its corresponding oneDNN Graph op. // If mapping isn't possible, then create a wildcard op instead. // The mapping is done as per oneDNN Graph op schema defined in // third_party/ideep/mkl-dnn/src/interface/op_def.hpp. Operator LlgaGraphHelper::createOperator(Node* node) { auto nodeKind = node->kind(); // we're using an if-else clause instead of a switch staement // because we would soon be adding custom ops with function schemas. // We would have to use Symbol::fromQualString at that time anyway, // but we are okay with this choice, since this code is not in the hot-path. if (nodeKind == Symbol::fromQualString("aten::conv2d")) { fixConvOptionalBias(node); return Operator(node, opkind::Convolution) .setInput(0, 1, 2) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4) .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4) .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5) .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 6) .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX")) .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX")); } else if ( (nodeKind == Symbol::fromQualString("aten::_convolution")) || (nodeKind == Symbol::fromQualString("aten::convolution"))) { bool transposed = toIValue(node->namedInput("transposed"))->toBool(); REQUIRE(!transposed); return Operator(node, opkind::Convolution) .setInput(0, 1, 2) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4) .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4) .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5) .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 8) .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX")) .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX")); } else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) { auto training = toIValue(node->namedInput("training")); REQUIRE(training.has_value()); // cannot get training status in script mode if (!training->toBool()) { return Operator(node, opkind::BatchNormInference) .setInput(0, 1, 2, 3, 4) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 7) .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX")); } } else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) { auto normalized_shape = toIValue(node->namedInput("normalized_shape")); REQUIRE(normalized_shape->toIntList().size() == 1); return Operator(node, opkind::LayerNorm) .setInput(0, 2, 3) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 4) .setAttr(dnnl::graph::op::attr::keep_stats, false); } else if (nodeKind == Symbol::fromQualString("aten::addmm")) { auto alpha = toIValue(node->namedInput("alpha")); auto beta = toIValue(node->namedInput("beta")); if (alpha.has_value() && beta.has_value()) { if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) { return Operator(node, opkind::MatMul) .setInput(1, 2, 0) .setOutput(dnnl_graph_, 0); } else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) { return Operator(node, opkind::MatMul) .setInput(1, 2) .setOutput(dnnl_graph_, 0); } } } else if (nodeKind == Symbol::fromQualString("aten::add")) return makeBinaryOp(node, opkind::Add); else if (nodeKind == Symbol::fromQualString("aten::mul")) return makeBinaryOp(node, opkind::Multiply); else if (nodeKind == Symbol::fromQualString("aten::div")) return makeBinaryOp(node, opkind::Divide); else if (nodeKind == Symbol::fromQualString("aten::tanh")) return makeEltwiseOp(node, opkind::Tanh); else if (nodeKind == Symbol::fromQualString("aten::relu")) return makeEltwiseOp(node, opkind::ReLU); else if (nodeKind == Symbol::fromQualString("aten::elu")) return makeEltwiseOp(node, opkind::Elu) .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1); else if (nodeKind == Symbol::fromQualString("aten::sigmoid")) return makeEltwiseOp(node, opkind::Sigmoid); else if (nodeKind == Symbol::fromQualString("aten::gelu")) return makeEltwiseOp(node, opkind::GELU); else if (nodeKind == Symbol::fromQualString("aten::round")) return makeEltwiseOp(node, opkind::Round); else if (nodeKind == Symbol::fromQualString("aten::exp")) return makeEltwiseOp(node, opkind::Exp); else if (nodeKind == Symbol::fromQualString("aten::sqrt")) return makeEltwiseOp(node, opkind::Sqrt); else if (nodeKind == Symbol::fromQualString("aten::abs")) return makeEltwiseOp(node, opkind::Abs); else if (nodeKind == Symbol::fromQualString("aten::square")) return makeEltwiseOp(node, opkind::Square); else if (nodeKind == Symbol::fromQualString("aten::clamp")) { // PyTorch API already checks that both min & max are not None. // But we can check it nevertheless. auto clamp_min = toIValue(node->input(1)); auto clamp_max = toIValue(node->input(2)); REQUIRE(!(clamp_max->isNone() && clamp_min->isNone())); auto clamp_min_value = (clamp_min->isNone()) ? -std::numeric_limits::infinity() : Operator::ScalarToFloat(node, 1); auto clamp_max_value = (clamp_max->isNone()) ? std::numeric_limits::infinity() : Operator::ScalarToFloat(node, 2); return makeEltwiseOp(node, opkind::Clamp) .setAttr(dnnl::graph::op::attr::min, clamp_min_value) .setAttr(dnnl::graph::op::attr::max, clamp_max_value); } else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) { return makeEltwiseOp(node, opkind::Clamp) .setAttr(dnnl::graph::op::attr::min, Operator::ScalarToFloat, 1) .setAttr(dnnl::graph::op::attr::max, Operator::ScalarToFloat, 2); } else if (nodeKind == Symbol::fromQualString("aten::hardswish")) return makeEltwiseOp(node, opkind::HardSwish); else if (nodeKind == Symbol::fromQualString("aten::log")) return makeEltwiseOp(node, opkind::Log); else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) { return makeEltwiseOp(node, opkind::LeakyReLU) .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1); } else if (nodeKind == Symbol::fromQualString("aten::relu6")) { return makeEltwiseOp(node, opkind::Clamp) .setAttr(dnnl::graph::op::attr::min, 0.f) .setAttr(dnnl::graph::op::attr::max, 6.f); } else if ( (nodeKind == Symbol::fromQualString("aten::softmax")) || (nodeKind == Symbol::fromQualString("aten::_softmax"))) { auto axis = toIValue(node->namedInput("dim"))->toInt(); return Operator(node, opkind::SoftMax) .setInput(0) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::axis, axis); } else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) { auto axis = toIValue(node->namedInput("dim"))->toInt(); return Operator(node, opkind::LogSoftmax) .setInput(0) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::axis, axis); } else if (nodeKind == Symbol::fromQualString("aten::cat")) { auto o = Operator(node, opkind::Concat); REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct); REQUIRE(node->namedInput("tensors")->uses().size() == 1); REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant); // aten::cat needs a special handling since it takes a Tensor[] as input. // We set the inputs of ListConstruct as the inputs of cat. // // Pytorch IR: LLGA sees: // %a %b %c %dim %a %b %c // \ | / | \ | / // prim::ListConstruct prim::Constant llga::Concat[axis=%dim] // \ / // aten::cat auto listConstruct = node->input(0)->node(); for (auto input : listConstruct->inputs()) o.setInputValue(input); return o.setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1); } else if ( (nodeKind == Symbol::fromQualString("aten::max_pool2d")) || (nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) { // Currently, LLGA lacks support to create indices mask. // Once it's supported, max_pool2d_with_indices should be mapped differently REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant); auto rounding_type = toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor"; return Operator(node, opkind::MaxPool) .setInput(0) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1) .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2) .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 4) .setAttr( dnnl::graph::op::attr::rounding_type, std::string(rounding_type)) .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX")); } else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) { // TODO: do we need add checks for all Constants? REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant); auto rounding_type = toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor"; auto divisor_override = toIValue(node->namedInput("divisor_override")); REQUIRE(divisor_override->isNone()); return Operator(node, opkind::AvgPool) .setInput(0) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1) .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2) .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3) .setAttr(dnnl::graph::op::attr::exclude_pad, !Operator::Bool(node, 5)) .setAttr( dnnl::graph::op::attr::rounding_type, std::string(rounding_type)) .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX")); } else if (nodeKind == Symbol::fromQualString("aten::matmul")) { auto dim0 = getDimensions(node->namedInput("self")).value_or(-1); auto dim1 = getDimensions(node->namedInput("other")).value_or(-1); // TODO: support all shape combinations REQUIRE( (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) || (dim0 == 3 && dim1 == 2)); return Operator(node, opkind::MatMul) .setInput(0, 1) .setOutput(dnnl_graph_, 0); } // fall through else if (nodeKind == Symbol::fromQualString("aten::mm")) { return Operator(node, opkind::MatMul) .setInput(0, 1) .setOutput(dnnl_graph_, 0); } else if (nodeKind == Symbol::fromQualString("aten::bmm")) { return Operator(node, opkind::MatMul) .setInput(0, 1) .setOutput(dnnl_graph_, 0); } else if (nodeKind == Symbol::fromQualString("aten::linear")) { return Operator(node, opkind::MatMul) .setInput(0, 1, 2) .setOutput(dnnl_graph_, 0) .setAttr(dnnl::graph::op::attr::transpose_b, true); } else if (nodeKind == Symbol::fromQualString("aten::permute")) { REQUIRE(aliasDb_->hasInputWriters(node) == false); return Operator(node, opkind::StaticTranspose) .setInput(0) .setOutput(dnnl_graph_, 0) .setAttr( dnnl::graph::op::attr::order, toIValue(node->namedInput("dims"))->toIntVector()); } else if (nodeKind == Symbol::fromQualString("aten::contiguous")) { // Contiguous should only be mapped to oneDNN Graph if the destination // memory-layout is different than the source memory-format // Strides would be different, but shape would be same auto typeOfInput = node->input(0)->type()->expect(); auto typeOfOutput = node->output(0)->type()->expect(); auto inputStrides = typeOfInput->strides().concrete_sizes(); auto outputStrides = typeOfOutput->strides().concrete_sizes(); REQUIRE(inputStrides != outputStrides); return Operator(node, opkind::Reorder) .setInput(0) .setOutput(dnnl_graph_, 0); } GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard"); return makeWildcardOp(node); } static DeviceType inferDeviceFromValue(Value* v) { auto tt = v->type()->cast(); if (!tt) { return at::kCPU; } auto device = tt->device(); if (!device) { return at::kCPU; } return device->type(); } static DeviceType inferDevice(const std::shared_ptr& graph) { auto dt = inferDeviceFromValue(graph->inputs()[0]); TORCH_CHECK( std::all_of( graph->inputs().begin(), graph->inputs().end(), [dt](Value* v) { return inferDeviceFromValue(v) == dt; }), "All inputs must have the same deive type"); return dt; } static dnnl::engine::kind getLlgaEngineKind(DeviceType type) { switch (type) { case DeviceType::CPU: return dnnl::engine::kind::cpu; default: TORCH_CHECK(false, "Not support device type ", type); } } static void mayAddListConstructIntoConcatPartition( Node* n, OpPartitionMap& opToOwningPartition) { // Since prim::ListConstruct is not visible to the LLGA, // it will not be in any partition returned from partfuseritioning results. // We need rewrite opToOwningPartition to make the prim::ListConstruct to be // 'virtually' in the same partition with the aten::cat, so that // prim::ListConstruct can be fused into the fusion group by graph fuser. // We emphasize on 'virtually' because get_num_ops() for cat's partition // would still return 1. if (n->kind() == aten::cat && opToOwningPartition.has(n)) { auto listConstrcut = n->namedInput("tensors")->node(); auto partitionId = opToOwningPartition.get(n); opToOwningPartition.add(listConstrcut, partitionId); } } // Verify that input tensors are compatible with oneDNN Graph. // Scalars would be converted to 1-D tensors later anyway, // but they shouldn't be complex-double // If this check fails, convert op to wildcard static bool checkInputCompatibility(Node* node) { auto allInputs = node->inputs(); for (auto input : allInputs) { c10::IValue inputIValue = toIValue(input); if (inputIValue.isTensor()) { const at::Tensor& tensor = inputIValue.toTensor(); if (tensor.device() != at::kCPU) { return false; } auto dtype = tensor.scalar_type(); if ((dtype != at::ScalarType::BFloat16) && (dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) { // We've allowed Long dtype here although oneDNN Graph does not support // Long dtype because oneDNN Graph will end up not handling the op that // has an input with Long dtype, so it'd be handled by PyTorch. return false; } } else if (inputIValue.isScalar()) { if (inputIValue.isComplexDouble()) { return false; } } else if (input->type()->isSubtypeOf(TensorType::get())) { auto input_typeptr = input->type()->cast(); if (input_typeptr->scalarType().has_value()) { at::ScalarType dtype = input_typeptr->scalarType().value(); if ((dtype != at::ScalarType::Float) && (dtype != at::ScalarType::BFloat16)) { return false; } } } } return true; } LlgaGraphHelper::LlgaGraphHelper( const std::shared_ptr& graph, dnnl::graph::partition::policy policy) { auto deviceType = inferDevice(graph); auto engineKind = getLlgaEngineKind(deviceType); dnnl_graph_ = std::make_unique(engineKind); aliasDb_ = std::make_unique(graph); GRAPH_DEBUG("Constructing LLGA graph"); // TODO: select nodes in top-level block for now for (auto* node : graph->block()->nodes()) { auto kindOfNode = node->kind(); GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString()); if (checkInputCompatibility(node)) { auto op = createOperator(node); dnnl_graph_->add_op(op.llgaOp()); GRAPH_DEBUG(" Added node ", kindOfNode.toQualString()); } else { GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString()); dnnl_graph_->add_op(makeWildcardOp(node).llgaOp()); } for (Value* input : node->inputs()) { tensorIdToValue_.emplace(input->unique(), input); } } dnnl_graph_->finalize(); GRAPH_DEBUG("Get Partitions"); std::vector partitions = dnnl_graph_->get_partitions(policy); // excluded unsupported Wildcard partitions for (auto& partition : partitions) { if (partition.is_supported()) { partitions_.push_back(partition); } } GRAPH_DEBUG(" Got #partitions: ", partitions_.size()); for (size_t partId = 0; partId < partitions_.size(); partId++) { for (auto opId : partitions_[partId].get_ops()) { opToOwningPartition_.add(opId, partId); } } // Scanning the graph again for post processing for (auto* node : graph->block()->nodes()) { mayAddListConstructIntoConcatPartition(node, opToOwningPartition_); } } bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) { return node->hasAttribute(attr::Subgraph) && node->kind() == prim::oneDNNFusionGroup; } bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) { TORCH_CHECK( isLlgaSubgraph(subgraph), "The consumer node does not contain a subgraph"); if (!shouldConsiderForMerge(toMerge)) { return false; } return opToOwningPartition_.get(toMerge) == opToOwningPartition_.get(subgraph); } // Except for conv & GEMMs, which should always be handled by oneDNN Graph, // only use single-op partitions for ops unsupported by NNC, or ops // that oneDNN executes faster. prim::ListConstruct is an exception, since // we simply want to fuse it with cat. static bool isBetterSuitedForLLGA(NodeKind kindOfOp) { return ( (kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) || (kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) || (kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) || (kindOfOp == aten::mm) || (kindOfOp == aten::linear) || (kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct)); } bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) { if (opToOwningPartition_.has(node)) { auto partitionId = opToOwningPartition_.get(node); if (partitions_[partitionId].get_ops_num() == 1) { auto kindOfNode = node->kind(); return isBetterSuitedForLLGA(kindOfNode); } else { // multi-op partition return true; } } else { // this op isn't present in any partition return false; } } bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) { // if we're already in the process of merging if (isLlgaSubgraph(node)) { return true; } return checkForSingleOpPartition(node); } Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) { auto partitionId = opToOwningPartition_.get(n); GRAPH_DEBUG( "Creating FusionGroup_", partitionId, " for ", n->kind().toQualString()); auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( n, prim::oneDNNFusionGroup, aliasDb); opToOwningPartition_.add(group, partitionId); return group; } void LlgaGraphHelper::mergeNodeIntoSubgraph( Node* toMerge, Node* subgraphNode, AliasDb& aliasDb) { if (isLlgaSubgraph(toMerge)) { GRAPH_DEBUG( "Merging ", toMerge->kind().toQualString(), "_", opToOwningPartition_.get(toMerge), " into ", subgraphNode->kind().toQualString(), "_", opToOwningPartition_.get(subgraphNode)); } else { GRAPH_DEBUG( "Merging ", toMerge->kind().toQualString(), " into ", subgraphNode->kind().toQualString(), "_", opToOwningPartition_.get(subgraphNode)); } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( toMerge, subgraphNode, aliasDb); } void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) { TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node"); auto partitionId = opToOwningPartition_.get(subgraphNode); auto expectOpNum = partitions_[partitionId].get_ops_num(); auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph)); if (expectOpNum != actualOpNum) { GRAPH_DEBUG( "Unmerging FusionGroup_", partitionId, ". Expected ", expectOpNum, " ops, but got ", actualOpNum, " ops."); SubgraphUtils::unmergeSubgraph(subgraphNode); } } size_t LlgaGraphHelper::countSupportedOps( const std::shared_ptr& graph) const { // TODO: count nodes in top-level block for now size_t cnt = 0; for (auto* node : graph->block()->nodes()) { auto nodeKind = node->kind(); if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) { cnt++; } } return cnt; } std::vector LlgaGraphHelper::getPartitions() const { return partitions_; } std::map LlgaGraphHelper::getTensorIdToValue() const { return tensorIdToValue_; } LlgaNodeWrapper::LlgaNodeWrapper(const Node* node) : n(const_cast(node)) { // NOLINT TORCH_CHECK( LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node"); } void LlgaNodeWrapper::setOpaqueLayout(size_t offset) { const auto num_output = n->is(attr::output_layouts).size(); TORCH_CHECK( offset < num_output, "Out of range. (Invalid index ", offset, " for attr::output_layouts with size ", num_output, ")"); auto& layouts = const_cast&>(n->is(attr::output_layouts)); // NOLINT layouts.at(offset) = OPAQUE_LAYOUT; } bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const { const auto num_output = n->is(attr::output_layouts).size(); TORCH_CHECK( offset < num_output, "Out of range. (Invalid index ", offset, " for attr::output_layouts with size ", num_output, ")"); return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT; } } // namespace onednn } // namespace fuser } // namespace jit } // namespace torch