/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.rkpdapp.provisioner;
18 
19 import android.content.Context;
20 import android.os.RemoteException;
21 import android.util.Log;
22 
23 import com.android.rkpdapp.GeekResponse;
24 import com.android.rkpdapp.RkpdException;
25 import com.android.rkpdapp.database.InstantConverter;
26 import com.android.rkpdapp.database.ProvisionedKey;
27 import com.android.rkpdapp.database.ProvisionedKeyDao;
28 import com.android.rkpdapp.database.RkpKey;
29 import com.android.rkpdapp.interfaces.ServerInterface;
30 import com.android.rkpdapp.interfaces.SystemInterface;
31 import com.android.rkpdapp.metrics.ProvisioningAttempt;
32 import com.android.rkpdapp.utils.Settings;
33 import com.android.rkpdapp.utils.StatsProcessor;
34 import com.android.rkpdapp.utils.X509Utils;
35 
36 import java.security.cert.X509Certificate;
37 import java.time.Instant;
38 import java.util.ArrayList;
39 import java.util.Arrays;
40 import java.util.List;
41 
42 import co.nstant.in.cbor.CborException;
43 
44 /**
45  * Provides an easy package to run the provisioning process from start to finish, interfacing
46  * with the system interface and the server backend in order to provision attestation certificates
47  * to the device.
48  */
49 public class Provisioner {
50     private static final String TAG = "RkpdProvisioner";
51     private static final int FAILURE_MAXIMUM = 5;
52     private static final Object provisionKeysLock = new Object();
53 
54     private final Context mContext;
55     private final ProvisionedKeyDao mKeyDao;
56 
Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao)57     public Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao) {
58         mContext = applicationContext;
59         mKeyDao = keyDao;
60     }
61 
62     /**
63      * Check to see if we need to perform provisioning or not for the given
64      * IRemotelyProvisionedComponent.
65      * @param serviceName the name of the remotely provisioned component to be provisioned
66      * @return true if the remotely provisioned component requires more keys, false if the pool
67      *         of available keys is healthy.
68      */
isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName)69     public boolean isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName) {
70         return calculateKeysRequired(metrics, serviceName) > 0;
71     }
72 
73     /**
74      * Generate, sign and store remotely provisioned keys.
75      */
provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse geekResponse)76     public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface,
77             GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException {
78         synchronized (provisionKeysLock) {
79             try {
80                 int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName());
81                 Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
82                 if (keysRequired == 0) {
83                     metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
84                     return;
85                 }
86 
87                 List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
88                 checkForInterrupts();
89                 List<byte[]> certChains = fetchCertificates(metrics, keysGenerated, systemInterface,
90                         geekResponse);
91                 checkForInterrupts();
92                 List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);
93 
94                 mKeyDao.insertKeys(keys);
95                 Log.i(TAG, "Total provisioned keys: " + keys.size());
96                 metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
97             } catch (InterruptedException e) {
98                 metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
99                 throw e;
100             } catch (RkpdException e) {
101                 if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
102                     Log.e(TAG, "Too many failures, resetting defaults.");
103                     Settings.resetDefaultConfig(mContext);
104                 }
105                 // Rethrow to provide failure signal to caller
106                 throw e;
107             }
108         }
109     }
110 
generateKeys(ProvisioningAttempt metrics, int numKeysRequired, SystemInterface systemInterface)111     private List<RkpKey> generateKeys(ProvisioningAttempt metrics, int numKeysRequired,
112             SystemInterface systemInterface)
113             throws CborException, RkpdException, InterruptedException {
114         List<RkpKey> keyArray = new ArrayList<>(numKeysRequired);
115         checkForInterrupts();
116         for (long i = 0; i < numKeysRequired; i++) {
117             keyArray.add(systemInterface.generateKey(metrics));
118         }
119         return keyArray;
120     }
121 
fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated, SystemInterface systemInterface, GeekResponse geekResponse)122     private List<byte[]> fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated,
123             SystemInterface systemInterface, GeekResponse geekResponse)
124             throws RkpdException, CborException, InterruptedException {
125         int provisionedSoFar = 0;
126         List<byte[]> certChains = new ArrayList<>(keysGenerated.size());
127         int maxBatchSize = 0;
128         try {
129             maxBatchSize = systemInterface.getBatchSize();
130         } catch (RemoteException e) {
131             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
132                     "Error getting batch size from the system", e);
133         }
134         while (provisionedSoFar != keysGenerated.size()) {
135             int batchSize = Math.min(keysGenerated.size() - provisionedSoFar, maxBatchSize);
136             certChains.addAll(batchProvision(metrics, systemInterface, geekResponse,
137                     keysGenerated.subList(provisionedSoFar, batchSize + provisionedSoFar)));
138             provisionedSoFar += batchSize;
139         }
140         return certChains;
141     }
142 
batchProvision(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse response, List<RkpKey> keysGenerated)143     private List<byte[]> batchProvision(ProvisioningAttempt metrics,
144             SystemInterface systemInterface,
145             GeekResponse response, List<RkpKey> keysGenerated)
146             throws RkpdException, CborException, InterruptedException {
147         int batch_size = keysGenerated.size();
148         if (batch_size < 1) {
149             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
150                     "Request at least 1 key to be signed. Num requested: " + batch_size);
151         }
152         byte[] certRequest = systemInterface.generateCsr(metrics, response, keysGenerated);
153         if (certRequest == null) {
154             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
155                     "Failed to serialize payload");
156         }
157         return new ServerInterface(mContext).requestSignedCertificates(certRequest,
158                 response.getChallenge(), metrics);
159     }
160 
associateCertsWithKeys(List<byte[]> certChains, List<RkpKey> keysGenerated)161     private List<ProvisionedKey> associateCertsWithKeys(List<byte[]> certChains,
162             List<RkpKey> keysGenerated) throws RkpdException {
163         List<ProvisionedKey> provisionedKeys = new ArrayList<>();
164         for (byte[] chain : certChains) {
165             X509Certificate cert = X509Utils.formatX509Certs(chain)[0];
166             long expirationDate = cert.getNotAfter().getTime();
167             byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(cert);
168             if (rawPublicKey == null) {
169                 Log.e(TAG, "Skipping malformed public key.");
170                 continue;
171             }
172             for (RkpKey key : keysGenerated) {
173                 if (Arrays.equals(key.getPublicKey(), rawPublicKey)) {
174                     provisionedKeys.add(key.generateProvisionedKey(chain,
175                             InstantConverter.fromTimestamp(expirationDate)));
176                     keysGenerated.remove(key);
177                     break;
178                 }
179             }
180         }
181         return provisionedKeys;
182     }
183 
184     /**
185      * Calculate the number of keys to be provisioned.
186      */
calculateKeysRequired(ProvisioningAttempt metrics, String serviceName)187     private int calculateKeysRequired(ProvisioningAttempt metrics, String serviceName) {
188         int numExtraAttestationKeys = Settings.getExtraSignedKeysAvailable(mContext);
189         Instant expirationTime = Settings.getExpirationTime(mContext);
190         StatsProcessor.PoolStats poolStats = StatsProcessor.processPool(mKeyDao, serviceName,
191                 numExtraAttestationKeys, expirationTime);
192         metrics.setIsKeyPoolEmpty(poolStats.keysUnassigned == 0);
193         return poolStats.keysToGenerate;
194     }
195 
checkForInterrupts()196     private void checkForInterrupts() throws InterruptedException {
197         if (Thread.interrupted()) {
198             throw new InterruptedException();
199         }
200     }
201 }
202