• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/nn/modules/loss.h>
2 
3 namespace F = torch::nn::functional;
4 
5 namespace torch {
6 namespace nn {
7 
L1LossImpl(L1LossOptions options_)8 L1LossImpl::L1LossImpl(L1LossOptions options_) : options(std::move(options_)) {}
9 
reset()10 void L1LossImpl::reset() {}
11 
pretty_print(std::ostream & stream) const12 void L1LossImpl::pretty_print(std::ostream& stream) const {
13   stream << "torch::nn::L1Loss()";
14 }
15 
forward(const Tensor & input,const Tensor & target)16 Tensor L1LossImpl::forward(const Tensor& input, const Tensor& target) {
17   return F::detail::l1_loss(input, target, options.reduction());
18 }
19 
20 // ============================================================================
21 
KLDivLossImpl(KLDivLossOptions options_)22 KLDivLossImpl::KLDivLossImpl(KLDivLossOptions options_)
23     : options(std::move(options_)) {}
24 
reset()25 void KLDivLossImpl::reset() {}
26 
pretty_print(std::ostream & stream) const27 void KLDivLossImpl::pretty_print(std::ostream& stream) const {
28   stream << "torch::nn::KLDivLoss()";
29 }
30 
forward(const Tensor & input,const Tensor & target)31 Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) {
32   return F::detail::kl_div(
33       input, target, options.reduction(), options.log_target());
34 }
35 
36 // ============================================================================
37 
MSELossImpl(MSELossOptions options_)38 MSELossImpl::MSELossImpl(MSELossOptions options_)
39     : options(std::move(options_)) {}
40 
reset()41 void MSELossImpl::reset() {}
42 
pretty_print(std::ostream & stream) const43 void MSELossImpl::pretty_print(std::ostream& stream) const {
44   stream << "torch::nn::MSELoss()";
45 }
46 
forward(const Tensor & input,const Tensor & target)47 Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) {
48   return F::detail::mse_loss(input, target, options.reduction());
49 }
50 
51 // ============================================================================
52 
BCELossImpl(BCELossOptions options_)53 BCELossImpl::BCELossImpl(BCELossOptions options_)
54     : options(std::move(options_)) {
55   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
56   reset();
57 }
58 
reset()59 void BCELossImpl::reset() {
60   register_buffer("weight", options.weight());
61 }
62 
pretty_print(std::ostream & stream) const63 void BCELossImpl::pretty_print(std::ostream& stream) const {
64   stream << "torch::nn::BCELoss()";
65 }
66 
forward(const Tensor & input,const Tensor & target)67 Tensor BCELossImpl::forward(const Tensor& input, const Tensor& target) {
68   return F::detail::binary_cross_entropy(
69       input, target, options.weight(), options.reduction());
70 }
71 
72 // ============================================================================
73 
HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_)74 HingeEmbeddingLossImpl::HingeEmbeddingLossImpl(
75     HingeEmbeddingLossOptions options_)
76     : options(std::move(options_)) {}
77 
reset()78 void HingeEmbeddingLossImpl::reset() {}
79 
pretty_print(std::ostream & stream) const80 void HingeEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
81   stream << "torch::nn::HingeEmbeddingLoss(margin=" << options.margin() << ")";
82 }
83 
forward(const Tensor & input,const Tensor & target)84 Tensor HingeEmbeddingLossImpl::forward(
85     const Tensor& input,
86     const Tensor& target) {
87   return F::detail::hinge_embedding_loss(
88       input, target, options.margin(), options.reduction());
89 }
90 
91 // ============================================================================
92 
MultiMarginLossImpl(MultiMarginLossOptions options_)93 MultiMarginLossImpl::MultiMarginLossImpl(MultiMarginLossOptions options_)
94     : options(std::move(options_)) {
95   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
96   reset();
97 }
98 
reset()99 void MultiMarginLossImpl::reset() {
100   TORCH_CHECK(
101       (options.p() == 1) || (options.p() == 2),
102       "only p == 1 and p == 2 supported");
103   TORCH_CHECK(!options.weight().defined() || options.weight().dim() == 1);
104 
105   register_buffer("weight", options.weight());
106 }
107 
pretty_print(std::ostream & stream) const108 void MultiMarginLossImpl::pretty_print(std::ostream& stream) const {
109   stream << "torch::nn::MultiMarginLoss(p=" << options.p()
110          << ", margin=" << options.margin() << ", weight=" << options.weight()
111          << ", reduction=" << enumtype::get_enum_name(options.reduction())
112          << ")";
113 }
114 
forward(const Tensor & input,const Tensor & target)115 Tensor MultiMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
116   return F::detail::multi_margin_loss(
117       input,
118       target,
119       options.p(),
120       options.margin(),
121       options.weight(),
122       options.reduction());
123 }
124 
125 // ============================================================================
126 
CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_)127 CosineEmbeddingLossImpl::CosineEmbeddingLossImpl(
128     CosineEmbeddingLossOptions options_)
129     : options(std::move(options_)) {}
130 
reset()131 void CosineEmbeddingLossImpl::reset() {}
132 
pretty_print(std::ostream & stream) const133 void CosineEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
134   stream << "torch::nn::CosineEmbeddingLoss(margin=" << options.margin() << ")";
135 }
136 
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)137 Tensor CosineEmbeddingLossImpl::forward(
138     const Tensor& input1,
139     const Tensor& input2,
140     const Tensor& target) {
141   return F::detail::cosine_embedding_loss(
142       input1, input2, target, options.margin(), options.reduction());
143 }
144 // ============================================================================
145 
MultiLabelSoftMarginLossImpl(torch::nn::MultiLabelSoftMarginLossOptions options_)146 MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl(
147     torch::nn::MultiLabelSoftMarginLossOptions options_)
148     : options(std::move(options_)) {
149   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
150   reset();
151 }
152 
pretty_print(std::ostream & stream) const153 void MultiLabelSoftMarginLossImpl::pretty_print(std::ostream& stream) const {
154   stream << "torch::nn::MultiLabelSoftMarginLoss()";
155 }
156 
reset()157 void MultiLabelSoftMarginLossImpl::reset() {
158   register_buffer("weight", options.weight());
159 }
160 
forward(const Tensor & input,const Tensor & target)161 Tensor MultiLabelSoftMarginLossImpl::forward(
162     const Tensor& input,
163     const Tensor& target) {
164   return F::detail::multilabel_soft_margin_loss(
165       input, target, options.weight(), options.reduction());
166 }
167 
168 // ============================================================================
169 
TripletMarginLossImpl(TripletMarginLossOptions options_)170 TripletMarginLossImpl::TripletMarginLossImpl(TripletMarginLossOptions options_)
171     : options(std::move(options_)) {}
172 
reset()173 void TripletMarginLossImpl::reset() {}
174 
pretty_print(std::ostream & stream) const175 void TripletMarginLossImpl::pretty_print(std::ostream& stream) const {
176   stream << "torch::nn::TripletMarginLoss(margin=" << options.margin()
177          << ", p=" << options.p() << ", eps=" << options.eps() << std::boolalpha
178          << ", swap=" << options.swap() << ")";
179 }
180 
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)181 Tensor TripletMarginLossImpl::forward(
182     const Tensor& anchor,
183     const Tensor& positive,
184     const Tensor& negative) {
185   return F::detail::triplet_margin_loss(
186       anchor,
187       positive,
188       negative,
189       options.margin(),
190       options.p(),
191       options.eps(),
192       options.swap(),
193       options.reduction());
194 }
195 
196 // ============================================================================
197 
TripletMarginWithDistanceLossImpl(TripletMarginWithDistanceLossOptions options_)198 TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl(
199     TripletMarginWithDistanceLossOptions options_)
200     : options(std::move(options_)) {}
201 
reset()202 void TripletMarginWithDistanceLossImpl::reset() {}
203 
pretty_print(std::ostream & stream) const204 void TripletMarginWithDistanceLossImpl::pretty_print(
205     std::ostream& stream) const {
206   stream << "torch::nn::TripletMarginWithDistanceLoss(margin="
207          << options.margin() << std::boolalpha << ", swap=" << options.swap()
208          << ")";
209 }
210 
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)211 Tensor TripletMarginWithDistanceLossImpl::forward(
212     const Tensor& anchor,
213     const Tensor& positive,
214     const Tensor& negative) {
215   return F::detail::triplet_margin_with_distance_loss(
216       anchor,
217       positive,
218       negative,
219       options.distance_function(),
220       options.margin(),
221       options.swap(),
222       options.reduction());
223 }
224 
225 // ============================================================================
226 
MultiLabelMarginLossImpl(torch::nn::MultiLabelMarginLossOptions options_)227 MultiLabelMarginLossImpl::MultiLabelMarginLossImpl(
228     torch::nn::MultiLabelMarginLossOptions options_)
229     : options(std::move(options_)) {}
230 
reset()231 void MultiLabelMarginLossImpl::reset() {}
232 
pretty_print(std::ostream & stream) const233 void MultiLabelMarginLossImpl::pretty_print(std::ostream& stream) const {
234   stream << "torch::nn::MultiLabelMarginLoss()";
235 }
236 
forward(const Tensor & input,const Tensor & target)237 Tensor MultiLabelMarginLossImpl::forward(
238     const Tensor& input,
239     const Tensor& target) {
240   return F::detail::multilabel_margin_loss(input, target, options.reduction());
241 }
242 
243 // ============================================================================
244 
SoftMarginLossImpl(torch::nn::SoftMarginLossOptions options_)245 SoftMarginLossImpl::SoftMarginLossImpl(
246     torch::nn::SoftMarginLossOptions options_)
247     : options(std::move(options_)) {}
248 
reset()249 void SoftMarginLossImpl::reset() {}
250 
pretty_print(std::ostream & stream) const251 void SoftMarginLossImpl::pretty_print(std::ostream& stream) const {
252   stream << "torch::nn::SoftMarginLoss()";
253 }
254 
forward(const Tensor & input,const Tensor & target)255 Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
256   return F::detail::soft_margin_loss(input, target, options.reduction());
257 }
258 
259 // ============================================================================
260 
SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)261 SmoothL1LossImpl::SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)
262     : options(std::move(options_)) {}
263 
reset()264 void SmoothL1LossImpl::reset() {}
265 
pretty_print(std::ostream & stream) const266 void SmoothL1LossImpl::pretty_print(std::ostream& stream) const {
267   stream << "torch::nn::SmoothL1Loss";
268 }
269 
forward(const Tensor & input,const Tensor & target)270 Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) {
271   return F::detail::smooth_l1_loss(
272       input, target, options.reduction(), options.beta());
273 }
274 
275 // ============================================================================
276 
HuberLossImpl(torch::nn::HuberLossOptions options_)277 HuberLossImpl::HuberLossImpl(torch::nn::HuberLossOptions options_)
278     : options(std::move(options_)) {}
279 
reset()280 void HuberLossImpl::reset() {}
281 
pretty_print(std::ostream & stream) const282 void HuberLossImpl::pretty_print(std::ostream& stream) const {
283   stream << "torch::nn::HuberLoss";
284 }
285 
forward(const Tensor & input,const Tensor & target)286 Tensor HuberLossImpl::forward(const Tensor& input, const Tensor& target) {
287   return F::detail::huber_loss(
288       input, target, options.reduction(), options.delta());
289 }
290 
291 // ============================================================================
292 
CTCLossImpl(CTCLossOptions options_)293 CTCLossImpl::CTCLossImpl(CTCLossOptions options_)
294     : options(std::move(options_)) {}
295 
reset()296 void CTCLossImpl::reset() {}
297 
pretty_print(std::ostream & stream) const298 void CTCLossImpl::pretty_print(std::ostream& stream) const {
299   stream << "torch::nn::CTCLoss()";
300 }
301 
forward(const Tensor & log_probs,const Tensor & targets,const Tensor & input_lengths,const Tensor & target_lengths)302 Tensor CTCLossImpl::forward(
303     const Tensor& log_probs,
304     const Tensor& targets,
305     const Tensor& input_lengths,
306     const Tensor& target_lengths) {
307   return F::detail::ctc_loss(
308       log_probs,
309       targets,
310       input_lengths,
311       target_lengths,
312       options.blank(),
313       options.reduction(),
314       options.zero_infinity());
315 }
316 
317 // ============================================================================
318 
PoissonNLLLossImpl(PoissonNLLLossOptions options_)319 PoissonNLLLossImpl::PoissonNLLLossImpl(PoissonNLLLossOptions options_)
320     : options(std::move(options_)) {}
321 
reset()322 void PoissonNLLLossImpl::reset() {}
323 
pretty_print(std::ostream & stream) const324 void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const {
325   stream << "torch::nn::PoissonNLLLoss()";
326 }
327 
forward(const Tensor & log_input,const Tensor & target)328 Tensor PoissonNLLLossImpl::forward(
329     const Tensor& log_input,
330     const Tensor& target) {
331   return F::detail::poisson_nll_loss(
332       log_input,
333       target,
334       options.log_input(),
335       options.full(),
336       options.eps(),
337       options.reduction());
338 }
339 
340 // ============================================================================
341 
MarginRankingLossImpl(MarginRankingLossOptions options_)342 MarginRankingLossImpl::MarginRankingLossImpl(MarginRankingLossOptions options_)
343     : options(std::move(options_)) {}
344 
reset()345 void MarginRankingLossImpl::reset() {}
346 
pretty_print(std::ostream & stream) const347 void MarginRankingLossImpl::pretty_print(std::ostream& stream) const {
348   stream << "torch::nn::MarginRankingLoss()";
349 }
350 
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)351 Tensor MarginRankingLossImpl::forward(
352     const Tensor& input1,
353     const Tensor& input2,
354     const Tensor& target) {
355   return F::detail::margin_ranking_loss(
356       input1, input2, target, options.margin(), options.reduction());
357 }
358 
359 // ============================================================================
360 
NLLLossImpl(NLLLossOptions options_)361 NLLLossImpl::NLLLossImpl(NLLLossOptions options_)
362     : options(std::move(options_)) {
363   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
364   reset();
365 }
366 
reset()367 void NLLLossImpl::reset() {
368   weight = register_buffer("weight", options.weight());
369 }
370 
pretty_print(std::ostream & stream) const371 void NLLLossImpl::pretty_print(std::ostream& stream) const {
372   stream << "torch::nn::NLLLoss()";
373 }
374 
forward(const Tensor & input,const Tensor & target)375 Tensor NLLLossImpl::forward(const Tensor& input, const Tensor& target) {
376   return F::detail::nll_loss(
377       input, target, weight, options.ignore_index(), options.reduction());
378 }
379 
380 // ============================================================================
381 
CrossEntropyLossImpl(CrossEntropyLossOptions options_)382 CrossEntropyLossImpl::CrossEntropyLossImpl(CrossEntropyLossOptions options_)
383     : options(std::move(options_)) {
384   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
385   reset();
386 }
387 
reset()388 void CrossEntropyLossImpl::reset() {
389   weight = register_buffer("weight", options.weight());
390 }
391 
pretty_print(std::ostream & stream) const392 void CrossEntropyLossImpl::pretty_print(std::ostream& stream) const {
393   stream << "torch::nn::CrossEntropyLoss()";
394 }
395 
forward(const Tensor & input,const Tensor & target)396 Tensor CrossEntropyLossImpl::forward(
397     const Tensor& input,
398     const Tensor& target) {
399   return F::detail::cross_entropy(
400       input,
401       target,
402       weight,
403       options.ignore_index(),
404       options.reduction(),
405       options.label_smoothing());
406 }
407 
408 // ============================================================================
409 
BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)410 BCEWithLogitsLossImpl::BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)
411     : options(std::move(options_)) {
412   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
413   reset();
414 }
415 
reset()416 void BCEWithLogitsLossImpl::reset() {
417   weight = register_buffer("weight", options.weight());
418   pos_weight = register_buffer("pos_weight", options.pos_weight());
419 }
420 
pretty_print(std::ostream & stream) const421 void BCEWithLogitsLossImpl::pretty_print(std::ostream& stream) const {
422   stream << "torch::nn::BCEWithLogitsLoss()";
423 }
424 
forward(const Tensor & input,const Tensor & target)425 Tensor BCEWithLogitsLossImpl::forward(
426     const Tensor& input,
427     const Tensor& target) {
428   return F::detail::binary_cross_entropy_with_logits(
429       input,
430       target,
431       options.weight(),
432       options.reduction(),
433       options.pos_weight());
434 }
435 
436 } // namespace nn
437 } // namespace torch
438