1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/tpu_init_mode.h"
17
18 #include <atomic>
19
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/platform/mutex.h"
22
23 namespace tensorflow {
24
25 namespace {
26
27 mutex init_mode_mutex(LINKER_INITIALIZED);
28 TPUInitMode init_mode TF_GUARDED_BY(init_mode_mutex);
29
30 } // namespace
31
32 namespace test {
33
ForceSetTPUInitMode(const TPUInitMode mode)34 void ForceSetTPUInitMode(const TPUInitMode mode) {
35 mutex_lock l(init_mode_mutex);
36 init_mode = mode;
37 }
38
39 } // namespace test
40
SetTPUInitMode(const TPUInitMode mode)41 Status SetTPUInitMode(const TPUInitMode mode) {
42 if (mode == TPUInitMode::kNone) {
43 return errors::InvalidArgument("State cannot be set to: ",
44 static_cast<int>(mode));
45 }
46 {
47 mutex_lock l(init_mode_mutex);
48 if (init_mode != TPUInitMode::kNone && mode != init_mode) {
49 return errors::FailedPrecondition(
50 "TPUInit already attempted with mode: ", static_cast<int>(init_mode),
51 " and cannot be changed to: ", static_cast<int>(mode),
52 ". You are most probably trying to initialize the TPU system, both "
53 "using the explicit API and using an initialization Op within the "
54 "graph; please choose one. ");
55 }
56 init_mode = mode;
57 }
58 return Status::OK();
59 }
60
GetTPUInitMode()61 TPUInitMode GetTPUInitMode() {
62 mutex_lock l(init_mode_mutex);
63 return init_mode;
64 }
65
66 } // namespace tensorflow
67