• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
33 
34 using tensorflow::string;
35 using tensorflow::int32;
36 
37 namespace tensorflow {
38 namespace example {
39 
// Command-line configurable settings for this example; the in-class
// initializers below are the defaults used when a flag is not given.
struct Options {
  // How many sessions run concurrently (ConcurrentSessions() currently
  // insists on exactly one).
  int num_concurrent_sessions = 1;
  // How many steps run concurrently inside each session.
  int num_concurrent_steps = 10;
  // How many power-iteration rounds each step performs.
  int num_iterations = 100;
  // When true the graph is placed on "/device:GPU:0" instead of the CPU.
  bool use_gpu = false;
};
46 
47 // A = [3 2; -1 0]; x = rand(2, 1);
48 // We want to compute the largest eigenvalue for A.
49 // repeat x = y / y.norm(); y = A * x; end
CreateGraphDef()50 GraphDef CreateGraphDef() {
51   // TODO(jeff,opensource): This should really be a more interesting
52   // computation.  Maybe turn this into an mnist model instead?
53   Scope root = Scope::NewRootScope();
54   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
55 
56   // A = [3 2; -1 0].  Using Const<float> means the result will be a
57   // float tensor even though the initializer has integers.
58   auto a = Const<float>(root, {{3, 2}, {-1, 0}});
59 
60   // x = [1.0; 1.0]
61   auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
62 
63   // y = A * x
64   auto y = MatMul(root.WithOpName("y"), a, x);
65 
66   // y2 = y.^2
67   auto y2 = Square(root, y);
68 
69   // y2_sum = sum(y2).  Note that you can pass constants directly as
70   // inputs.  Sum() will automatically create a Const node to hold the
71   // 0 value.
72   auto y2_sum = Sum(root, y2, 0);
73 
74   // y_norm = sqrt(y2_sum)
75   auto y_norm = Sqrt(root, y2_sum);
76 
77   // y_normalized = y ./ y_norm
78   Div(root.WithOpName("y_normalized"), y, y_norm);
79 
80   GraphDef def;
81   TF_CHECK_OK(root.ToGraphDef(&def));
82 
83   return def;
84 }
85 
DebugString(const Tensor & x,const Tensor & y)86 string DebugString(const Tensor& x, const Tensor& y) {
87   CHECK_EQ(x.NumElements(), 2);
88   CHECK_EQ(y.NumElements(), 2);
89   auto x_flat = x.flat<float>();
90   auto y_flat = y.flat<float>();
91   // Compute an estimate of the eigenvalue via
92   //      (x' A x) / (x' x) = (x' y) / (x' x)
93   // and exploit the fact that x' x = 1 by assumption
94   Eigen::Tensor<float, 0, Eigen::RowMajor> lambda = (x_flat * y_flat).sum();
95   return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
96                          lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1));
97 }
98 
ConcurrentSteps(const Options * opts,int session_index)99 void ConcurrentSteps(const Options* opts, int session_index) {
100   // Creates a session.
101   SessionOptions options;
102   std::unique_ptr<Session> session(NewSession(options));
103   GraphDef def = CreateGraphDef();
104   if (options.target.empty()) {
105     graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);
106   }
107 
108   TF_CHECK_OK(session->Create(def));
109 
110   // Spawn M threads for M concurrent steps.
111   const int M = opts->num_concurrent_steps;
112   std::unique_ptr<thread::ThreadPool> step_threads(
113       new thread::ThreadPool(Env::Default(), "trainer", M));
114 
115   for (int step = 0; step < M; ++step) {
116     step_threads->Schedule([&session, opts, session_index, step]() {
117       // Randomly initialize the input.
118       Tensor x(DT_FLOAT, TensorShape({2, 1}));
119       auto x_flat = x.flat<float>();
120       x_flat.setRandom();
121       Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =
122           x_flat.square().sum().sqrt().inverse();
123       x_flat = x_flat * inv_norm();
124 
125       // Iterations.
126       std::vector<Tensor> outputs;
127       for (int iter = 0; iter < opts->num_iterations; ++iter) {
128         outputs.clear();
129         TF_CHECK_OK(
130             session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));
131         CHECK_EQ(size_t{2}, outputs.size());
132 
133         const Tensor& y = outputs[0];
134         const Tensor& y_norm = outputs[1];
135         // Print out lambda, x, and y.
136         std::printf("%06d/%06d %s\n", session_index, step,
137                     DebugString(x, y).c_str());
138         // Copies y_normalized to x.
139         x = y_norm;
140       }
141     });
142   }
143 
144   // Delete the threadpool, thus waiting for all threads to complete.
145   step_threads.reset(nullptr);
146   TF_CHECK_OK(session->Close());
147 }
148 
ConcurrentSessions(const Options & opts)149 void ConcurrentSessions(const Options& opts) {
150   // Spawn N threads for N concurrent sessions.
151   const int N = opts.num_concurrent_sessions;
152 
153   // At the moment our Session implementation only allows
154   // one concurrently computing Session on GPU.
155   CHECK_EQ(1, N) << "Currently can only have one concurrent session.";
156 
157   thread::ThreadPool session_threads(Env::Default(), "trainer", N);
158   for (int i = 0; i < N; ++i) {
159     session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
160   }
161 }
162 
163 }  // end namespace example
164 }  // end namespace tensorflow
165 
166 namespace {
167 
ParseInt32Flag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,int32 * dst)168 bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
169                     int32* dst) {
170   if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
171       tensorflow::str_util::ConsumePrefix(&arg, "=")) {
172     char extra;
173     return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
174   }
175 
176   return false;
177 }
178 
ParseBoolFlag(tensorflow::StringPiece arg,tensorflow::StringPiece flag,bool * dst)179 bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
180                    bool* dst) {
181   if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
182     if (arg.empty()) {
183       *dst = true;
184       return true;
185     }
186 
187     if (arg == "=true") {
188       *dst = true;
189       return true;
190     } else if (arg == "=false") {
191       *dst = false;
192       return true;
193     }
194   }
195 
196   return false;
197 }
198 
199 }  // namespace
200 
main(int argc,char * argv[])201 int main(int argc, char* argv[]) {
202   tensorflow::example::Options opts;
203   std::vector<char*> unknown_flags;
204   for (int i = 1; i < argc; ++i) {
205     if (string(argv[i]) == "--") {
206       while (i < argc) {
207         unknown_flags.push_back(argv[i]);
208         ++i;
209       }
210       break;
211     }
212 
213     if (ParseInt32Flag(argv[i], "--num_concurrent_sessions",
214                        &opts.num_concurrent_sessions) ||
215         ParseInt32Flag(argv[i], "--num_concurrent_steps",
216                        &opts.num_concurrent_steps) ||
217         ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) ||
218         ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) {
219       continue;
220     }
221 
222     fprintf(stderr, "Unknown flag: %s\n", argv[i]);
223     return -1;
224   }
225 
226   // Passthrough any unknown flags.
227   int dst = 1;  // Skip argv[0]
228   for (char* f : unknown_flags) {
229     argv[dst++] = f;
230   }
231   argv[dst++] = nullptr;
232   argc = static_cast<int>(unknown_flags.size() + 1);
233   tensorflow::port::InitMain(argv[0], &argc, &argv);
234   tensorflow::example::ConcurrentSessions(opts);
235 }
236