/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "lang_id/common/flatbuffers/model-utils.h" #include #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/math/checksum.h" namespace libtextclassifier3 { namespace saft_fbs { namespace { // Returns true if we have clear evidence that |model| fails its checksum. // // E.g., if |model| has the crc32 field, and the value of that field does not // match the checksum, then this function returns true. If there is no crc32 // field, then we don't know what the original (at build time) checksum was, so // we don't know anything clear and this function returns false. bool ClearlyFailsChecksum(const Model &model) { if (!flatbuffers::IsFieldPresent(&model, Model::VT_CRC32)) { SAFTM_LOG(WARNING) << "No CRC32, most likely an old model; skip CRC32 check"; return false; } const mobile::uint32 expected_crc32 = model.crc32(); const mobile::uint32 actual_crc32 = ComputeCrc2Checksum(&model); if (actual_crc32 != expected_crc32) { SAFTM_LOG(ERROR) << "Corrupt model: different CRC32: " << actual_crc32 << " vs " << expected_crc32; return true; } SAFTM_LOG(INFO) << "Successfully checked CRC32 " << actual_crc32; return false; } } // namespace const Model *GetVerifiedModelFromBytes(const char *data, size_t num_bytes) { if ((data == nullptr) || (num_bytes == 0)) { SAFTM_LOG(ERROR) << "GetModel called on an empty sequence of bytes"; return nullptr; } const uint8_t *start = reinterpret_cast(data); flatbuffers::Verifier verifier(start, num_bytes); if (!VerifyModelBuffer(verifier)) { SAFTM_LOG(ERROR) << "Not a valid Model flatbuffer"; return nullptr; } const Model *model = GetModel(start); if (model == nullptr) { return nullptr; } if (ClearlyFailsChecksum(*model)) { return nullptr; } return model; } const ModelInput *GetInputByName(const Model *model, const string &name) { if (model == nullptr) { SAFTM_LOG(ERROR) << "GetInputByName called with model == nullptr"; return nullptr; } const auto *inputs = model->inputs(); if (inputs == nullptr) { // We should always have a list of inputs; maybe an empty one, if no inputs, // but the list should be there. SAFTM_LOG(ERROR) << "null inputs"; return nullptr; } for (const ModelInput *input : *inputs) { if (input != nullptr) { const flatbuffers::String *input_name = input->name(); if (input_name && input_name->str() == name) { return input; } } } return nullptr; } mobile::StringPiece GetInputBytes(const ModelInput *input) { if ((input == nullptr) || (input->data() == nullptr)) { SAFTM_LOG(ERROR) << "ModelInput has no content"; return mobile::StringPiece(nullptr, 0); } const flatbuffers::Vector *input_data = input->data(); if (input_data == nullptr) { SAFTM_LOG(ERROR) << "null input data"; return mobile::StringPiece(nullptr, 0); } return mobile::StringPiece(reinterpret_cast(input_data->data()), input_data->size()); } bool FillParameters(const Model &model, mobile::TaskContext *context) { if (context == nullptr) { SAFTM_LOG(ERROR) << "null context"; return false; } const auto *parameters = model.parameters(); if (parameters == nullptr) { // We should always have a list of parameters; maybe an empty one, if no // parameters, but the list should be there. SAFTM_LOG(ERROR) << "null list of parameters"; return false; } for (const ModelParameter *p : *parameters) { if (p == nullptr) { SAFTM_LOG(ERROR) << "null parameter"; return false; } if (p->name() == nullptr) { SAFTM_LOG(ERROR) << "null parameter name"; return false; } const string name = p->name()->str(); if (name.empty()) { SAFTM_LOG(ERROR) << "empty parameter name"; return false; } if (p->value() == nullptr) { SAFTM_LOG(ERROR) << "null parameter name"; return false; } context->SetParameter(name, p->value()->str()); } return true; } namespace { // Updates |*crc| with the information from |s|. Auxiliary for // ComputeCrc2Checksum. // // The bytes from |info| are also used to update the CRC32 checksum. |info| // should be a brief tag that indicates what |s| represents. The idea is to add // some structure to the information that goes into the CRC32 computation. template void UpdateCrc(mobile::Crc32 *crc, const flatbuffers::Vector *s, mobile::StringPiece info) { crc->Update("|"); crc->Update(info.data(), info.size()); crc->Update(":"); if (s == nullptr) { crc->Update("empty"); } else { crc->Update(reinterpret_cast(s->data()), s->size() * sizeof(T)); } } } // namespace mobile::uint32 ComputeCrc2Checksum(const Model *model) { // Implementation note: originally, I (salcianu@) thought we can just compute // a CRC32 checksum of the model bytes. Unfortunately, the expected checksum // is there too (and because we don't control the flatbuffer format, we can't // "arrange" for it to be placed at the head / tail of those bytes). Instead, // we traverse |model| and feed into the CRC32 computation those parts we are // interested in (which excludes the crc32 field). // // Note: storing the checksum outside the Model would be too disruptive for // the way we currently ship our models. mobile::Crc32 crc; if (model == nullptr) { return crc.Get(); } crc.Update("|Parameters:"); const auto *parameters = model->parameters(); if (parameters != nullptr) { for (const ModelParameter *p : *parameters) { if (p != nullptr) { UpdateCrc(&crc, p->name(), "name"); UpdateCrc(&crc, p->value(), "value"); } } } crc.Update("|Inputs:"); const auto *inputs = model->inputs(); if (inputs != nullptr) { for (const ModelInput *input : *inputs) { if (input != nullptr) { UpdateCrc(&crc, input->name(), "name"); UpdateCrc(&crc, input->type(), "type"); UpdateCrc(&crc, input->sub_type(), "sub-type"); UpdateCrc(&crc, input->data(), "data"); } } } return crc.Get(); } } // namespace saft_fbs } // namespace nlp_saft