1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17
18 #include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
19 #include "minddata/dataset/kernels/image/slice_patches_op.h"
20 #include "minddata/dataset/kernels/ir/validators.h"
21
22 namespace mindspore {
23 namespace dataset {
24 namespace vision {
25 // SlicePatchesOperation
SlicePatchesOperation(int32_t num_height,int32_t num_width,SliceMode slice_mode,uint8_t fill_value)26 SlicePatchesOperation::SlicePatchesOperation(int32_t num_height, int32_t num_width, SliceMode slice_mode,
27 uint8_t fill_value)
28 : TensorOperation(),
29 num_height_(num_height),
30 num_width_(num_width),
31 slice_mode_(slice_mode),
32 fill_value_(fill_value) {}
33
34 SlicePatchesOperation::~SlicePatchesOperation() = default;
35
Name() const36 std::string SlicePatchesOperation::Name() const { return kSlicePatchesOperation; }
37
ValidateParams()38 Status SlicePatchesOperation::ValidateParams() {
39 RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_height", num_height_));
40 RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_width", num_width_));
41 return Status::OK();
42 }
43
Build()44 std::shared_ptr<TensorOp> SlicePatchesOperation::Build() {
45 auto tensor_op = std::make_shared<SlicePatchesOp>(num_height_, num_width_, slice_mode_, fill_value_);
46 return tensor_op;
47 }
48
to_json(nlohmann::json * out_json)49 Status SlicePatchesOperation::to_json(nlohmann::json *out_json) {
50 nlohmann::json args;
51 args["num_height"] = num_height_;
52 args["num_width"] = num_width_;
53 args["slice_mode"] = slice_mode_;
54 args["fill_value"] = fill_value_;
55 *out_json = args;
56 return Status::OK();
57 }
58
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)59 Status SlicePatchesOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
60 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_height") != op_params.end(), "Failed to find num_height");
61 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_width") != op_params.end(), "Failed to find num_width");
62 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("slice_mode") != op_params.end(), "Failed to find slice_mode");
63 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("fill_value") != op_params.end(), "Failed to find fill_value");
64 int32_t num_height = op_params["num_height"];
65 int32_t num_width = op_params["num_width"];
66 SliceMode slice_mode = static_cast<SliceMode>(op_params["slice_mode"]);
67 uint8_t fill_value = op_params["fill_value"];
68 *operation = std::make_shared<vision::SlicePatchesOperation>(num_height, num_width, slice_mode, fill_value);
69 return Status::OK();
70 }
71 } // namespace vision
72 } // namespace dataset
73 } // namespace mindspore
74