• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <algorithm>
#include <limits>

#include "minddata/dataset/kernels/ir/vision/slice_patches_ir.h"
#include "minddata/dataset/kernels/image/slice_patches_op.h"
#include "minddata/dataset/kernels/ir/validators.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 namespace vision {
25 // SlicePatchesOperation
SlicePatchesOperation(int32_t num_height,int32_t num_width,SliceMode slice_mode,uint8_t fill_value)26 SlicePatchesOperation::SlicePatchesOperation(int32_t num_height, int32_t num_width, SliceMode slice_mode,
27                                              uint8_t fill_value)
28     : TensorOperation(),
29       num_height_(num_height),
30       num_width_(num_width),
31       slice_mode_(slice_mode),
32       fill_value_(fill_value) {}
33 
34 SlicePatchesOperation::~SlicePatchesOperation() = default;
35 
Name() const36 std::string SlicePatchesOperation::Name() const { return kSlicePatchesOperation; }
37 
ValidateParams()38 Status SlicePatchesOperation::ValidateParams() {
39   RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_height", num_height_));
40   RETURN_IF_NOT_OK(ValidateIntScalarPositive("SlicePatches", "num_width", num_width_));
41   return Status::OK();
42 }
43 
Build()44 std::shared_ptr<TensorOp> SlicePatchesOperation::Build() {
45   auto tensor_op = std::make_shared<SlicePatchesOp>(num_height_, num_width_, slice_mode_, fill_value_);
46   return tensor_op;
47 }
48 
to_json(nlohmann::json * out_json)49 Status SlicePatchesOperation::to_json(nlohmann::json *out_json) {
50   nlohmann::json args;
51   args["num_height"] = num_height_;
52   args["num_width"] = num_width_;
53   args["slice_mode"] = slice_mode_;
54   args["fill_value"] = fill_value_;
55   *out_json = args;
56   return Status::OK();
57 }
58 
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)59 Status SlicePatchesOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
60   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_height") != op_params.end(), "Failed to find num_height");
61   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_width") != op_params.end(), "Failed to find num_width");
62   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("slice_mode") != op_params.end(), "Failed to find slice_mode");
63   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("fill_value") != op_params.end(), "Failed to find fill_value");
64   int32_t num_height = op_params["num_height"];
65   int32_t num_width = op_params["num_width"];
66   SliceMode slice_mode = static_cast<SliceMode>(op_params["slice_mode"]);
67   uint8_t fill_value = op_params["fill_value"];
68   *operation = std::make_shared<vision::SlicePatchesOperation>(num_height, num_width, slice_mode, fill_value);
69   return Status::OK();
70 }
71 }  // namespace vision
72 }  // namespace dataset
73 }  // namespace mindspore
74