/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.ext.services.resolver;

import android.content.Context;
import android.content.Intent;
import android.content.SharedPreferences;
import android.os.Environment;
import android.os.IBinder;
import android.os.UserManager;
import android.os.storage.StorageManager;
import android.service.resolver.ResolverRankerService;
import android.service.resolver.ResolverTarget;
import android.util.ArrayMap;
import android.util.Log;

import androidx.annotation.VisibleForTesting;
import androidx.core.util.Preconditions;

import java.io.File;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * A Logistic Regression based {@link android.service.resolver.ResolverRankerService}, to be used
 * in {@link ResolverComparator}.
 *
 * <p>Each {@link ResolverTarget} is described by four features (launch, time-spent, recency and
 * chooser scores). Its select probability is
 * {@code 1 / (1 + exp(-(bias + sum_i(weight_i * feature_i))))}. The model starts from a
 * pre-trained set of parameters and is refined on-device with a pairwise gradient update each time
 * a target is selected. The refined parameters are persisted in {@link SharedPreferences} obtained
 * from the credential-protected storage context.
 */
public final class LRResolverRankerService extends ResolverRankerService {
    private static final String TAG = "LRResolverRankerService";

    private static final boolean DEBUG = false;

    private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
    private static final String BIAS_PREF_KEY = "bias";
    private static final String VERSION_PREF_KEY = "version";

    // Feature names; also used as the SharedPreferences keys for the corresponding weights.
    private static final String LAUNCH_SCORE = "launch";
    private static final String TIME_SPENT_SCORE = "timeSpent";
    private static final String RECENCY_SCORE = "recency";
    private static final String CHOOSER_SCORE = "chooser";

    // parameters for a pre-trained model, to initialize the app ranker. When updating the
    // pre-trained model, please update these params, as well as initModel().
    private static final int CURRENT_VERSION = 1;
    private static final float LEARNING_RATE = 0.0001f;
    private static final float REGULARIZER_PARAM = 0.0001f;

    private SharedPreferences mParamSharedPref;
    private float mBias;
    private boolean mInitModelDone;

    /** Per-feature model weights, keyed by feature name. Populated by {@link #initModel()}. */
    @VisibleForTesting
    ArrayMap<String, Float> mFeatureWeights;

    /**
     * Tries to load the model eagerly before handing out the binder.
     *
     * <p>The return value of {@link #initModel()} is ignored here. If the user is still locked,
     * loading is retried lazily by the prediction and training callbacks.
     */
    @Override
    public IBinder onBind(Intent intent) {
        initModel();
        return super.onBind(intent);
    }

    /**
     * Scores every target with the current model and stores the result via
     * {@link ResolverTarget#setSelectProbability(float)}.
     *
     * @throws IllegalStateException if the model cannot be loaded yet (e.g. the user is locked)
     */
    @Override
    public void onPredictSharingProbabilities(List<ResolverTarget> targets) {
        Preconditions.checkState(initModel(), "Service is not ready yet");

        final int size = targets.size();
        for (int i = 0; i < size; ++i) {
            final ResolverTarget target = targets.get(i);
            target.setSelectProbability(predict(getFeatures(target)));
        }
    }

    /**
     * Refines the model after the target at {@code selectedPosition} was chosen.
     *
     * <p>For every other target whose stored select probability is higher than the selected
     * target's, the model takes one step towards "not selected" for that target and one step
     * towards "selected" for the chosen one. The stored probabilities are presumably those
     * written by a previous {@link #onPredictSharingProbabilities(List)} call; verify against the
     * framework caller. The parameters are persisted afterwards, even if no step was taken.
     * Out-of-range positions are ignored.
     *
     * @throws IllegalStateException if the model cannot be loaded yet (e.g. the user is locked)
     */
    @Override
    public void onTrainRankingModel(List<ResolverTarget> targets, int selectedPosition) {
        Preconditions.checkState(initModel(), "Service is not ready yet");

        final int size = targets.size();
        if (selectedPosition < 0 || selectedPosition >= size) {
            if (DEBUG) {
                Log.d(TAG, "Invalid Position of Selected App " + selectedPosition);
            }
            return;
        }
        final ResolverTarget selected = targets.get(selectedPosition);
        final ArrayMap<String, Float> positive = getFeatures(selected);
        final float positiveProbability = selected.getSelectProbability();
        for (int i = 0; i < size; ++i) {
            if (i == selectedPosition) {
                continue;
            }
            final ResolverTarget candidate = targets.get(i);
            final float negativeProbability = candidate.getSelectProbability();
            // Only mis-ranked pairs (a non-selected target scored above the selected one)
            // contribute a training signal.
            if (negativeProbability > positiveProbability) {
                update(getFeatures(candidate), negativeProbability, false);
                update(positive, positiveProbability, true);
            }
        }
        commitUpdate();
    }

    /**
     * Loads the model parameters once the user is unlocked.
     *
     * <p>If no persisted parameters exist, or they were written by an older
     * {@code CURRENT_VERSION}, the hard-coded pre-trained parameters are used. Otherwise the
     * persisted parameters are read back.
     *
     * <p>This is not thread safe, but ResolverRankerService has added the protection to call into
     * it in the same Handler.
     *
     * @return {@code true} if the model is loaded (now or earlier), {@code false} if the user
     *     manager is unavailable or the user is not unlocked yet
     */
    private boolean initModel() {
        if (mInitModelDone) {
            return true;
        }
        final UserManager userManager = (UserManager) getSystemService(Context.USER_SERVICE);
        if (userManager == null || !userManager.isUserUnlocked()) {
            // Credential-protected preferences are not accessible before unlock.
            return false;
        }
        mParamSharedPref = getParamSharedPref();
        mFeatureWeights = new ArrayMap<>(4);
        if (mParamSharedPref == null
                || mParamSharedPref.getInt(VERSION_PREF_KEY, 0) < CURRENT_VERSION) {
            // Initializing the app ranker to a pre-trained model. When updating the pre-trained
            // model, please increment CURRENT_VERSION, and update LEARNING_RATE and
            // REGULARIZER_PARAM.
            mBias = -1.6568f;
            mFeatureWeights.put(LAUNCH_SCORE, 2.5543f);
            mFeatureWeights.put(TIME_SPENT_SCORE, 2.8412f);
            mFeatureWeights.put(RECENCY_SCORE, 0.269f);
            mFeatureWeights.put(CHOOSER_SCORE, 4.2222f);
        } else {
            mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
            mFeatureWeights.put(LAUNCH_SCORE, mParamSharedPref.getFloat(LAUNCH_SCORE, 0.0f));
            mFeatureWeights.put(
                    TIME_SPENT_SCORE, mParamSharedPref.getFloat(TIME_SPENT_SCORE, 0.0f));
            mFeatureWeights.put(RECENCY_SCORE, mParamSharedPref.getFloat(RECENCY_SCORE, 0.0f));
            mFeatureWeights.put(CHOOSER_SCORE, mParamSharedPref.getFloat(CHOOSER_SCORE, 0.0f));
        }
        mInitModelDone = true;
        return true;
    }

    /** Extracts the four model features of {@code target} into a name-to-value map. */
    private ArrayMap<String, Float> getFeatures(ResolverTarget target) {
        final ArrayMap<String, Float> features = new ArrayMap<>(4);
        features.put(RECENCY_SCORE, target.getRecencyScore());
        features.put(TIME_SPENT_SCORE, target.getTimeSpentScore());
        features.put(LAUNCH_SCORE, target.getLaunchScore());
        features.put(CHOOSER_SCORE, target.getChooserScore());
        return features;
    }

    /**
     * Computes the logistic score {@code 1 / (1 + exp(-(bias + w . x)))} for a feature map.
     *
     * <p>Features without a known weight contribute nothing.
     *
     * @return the probability, or {@code 0} if {@code features} is {@code null}
     */
    private float predict(ArrayMap<String, Float> features) {
        if (features == null) {
            return 0.0f;
        }
        final int featureSize = features.size();
        float sum = 0.0f;
        for (int i = 0; i < featureSize; i++) {
            final float weight = mFeatureWeights.getOrDefault(features.keyAt(i), 0.0f);
            sum += weight * features.valueAt(i);
        }
        return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
    }

    /**
     * Applies one L2-regularized gradient step for a single labeled example.
     *
     * @param features the example's feature values; {@code null} is ignored
     * @param predict the model's predicted probability for this example
     * @param isSelected the label: {@code true} for the chosen target, {@code false} otherwise
     */
    private void update(ArrayMap<String, Float> features, float predict, boolean isSelected) {
        if (features == null) {
            return;
        }
        // (label - prediction): the logistic-loss gradient with respect to the logit.
        final float error = isSelected ? 1.0f - predict : -predict;
        // The bias is one parameter, so it takes exactly one step per example rather than one
        // step per feature (which would scale its learning rate by the feature count).
        mBias += LEARNING_RATE * error;
        final int featureSize = features.size();
        for (int i = 0; i < featureSize; i++) {
            final String featureName = features.keyAt(i);
            final float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
            final float newWeight = currentWeight
                    - LEARNING_RATE * REGULARIZER_PARAM * currentWeight
                    + LEARNING_RATE * error * features.valueAt(i);
            mFeatureWeights.put(featureName, newWeight);
        }
        if (DEBUG) {
            Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
        }
    }

    /**
     * Persists the bias, all feature weights and {@code CURRENT_VERSION} asynchronously.
     *
     * <p>This is best-effort: any failure is logged and otherwise ignored, so training never
     * crashes the service.
     */
    private void commitUpdate() {
        try {
            final SharedPreferences.Editor editor = mParamSharedPref.edit();
            editor.putFloat(BIAS_PREF_KEY, mBias);
            final int size = mFeatureWeights.size();
            for (int i = 0; i < size; i++) {
                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
            }
            editor.putInt(VERSION_PREF_KEY, CURRENT_VERSION);
            editor.apply();
        } catch (Exception e) {
            // Pass the throwable so the stack trace is preserved in the log.
            Log.e(TAG, "Failed to commit update", e);
        }
    }

    /** Returns the model-parameter preferences from credential-protected storage. */
    private SharedPreferences getParamSharedPref() {
        // NOTE: ExtServices sets android:defaultToDeviceProtectedStorage="true" so we need this
        // to make sure we're upgrading these preferences correctly.
        return createCredentialProtectedStorageContext()
                .getSharedPreferences(PARAM_SHARED_PREF_NAME + ".xml", Context.MODE_PRIVATE);
    }
}