The C++ Frontend
================

The PyTorch C++ frontend is a C++17 library for CPU and GPU
tensor computation, with automatic differentiation and high-level building
blocks for state-of-the-art machine learning applications.

Description
-----------

The PyTorch C++ frontend can be thought of as a C++ version of the
PyTorch Python frontend, providing automatic differentiation and various
higher-level abstractions for machine learning and neural networks.  Specifically,
it consists of the following components:

+----------------------+------------------------------------------------------------------------+
| Component            | Description                                                            |
+======================+========================================================================+
| ``torch::Tensor``    | Automatically differentiable, efficient CPU and GPU enabled tensors    |
+----------------------+------------------------------------------------------------------------+
| ``torch::nn``        | A collection of composable modules for neural network modeling         |
+----------------------+------------------------------------------------------------------------+
| ``torch::optim``     | Optimization algorithms like SGD, Adam or RMSprop to train your models |
+----------------------+------------------------------------------------------------------------+
| ``torch::data``      | Datasets, data pipelines and multi-threaded, asynchronous data loader  |
+----------------------+------------------------------------------------------------------------+
| ``torch::serialize`` | A serialization API for storing and loading model checkpoints          |
+----------------------+------------------------------------------------------------------------+
| ``torch::python``    | Glue to bind your C++ models into Python                               |
+----------------------+------------------------------------------------------------------------+
| ``torch::jit``       | Pure C++ access to the TorchScript JIT compiler                        |
+----------------------+------------------------------------------------------------------------+

End-to-end example
------------------

Here is an end-to-end example of defining and training a simple
neural network on the MNIST dataset:

.. code-block:: cpp

  #include <torch/torch.h>

  #include <iostream>
  #include <memory>

  // Define a new Module.
  struct Net : torch::nn::Module {
    Net() {
      // Construct and register three Linear submodules.
      fc1 = register_module("fc1", torch::nn::Linear(784, 64));
      fc2 = register_module("fc2", torch::nn::Linear(64, 32));
      fc3 = register_module("fc3", torch::nn::Linear(32, 10));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
      // Flatten each image to 784 values, then apply the first layer.
      x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
      // Dropout is only active while the module is in training mode.
      x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
      x = torch::relu(fc2->forward(x));
      // Return per-class log-probabilities, as expected by nll_loss.
      x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
      return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
  };

  int main() {
    // Create a new Net.
    auto net = std::make_shared<Net>();

    // Create a data loader for the MNIST dataset.
    auto data_loader = torch::data::make_data_loader(
        torch::data::datasets::MNIST("./data").map(
            torch::data::transforms::Stack<>()),
        /*batch_size=*/64);

    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

    for (size_t epoch = 1; epoch <= 10; ++epoch) {
      size_t batch_index = 0;
      // Iterate the data loader to yield batches from the dataset.
      for (auto& batch : *data_loader) {
        // Reset gradients.
        optimizer.zero_grad();
        // Execute the model on the input data.
        torch::Tensor prediction = net->forward(batch.data);
        // Compute a loss value to judge the prediction of our model.
        torch::Tensor loss = torch::nll_loss(prediction, batch.target);
        // Compute gradients of the loss w.r.t. the parameters of our model.
        loss.backward();
        // Update the parameters based on the calculated gradients.
        optimizer.step();
        // Output the loss and checkpoint every 100 batches.
        if (++batch_index % 100 == 0) {
          std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                    << " | Loss: " << loss.item<float>() << std::endl;
          // Serialize your model periodically as a checkpoint.
          torch::save(net, "net.pt");
        }
      }
    }
  }

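To make the arithmetic of one training iteration concrete, the following is a
minimal plain C++ sketch of what the example's ``log_softmax``, ``nll_loss``
and SGD update compute for a single sample.  It deliberately does not use
LibTorch: the ``sketch`` namespace and its three helpers are hypothetical
names introduced here for illustration only, and the batched
``torch::nll_loss`` call additionally averages over the batch by default.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustration only: these helpers are NOT part of LibTorch.
namespace sketch {

// log_softmax over one non-empty row of logits: x_i - log(sum_j exp(x_j)).
// Subtracting the maximum first keeps exp() from overflowing.
std::vector<double> log_softmax(const std::vector<double>& logits) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum_exp = 0.0;
  for (double v : logits) {
    sum_exp += std::exp(v - max_logit);
  }
  const double log_sum = max_logit + std::log(sum_exp);

  std::vector<double> out;
  out.reserve(logits.size());
  for (double v : logits) {
    out.push_back(v - log_sum);
  }
  return out;
}

// Negative log-likelihood for one sample: minus the log-probability
// assigned to the correct class. Throws if target is out of range.
double nll_loss(const std::vector<double>& log_probs, std::size_t target) {
  return -log_probs.at(target);
}

// One plain SGD update, p <- p - lr * grad, element by element.
// Assumes params and grads have the same length.
void sgd_step(std::vector<double>& params,
              const std::vector<double>& grads,
              double lr) {
  for (std::size_t i = 0; i < params.size(); ++i) {
    params[i] -= lr * grads[i];
  }
}

}  // namespace sketch
```

For two equal logits, ``sketch::log_softmax({0.0, 0.0})`` yields ``-log(2)``
for both classes, so the loss for either target is ``log(2)``, roughly 0.693.
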
For more complete examples of using the PyTorch C++ frontend, see `the example repository
<https://github.com/pytorch/examples/tree/master/cpp>`_.

Philosophy
----------

PyTorch's C++ frontend was designed with the idea that the Python frontend is
great, and should be used when possible; but in some settings, performance and
portability requirements make the use of the Python interpreter infeasible. For
example, Python is a poor choice for low-latency, high-performance or
multithreaded environments, such as video games or production servers.  The
goal of the C++ frontend is to address these use cases without sacrificing
the user experience of the Python frontend.

As such, the C++ frontend has been written with a few philosophical goals in mind:

* **Closely model the Python frontend in its design**, naming, conventions and
  functionality.  While there may be occasional differences between the two
  frontends (e.g., where we have dropped deprecated features or fixed "warts"
  in the Python frontend), we guarantee that the effort in porting a Python model
  to C++ should lie exclusively in **translating language features**,
  not modifying functionality or behavior.

* **Prioritize flexibility and user-friendliness over micro-optimization.**
  In C++, you can often get optimal code, but at the cost of an extremely
  unfriendly user experience.  Flexibility and dynamism are at the heart of
  PyTorch, and the C++ frontend seeks to preserve this experience, in some
  cases sacrificing performance (or "hiding" performance knobs) to keep APIs
  simple and explicable.  We want researchers who don't write C++ for a living
  to be able to use our APIs.

A word of warning: Python is not necessarily slower than
C++! The Python frontend calls into C++ for almost anything computationally expensive
(especially any kind of numeric operation), and these operations take up
the bulk of the time spent in a program.  If you would prefer to write Python,
and can afford to write Python, we recommend using the Python interface to
PyTorch. However, if you would prefer to write C++, or need to write C++
(because of multithreading, latency or deployment requirements), the
C++ frontend to PyTorch provides an API that is approximately as convenient,
flexible, friendly and intuitive as its Python counterpart. The two frontends
serve different use cases, work hand in hand, and neither is meant to
unconditionally replace the other.

Installation
------------

Instructions on how to install the C++ frontend library distribution, including
an example of how to build a minimal application depending on LibTorch, can be
found in the `installation guide <https://pytorch.org/cppdocs/installing.html>`_.