1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h"
17
18 #include "tensorflow/core/lib/core/notification.h"
19 #include "tensorflow/core/platform/test.h"
20
21 namespace tensorflow {
22 namespace {
23
// Verifies that FindOrCreate hands back a CancellationManager for a step id
// that has not been used before.
TEST(PartialRunMgrFindOrCreate, Create) {
  PartialRunMgr partial_run_mgr;
  const int step_id = 1;
  // Start from nullptr so the check below can only succeed if FindOrCreate
  // really wrote the out-parameter; an uninitialized pointer could hold any
  // value and let the test pass by accident.
  CancellationManager* cancellation_manager = nullptr;
  partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
  EXPECT_TRUE(cancellation_manager != nullptr);
}
32
// Verifies that a second FindOrCreate call for the same step id yields the
// CancellationManager produced by the first call.
TEST(PartialRunMgrFindOrCreate, Find) {
  PartialRunMgr partial_run_mgr;
  const int step_id = 1;

  // Both out-parameters start as nullptr; otherwise a FindOrCreate that never
  // wrote them would leave the comparison reading indeterminate values.
  CancellationManager* cancellation_manager = nullptr;
  partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
  // Stop here if nothing was returned, so the equality check below cannot
  // pass merely by comparing nullptr with nullptr.
  ASSERT_TRUE(cancellation_manager != nullptr);

  // Looking for the same step should return the same cancellation_manager.
  CancellationManager* found_cancellation_manager = nullptr;
  partial_run_mgr.FindOrCreate(step_id, &found_cancellation_manager);
  EXPECT_EQ(cancellation_manager, found_cancellation_manager);
}
44
// Verifies that FindOrCreate returns distinct CancellationManagers for two
// different step ids.
TEST(PartialRunMgrFindOrCreate, NewCreate) {
  PartialRunMgr partial_run_mgr;

  // Out-parameters are initialized so that a call which fails to write them is
  // detected instead of comparing indeterminate pointer values.
  const int step_id = 1;
  CancellationManager* cancellation_manager = nullptr;
  partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
  EXPECT_TRUE(cancellation_manager != nullptr);

  // FindOrCreate on a new step should return a new cancellation_manager.
  const int new_step_id = 2;
  CancellationManager* new_cancellation_manager = nullptr;
  partial_run_mgr.FindOrCreate(new_step_id, &new_cancellation_manager);
  // Checked separately: the inequality below would also hold if only one of
  // the two lookups had produced a manager.
  EXPECT_TRUE(new_cancellation_manager != nullptr);
  EXPECT_NE(cancellation_manager, new_cancellation_manager);
}
57
// Checks that the PartialRun for a step is removed once both PartialRunDone
// and ExecutorDone have been reported for it.
TEST(PartialRunMgr, PartialRunRemoved) {
  PartialRunMgr partial_run_mgr;
  int step_id = 1;
  CancellationManager* cancellation_manager;
  partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);

  // Every completion callback handed to the manager bumps this one counter.
  int callback_invocations = 0;
  auto count_invocation = [&callback_invocations](Status) {
    ++callback_invocations;
  };

  // First round: completes the partial run registered above.
  partial_run_mgr.PartialRunDone(step_id, count_invocation, Status::OK());
  partial_run_mgr.ExecutorDone(step_id, Status::OK());

  // Second round on the same step id: the callback must not fire again. A
  // total of exactly one invocation shows the original PartialRun is gone.
  partial_run_mgr.PartialRunDone(step_id, count_invocation, Status::OK());
  partial_run_mgr.ExecutorDone(step_id, Status::OK());

  EXPECT_EQ(1, callback_invocations);
}
79
80 struct StatusTestParam {
81 Status executor_status;
82 Status partial_run_status;
83 Status expected_status;
84 };
85
86 class StatusPropagationTest : public ::testing::TestWithParam<StatusTestParam> {
87 protected:
88 PartialRunMgr partial_run_mgr_;
89
90 // State to help keep track of when the callback is called.
91 Notification invoked_;
92 Status status_;
93
set_status(const Status & status)94 void set_status(const Status& status) {
95 status_ = status;
96 invoked_.Notify();
97 }
98
99 // Blocks until status is set.
status()100 Status status() {
101 invoked_.WaitForNotification();
102 return status_;
103 }
104 };
105
// Error propagation when ExecutorDone is reported before PartialRunDone: the
// status handed to the PartialRunDone callback must equal the expected one.
TEST_P(StatusPropagationTest, ExecutorDoneFirst) {
  const StatusTestParam test_case = GetParam();
  const int step_id = 1;

  CancellationManager* cancellation_manager;
  partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager);

  // Forwards whatever status the manager delivers into the fixture.
  auto record_status = [this](Status status) { set_status(status); };

  partial_run_mgr_.ExecutorDone(step_id, test_case.executor_status);
  partial_run_mgr_.PartialRunDone(step_id, record_status,
                                  test_case.partial_run_status);

  // status() blocks until the callback above has been invoked.
  EXPECT_EQ(status(), test_case.expected_status);
}
121
// Error propagation when PartialRunDone is reported before ExecutorDone: the
// status handed to the PartialRunDone callback must equal the expected one.
TEST_P(StatusPropagationTest, PartialRunDoneFirst) {
  const StatusTestParam test_case = GetParam();
  const int step_id = 1;

  CancellationManager* cancellation_manager;
  partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager);

  // Forwards whatever status the manager delivers into the fixture.
  auto record_status = [this](Status status) { set_status(status); };

  partial_run_mgr_.PartialRunDone(step_id, record_status,
                                  test_case.partial_run_status);
  partial_run_mgr_.ExecutorDone(step_id, test_case.executor_status);

  // status() blocks until the callback above has been invoked.
  EXPECT_EQ(status(), test_case.expected_status);
}
137
// Instantiate the tests for every combination of executor and partial-run
// status (both OK, only one failing, both failing). Each combination runs
// under both call orders of ExecutorDone and PartialRunDone.
ExecutorError()140 Status ExecutorError() { return errors::Internal("executor error"); }
PartialRunError()141 Status PartialRunError() { return errors::Internal("partial run error"); }
142 INSTANTIATE_TEST_SUITE_P(
143 PartialRunMgr, StatusPropagationTest,
144 ::testing::Values(
145 StatusTestParam{Status::OK(), Status::OK(), Status::OK()},
146 StatusTestParam{ExecutorError(), Status::OK(), ExecutorError()},
147 StatusTestParam{Status::OK(), PartialRunError(), PartialRunError()},
148 StatusTestParam{ExecutorError(), PartialRunError(), ExecutorError()}));
149
150 } // namespace
151 } // namespace tensorflow
152