• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/data_loader/file_data_loader.h>
10 #include <executorch/extension/training/module/training_module.h>
11 
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/platform/runtime.h>
14 #include <gtest/gtest.h>
15 
16 // @lint-ignore-every CLANGTIDY facebook-hte-CArray
17 
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::Error;
22 using torch::executor::Span;
23 using torch::executor::testing::TensorFactory;
24 
25 class TrainingModuleTest : public ::testing::Test {
26  protected:
SetUp()27   void SetUp() override {
28     torch::executor::runtime_init();
29   }
30 };
31 
// Runs one joint forward/backward step on the simple training program and
// checks the returned outputs, the named gradients and the named parameters
// of its linear layer.
TEST_F(TrainingModuleTest, JointGraphTest) {
  // Create a loader for the serialized simple-training program. The path is
  // supplied by the test environment; fail with a clear message if it is not.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  ASSERT_NE(path, nullptr) << "ET_MODULE_SIMPLE_TRAIN_PATH is not set";
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_res = torch::executor::util::FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(input);
  inputs.push_back(label);

  auto res = mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(res.error(), Error::Ok);
  ASSERT_EQ(res.get().size(), 1);

  // Test gradients: expect exactly linear.weight (3x3) and linear.bias (3).
  auto grad_res = mod.named_gradients("forward");
  ASSERT_EQ(grad_res.error(), Error::Ok);
  auto& grad = grad_res.get();
  ASSERT_EQ(grad.size(), 2);
  auto grad_weight = grad.find("linear.weight");
  auto grad_bias = grad.find("linear.bias");
  ASSERT_NE(grad_weight, grad.end());
  ASSERT_NE(grad_bias, grad.end());

  // Check rank before indexing sizes() so a wrong rank fails the assertion
  // instead of reading past the end of the sizes array.
  ASSERT_EQ(grad_weight->second.dim(), 2);
  ASSERT_EQ(grad_weight->second.sizes()[0], 3);
  ASSERT_EQ(grad_weight->second.sizes()[1], 3);
  ASSERT_EQ(grad_bias->second.dim(), 1);
  ASSERT_EQ(grad_bias->second.sizes()[0], 3);

  // Test parameters. These must come from named_parameters(), not from the
  // gradient map, and must be compared against their own map's end().
  auto param_res = mod.named_parameters("forward");
  ASSERT_EQ(param_res.error(), Error::Ok);
  auto& param = param_res.get();
  ASSERT_EQ(param.size(), 2);
  auto param_weight = param.find("linear.weight");
  auto param_bias = param.find("linear.bias");
  ASSERT_NE(param_weight, param.end());
  ASSERT_NE(param_bias, param.end());

  ASSERT_EQ(param_weight->second.dim(), 2);
  ASSERT_EQ(param_weight->second.sizes()[0], 3);
  ASSERT_EQ(param_weight->second.sizes()[1], 3);
  ASSERT_EQ(param_bias->second.dim(), 1);
  ASSERT_EQ(param_bias->second.sizes()[0], 3);
}
83 
// Verifies that execute_forward_backward rejects a program that was not
// exported for training (the ModuleAdd program) with InvalidArgument.
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
  // Create a loader for the serialized ModuleAdd program. The path is
  // supplied by the test environment; fail with a clear message if it is not.
  const char* path = std::getenv("ET_MODULE_ADD_PATH");
  ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_PATH is not set";
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_res = torch::executor::util::FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({2, 2}, {1.0, 1.0, 1.0, 1.0});
  Tensor input2 = tf.make({2, 2}, {1.0, 0.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(input);
  inputs.push_back(input2);

  // A non-training module should fail to execute forward/backward because it
  // can't find the gradients or parameters.
  auto res = mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(res.error(), Error::InvalidArgument);
}
108