• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/extension/data_loader/file_data_loader.h>
10 #include <executorch/extension/evalue_util/print_evalue.h>
11 #include <executorch/extension/training/module/training_module.h>
12 #include <executorch/extension/training/optimizer/sgd.h>
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15 #include <executorch/runtime/core/span.h>
16 #include <executorch/runtime/executor/method.h>
17 #include <executorch/runtime/executor/program.h>
18 #include <executorch/runtime/executor/test/managed_memory_manager.h>
19 #include <executorch/runtime/platform/runtime.h>
20 
21 #include <gtest/gtest.h>
22 #include <iostream>
23 
// @lint-ignore-every CLANGTIDY facebook-hte-CArray
25 
26 using namespace ::testing;
27 using namespace executorch::extension::training::optimizer;
28 using namespace torch::executor::testing;
29 using exec_aten::ScalarType;
30 using exec_aten::Tensor;
31 using namespace torch::executor;
32 using torch::executor::util::FileDataLoader;
33 
34 class TrainingLoopTest : public ::testing::Test {
35  protected:
SetUp()36   void SetUp() override {}
37 };
38 
// Runs one forward/backward pass on a small training program, applies one
// SGD step (learning rate 0.1) with the resulting gradients, and checks that
// the first element of "linear.weight" changed.
//
// The program path comes from the ET_MODULE_SIMPLE_TRAIN_PATH environment
// variable. The program is expected to expose a "forward" method whose
// parameters/gradients are named "linear.weight" and "linear.bias".
TEST_F(TrainingLoopTest, OptimizerSteps) {
  // std::getenv returns nullptr when the variable is unset; fail with a clear
  // message instead of handing a null path to the loader.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  ASSERT_NE(path, nullptr)
      << "ET_MODULE_SIMPLE_TRAIN_PATH must point to the training program";

  executorch::runtime::Result<FileDataLoader> loader_res =
      FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<FileDataLoader>(std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  // Create inputs.
  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});

  // Forward + backward so that gradients are available afterwards.
  auto res = mod.execute_forward_backward("forward", {input, label});
  ASSERT_TRUE(res.ok());

  // Fetch the named parameters and make sure the one we inspect exists
  // before calling at() on it.
  auto param_res = mod.named_parameters("forward");
  ASSERT_EQ(param_res.error(), Error::Ok);
  auto& params = param_res.get();
  ASSERT_NE(params.find("linear.weight"), params.end());

  // Snapshot one weight value so the update can be detected after step().
  const float orig_data = params.at("linear.weight").data_ptr<float>()[0];

  SGDOptions options{0.1};
  SGD optimizer(params, options);

  // Get the gradients; the model is expected to have exactly a weight and a
  // bias gradient.
  auto grad_res = mod.named_gradients("forward");
  ASSERT_EQ(grad_res.error(), Error::Ok);
  auto& grad = grad_res.get();
  ASSERT_EQ(grad.size(), 2u); // unsigned literal: avoids sign-compare warning
  ASSERT_NE(grad.find("linear.weight"), grad.end());
  ASSERT_NE(grad.find("linear.bias"), grad.end());

  // Step
  auto opt_err = optimizer.step(grad);
  ASSERT_EQ(opt_err, Error::Ok);

  // Check that the data has changed.
  ASSERT_NE(params.at("linear.weight").data_ptr<float>()[0], orig_data);
}
83