1#!/usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Copyright 2014 Google Inc. All Rights Reserved. 5# 6# Licensed under the Apache License, Version 2.0 (the "License"); 7# you may not use this file except in compliance with the License. 8# You may obtain a copy of the License at 9# 10# http://www.apache.org/licenses/LICENSE-2.0 11# 12# Unless required by applicable law or agreed to in writing, software 13# distributed under the License is distributed on an "AS IS" BASIS, 14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15# See the License for the specific language governing permissions and 16# limitations under the License. 17 18"""Simple command-line sample for the Google Prediction API 19 20Command-line application that trains on your input data. This sample does 21the same thing as the Hello Prediction! example. You might want to run 22the setup.sh script to load the sample data to Google Storage. 23 24Usage: 25 $ python prediction.py "bucket/object" "model_id" "project_id" 26 27You can also get help on all the command-line flags the program understands 28by running: 29 30 $ python prediction.py --help 31 32To get detailed log output run: 33 34 $ python prediction.py --logging_level=DEBUG 35""" 36from __future__ import print_function 37 38__author__ = ('jcgregorio@google.com (Joe Gregorio), ' 39 'marccohen@google.com (Marc Cohen)') 40 41import argparse 42import pprint 43import sys 44import time 45 46from apiclient import sample_tools 47from oauth2client import client 48 49 50# Time to wait (in seconds) between successive checks of training status. 51SLEEP_TIME = 10 52 53 54# Declare command-line flags. 
# This parser is handed to sample_tools.init() as a parent parser (see main()).
# add_help=False avoids a duplicate -h/--help option in the combined parser.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument('object_name',
                       help='Full Google Storage path of csv data (ex bucket/object)')
argparser.add_argument('model_id',
                       help='Model Id of your choosing to name trained model')
argparser.add_argument('project_id',
                       help='Project Id of your Google Cloud Project')


def print_header(line):
    """Print *line* framed above and below by a rule of '=' of equal length.

    A blank line is emitted before the top rule to separate sections.

    Args:
      line: the header text to display.
    """
    rule = '=' * len(line)
    print('\n' + rule)
    print(line)
    print(rule)


def main(argv):
    """Run the end-to-end Prediction API demo.

    Steps:
    1. List up to ten existing models.
    2. Submit a training request for the CSV data at flags.object_name.
    3. Poll until training finishes.
    4. Analyze the trained model.
    5. Run two sample predictions.
    6. Delete the model.

    Args:
      argv: the argument vector handed to sample_tools.init(); this script
        passes sys.argv.

    Raises:
      Exception: if polling reports a training state other than 'RUNNING'
        or 'DONE'.

    An expired or revoked credential (client.AccessTokenRefreshError) is
    caught and reported with a message instead of propagating.
    """
    # If you previously ran this app with an earlier version of the API
    # or if you change the list of scopes below, revoke your app's permission
    # here: https://accounts.google.com/IssuedAuthSubTokens
    # Then re-run the app to re-authorize it.
    service, flags = sample_tools.init(
        argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
        scope=(
            'https://www.googleapis.com/auth/prediction',
            'https://www.googleapis.com/auth/devstorage.read_only'))

    try:
        # Get access to the Prediction API.
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10, project=flags.project_id).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {'id': flags.model_id, 'storageDataLocation': flags.object_name}
        start = papi.insert(body=body, project=flags.project_id).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete.
        print_header('Waiting for training to complete')
        while True:
            status = papi.get(id=flags.model_id, project=flags.project_id).execute()
            state = status['trainingStatus']
            print('Training state: ' + state)
            if state == 'DONE':
                break
            if state != 'RUNNING':
                raise Exception('Training Error: ' + state)
            time.sleep(SLEEP_TIME)

        # Job has completed. This runs only after the loop exits on 'DONE'.
        # It lives outside the loop so that it is actually reachable.
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(id=flags.model_id, project=flags.project_id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make some predictions using the newly trained model.
        print_header('Making some predictions')
        for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
            body = {'input': {'csvInstance': [sample_text]}}
            result = papi.predict(
                body=body, id=flags.model_id, project=flags.project_id).execute()
            print('Prediction results for "%s"...' % sample_text)
            pprint.pprint(result)

        # Delete model. The response body is not used.
        print_header('Deleting model')
        papi.delete(id=flags.model_id, project=flags.project_id).execute()
        print('Model deleted.')

    except client.AccessTokenRefreshError:
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize.')


if __name__ == '__main__':
    main(sys.argv)