#include #include #include namespace torch::jit { void ArgumentSpecCreator::scan( const TypePtr& typ, size_t depth, const WrittenSlots& written_slots) { auto finishAggregate = [&](size_t pos) { // it is possible after all the work we did to scan this aggregate, // we found no tensors or optionals to specialize. In this case, just // generate a skip for the whole aggregate. bool any_spec = std::any_of( instructions_.begin() + pos, instructions_.end(), [](Inst i) { return i == SPECIALIZE_TENSOR || i == SPECIALIZE_OPTIONAL || i == SPECIALIZE_OPTIONAL_TENSOR; }); if (!any_spec) { instructions_[pos] = SKIP; instructions_.resize(pos + 1); } else { instructions_.emplace_back(LEAVE); } }; // the simple vm that scans instructions_ has a limited stack depth, // this prevents going deeper than that. if (depth >= ARG_SPEC_DEPTH_LIMIT) { instructions_.emplace_back(SKIP); } if (typ->isSubtypeOf(*TensorType::get())) { num_tensors_++; instructions_.emplace_back(SPECIALIZE_TENSOR); } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) { num_tensors_++; num_optionals_++; instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR); } else if (typ->kind() == TypeKind::OptionalType) { // note that Optional[Tuple] or Optional[Class] will just register // as optional (previously they didn't at all, so it's not a regression). num_optionals_++; instructions_.emplace_back(SPECIALIZE_OPTIONAL); } else if (auto tup = typ->cast()) { size_t pos = instructions_.size(); instructions_.emplace_back(ENTER_TUPLE); for (const auto& elem : tup->containedTypes()) { scan(elem, depth + 1, written_slots); } finishAggregate(pos); } else if (auto cls = typ->cast()) { size_t pos = instructions_.size(); instructions_.emplace_back(ENTER_OBJECT); for (size_t i = 0; i < cls->numAttributes(); ++i) { auto key = cls->name()->qualifiedName() + cls->getAttributes().at(i).getName(); // it is only safe to specialize because someone might have written to it if (!written_slots.count(key)) { scan(cls->containedTypes().at(i), depth + 1, written_slots); } else { instructions_.emplace_back(SKIP); } } finishAggregate(pos); } else { instructions_.emplace_back(SKIP); } }; // this is a coarse-grained guarantee that the slots of a class will not be // modified by the function. It works fine for things that used be read-only // modules, but will be overly conservative when some classes are written to. // Doing alias analysis and looking for writes to the class would be more // accurate. static void scanWrittenSlots( Block* block, ArgumentSpecCreator::WrittenSlots& written_slots) { for (Node* n : block->nodes()) { if (n->kind() == prim::SetAttr) { if (auto cls = n->inputs().at(0)->type()->cast()) { written_slots.insert(cls->name()->qualifiedName() + n->s(attr::name)); } } for (Block* subblock : n->blocks()) { scanWrittenSlots(subblock, written_slots); } if (n->hasAttribute(attr::Subgraph)) { scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots); } } } ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph) : num_inputs_(graph.inputs().size()) { WrittenSlots written_slots; scanWrittenSlots(graph.block(), written_slots); for (Value* input : graph.inputs()) { scan(input->type(), 0, written_slots); } } void ArgumentSpecCreator::dump() const { for (Inst inst : instructions_) { switch (inst) { case LEAVE: std::cout << "] "; break; case ENTER_TUPLE: std::cout << "Tuple["; break; case ENTER_OBJECT: std::cout << "Object["; break; case SKIP: std::cout << "Skip "; break; case SPECIALIZE_TENSOR: std::cout << "SpecializeTensor "; break; case SPECIALIZE_OPTIONAL_TENSOR: std::cout << "SpecializeOptionalTensor "; break; case SPECIALIZE_OPTIONAL: std::cout << "SpecializeOptional "; break; } } std::cout << "\n"; } ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input) const { ArgumentSpec spec(num_tensors_, num_optionals_); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) const IValue* stack[ARG_SPEC_DEPTH_LIMIT]; // The stack of IValue lists // The stack gets initialized with the input list stack[0] = last(input, num_inputs_).begin(); size_t stack_top = 0; // offset to the top of the stack for (Inst inst : instructions_) { switch (inst) { case SPECIALIZE_OPTIONAL_TENSOR: { // consume a tensor optional and add to the argspec // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) auto& arg = *stack[stack_top]++; spec.addOptional(arg); if (!arg.isNone()) { spec.addTensor(arg, with_grad); } } break; case SPECIALIZE_TENSOR: // consume a tensor and add to the argspec // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) spec.addTensor(*stack[stack_top]++, with_grad); break; case SPECIALIZE_OPTIONAL: // consume a non-tensor optional and add to the argspec // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) spec.addOptional(*stack[stack_top]++); break; case ENTER_TUPLE: { // consume tuple // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) const IValue* iv = stack[stack_top]++; AT_ASSERT(iv->isTuple(), "Expected Tuple but got ", iv->tagKind()); auto p = *reinterpret_cast(iv); auto tup_ptr = &p->elements()[0]; // push list of tuple elements to the stack stack[++stack_top] = tup_ptr; } break; case ENTER_OBJECT: { // consume object // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) const IValue* iv = stack[stack_top]++; AT_ASSERT(iv->isObject(), "Expected Object but got ", iv->tagKind()); auto obj_ptr = &iv->toObjectRef().slots()[0]; // push list of object elements to the stack stack[++stack_top] = obj_ptr; } break; case SKIP: // consume and skip an element // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) stack[stack_top]++; break; case LEAVE: --stack_top; break; } } return spec; } // For every input of a given graph, returns a most detailed type that can be // inferred for it based on this ArgumentSpec. void ArgumentSpecCreator::specializeTypes( Graph& graph, const ArgumentSpec& spec) const { auto input_types = fmap(graph.inputs(), [](Value* input) { return input->type(); }); std::vector> result_stack; result_stack.emplace_back(); std::vector input_stack = {input_types.data()}; std::vector> aggregate_creators; size_t tensor_arg_spec_offset = 0; // number of specialized tensors seen so far size_t optional_arg_spec_offset = 0; // number of specialized optionals seen so far for (Inst inst : instructions_) { switch (inst) { case SPECIALIZE_OPTIONAL_TENSOR: { auto& input_type = *input_stack.back()++; auto is_present = spec.isPresent(optional_arg_spec_offset++); if (!is_present) { result_stack.back().emplace_back(input_type); break; } auto& arg = spec.tensorAt(tensor_arg_spec_offset++); AT_ASSERT(arg.defined()); result_stack.back().emplace_back(arg.toType()); } break; case SPECIALIZE_TENSOR: { input_stack.back()++; auto& arg = spec.tensorAt(tensor_arg_spec_offset++); if (!arg.defined()) { result_stack.back().emplace_back(TensorType::get()->withUndefined()); } else { result_stack.back().emplace_back(arg.toType()); } } break; case SPECIALIZE_OPTIONAL: { auto is_present = spec.isPresent(optional_arg_spec_offset++); auto ot = (*input_stack.back()++)->expect(); if (!is_present) { result_stack.back().emplace_back(ot); } else { result_stack.back().emplace_back(ot->getElementType()); } } break; case ENTER_TUPLE: { auto tup = (*input_stack.back()++)->expect(); input_stack.emplace_back(tup->elements().data()); result_stack.emplace_back(); aggregate_creators.emplace_back( [&] { return TupleType::create(result_stack.back()); }); } break; case ENTER_OBJECT: { auto cls = (*input_stack.back()++)->expect(); input_stack.emplace_back(cls->containedTypes().data()); result_stack.emplace_back(); aggregate_creators.emplace_back( [&result_stack, cls] { return cls->refine(result_stack.back()); }); } break; case SKIP: result_stack.back().emplace_back(*input_stack.back()++); break; case LEAVE: TypePtr result = aggregate_creators.back()(); result_stack.pop_back(); aggregate_creators.pop_back(); input_stack.pop_back(); result_stack.back().emplace_back(std::move(result)); break; } } AT_ASSERT(result_stack.size() == 1); // FIXME: by doing this only on the inputs, we only capture graph inputs and // not // optionals in tuples or objects. For that to work, we would have // to investigate the uses of the inputs in detail to change the // accesses/ unwrapping auto inputs = graph.inputs(); for (const auto i : c10::irange(inputs.size())) { auto t = result_stack.back()[i]; if (auto ot = t->cast()) { // if an optional input hasn't been specialized above, it is None // so we disconnect the input here and replace its uses with // a constant WithInsertPoint guard(*graph.nodes().begin()); auto c = graph.insertConstant({}); inputs[i]->replaceAllUsesWith(c); } else { inputs[i]->setType(t); } } } } // namespace torch::jit