/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {

// Number of examples to precalculate.
const int kPrecalc = 3000;
// Number of words to read into a sentence before processing.
const int kSentenceSize = 1000;

namespace {
bool ScanWord(StringPiece* input, string* word) {
  str_util::RemoveLeadingWhitespace(input);
  StringPiece tmp;
  if (str_util::ConsumeNonWhitespace(input, &tmp)) {
    word->assign(tmp.data(), tmp.size());
    return true;
  } else {
    return false;
  }
}

}  // end namespace

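// Skipgram op: reads a text corpus from the file named by the "filename"
// attr, builds a vocabulary of the words occurring at least min_count
// times, and emits batches of (example, label) word-id pairs drawn from
// a sliding window over the corpus, with optional frequency-based
// subsampling. Outputs, in order: vocab words, vocab frequencies,
// words per epoch, current epoch, total words processed, examples, labels.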
class SkipgramOp : public OpKernel {
 public:
  explicit SkipgramOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), rng_(&philox_) {
    string filename;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));

    mutex_lock l(mu_);
    example_pos_ = corpus_size_;
    label_pos_ = corpus_size_;
    label_limit_ = corpus_size_;
    sentence_index_ = kSentenceSize;
    for (int i = 0; i < kPrecalc; ++i) {
      NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label);
    }
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor words_per_epoch(DT_INT64, TensorShape({}));
    Tensor current_epoch(DT_INT32, TensorShape({}));
    Tensor total_words_processed(DT_INT64, TensorShape({}));
    Tensor examples(DT_INT32, TensorShape({batch_size_}));
    auto Texamples = examples.flat<int32>();
    Tensor labels(DT_INT32, TensorShape({batch_size_}));
    auto Tlabels = labels.flat<int32>();
    {
      mutex_lock l(mu_);
      for (int i = 0; i < batch_size_; ++i) {
        Texamples(i) = precalc_examples_[precalc_index_].input;
        Tlabels(i) = precalc_examples_[precalc_index_].label;
        precalc_index_++;
        if (precalc_index_ >= kPrecalc) {
          precalc_index_ = 0;
          for (int j = 0; j < kPrecalc; ++j) {
            NextExample(&precalc_examples_[j].input,
                        &precalc_examples_[j].label);
          }
        }
      }
      words_per_epoch.scalar<int64>()() = corpus_size_;
      current_epoch.scalar<int32>()() = current_epoch_;
      total_words_processed.scalar<int64>()() = total_words_processed_;
    }
    ctx->set_output(0, word_);
    ctx->set_output(1, freq_);
    ctx->set_output(2, words_per_epoch);
    ctx->set_output(3, current_epoch);
    ctx->set_output(4, total_words_processed);
    ctx->set_output(5, examples);
    ctx->set_output(6, labels);
  }

 private:
  struct Example {
    int32 input;
    int32 label;
  };

  int32 batch_size_ = 0;
  int32 window_size_ = 5;
  float subsample_ = 1e-3;
  int min_count_ = 5;
  int32 vocab_size_ = 0;
  Tensor word_;
  Tensor freq_;
  int64 corpus_size_ = 0;
  std::vector<int32> corpus_;
  std::vector<Example> precalc_examples_;
  int precalc_index_ = 0;
  std::vector<int32> sentence_;
  int sentence_index_ = 0;

  mutex mu_;
  random::PhiloxRandom philox_ TF_GUARDED_BY(mu_);
  random::SimplePhilox rng_ TF_GUARDED_BY(mu_);
  int32 current_epoch_ TF_GUARDED_BY(mu_) = -1;
  int64 total_words_processed_ TF_GUARDED_BY(mu_) = 0;
  int32 example_pos_ TF_GUARDED_BY(mu_);
  int32 label_pos_ TF_GUARDED_BY(mu_);
  int32 label_limit_ TF_GUARDED_BY(mu_);

  // {example_pos_, label_pos_} is the cursor for the next example.
  // example_pos_ wraps around at the end of corpus_. For each
  // example, we randomly generate [label_pos_, label_limit_) for
  // labels.
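  // Example: with sentence_index_ = 10 and a sampled skip of 2,
  // label_pos_ = 8 and label_limit_ = 13, so labels are drawn from
  // sentence positions 8, 9, 11, and 12 (the center position 10 is
  // skipped by the loop below).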
  void NextExample(int32* example, int32* label)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    while (true) {
      if (label_pos_ >= label_limit_) {
        ++total_words_processed_;
        ++sentence_index_;
        if (sentence_index_ >= kSentenceSize) {
          sentence_index_ = 0;
          for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) {
            if (example_pos_ >= corpus_size_) {
              ++current_epoch_;
              example_pos_ = 0;
            }
            if (subsample_ > 0) {
              int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
              // See Eq. 5 in http://arxiv.org/abs/1310.4546
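              // With f = word_freq / corpus_size_ and t = subsample_,
              // keep_prob = sqrt(t / f) + t / f: frequent words (f >> t)
              // are kept with low probability, while rare words (f <= t)
              // are always kept since keep_prob >= 1.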
              float keep_prob =
                  (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
                  (subsample_ * corpus_size_) / word_freq;
              if (rng_.RandFloat() > keep_prob) {
                i--;
                continue;
              }
            }
            sentence_[i] = corpus_[example_pos_];
          }
        }
        const int32 skip = 1 + rng_.Uniform(window_size_);
        label_pos_ = std::max<int32>(0, sentence_index_ - skip);
        label_limit_ =
            std::min<int32>(kSentenceSize, sentence_index_ + skip + 1);
      }
      if (sentence_index_ != label_pos_) {
        break;
      }
      ++label_pos_;
    }
    *example = sentence_[sentence_index_];
    *label = sentence_[label_pos_++];
  }

  Status Init(Env* env, const string& filename) {
    string data;
    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
    StringPiece input = data;
    string w;
    corpus_size_ = 0;
    std::unordered_map<string, int32> word_freq;
    while (ScanWord(&input, &w)) {
      ++(word_freq[w]);
      ++corpus_size_;
    }
    if (corpus_size_ < window_size_ * 10) {
      return errors::InvalidArgument(
          "The text file ", filename,
          " contains too little data: ", corpus_size_, " words");
    }
    typedef std::pair<string, int32> WordFreq;
    std::vector<WordFreq> ordered;
    for (const auto& p : word_freq) {
      if (p.second >= min_count_) ordered.push_back(p);
    }
    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
              << " unique words, " << ordered.size()
              << " unique frequent words.";
    word_freq.clear();
    std::sort(ordered.begin(), ordered.end(),
              [](const WordFreq& x, const WordFreq& y) {
                return x.second > y.second;
              });
    vocab_size_ = static_cast<int32>(1 + ordered.size());
    Tensor word(DT_STRING, TensorShape({vocab_size_}));
    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
    word.flat<tstring>()(0) = "UNK";
    static const int32 kUnkId = 0;
    std::unordered_map<string, int32> word_id;
    int64 total_counted = 0;
    for (std::size_t i = 0; i < ordered.size(); ++i) {
      const auto& w = ordered[i].first;
      auto id = i + 1;
      word.flat<tstring>()(id) = w;
      auto word_count = ordered[i].second;
      freq.flat<int32>()(id) = word_count;
      total_counted += word_count;
      word_id[w] = id;
    }
    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
    word_ = word;
    freq_ = freq;
    corpus_.reserve(corpus_size_);
    input = data;
    while (ScanWord(&input, &w)) {
      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
    }
    precalc_examples_.resize(kPrecalc);
    sentence_.resize(kSentenceSize);
    return Status::OK();
  }
};

REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);

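// NegTrain op: one training step of skip-gram with negative sampling.
// For each (example, label) pair it updates the input embeddings w_in
// and output embeddings w_out in place, drawing num_negative_samples
// negative words from a sampler weighted by vocab_count.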
class NegTrainOp : public OpKernel {
 public:
  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    base_.Init(0, 0);

    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));

    std::vector<int32> vocab_count;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));

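    // Negative samples are drawn from the unigram distribution raised
    // to the 3/4 power, as in http://arxiv.org/abs/1310.4546.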
    std::vector<float> vocab_weights;
    vocab_weights.reserve(vocab_count.size());
    for (const auto& f : vocab_count) {
      float r = std::pow(static_cast<float>(f), 0.75f);
      vocab_weights.push_back(r);
    }
    sampler_ = new random::DistributionSampler(vocab_weights);
  }

  ~NegTrainOp() override { delete sampler_; }

  void Compute(OpKernelContext* ctx) override {
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));

    auto Tw_in = w_in.matrix<float>();
    auto Tw_out = w_out.matrix<float>();
    auto Texamples = examples.flat<int32>();
    auto Tlabels = labels.flat<int32>();
    auto lr = learning_rate.scalar<float>()();
    const int64 vocab_size = w_in.dim_size(0);
    const int64 dims = w_in.dim_size(1);
    const int64 batch_size = examples.dim_size(0);
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));

    // Gradient accumulator for v_in.
    Tensor buf(DT_FLOAT, TensorShape({dims}));
    auto Tbuf = buf.flat<float>();

    // Scalar buffer to hold sigmoid(+/- dot).
    Tensor g_buf(DT_FLOAT, TensorShape({}));
    auto g = g_buf.scalar<float>();

    // The following loop needs 2 random 32-bit values per negative
    // sample.  We reserve 8 values per sample just in case the
    // underlying implementation changes.
    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
    random::SimplePhilox srnd(&rnd);

    for (int64 i = 0; i < batch_size; ++i) {
      const int32 example = Texamples(i);
      DCHECK(0 <= example && example < vocab_size) << example;
      const int32 label = Tlabels(i);
      DCHECK(0 <= label && label < vocab_size) << label;
      auto v_in = Tw_in.chip<0>(example);

      // Positive: example predicts label.
      //   forward: x = v_in' * v_out
      //            l = log(sigmoid(x))
      //   backward: dl/dx = g = sigmoid(-x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
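      // Note: (dot.exp() + 1).inverse() below evaluates to
      // 1 / (1 + exp(x)) = sigmoid(-x), so g holds dl/dx directly.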
      {
        auto v_out = Tw_out.chip<0>(label);
        auto dot = (v_in * v_out).sum();
        g = (dot.exp() + 1.f).inverse();
        Tbuf = v_out * (g() * lr);
        v_out += v_in * (g() * lr);
      }

      // Negative samples:
      //   forward: x = v_in' * v_sample
      //            l = log(sigmoid(-x))
      //   backward: dl/dx = g = -sigmoid(x)
      //             dl/d(v_in) = g * v_sample'
      //             dl/d(v_sample) = v_in' * g
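      // Note: -((-dot).exp() + 1).inverse() below evaluates to
      // -1 / (1 + exp(-x)) = -sigmoid(x), matching g above.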
      for (int j = 0; j < num_samples_; ++j) {
        const int sample = sampler_->Sample(&srnd);
        if (sample == label) continue;  // Skip.
        auto v_sample = Tw_out.chip<0>(sample);
        auto dot = (v_in * v_sample).sum();
        g = -((-dot).exp() + 1.f).inverse();
        Tbuf += v_sample * (g() * lr);
        v_sample += v_in * (g() * lr);
      }

      // Applies the gradient on v_in.
      v_in += Tbuf;
    }
  }

 private:
  int32 num_samples_ = 0;
  random::DistributionSampler* sampler_ = nullptr;
  GuardedPhiloxRandom base_;
};

REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);

}  // end namespace tensorflow