#pragma once #include #include #include #include #include #include #include #include #include namespace torch { namespace lazy { struct TorchScriptIrBuilder : IrBuilder { NodePtr MakeDeviceData( const std::shared_ptr& data) const override { return DeviceData::Create(data); } // TODO: Scalar node is not currently used by ts_backend. Enable reusing // Scalar node later if needed. NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode(value, type); } NodePtr MakeExpand( const Value& input0, const std::vector& size, const bool& is_scalar_expand) const override { return ReuseOrMakeNode(input0, size, is_scalar_expand); } NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, const std::optional& stype = std::nullopt) const override { return ReuseOrMakeNode(input0, dtype, stype); } NodePtr MakeTensorList(const OpList& inputs) const override { return ReuseOrMakeNode(inputs); } // Generic needs cleanup NodePtr MakeGeneric( const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const override { return MakeNode(op, operands, shape, num_outputs, hash_seed); } // dynamic ir nodes // TODO: verify if IR node reusing works for Dynamic shape ops NodePtr MakeSizeNode(const Value& input, size_t dim) const override { return MakeNode(input, dim); } NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { return MakeNode(a, b); } NodePtr MakeSizeMul(const Value& a, const Value& b) const override { return MakeNode(a, b); } NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { return MakeNode(a, b); } }; } // namespace lazy } // namespace torch