1 #include <torch/nn/modules/loss.h>
2
3 namespace F = torch::nn::functional;
4
5 namespace torch {
6 namespace nn {
7
L1LossImpl(L1LossOptions options_)8 L1LossImpl::L1LossImpl(L1LossOptions options_) : options(std::move(options_)) {}
9
reset()10 void L1LossImpl::reset() {}
11
pretty_print(std::ostream & stream) const12 void L1LossImpl::pretty_print(std::ostream& stream) const {
13 stream << "torch::nn::L1Loss()";
14 }
15
forward(const Tensor & input,const Tensor & target)16 Tensor L1LossImpl::forward(const Tensor& input, const Tensor& target) {
17 return F::detail::l1_loss(input, target, options.reduction());
18 }
19
20 // ============================================================================
21
KLDivLossImpl(KLDivLossOptions options_)22 KLDivLossImpl::KLDivLossImpl(KLDivLossOptions options_)
23 : options(std::move(options_)) {}
24
reset()25 void KLDivLossImpl::reset() {}
26
pretty_print(std::ostream & stream) const27 void KLDivLossImpl::pretty_print(std::ostream& stream) const {
28 stream << "torch::nn::KLDivLoss()";
29 }
30
forward(const Tensor & input,const Tensor & target)31 Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) {
32 return F::detail::kl_div(
33 input, target, options.reduction(), options.log_target());
34 }
35
36 // ============================================================================
37
MSELossImpl(MSELossOptions options_)38 MSELossImpl::MSELossImpl(MSELossOptions options_)
39 : options(std::move(options_)) {}
40
reset()41 void MSELossImpl::reset() {}
42
pretty_print(std::ostream & stream) const43 void MSELossImpl::pretty_print(std::ostream& stream) const {
44 stream << "torch::nn::MSELoss()";
45 }
46
forward(const Tensor & input,const Tensor & target)47 Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) {
48 return F::detail::mse_loss(input, target, options.reduction());
49 }
50
51 // ============================================================================
52
BCELossImpl(BCELossOptions options_)53 BCELossImpl::BCELossImpl(BCELossOptions options_)
54 : options(std::move(options_)) {
55 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
56 reset();
57 }
58
reset()59 void BCELossImpl::reset() {
60 register_buffer("weight", options.weight());
61 }
62
pretty_print(std::ostream & stream) const63 void BCELossImpl::pretty_print(std::ostream& stream) const {
64 stream << "torch::nn::BCELoss()";
65 }
66
forward(const Tensor & input,const Tensor & target)67 Tensor BCELossImpl::forward(const Tensor& input, const Tensor& target) {
68 return F::detail::binary_cross_entropy(
69 input, target, options.weight(), options.reduction());
70 }
71
72 // ============================================================================
73
HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_)74 HingeEmbeddingLossImpl::HingeEmbeddingLossImpl(
75 HingeEmbeddingLossOptions options_)
76 : options(std::move(options_)) {}
77
reset()78 void HingeEmbeddingLossImpl::reset() {}
79
pretty_print(std::ostream & stream) const80 void HingeEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
81 stream << "torch::nn::HingeEmbeddingLoss(margin=" << options.margin() << ")";
82 }
83
forward(const Tensor & input,const Tensor & target)84 Tensor HingeEmbeddingLossImpl::forward(
85 const Tensor& input,
86 const Tensor& target) {
87 return F::detail::hinge_embedding_loss(
88 input, target, options.margin(), options.reduction());
89 }
90
91 // ============================================================================
92
MultiMarginLossImpl(MultiMarginLossOptions options_)93 MultiMarginLossImpl::MultiMarginLossImpl(MultiMarginLossOptions options_)
94 : options(std::move(options_)) {
95 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
96 reset();
97 }
98
reset()99 void MultiMarginLossImpl::reset() {
100 TORCH_CHECK(
101 (options.p() == 1) || (options.p() == 2),
102 "only p == 1 and p == 2 supported");
103 TORCH_CHECK(!options.weight().defined() || options.weight().dim() == 1);
104
105 register_buffer("weight", options.weight());
106 }
107
pretty_print(std::ostream & stream) const108 void MultiMarginLossImpl::pretty_print(std::ostream& stream) const {
109 stream << "torch::nn::MultiMarginLoss(p=" << options.p()
110 << ", margin=" << options.margin() << ", weight=" << options.weight()
111 << ", reduction=" << enumtype::get_enum_name(options.reduction())
112 << ")";
113 }
114
forward(const Tensor & input,const Tensor & target)115 Tensor MultiMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
116 return F::detail::multi_margin_loss(
117 input,
118 target,
119 options.p(),
120 options.margin(),
121 options.weight(),
122 options.reduction());
123 }
124
125 // ============================================================================
126
CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_)127 CosineEmbeddingLossImpl::CosineEmbeddingLossImpl(
128 CosineEmbeddingLossOptions options_)
129 : options(std::move(options_)) {}
130
reset()131 void CosineEmbeddingLossImpl::reset() {}
132
pretty_print(std::ostream & stream) const133 void CosineEmbeddingLossImpl::pretty_print(std::ostream& stream) const {
134 stream << "torch::nn::CosineEmbeddingLoss(margin=" << options.margin() << ")";
135 }
136
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)137 Tensor CosineEmbeddingLossImpl::forward(
138 const Tensor& input1,
139 const Tensor& input2,
140 const Tensor& target) {
141 return F::detail::cosine_embedding_loss(
142 input1, input2, target, options.margin(), options.reduction());
143 }
144 // ============================================================================
145
MultiLabelSoftMarginLossImpl(torch::nn::MultiLabelSoftMarginLossOptions options_)146 MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl(
147 torch::nn::MultiLabelSoftMarginLossOptions options_)
148 : options(std::move(options_)) {
149 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
150 reset();
151 }
152
pretty_print(std::ostream & stream) const153 void MultiLabelSoftMarginLossImpl::pretty_print(std::ostream& stream) const {
154 stream << "torch::nn::MultiLabelSoftMarginLoss()";
155 }
156
reset()157 void MultiLabelSoftMarginLossImpl::reset() {
158 register_buffer("weight", options.weight());
159 }
160
forward(const Tensor & input,const Tensor & target)161 Tensor MultiLabelSoftMarginLossImpl::forward(
162 const Tensor& input,
163 const Tensor& target) {
164 return F::detail::multilabel_soft_margin_loss(
165 input, target, options.weight(), options.reduction());
166 }
167
168 // ============================================================================
169
TripletMarginLossImpl(TripletMarginLossOptions options_)170 TripletMarginLossImpl::TripletMarginLossImpl(TripletMarginLossOptions options_)
171 : options(std::move(options_)) {}
172
reset()173 void TripletMarginLossImpl::reset() {}
174
pretty_print(std::ostream & stream) const175 void TripletMarginLossImpl::pretty_print(std::ostream& stream) const {
176 stream << "torch::nn::TripletMarginLoss(margin=" << options.margin()
177 << ", p=" << options.p() << ", eps=" << options.eps() << std::boolalpha
178 << ", swap=" << options.swap() << ")";
179 }
180
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)181 Tensor TripletMarginLossImpl::forward(
182 const Tensor& anchor,
183 const Tensor& positive,
184 const Tensor& negative) {
185 return F::detail::triplet_margin_loss(
186 anchor,
187 positive,
188 negative,
189 options.margin(),
190 options.p(),
191 options.eps(),
192 options.swap(),
193 options.reduction());
194 }
195
196 // ============================================================================
197
TripletMarginWithDistanceLossImpl(TripletMarginWithDistanceLossOptions options_)198 TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl(
199 TripletMarginWithDistanceLossOptions options_)
200 : options(std::move(options_)) {}
201
reset()202 void TripletMarginWithDistanceLossImpl::reset() {}
203
pretty_print(std::ostream & stream) const204 void TripletMarginWithDistanceLossImpl::pretty_print(
205 std::ostream& stream) const {
206 stream << "torch::nn::TripletMarginWithDistanceLoss(margin="
207 << options.margin() << std::boolalpha << ", swap=" << options.swap()
208 << ")";
209 }
210
forward(const Tensor & anchor,const Tensor & positive,const Tensor & negative)211 Tensor TripletMarginWithDistanceLossImpl::forward(
212 const Tensor& anchor,
213 const Tensor& positive,
214 const Tensor& negative) {
215 return F::detail::triplet_margin_with_distance_loss(
216 anchor,
217 positive,
218 negative,
219 options.distance_function(),
220 options.margin(),
221 options.swap(),
222 options.reduction());
223 }
224
225 // ============================================================================
226
MultiLabelMarginLossImpl(torch::nn::MultiLabelMarginLossOptions options_)227 MultiLabelMarginLossImpl::MultiLabelMarginLossImpl(
228 torch::nn::MultiLabelMarginLossOptions options_)
229 : options(std::move(options_)) {}
230
reset()231 void MultiLabelMarginLossImpl::reset() {}
232
pretty_print(std::ostream & stream) const233 void MultiLabelMarginLossImpl::pretty_print(std::ostream& stream) const {
234 stream << "torch::nn::MultiLabelMarginLoss()";
235 }
236
forward(const Tensor & input,const Tensor & target)237 Tensor MultiLabelMarginLossImpl::forward(
238 const Tensor& input,
239 const Tensor& target) {
240 return F::detail::multilabel_margin_loss(input, target, options.reduction());
241 }
242
243 // ============================================================================
244
SoftMarginLossImpl(torch::nn::SoftMarginLossOptions options_)245 SoftMarginLossImpl::SoftMarginLossImpl(
246 torch::nn::SoftMarginLossOptions options_)
247 : options(std::move(options_)) {}
248
reset()249 void SoftMarginLossImpl::reset() {}
250
pretty_print(std::ostream & stream) const251 void SoftMarginLossImpl::pretty_print(std::ostream& stream) const {
252 stream << "torch::nn::SoftMarginLoss()";
253 }
254
forward(const Tensor & input,const Tensor & target)255 Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
256 return F::detail::soft_margin_loss(input, target, options.reduction());
257 }
258
259 // ============================================================================
260
SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)261 SmoothL1LossImpl::SmoothL1LossImpl(torch::nn::SmoothL1LossOptions options_)
262 : options(std::move(options_)) {}
263
reset()264 void SmoothL1LossImpl::reset() {}
265
pretty_print(std::ostream & stream) const266 void SmoothL1LossImpl::pretty_print(std::ostream& stream) const {
267 stream << "torch::nn::SmoothL1Loss";
268 }
269
forward(const Tensor & input,const Tensor & target)270 Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) {
271 return F::detail::smooth_l1_loss(
272 input, target, options.reduction(), options.beta());
273 }
274
275 // ============================================================================
276
HuberLossImpl(torch::nn::HuberLossOptions options_)277 HuberLossImpl::HuberLossImpl(torch::nn::HuberLossOptions options_)
278 : options(std::move(options_)) {}
279
reset()280 void HuberLossImpl::reset() {}
281
pretty_print(std::ostream & stream) const282 void HuberLossImpl::pretty_print(std::ostream& stream) const {
283 stream << "torch::nn::HuberLoss";
284 }
285
forward(const Tensor & input,const Tensor & target)286 Tensor HuberLossImpl::forward(const Tensor& input, const Tensor& target) {
287 return F::detail::huber_loss(
288 input, target, options.reduction(), options.delta());
289 }
290
291 // ============================================================================
292
CTCLossImpl(CTCLossOptions options_)293 CTCLossImpl::CTCLossImpl(CTCLossOptions options_)
294 : options(std::move(options_)) {}
295
reset()296 void CTCLossImpl::reset() {}
297
pretty_print(std::ostream & stream) const298 void CTCLossImpl::pretty_print(std::ostream& stream) const {
299 stream << "torch::nn::CTCLoss()";
300 }
301
forward(const Tensor & log_probs,const Tensor & targets,const Tensor & input_lengths,const Tensor & target_lengths)302 Tensor CTCLossImpl::forward(
303 const Tensor& log_probs,
304 const Tensor& targets,
305 const Tensor& input_lengths,
306 const Tensor& target_lengths) {
307 return F::detail::ctc_loss(
308 log_probs,
309 targets,
310 input_lengths,
311 target_lengths,
312 options.blank(),
313 options.reduction(),
314 options.zero_infinity());
315 }
316
317 // ============================================================================
318
PoissonNLLLossImpl(PoissonNLLLossOptions options_)319 PoissonNLLLossImpl::PoissonNLLLossImpl(PoissonNLLLossOptions options_)
320 : options(std::move(options_)) {}
321
reset()322 void PoissonNLLLossImpl::reset() {}
323
pretty_print(std::ostream & stream) const324 void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const {
325 stream << "torch::nn::PoissonNLLLoss()";
326 }
327
forward(const Tensor & log_input,const Tensor & target)328 Tensor PoissonNLLLossImpl::forward(
329 const Tensor& log_input,
330 const Tensor& target) {
331 return F::detail::poisson_nll_loss(
332 log_input,
333 target,
334 options.log_input(),
335 options.full(),
336 options.eps(),
337 options.reduction());
338 }
339
340 // ============================================================================
341
MarginRankingLossImpl(MarginRankingLossOptions options_)342 MarginRankingLossImpl::MarginRankingLossImpl(MarginRankingLossOptions options_)
343 : options(std::move(options_)) {}
344
reset()345 void MarginRankingLossImpl::reset() {}
346
pretty_print(std::ostream & stream) const347 void MarginRankingLossImpl::pretty_print(std::ostream& stream) const {
348 stream << "torch::nn::MarginRankingLoss()";
349 }
350
forward(const Tensor & input1,const Tensor & input2,const Tensor & target)351 Tensor MarginRankingLossImpl::forward(
352 const Tensor& input1,
353 const Tensor& input2,
354 const Tensor& target) {
355 return F::detail::margin_ranking_loss(
356 input1, input2, target, options.margin(), options.reduction());
357 }
358
359 // ============================================================================
360
NLLLossImpl(NLLLossOptions options_)361 NLLLossImpl::NLLLossImpl(NLLLossOptions options_)
362 : options(std::move(options_)) {
363 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
364 reset();
365 }
366
reset()367 void NLLLossImpl::reset() {
368 weight = register_buffer("weight", options.weight());
369 }
370
pretty_print(std::ostream & stream) const371 void NLLLossImpl::pretty_print(std::ostream& stream) const {
372 stream << "torch::nn::NLLLoss()";
373 }
374
forward(const Tensor & input,const Tensor & target)375 Tensor NLLLossImpl::forward(const Tensor& input, const Tensor& target) {
376 return F::detail::nll_loss(
377 input, target, weight, options.ignore_index(), options.reduction());
378 }
379
380 // ============================================================================
381
CrossEntropyLossImpl(CrossEntropyLossOptions options_)382 CrossEntropyLossImpl::CrossEntropyLossImpl(CrossEntropyLossOptions options_)
383 : options(std::move(options_)) {
384 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
385 reset();
386 }
387
reset()388 void CrossEntropyLossImpl::reset() {
389 weight = register_buffer("weight", options.weight());
390 }
391
pretty_print(std::ostream & stream) const392 void CrossEntropyLossImpl::pretty_print(std::ostream& stream) const {
393 stream << "torch::nn::CrossEntropyLoss()";
394 }
395
forward(const Tensor & input,const Tensor & target)396 Tensor CrossEntropyLossImpl::forward(
397 const Tensor& input,
398 const Tensor& target) {
399 return F::detail::cross_entropy(
400 input,
401 target,
402 weight,
403 options.ignore_index(),
404 options.reduction(),
405 options.label_smoothing());
406 }
407
408 // ============================================================================
409
BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)410 BCEWithLogitsLossImpl::BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_)
411 : options(std::move(options_)) {
412 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
413 reset();
414 }
415
reset()416 void BCEWithLogitsLossImpl::reset() {
417 weight = register_buffer("weight", options.weight());
418 pos_weight = register_buffer("pos_weight", options.pos_weight());
419 }
420
pretty_print(std::ostream & stream) const421 void BCEWithLogitsLossImpl::pretty_print(std::ostream& stream) const {
422 stream << "torch::nn::BCEWithLogitsLoss()";
423 }
424
forward(const Tensor & input,const Tensor & target)425 Tensor BCEWithLogitsLossImpl::forward(
426 const Tensor& input,
427 const Tensor& target) {
428 return F::detail::binary_cross_entropy_with_logits(
429 input,
430 target,
431 options.weight(),
432 options.reduction(),
433 options.pos_weight());
434 }
435
436 } // namespace nn
437 } // namespace torch
438