#include #include #include namespace torch { namespace jit { namespace fuser { namespace onednn { static void LayoutPropagation(Node* n) { if (!LlgaGraphHelper::isLlgaSubgraph(n)) return; // initial attr::output_layouts if undefined if (!n->hasAttribute(attr::output_layouts)) { const auto num_output = n->outputs().size(); GRAPH_DEBUG("Initial output_layouts of size ", num_output); std::vector layouts(num_output, STRIDED_LAYOUT); n->is_(attr::output_layouts, layouts); } for (auto input : n->inputs()) { auto prev = input->node(); auto offset = input->offset(); if (LlgaGraphHelper::isLlgaSubgraph(prev)) { bool useOpaqueLayout = true; for (auto& use : input->uses()) { if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) { useOpaqueLayout = false; break; } } if (useOpaqueLayout) { LlgaNodeWrapper(prev).setOpaqueLayout(offset); } } } } static void LayoutPropagation(at::ArrayRef blocks) { for (Block* block : blocks) for (Node* node : block->nodes()) LayoutPropagation(node); } void PropagateLayout(const std::shared_ptr& graph) { LayoutPropagation(graph->block()); } } // namespace onednn } // namespace fuser } // namespace jit } // namespace torch