• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Support for training models.
17
18See the [Training](https://tensorflow.org/api_guides/python/train) guide.
19"""
20
21# Optimizers.
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26# pylint: disable=g-bad-import-order,unused-import
27from tensorflow.python.ops.sdca_ops import sdca_optimizer
28from tensorflow.python.ops.sdca_ops import sdca_fprint
29from tensorflow.python.ops.sdca_ops import sdca_shrink_l1
30from tensorflow.python.training.adadelta import AdadeltaOptimizer
31from tensorflow.python.training.adagrad import AdagradOptimizer
32from tensorflow.python.training.adagrad_da import AdagradDAOptimizer
33from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer
34from tensorflow.python.training.adam import AdamOptimizer
35from tensorflow.python.training.ftrl import FtrlOptimizer
36from tensorflow.python.training.momentum import MomentumOptimizer
37from tensorflow.python.training.moving_averages import ExponentialMovingAverage
38from tensorflow.python.training.optimizer import Optimizer
39from tensorflow.python.training.rmsprop import RMSPropOptimizer
40from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
41from tensorflow.python.training.proximal_gradient_descent import ProximalGradientDescentOptimizer
42from tensorflow.python.training.sync_replicas_optimizer import SyncReplicasOptimizer
43
44# Utility classes for training.
45from tensorflow.python.training.coordinator import Coordinator
46from tensorflow.python.training.coordinator import LooperThread
47# go/tf-wildcard-import
48# pylint: disable=wildcard-import
49from tensorflow.python.training.queue_runner import *
50
51# For the module level doc.
52from tensorflow.python.training import input as _input
53from tensorflow.python.training.input import *  # pylint: disable=redefined-builtin
54# pylint: enable=wildcard-import
55
56from tensorflow.python.training.basic_session_run_hooks import get_or_create_steps_per_run_variable
57from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
58from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
59from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
60from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook
61from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener
62from tensorflow.python.training.basic_session_run_hooks import StepCounterHook
63from tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError
64from tensorflow.python.training.basic_session_run_hooks import NanTensorHook
65from tensorflow.python.training.basic_session_run_hooks import SummarySaverHook
66from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook
67from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
68from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
69from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
70from tensorflow.python.training.basic_loops import basic_train_loop
71from tensorflow.python.training.tracking.python_state import PythonState
72from tensorflow.python.training.tracking.util import Checkpoint
73from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
74from tensorflow.python.training.checkpoint_utils import list_variables
75from tensorflow.python.training.checkpoint_utils import load_checkpoint
76from tensorflow.python.training.checkpoint_utils import load_variable
77
78from tensorflow.python.training.device_setter import replica_device_setter
79from tensorflow.python.training.monitored_session import Scaffold
80from tensorflow.python.training.monitored_session import MonitoredTrainingSession
81from tensorflow.python.training.monitored_session import SessionCreator
82from tensorflow.python.training.monitored_session import ChiefSessionCreator
83from tensorflow.python.training.monitored_session import WorkerSessionCreator
84from tensorflow.python.training.monitored_session import MonitoredSession
85from tensorflow.python.training.monitored_session import SingularMonitoredSession
86from tensorflow.python.training.saver import Saver
87from tensorflow.python.training.checkpoint_management import checkpoint_exists
88from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto
89from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes
90from tensorflow.python.training.checkpoint_management import get_checkpoint_state
91from tensorflow.python.training.checkpoint_management import latest_checkpoint
92from tensorflow.python.training.checkpoint_management import update_checkpoint_state
93from tensorflow.python.training.saver import export_meta_graph
94from tensorflow.python.training.saver import import_meta_graph
95from tensorflow.python.training.session_run_hook import SessionRunHook
96from tensorflow.python.training.session_run_hook import SessionRunArgs
97from tensorflow.python.training.session_run_hook import SessionRunContext
98from tensorflow.python.training.session_run_hook import SessionRunValues
99from tensorflow.python.training.session_manager import SessionManager
100from tensorflow.python.training.summary_io import summary_iterator
101from tensorflow.python.training.supervisor import Supervisor
102from tensorflow.python.training.training_util import write_graph
103from tensorflow.python.training.training_util import global_step
104from tensorflow.python.training.training_util import get_global_step
105from tensorflow.python.training.training_util import assert_global_step
106from tensorflow.python.training.training_util import create_global_step
107from tensorflow.python.training.training_util import get_or_create_global_step
108from tensorflow.python.training.warm_starting_util import VocabInfo
109from tensorflow.python.training.warm_starting_util import warm_start
110from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
111from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
112from tensorflow.python.util.tf_export import tf_export
113
114# pylint: disable=wildcard-import
115# Training data protos.
116from tensorflow.core.example.example_pb2 import *
117from tensorflow.core.example.feature_pb2 import *
118from tensorflow.core.protobuf.saver_pb2 import *
119
120# Utility op.  Open Source. TODO(touts): move to nn?
121from tensorflow.python.training.learning_rate_decay import *
122# pylint: enable=wildcard-import
123
124# Distributed computing support.
125from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
126from tensorflow.core.protobuf.cluster_pb2 import JobDef
127from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
128from tensorflow.python.training.server_lib import ClusterSpec
129from tensorflow.python.training.server_lib import Server
130
# Register the example/feature, saver, cluster and server proto classes brought
# in by the imports above under their `train.*` export names. Each table entry
# pairs an exporter built by `tf_export` with the symbol it is applied to; the
# exporters are applied in table order, and their return values are discarded.
# Note that SaverDef is exported through the `v1=[...]` keyword only, unlike
# every other entry, which uses a positional API name.
# Several of these names arrive via wildcard imports, hence the pylint pragma.
# pylint: disable=undefined-variable
for _exporter, _symbol in (
    (tf_export("train.BytesList"), BytesList),
    (tf_export("train.ClusterDef"), ClusterDef),
    (tf_export("train.Example"), Example),
    (tf_export("train.Feature"), Feature),
    (tf_export("train.Features"), Features),
    (tf_export("train.FeatureList"), FeatureList),
    (tf_export("train.FeatureLists"), FeatureLists),
    (tf_export("train.FloatList"), FloatList),
    (tf_export("train.Int64List"), Int64List),
    (tf_export("train.JobDef"), JobDef),
    (tf_export(v1=["train.SaverDef"]), SaverDef),
    (tf_export("train.SequenceExample"), SequenceExample),
    (tf_export("train.ServerDef"), ServerDef),
):
  _exporter(_symbol)
# Remove the loop variables so they do not linger as module attributes.
del _exporter, _symbol
# pylint: enable=undefined-variable
146
147