Tensor Basics
=============

The ATen tensor library backing PyTorch is a simple tensor library that exposes
the Tensor operations in Torch directly in C++17. ATen's API is auto-generated
from the same declarations PyTorch uses, so the two APIs will track each other
over time.

Tensor types are resolved dynamically, such that the API is generic and does not
include templates. That is, there is one ``Tensor`` type. It can hold a CPU or
CUDA Tensor, and the tensor may hold doubles, floats, ints, etc. This design
makes it easy to write generic code without templating everything.
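
As a minimal sketch of what this dynamic typing looks like in practice (the
shapes and options below are illustrative, not taken from the ATen docs):

.. code-block:: cpp

  #include <torch/torch.h>
  #include <iostream>

  // One Tensor type; dtype and device are runtime properties.
  torch::Tensor a = torch::ones({2, 2}, torch::kFloat64);
  torch::Tensor b = torch::zeros({2, 2}, torch::kInt32);
  std::cout << a.dtype() << " on " << a.device() << std::endl;
  std::cout << b.dtype() << " on " << b.device() << std::endl;

  // The same non-templated code also works for a CUDA tensor.
  if (torch::cuda::is_available()) {
    torch::Tensor c = torch::ones({2, 2}, torch::device(torch::kCUDA));
    std::cout << c.dtype() << " on " << c.device() << std::endl;
  }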

See https://pytorch.org/cppdocs/api/namespace_at.html#functions for the provided
API. Excerpt:

.. code-block:: cpp

  Tensor atan2(const Tensor & other) const;
  Tensor & atan2_(const Tensor & other);
  Tensor pow(Scalar exponent) const;
  Tensor pow(const Tensor & exponent) const;
  Tensor & pow_(Scalar exponent);
  Tensor & pow_(const Tensor & exponent);
  Tensor lerp(const Tensor & end, Scalar weight) const;
  Tensor & lerp_(const Tensor & end, Scalar weight);
  Tensor histc() const;
  Tensor histc(int64_t bins) const;
  Tensor histc(int64_t bins, Scalar min) const;
  Tensor histc(int64_t bins, Scalar min, Scalar max) const;

In-place operations are also provided, and are always suffixed by ``_`` to
indicate that they will modify the Tensor.
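
As a quick sketch of this convention (the tensors here are just illustrative):

.. code-block:: cpp

  torch::Tensor x = torch::ones({2, 2});
  torch::Tensor y = torch::ones({2, 2});

  torch::Tensor z = x.add(y);  // out-of-place: returns a new Tensor, x is unchanged
  x.add_(y);                   // in-place: modifies x and returns a reference to it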

Efficient Access to Tensor Elements
-----------------------------------

When using Tensor-wide operations, the relative cost of dynamic dispatch is very
small. However, there are cases, especially in your own kernels, where efficient
element-wise access is needed, and the cost of dynamic dispatch inside the
element-wise loop is very high. ATen provides *accessors* that are created with
a single dynamic check that a Tensor has the expected type and number of
dimensions. Accessors then expose an API for accessing the Tensor elements
efficiently.

Accessors are temporary views of a Tensor. They are only valid for the lifetime
of the tensor that they view and hence should only be used locally in a
function, like iterators.

Note that accessors are not compatible with CUDA tensors inside kernel functions.
Instead, you will have to use a *packed accessor*, which behaves the same way but
copies tensor metadata instead of pointing to it.

It is thus recommended to use *accessors* for CPU tensors and *packed accessors*
for CUDA tensors.

CPU accessors
*************

.. code-block:: cpp

  torch::Tensor foo = torch::rand({12, 12});

  // assert foo is 2-dimensional and holds floats.
  auto foo_a = foo.accessor<float,2>();
  float trace = 0;

  for(int i = 0; i < foo_a.size(0); i++) {
    // use the accessor foo_a to get tensor data.
    trace += foo_a[i][i];
  }

CUDA accessors
**************

.. code-block:: cpp

  __global__ void packed_accessor_kernel(
      torch::PackedTensorAccessor64<float, 2> foo,
      float* trace) {
    int i = threadIdx.x;
    gpuAtomicAdd(trace, foo[i][i]);
  }

  // The tensor must live on the GPU for the kernel to access it.
  torch::Tensor foo = torch::rand({12, 12}, torch::kCUDA);
  // The accumulator must also live in CUDA memory; a zero-dimensional
  // CUDA tensor is a convenient way to allocate it.
  torch::Tensor trace = torch::zeros({}, foo.options());

  // assert foo is 2-dimensional and holds floats.
  auto foo_a = foo.packed_accessor64<float, 2>();

  packed_accessor_kernel<<<1, 12>>>(foo_a, trace.data_ptr<float>());

In addition to ``PackedTensorAccessor64`` and ``packed_accessor64`` there are
also the corresponding ``PackedTensorAccessor32`` and ``packed_accessor32``,
which use 32-bit integers for indexing. This can be quite a bit faster on CUDA
but may lead to overflows in the indexing calculations.

Note that the accessor templates accept additional parameters, such as the
pointer restriction and the integer type used for indexing. See the
documentation for a thorough template description of *accessors* and *packed
accessors*.
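
As an illustration, here is a hedged sketch of a kernel parameter that combines
32-bit indexing with the ``RestrictPtrTraits`` pointer restriction (the kernel
mirrors the example above and is only meant to show the template arguments):

.. code-block:: cpp

  __global__ void packed_accessor32_kernel(
      // 32-bit indexing plus __restrict__ pointers for the compiler.
      torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> foo,
      float* trace) {
    int i = threadIdx.x;
    gpuAtomicAdd(trace, foo[i][i]);
  }

  // The matching accessor is obtained with the same template arguments.
  auto foo_a = foo.packed_accessor32<float, 2, torch::RestrictPtrTraits>();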

Using Externally Created Data
-----------------------------

If you already have your tensor data allocated in memory (CPU or CUDA),
you can view that memory as a ``Tensor`` in ATen:

.. code-block:: cpp

  float data[] = { 1, 2, 3,
                   4, 5, 6 };
  torch::Tensor f = torch::from_blob(data, {2, 3});

These tensors cannot be resized because ATen does not own the memory, but
otherwise behave as normal tensors.
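
As a small sketch (the variable names are illustrative), ``from_blob`` can also
be told the dtype of the external buffer, and ``clone()`` produces a tensor that
owns its own copy of the data:

.. code-block:: cpp

  double blob[] = { 1, 2, 3,
                    4, 5, 6 };
  torch::Tensor d = torch::from_blob(
      blob, {2, 3}, torch::TensorOptions().dtype(torch::kFloat64));

  // clone() copies into memory that ATen owns, so the result behaves like
  // any other tensor and may outlive the original buffer.
  torch::Tensor owned = d.clone();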

Scalars and zero-dimensional tensors
------------------------------------

In addition to the ``Tensor`` objects, ATen also includes ``Scalar``\s that
represent a single number. Like a Tensor, Scalars are dynamically typed and can
hold any one of ATen's number types. Scalars can be implicitly constructed from
C++ number types. Scalars are needed because some functions like ``addmm`` take
numbers along with Tensors and expect these numbers to be the same dynamic type
as the tensor. They are also used in the API to indicate places where a function
will *always* return a Scalar value, like ``sum``.

.. code-block:: cpp

  namespace torch {
  Tensor addmm(Scalar beta, const Tensor & self,
               Scalar alpha, const Tensor & mat1,
               const Tensor & mat2);
  Scalar sum(const Tensor & self);
  } // namespace torch

  // Usage.
  torch::Tensor a = ...
  torch::Tensor b = ...
  torch::Tensor c = ...
  torch::Tensor r = torch::addmm(1.0, a, .5, b, c);

In addition to ``Scalar``\s, ATen also allows ``Tensor`` objects to be
zero-dimensional. These Tensors hold a single value and they can be references
to a single element in a larger ``Tensor``. They can be used anywhere a
``Tensor`` is expected. They are normally created by operators like ``select``
which reduce the dimensions of a ``Tensor``.

.. code-block:: cpp

  torch::Tensor two = torch::rand({10, 20});
  two[1][2] = 4;
  // ^^^^^^ <- zero-dimensional Tensor
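
A zero-dimensional Tensor can be converted back to a plain C++ number with
``item``; as a small sketch continuing the example above:

.. code-block:: cpp

  torch::Tensor element = two[1][2];      // a zero-dimensional Tensor
  float value = element.item<float>();    // extract its value as a C++ float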