• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/kernel/kernel_includes.h>
10 #include <executorch/runtime/platform/assert.h>
11 
12 #include <cstdint>
13 #include <cstring>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using exec_aten::Tensor;
20 
21 namespace {
22 
check_sizes(exec_aten::ArrayRef<int64_t> size_int64_t,exec_aten::ArrayRef<int32_t> size_int32_t)23 bool check_sizes(
24     exec_aten::ArrayRef<int64_t> size_int64_t,
25     exec_aten::ArrayRef<int32_t> size_int32_t) {
26   ET_LOG_AND_RETURN_IF_FALSE(size_int64_t.size() == size_int32_t.size());
27   for (int i = 0; i < size_int64_t.size(); i++) {
28     ET_LOG_AND_RETURN_IF_FALSE(((int64_t)size_int32_t[i] == size_int64_t[i]));
29   }
30 
31   return true;
32 }
33 
34 } // namespace
35 
36 /*
37  * Zero the out tensor
38  *
39  * zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
40  */
zeros_out(KernelRuntimeContext & ctx,IntArrayRef size,Tensor & out)41 Tensor& zeros_out(KernelRuntimeContext& ctx, IntArrayRef size, Tensor& out) {
42   (void)ctx;
43 
44   // Resize for dynamic shape
45   ET_KERNEL_CHECK_MSG(
46       ctx,
47       resize_tensor(out, size) == Error::Ok,
48       InvalidArgument,
49       out,
50       "Failed to resize output tensor.");
51 
52   ET_KERNEL_CHECK(ctx, check_sizes(size, out.sizes()), InvalidArgument, out);
53 
54   void* out_data = out.mutable_data_ptr();
55   if (out_data != nullptr) {
56     /*
57      * Assuming storage is contiguous and zero tensor is indeed a string of
58      * zeros
59      */
60     memset(out_data, 0, out.nbytes());
61   }
62   return out;
63 }
64 
65 } // namespace native
66 } // namespace executor
67 } // namespace torch
68