1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/runtime/kernel/kernel_includes.h>
10 #include <executorch/runtime/platform/assert.h>
11
12 #include <cstdint>
13 #include <cstring>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using exec_aten::Tensor;
20
21 namespace {
22
check_sizes(exec_aten::ArrayRef<int64_t> size_int64_t,exec_aten::ArrayRef<int32_t> size_int32_t)23 bool check_sizes(
24 exec_aten::ArrayRef<int64_t> size_int64_t,
25 exec_aten::ArrayRef<int32_t> size_int32_t) {
26 ET_LOG_AND_RETURN_IF_FALSE(size_int64_t.size() == size_int32_t.size());
27 for (int i = 0; i < size_int64_t.size(); i++) {
28 ET_LOG_AND_RETURN_IF_FALSE(((int64_t)size_int32_t[i] == size_int64_t[i]));
29 }
30
31 return true;
32 }
33
34 } // namespace
35
/*
 * Resize `out` to `size` and fill it with zeros.
 *
 * zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
 */
zeros_out(KernelRuntimeContext & ctx,IntArrayRef size,Tensor & out)41 Tensor& zeros_out(KernelRuntimeContext& ctx, IntArrayRef size, Tensor& out) {
42 (void)ctx;
43
44 // Resize for dynamic shape
45 ET_KERNEL_CHECK_MSG(
46 ctx,
47 resize_tensor(out, size) == Error::Ok,
48 InvalidArgument,
49 out,
50 "Failed to resize output tensor.");
51
52 ET_KERNEL_CHECK(ctx, check_sizes(size, out.sizes()), InvalidArgument, out);
53
54 void* out_data = out.mutable_data_ptr();
55 if (out_data != nullptr) {
56 /*
57 * Assuming storage is contiguous and zero tensor is indeed a string of
58 * zeros
59 */
60 memset(out_data, 0, out.nbytes());
61 }
62 return out;
63 }
64
65 } // namespace native
66 } // namespace executor
67 } // namespace torch
68