• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Copyright 2014 Google Inc. All Rights Reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API.

Command-line application that trains a model on your input data, runs a few
sample predictions against it, and then deletes the model. This sample does
the same thing as the Hello Prediction! example. You might want to run the
setup.sh script first to load the sample data into Google Cloud Storage.

Usage:
  $ python prediction.py "bucket/object" "model_id" "project_id"

You can also get help on all the command-line flags the program understands
by running:

  $ python prediction.py --help

To get detailed log output run:

  $ python prediction.py --logging_level=DEBUG "bucket/object" "model_id" "project_id"
"""
36from __future__ import print_function
37
38__author__ = ('jcgregorio@google.com (Joe Gregorio), '
39              'marccohen@google.com (Marc Cohen)')
40
41import argparse
42import pprint
43import sys
44import time
45
46from apiclient import sample_tools
47from oauth2client import client
48
49
# Seconds to pause between successive polls of the model's training status.
SLEEP_TIME = 10
52
53
# Sample-specific positional arguments. The parser is created with
# add_help=False because main() passes it as a parent to sample_tools.init,
# and argparse parent parsers must not define their own -h/--help.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument(
    'object_name',
    help='Full Google Storage path of csv data (ex bucket/object)')
argparser.add_argument(
    'model_id',
    help='Model Id of your choosing to name trained model')
argparser.add_argument(
    'project_id',
    help='Project Id of your Google Cloud Project')
62
63
def print_header(line):
  """Print `line` between two rules of '=' characters matching its length.

  A blank line is written first so that consecutive headers stand apart.
  """
  rule = '=' * len(line)
  # The leading '' together with sep='\n' yields the initial blank line.
  print('', rule, line, rule, sep='\n')
71
72
def main(argv):
  """Run the end-to-end Prediction API sample.

  The steps, in order:
  1. List up to ten models in flags.project_id.
  2. Submit a training request for flags.object_name under flags.model_id.
  3. Poll the model every SLEEP_TIME seconds until training is DONE.
  4. Print the model analysis.
  5. Run two sample predictions.
  6. Delete the model.

  Args:
    argv: Command-line arguments (typically sys.argv). They are parsed by
      sample_tools.init together with the flags declared on `argparser`.

  Raises:
    Exception: If, while waiting for training, the model's trainingStatus is
      neither 'RUNNING' nor 'DONE'.
  """
  # If you previously ran this app with an earlier version of the API
  # or if you change the list of scopes below, revoke your app's permission
  # here: https://accounts.google.com/IssuedAuthSubTokens
  # Then re-run the app to re-authorize it.
  service, flags = sample_tools.init(
      argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
      scope=(
          'https://www.googleapis.com/auth/prediction',
          'https://www.googleapis.com/auth/devstorage.read_only'))

  try:
    # Get access to the Prediction API.
    papi = service.trainedmodels()

    # List models.
    print_header('Fetching list of first ten models')
    result = papi.list(maxResults=10, project=flags.project_id).execute()
    print('List results:')
    pprint.pprint(result)

    # Start training request on a data set.
    print_header('Submitting model training request')
    body = {'id': flags.model_id, 'storageDataLocation': flags.object_name}
    start = papi.insert(body=body, project=flags.project_id).execute()
    print('Training results:')
    pprint.pprint(start)

    # Wait for the training to complete, polling every SLEEP_TIME seconds.
    print_header('Waiting for training to complete')
    while True:
      status = papi.get(id=flags.model_id, project=flags.project_id).execute()
      state = status['trainingStatus']
      print('Training state: ' + state)
      if state == 'DONE':
        # Job has completed: report the final model status before leaving
        # the loop.
        print('Training completed:')
        pprint.pprint(status)
        break
      if state != 'RUNNING':
        # Any state other than RUNNING/DONE is treated as a failure.
        raise Exception('Training Error: ' + state)
      time.sleep(SLEEP_TIME)

    # Describe model.
    print_header('Fetching model description')
    result = papi.analyze(id=flags.model_id, project=flags.project_id).execute()
    print('Analyze results:')
    pprint.pprint(result)

    # Make some predictions using the newly trained model.
    print_header('Making some predictions')
    for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
      body = {'input': {'csvInstance': [sample_text]}}
      result = papi.predict(
          body=body, id=flags.model_id, project=flags.project_id).execute()
      print('Prediction results for "%s"...' % sample_text)
      pprint.pprint(result)

    # Delete model; the response body is not needed.
    print_header('Deleting model')
    papi.delete(id=flags.model_id, project=flags.project_id).execute()
    print('Model deleted.')

  except client.AccessTokenRefreshError:
    print('The credentials have been revoked or expired, please re-run '
          'the application to re-authorize.')
144
if __name__ == '__main__':
  # Run the sample only when executed as a script, not when imported.
  main(argv=sys.argv)
147