1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18
19 namespace tensorflow {
20
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24
25 namespace {
26
CandidateSamplerShapeFn(InferenceContext * c)27 Status CandidateSamplerShapeFn(InferenceContext* c) {
28 int64 num_sampled;
29 TF_RETURN_IF_ERROR(c->GetAttr("num_sampled", &num_sampled));
30 int64 num_true;
31 TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
32
33 ShapeHandle true_classes_shape;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes_shape));
35 DimensionHandle batch_size = c->Dim(true_classes_shape, 0);
36
37 ShapeHandle num_sampled_v = c->Vector(num_sampled);
38 c->set_output(0, num_sampled_v);
39 c->set_output(1, c->Matrix(batch_size, num_true));
40 c->set_output(2, num_sampled_v);
41 return Status::OK();
42 }
43
44 } // namespace
45
// Registers the UniformCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// NOTE(review): the name suggests candidates are drawn uniformly from a range
// bounded by `range_max`; the sampling itself is implemented by the kernel,
// not here -- confirm there.
REGISTER_OP("UniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike seed/seed2.
    .Attr("range_max: int >= 1")
    // NOTE(review): presumably seeds for the kernel's RNG; what a value of 0
    // means is decided by the kernel -- confirm.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // NOTE(review): presumably stateful because its results are random
    // (see seed/seed2) -- confirm.
    .SetIsStateful();
59
// Registers the LogUniformCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// The signature is identical to UniformCandidateSampler's.
// NOTE(review): the name suggests a log-uniform distribution over a range
// bounded by `range_max`; the sampling itself is implemented by the kernel,
// not here -- confirm there.
REGISTER_OP("LogUniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike seed/seed2.
    .Attr("range_max: int >= 1")
    // NOTE(review): presumably seeds for the kernel's RNG; what a value of 0
    // means is decided by the kernel -- confirm.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // NOTE(review): presumably stateful because its results are random
    // (see seed/seed2) -- confirm.
    .SetIsStateful();
73
// Registers the LearnedUnigramCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// The signature is identical to UniformCandidateSampler's.
// NOTE(review): the name suggests the unigram distribution is learned from the
// observed true_classes; that logic lives in the kernel -- confirm there.
REGISTER_OP("LearnedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike seed/seed2.
    .Attr("range_max: int >= 1")
    // NOTE(review): presumably seeds for the kernel's RNG; what a value of 0
    // means is decided by the kernel -- confirm.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // NOTE(review): presumably stateful because its results are random and/or
    // depend on learned state -- confirm.
    .SetIsStateful();
87
// Registers the ThreadUnsafeUnigramCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// The signature is identical to LearnedUnigramCandidateSampler's.
// NOTE(review): the "ThreadUnsafe" in the name presumably refers to the
// kernel's internal state handling; nothing in this registration expresses
// that -- confirm against the kernel.
REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike seed/seed2.
    .Attr("range_max: int >= 1")
    // NOTE(review): presumably seeds for the kernel's RNG; what a value of 0
    // means is decided by the kernel -- confirm.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // NOTE(review): presumably stateful because its results are random and/or
    // depend on internal state -- confirm.
    .SetIsStateful();
101
// Registers the FixedUnigramCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// Compared with the other samplers, this op adds attrs that presumably
// describe a fixed unigram distribution (vocab_file / unigrams / distortion)
// and sharding of the id range (num_reserved_ids / num_shards / shard).
// NOTE(review): this OpDef only enforces num_shards >= 1 and shard >= 0. It
// does not enforce shard < num_shards, nor any relationship between vocab_file
// and unigrams; presumably the kernel validates these -- confirm.
REGISTER_OP("FixedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike the attrs below.
    .Attr("range_max: int >= 1")
    // NOTE(review): presumably a path to per-id frequencies; empty by default.
    .Attr("vocab_file: string = ''")
    // NOTE(review): presumably an exponent applied to the distribution;
    // defaults to 1.0.
    .Attr("distortion: float = 1.0")
    .Attr("num_reserved_ids: int = 0")
    .Attr("num_shards: int >= 1 = 1")
    .Attr("shard: int >= 0 = 0")
    // NOTE(review): presumably an in-graph alternative to vocab_file; empty by
    // default.
    .Attr("unigrams: list(float) = []")
    // NOTE(review): presumably seeds for the kernel's RNG; what a value of 0
    // means is decided by the kernel -- confirm.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // NOTE(review): presumably stateful because its results are random
    // (see seed/seed2) -- confirm.
    .SetIsStateful();
121
// Registers the AllCandidateSampler op.
//
// Inputs and outputs are positional: CandidateSamplerShapeFn reads input 0 and
// sets outputs 0-2 by index, so the Input/Output calls below must not be
// reordered.
//   input  0 true_classes           int64, rank 2 (checked by the shape fn)
//   output 0 sampled_candidates     int64, shape [num_sampled]
//   output 1 true_expected_count    float, shape [dim 0 of true_classes, num_true]
//   output 2 sampled_expected_count float, shape [num_sampled]
//
// Unlike the other samplers, this op has no `range_max` attr. The shared shape
// fn reads only num_sampled and num_true, both of which are declared here.
// NOTE(review): the name suggests every id in [0, num_sampled) is returned
// rather than a random draw; that behavior lives in the kernel -- confirm.
REGISTER_OP("AllCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")  // No default value is given, unlike seed/seed2.
    // NOTE(review): declared for signature parity with the other samplers;
    // whether the kernel uses them is not visible here.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
134
135 REGISTER_OP("ComputeAccidentalHits")
136 .Input("true_classes: int64")
137 .Input("sampled_candidates: int64")
138 .Output("indices: int32")
139 .Output("ids: int64")
140 .Output("weights: float")
141 .Attr("num_true: int")
142 .Attr("seed: int = 0")
143 .Attr("seed2: int = 0")
__anon78463f730202(InferenceContext* c) 144 .SetShapeFn([](InferenceContext* c) {
145 int64 num_true;
146 TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
147
148 // Validate true_classes, must be a matrix.
149 ShapeHandle true_classes;
150 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
151 DimensionHandle unused;
152 TF_RETURN_IF_ERROR(
153 c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
154 // Validate sampled_candidates, must be a vector.
155 ShapeHandle sampled_candidates;
156 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates));
157
158 // All three outputs are the same shape.
159 ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);
160 c->set_output(0, v);
161 c->set_output(1, v);
162 c->set_output(2, v);
163 return Status::OK();
164 });
165
166 } // namespace tensorflow
167