1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/data_loader/file_data_loader.h>
10 #include <executorch/extension/training/module/training_module.h>
11
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/platform/runtime.h>
14 #include <gtest/gtest.h>
15
16 // @lint-ignore-every CLANGTIDY facebook-hte-CArray
17
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::Error;
22 using torch::executor::Span;
23 using torch::executor::testing::TensorFactory;
24
25 class TrainingModuleTest : public ::testing::Test {
26 protected:
SetUp()27 void SetUp() override {
28 torch::executor::runtime_init();
29 }
30 };
31
// Loads the serialized simple-train joint graph (forward + backward) and
// verifies that execute_forward_backward produces a loss, and that the
// named gradients and named parameters maps have the expected entries
// and shapes (a 3x3 linear layer: weight {3,3}, bias {3}).
TEST_F(TrainingModuleTest, JointGraphTest) {
  // Create a loader for the serialized simple-train program.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  // Fail fast with a clear message if the test environment is missing the
  // model path; FileDataLoader::from would otherwise get a null path.
  ASSERT_NE(path, nullptr);
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_res = torch::executor::util::FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(input);
  inputs.push_back(label);

  // The joint graph returns a single output (the loss).
  auto res = mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(res.error(), Error::Ok);
  ASSERT_EQ(res.get().size(), 1);

  // Test Gradients
  auto grad_res = mod.named_gradients("forward");
  ASSERT_EQ(grad_res.error(), Error::Ok);
  auto& grad = grad_res.get();
  ASSERT_EQ(grad.size(), 2);
  ASSERT_NE(grad.find("linear.weight"), grad.end());
  ASSERT_NE(grad.find("linear.bias"), grad.end());

  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(grad.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.bias")->second.dim(), 1);

  // Test Parameters
  auto param_res = mod.named_parameters("forward");
  ASSERT_EQ(param_res.error(), Error::Ok);
  // Bug fix: this previously aliased grad_res.get(), so the assertions
  // below were re-checking the gradients map instead of the parameters.
  auto& param = param_res.get();
  ASSERT_EQ(param.size(), 2);
  // Bug fix: iterators from `param` must be compared against param.end(),
  // not grad.end() (comparing iterators of different containers is UB).
  ASSERT_NE(param.find("linear.weight"), param.end());
  ASSERT_NE(param.find("linear.bias"), param.end());

  ASSERT_EQ(param.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
}
83
// Loads a plain inference-only program (ModuleAdd) and verifies that the
// training API rejects it: execute_forward_backward must report
// InvalidArgument because the program carries no gradient/parameter info.
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
  // Build a data loader over the serialized ModuleAdd program.
  const char* model_path = std::getenv("ET_MODULE_ADD_PATH");
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_result = torch::executor::util::FileDataLoader::from(model_path);
  ASSERT_EQ(loader_result.error(), Error::Ok);
  auto data_loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_result.get()));

  auto module =
      executorch::extension::training::TrainingModule(std::move(data_loader));

  // Two 2x2 float operands for the add model.
  TensorFactory<ScalarType::Float> tensor_factory;
  Tensor lhs = tensor_factory.make({2, 2}, {1.0, 1.0, 1.0, 1.0});
  Tensor rhs = tensor_factory.make({2, 2}, {1.0, 0.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> method_inputs;
  method_inputs.push_back(lhs);
  method_inputs.push_back(rhs);

  // A non-training module cannot locate gradients or parameters, so the
  // forward/backward call is expected to fail with InvalidArgument.
  auto exec_result = module.execute_forward_backward("forward", method_inputs);
  ASSERT_EQ(exec_result.error(), Error::InvalidArgument);
}
108