1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/training/module/training_module.h>
#include <executorch/extension/training/optimizer/sgd.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

#include <cstdlib>
#include <iostream>
23
24 // @lint-ignore-every CLANGTIDY facebook-hte-CArray
25
26 using namespace ::testing;
27 using namespace executorch::extension::training::optimizer;
28 using namespace torch::executor::testing;
29 using exec_aten::ScalarType;
30 using exec_aten::Tensor;
31 using namespace torch::executor;
32 using torch::executor::util::FileDataLoader;
33
34 class TrainingLoopTest : public ::testing::Test {
35 protected:
SetUp()36 void SetUp() override {}
37 };
38
TEST_F(TrainingLoopTest,OptimizerSteps)39 TEST_F(TrainingLoopTest, OptimizerSteps) {
40 const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
41 executorch::runtime::Result<torch::executor::util::FileDataLoader>
42 loader_res = torch::executor::util::FileDataLoader::from(path);
43 ASSERT_EQ(loader_res.error(), Error::Ok);
44 auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
45 std::move(loader_res.get()));
46
47 auto mod = executorch::extension::training::TrainingModule(std::move(loader));
48
49 // Create inputs.
50 TensorFactory<ScalarType::Float> tf;
51 Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
52 Tensor label = tf.make({3}, {1.0, 0.0, 0.0});
53
54 auto res = mod.execute_forward_backward("forward", {input, label});
55 ASSERT_TRUE(res.ok());
56
57 // Set up optimizer.
58 // Get the params and names
59 auto param_res = mod.named_parameters("forward");
60 ASSERT_EQ(param_res.error(), Error::Ok);
61
62 float orig_data = param_res.get().at("linear.weight").data_ptr<float>()[0];
63
64 SGDOptions options{0.1};
65 SGD optimizer(param_res.get(), options);
66
67 // Get the gradients
68 auto grad_res = mod.named_gradients("forward");
69 ASSERT_EQ(grad_res.error(), Error::Ok);
70 auto& grad = grad_res.get();
71 ASSERT_EQ(grad.size(), 2);
72 ASSERT_NE(grad.find("linear.weight"), grad.end());
73 ASSERT_NE(grad.find("linear.bias"), grad.end());
74
75 // Step
76 auto opt_err = optimizer.step(grad_res.get());
77 ASSERT_EQ(opt_err, Error::Ok);
78
79 // Check that the data has changed.
80 ASSERT_NE(
81 param_res.get().at("linear.weight").data_ptr<float>()[0], orig_data);
82 }
83