1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_FUNC_IMPL__FUSED_INFER_ATTENTION_SCORE_H_ 18 #define MINDSPORE_CORE_OPS_FUNC_IMPL__FUSED_INFER_ATTENTION_SCORE_H_ 19 #include <vector> 20 #include "ops/ops_func_impl/op_func_impl.h" 21 22 namespace mindspore { 23 namespace ops { 24 enum FusedInferAttentionScoreInputIndex : size_t { 25 kFusedInferAttentionScoreInputQueryIndex = 0, 26 kFusedInferAttentionScoreInputKeyIndex, 27 kFusedInferAttentionScoreInputValueIndex, 28 kFusedInferAttentionScoreInputPseShiftIndex, 29 kFusedInferAttentionScoreInputAttnMaskIndex, 30 kFusedInferAttentionScoreInputActualSeqLengthsIndex, 31 kFusedInferAttentionScoreInputActualSeqLengthsKvIndex, 32 kFusedInferAttentionScoreInputDequantScale1Index, 33 kFusedInferAttentionScoreInputQuantScale1Index, 34 kFusedInferAttentionScoreInputDequantScale2Index, 35 kFusedInferAttentionScoreInputQuantScale2Index, 36 kFusedInferAttentionScoreInputQuantOffset2Index, 37 kFusedInferAttentionScoreInputAntiquantScaleIndex, 38 kFusedInferAttentionScoreInputAntiquantOffsetIndex, 39 kFusedInferAttentionScoreInputBlockTableIndex, 40 kFusedInferAttentionScoreInputQueryPaddingSizeIndex, 41 kFusedInferAttentionScoreInputKvPaddingSizeIndex, 42 // attrs 43 kFusedInferAttentionScoreInputNumHeadsIndex, 44 kFusedInferAttentionScoreInputScaleIndex, 45 kFusedInferAttentionScoreInputPreTokensIndex, 46 kFusedInferAttentionScoreInputNextTokensIndex, 47 kFusedInferAttentionScoreInputLayoutIndex, 48 kFusedInferAttentionScoreInputNumKeyValueHeadsIndex, 49 kFusedInferAttentionScoreInputSparseModeIndex, 50 kFusedInferAttentionScoreInputInnerPreciseIndex, 51 kFusedInferAttentionScoreInputBlockSizeIndex, 52 kFusedInferAttentionScoreInputAntiquantModeIndex, 53 kFusedInferAttentionScoreInputSoftmaxLseFlagIndex, 54 kFusedInferAttentionScoreInputsNum, 55 }; 56 enum FusedInferAttentionScoreOutputIndex : size_t { 57 kFusedInferAttentionScoreOutputAttentionOutIndex = 0, 58 kFusedInferAttentionScoreOutputSoftmaxLseIndex, 59 kFusedInferAttentionScoreOutputsNum, 60 }; 61 62 class MIND_API FusedInferAttentionScoreFuncImpl : public OpFuncImpl { 63 public: 64 BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override; 65 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override; 66 }; 67 } // namespace ops 68 } // namespace mindspore 69 #endif // MINDSPORE_CORE_OPS_FUNC_IMPL_FUSED_INFER_ATTENTION_SCORE_H_ 70