#pragma once #include #include #include #include #include #include #include namespace torch { namespace lazy { using TSOpVector = std::vector; class TORCH_API TsNode : public lazy::Node { public: TsNode( OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, hash_t hash_seed = kHashSeed); TsNode( OpKind op, OpList operands, const std::function& shape_fn, size_t num_outputs, hash_t hash_seed = kHashSeed); TsNode( OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed = kHashSeed); TsNode( OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); ~TsNode() override = default; hash_t hash() const override; hash_t shapeHash() const override; const std::string getPythonStacktrace() const; // Lower is a backend-specific method since it returns a backend specific // type. hence, it is convenient to define it differently per-backend rather // than at Node API virtual TSOpVector Lower( std::shared_ptr function, TSLoweringContext* loctx) const; private: // The hash of the dag WITH size info. Used for shape caching hash_t shape_hash_; // The hash of the dag used to look up the compiled graph by a hash // in this case, we will use the dag hash WITHOUT size info if dynamic shape // is enabled and use the dag hash WITH size info otherwise. hash_t dag_hash_; }; // Note: this OpKind is separate from ltc_ops.h since it would be a circular // import otherwise, I like leaving TensorList in this file, and I think most of // ltc_ops special cases will be deleted anyway const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); // TensorList represents an at::TensorList which is a vector[Tensor] but is also // a first-class IValue and can be fed as a single input to a TS program. It is // much easier to handle TensorLists in Lazy Tensor code if they are represented // as a single Node so there can be more than one TensorList and more than one // Tensor side-by-side as operands to an op. // // Note: shape is undefined for TensorList. We assert in some places that // #shapes matches #outputs and this stems from // the fact that currently all IR nodes represent tensors (there is no // type system for this IR). Becuase of this, TensorList is a bit of a // hack. // // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and // then implement it as NotImplemented for TensorList, also fixing the assertion // that would fail. struct TORCH_API TensorList : public TsNode { static OpKind ClassOpKind() { return tensor_list_opkind; } TensorList() = delete; TensorList(OpList values); bool CanBeReused(OpList values) const { return operands() == std::vector(values.begin(), values.end()); } TSOpVector Lower( std::shared_ptr function, TSLoweringContext* loctx) const override; }; } // namespace lazy } // namespace torch