• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#  Copyright (C) 2020 The Android Open Source Project
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
14#
15# Licensed under the Apache License, Version 2.0 (the "License");
16# you may not use this file except in compliance with the License.
17# You may obtain a copy of the License at
18#
19#     http://www.apache.org/licenses/LICENSE-2.0
20#
21# Unless required by applicable law or agreed to in writing, software
22# distributed under the License is distributed on an "AS IS" BASIS,
23# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24# See the License for the specific language governing permissions and
25# limitations under the License.
26"""
27Utility functions for atest.
28"""
29from __future__ import print_function
30
31import getpass
32import logging
33import os
34import subprocess
35import uuid
36try:
37    import httplib2
38except ModuleNotFoundError as e:
39    logging.debug('Import error due to %s', e)
40
41from pathlib import Path
42from socket import socket
43
44try:
45    # pylint: disable=import-error
46    from oauth2client import client as oauth2_client
47    from oauth2client.contrib import multistore_file
48    from oauth2client import tools as oauth2_tools
49except ModuleNotFoundError as e:
50    logging.debug('Import error due to %s', e)
51
52from atest.logstorage import logstorage_utils
53from atest import atest_utils
54from atest import constants
55
class RunFlowFlags:
    """Flag container handed to oauth2client.tools.run_flow.

    Attributes:
        auth_host_port: List of candidate ports, [8080, 8090].
        auth_host_name: Host name, "localhost".
        logging_level: Logging level name, "ERROR".
        noauth_local_webserver: True when browser auth is not requested.
    """

    def __init__(self, browser_auth):
        # The local web server is disabled exactly when browser auth is off.
        self.noauth_local_webserver = not browser_auth
        self.logging_level = "ERROR"
        self.auth_host_name = "localhost"
        self.auth_host_port = [8080, 8090]
63
64
class GCPHelper():
    """GCP bucket helper class."""
    def __init__(self, client_id=None, client_secret=None,
                 user_agent=None, scope=constants.SCOPE_BUILD_API_SCOPE):
        """Init stuff for GCPHelper class.
        Args:
            client_id: String, client id from the cloud project.
            client_secret: String, client secret for the client_id.
            user_agent: The user agent for the credential.
            scope: String, scopes separated by space.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.user_agent = user_agent
        self.scope = scope

    def _get_storage(self, creds_file_path):
        """Build the credential storage bound to this helper's client.

        Args:
            creds_file_path: Credential file path.
        Returns:
            The storage object returned by
            multistore_file.get_credential_storage.
        """
        return multistore_file.get_credential_storage(
            filename=os.path.abspath(creds_file_path),
            client_id=self.client_id,
            user_agent=self.user_agent,
            scope=self.scope)

    def get_refreshed_credential_from_file(self, creds_file_path):
        """Get refreshed credential from file.
        Args:
            creds_file_path: Credential file path.
        Returns:
            An oauth2client.OAuth2Credentials instance, or None when no
            credential is stored or the credential is marked invalid.
        """
        credential = self.get_credential_from_file(creds_file_path)
        if credential:
            try:
                credential.refresh(httplib2.Http())
            except oauth2_client.AccessTokenRefreshError as e:
                # A failed refresh is not fatal here; the validity check
                # below decides whether the credential is still returned.
                logging.debug('Token refresh error: %s', e)
            if not credential.invalid:
                return credential
        logging.debug('Cannot get credential.')
        return None

    def get_credential_from_file(self, creds_file_path):
        """Get credential from file.
        Args:
            creds_file_path: Credential file path.
        Returns:
            Whatever the credential storage's get() returns for this client.
        """
        return self._get_storage(creds_file_path).get()

    def get_credential_with_auth_flow(self, creds_file_path):
        """Get Credential object, running the oauth flow if needed.

        Tries, in order: an SSO-exchanged access token, a refreshed
        credential from creds_file_path, and finally the interactive oauth
        flow.

        Args:
            creds_file_path: Credential file path.
        Returns:
            An oauth2client.OAuth2Credentials instance.
        """
        # SSO auth
        try:
            token = self._get_sso_access_token()
            # _get_sso_access_token() returns None when no exchange command
            # is configured. Wrapping an empty token would produce a truthy
            # credential object without a usable token and would skip the
            # fallbacks below, so only use SSO when a token was obtained.
            if token:
                return oauth2_client.AccessTokenCredentials(token, 'atest')
        # pylint: disable=broad-except
        except Exception as e:
            logging.debug('Exception:%s', e)
        # GCP auth flow
        credentials = self.get_refreshed_credential_from_file(creds_file_path)
        if not credentials:
            return self._run_auth_flow(self._get_storage(creds_file_path))
        return credentials

    def _run_auth_flow(self, storage):
        """Get user oauth2 credentials.

        Using the loopback IP address flow for desktop clients.

        Args:
            storage: GCP storage object.
        Returns:
            An oauth2client.OAuth2Credentials instance.
        """
        flags = RunFlowFlags(browser_auth=True)

        # Get a free port on demand: binding to port 0 lets the OS choose
        # one; retry until the chosen port is at least 10000.
        port = None
        while not port or port < 10000:
            with socket() as local_socket:
                local_socket.bind(('', 0))
                _, port = local_socket.getsockname()
        # NOTE(review): flags.auth_host_port is still [8080, 8090]; confirm
        # that oauth2_tools.run_flow honours this redirect_uri rather than
        # the ports listed in the flags.
        flow = oauth2_client.OAuth2WebServerFlow(
            client_id=self.client_id,
            client_secret=self.client_secret,
            scope=self.scope,
            user_agent=self.user_agent,
            redirect_uri=f'http://localhost:{port}')
        return oauth2_tools.run_flow(
            flow=flow, storage=storage, flags=flags)

    @staticmethod
    def _get_sso_access_token():
        """Use stubby command line to exchange corp sso to a scoped oauth
        token.

        Returns:
            A token string, or None when constants.TOKEN_EXCHANGE_COMMAND
            is not set.
        Raises:
            subprocess.CalledProcessError: The exchange command exited with
                a non-zero status.
            IndexError: The command output contains no double quote.
        """
        if not constants.TOKEN_EXCHANGE_COMMAND:
            return None

        request = constants.TOKEN_EXCHANGE_REQUEST.format(
            user=getpass.getuser(), scope=constants.SCOPE)
        # The output format is: oauth2_token: "<TOKEN>"
        return subprocess.run(constants.TOKEN_EXCHANGE_COMMAND,
                              input=request,
                              check=True,
                              text=True,
                              shell=True,
                              stdout=subprocess.PIPE).stdout.split('"')[1]
195
196
def do_upload_flow(extra_args):
    """Run upload flow.

    Fetches the credential and, when one is available, prepares the build
    api data, records the resulting ids in extra_args and writes the access
    token to constants.TOKEN_FILE_PATH.

    Args:
        extra_args: Dict of extra args to add to test run. Updated in place
                    with the invocation id, work unit id, local build id
                    and build target.
    Return:
        tuple(credential, invocation), or (None, None) when no credential
        is available.
    """
    config_folder = os.path.join(atest_utils.get_misc_dir(), '.atest')
    creds = fetch_credential(config_folder, extra_args)
    if not creds:
        return None, None

    inv, workunit, local_build_id, build_target = _prepare_data(creds)
    extra_args[constants.INVOCATION_ID] = inv['invocationId']
    extra_args[constants.WORKUNIT_ID] = workunit['id']
    extra_args[constants.LOCAL_BUILD_ID] = local_build_id
    extra_args[constants.BUILD_TARGET] = build_target

    token_dir = os.path.dirname(constants.TOKEN_FILE_PATH)
    # dirname is '' for a bare file name and os.makedirs('') raises, so skip
    # creation then. exist_ok tolerates a directory that another process
    # creates between a check and the creation.
    if token_dir:
        os.makedirs(token_dir, exist_ok=True)
    with open(constants.TOKEN_FILE_PATH, 'w') as token_file:
        # Prefer the token from the token response when there is one.
        if creds.token_response:
            token_file.write(creds.token_response['access_token'])
        else:
            token_file.write(creds.access_token)
    return creds, inv
224
def fetch_credential(config_folder, extra_args):
    """Fetch the credential whenever --request-upload-result is specified.

    When constants.REQUEST_UPLOAD_RESULT is set in extra_args, the
    DO_NOT_UPLOAD marker file is removed and a credential is requested.
    Otherwise the stored credential file is removed, the marker file is
    created and a warning is printed.

    Args:
        config_folder: The directory path to put config file. Created if it
                       does not exist.
        extra_args: Dict of extra args to add to test run.
    Return:
        The credential object, or None when upload is not configured or not
        requested.
    """
    # exist_ok keeps a concurrent run that creates the folder first from
    # turning this into a FileExistsError.
    if not os.path.exists(config_folder):
        os.makedirs(config_folder, exist_ok=True)
    not_upload_file = os.path.join(config_folder, constants.DO_NOT_UPLOAD)
    # Do nothing if there is no related config.
    if (not constants.CREDENTIAL_FILE_NAME or
            not constants.TOKEN_FILE_PATH):
        return None

    creds_f = os.path.join(config_folder, constants.CREDENTIAL_FILE_NAME)
    if extra_args.get(constants.REQUEST_UPLOAD_RESULT):
        # Remove directly instead of checking first: the file may vanish
        # between a check and the removal.
        try:
            os.remove(not_upload_file)
        except FileNotFoundError:
            pass
    else:
        # TODO(b/275113186): Change back to default upload after AnTS upload
        #  extremely slow problem be solved.
        try:
            os.remove(creds_f)
        except FileNotFoundError:
            pass
        Path(not_upload_file).touch()

    # If DO_NOT_UPLOAD not exist, ATest will try to get the credential
    # from the file.
    if not os.path.exists(not_upload_file):
        return GCPHelper(
            client_id=constants.CLIENT_ID,
            client_secret=constants.CLIENT_SECRET,
            user_agent='atest').get_credential_with_auth_flow(creds_f)

    # TODO(b/275113186): Change back the warning message after the bug solved.
    atest_utils.colorful_print(
        'WARNING: AnTS upload disabled by default due to upload slowly'
        '(b/275113186). If you still want to upload test result to AnTS, '
        'please add the option --request-upload-result manually.',
        constants.YELLOW)
    return None
269
def _prepare_data(creds):
    """Register a local build, invocation and work unit via the build api.

    Logging at INFO level and below is disabled while the build client
    calls run; logging.disable is reset to NOTSET afterwards, also when a
    call raises.

    Args:
        creds: The credential object.
    Return:
        tuple(invocation, workunit, build id of the local build record,
        build target).
    """
    try:
        logging.disable(logging.INFO)
        build_client = logstorage_utils.BuildClient(creds)
        branch_name = _get_branch(build_client)
        build_target = _get_target(branch_name, build_client)
        # A fresh random uuid serves as the external id of the local build.
        record = build_client.insert_local_build(
            str(uuid.uuid4()), build_target, branch_name)
        build_client.insert_build_attempts(record)
        inv = build_client.insert_invocation(record)
        return (inv,
                build_client.insert_work_unit(inv),
                record['buildId'],
                build_target)
    finally:
        logging.disable(logging.NOTSET)
294
def _get_branch(build_client):
    """Pick the source code tree branch to report.

    Args:
        build_client: The build client object.
    Return:
        "git_<manifest branch>" when build_client.get_branch finds it;
        otherwise "git_master" if constants.CREDENTIAL_FILE_NAME is set,
        "aosp-master" if not.
    """
    if constants.CREDENTIAL_FILE_NAME:
        fallback = 'git_master'
    else:
        fallback = 'aosp-master'
    candidate = f'git_{atest_utils.get_manifest_branch()}'
    if build_client.get_branch(candidate):
        return candidate
    return fallback
308
def _get_target(branch, build_client):
    """Resolve the build target to report for the local build.

    Args:
        branch: The branch want to check.
        build_client: The build client object.
    Return:
        The local build target when the build client lists it for the
        branch, "aosp_x86-userdebug" otherwise.
    """
    fallback = 'aosp_x86-userdebug'
    selected = atest_utils.get_build_target()
    known_targets = []
    for entry in build_client.list_target(branch)['targets']:
        known_targets.append(entry['target'])
    if selected in known_targets:
        return selected
    return fallback
323