1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"google.golang.org/api/googleapi"
prediction "google.golang.org/api/prediction/v1.6"
)
// init registers predictionMain as the "prediction" demo, passing the
// OAuth scopes it requests as a single space-separated string.
func init() {
	scope := strings.Join([]string{
		prediction.DevstorageFullControlScope,
		prediction.DevstorageReadOnlyScope,
		prediction.DevstorageReadWriteScope,
		prediction.PredictionScope,
	}, " ")
	registerDemo("prediction", scope, predictionMain)
}
// predictionType bundles the Prediction API client with the command-line
// parameters that the demo's train and predict steps share.
type predictionType struct {
	// api is the Prediction service client used for every API call.
	api *prediction.Service
	// projectNumber is passed as the project argument of each Trainedmodels call.
	projectNumber string
	// bucketName and trainingFileName are joined to form the storage
	// location of the training data when a new model is inserted.
	bucketName       string
	trainingFileName string
	// modelName is the id of the trained model to look up, create and query.
	modelName string
}
// predictionMain is the entry point of the Prediction API demo.
//
// Training data that was uploaded to a pre-created Google Cloud Storage
// bucket is used to train a model through the Prediction API. Training takes
// a few minutes; once it has finished, a sample text is sent to the model,
// the API attempts to classify it, and the results are printed.
//
// To get started, follow the instructions found in the "Hello Prediction!"
// Getting Started Guide located here:
// https://developers.google.com/prediction/docs/hello_world
//
// argv must hold exactly four values: the project number, the bucket name,
// the training data file name and the model name. Otherwise a usage message
// is written to standard error and the function returns.
//
// Example usage:
//
//	go-api-demo -clientid="my-clientid" -secret="my-secret" prediction
//	    my-project-number my-bucket-name my-training-filename my-model-name
//
// Example output:
//
//	Predict result: language=Spanish
//	English Score: 0.000000
//	French Score: 0.000000
//	Spanish Score: 1.000000
//	analyze: output feature text=&{157 English}
//	analyze: output feature text=&{149 French}
//	analyze: output feature text=&{100 Spanish}
//	feature text count=406
func predictionMain(client *http.Client, argv []string) {
	const usage = "Usage: prediction project_number bucket training_data model_name"

	if len(argv) != 4 {
		fmt.Fprintln(os.Stderr, usage)
		return
	}

	svc, err := prediction.New(client)
	if err != nil {
		log.Fatalf("unable to create prediction API client: %v", err)
	}

	demo := &predictionType{
		api:              svc,
		projectNumber:    argv[0],
		bucketName:       argv[1],
		trainingFileName: argv[2],
		modelName:        argv[3],
	}

	// trainModel exits the process if the model is not fully trained yet,
	// so predictModel only runs against a finished model.
	demo.trainModel()
	demo.predictModel()
}
// trainModel makes sure a trained model named t.modelName exists in the
// project and has finished training.
//
// It first looks the model up. If the API reports 404 Not Found, a new model
// is inserted that trains on the file t.trainingFileName in bucket
// t.bucketName. Any other lookup or insert failure terminates the program via
// log.Fatalf. If the model's training status is not "DONE", a notice is
// printed and the process exits with status 0 so the user can re-run later.
func (t *predictionType) trainModel() {
	// First, check to see if our trained model already exists.
	res, err := t.api.Trainedmodels.Get(t.projectNumber, t.modelName).Do()
	if err != nil {
		// Only an API-level 404 means "model does not exist yet". Anything
		// else, including errors that are not *googleapi.Error (for example
		// transport failures), must not be mistaken for a missing model.
		if ae, ok := err.(*googleapi.Error); !ok || ae.Code != http.StatusNotFound {
			log.Fatalf("error getting trained model: %v", err)
		}
		log.Printf("Training model not found, creating new model.")

		// Cloud Storage locations always use forward slashes, while
		// filepath.Join uses the OS separator; normalise the result.
		location := filepath.ToSlash(filepath.Join(t.bucketName, t.trainingFileName))
		res, err = t.api.Trainedmodels.Insert(t.projectNumber, &prediction.Insert{
			Id:                  t.modelName,
			StorageDataLocation: location,
		}).Do()
		if err != nil {
			log.Fatalf("unable to create trained model: %v", err)
		}
	}
	if res.TrainingStatus != "DONE" {
		// Training is asynchronous; ask the user to come back later rather
		// than polling. Println terminates the line before the process exits.
		fmt.Println("Training model. Please wait and re-run program after a few minutes.")
		os.Exit(0)
	}
}
// predictModel sends a fixed Spanish sample sentence to the trained model,
// prints the predicted label and the per-label scores, then requests an
// analysis of the model and prints its output-feature text entries and the
// text count of each input feature.
//
// Any API failure terminates the program via log.Fatalf.
func (t *predictionType) predictModel() {
	// Model has now been trained. Predict with it.
	input := &prediction.Input{
		Input: &prediction.InputInput{
			CsvInstance: []interface{}{
				"Hola, con quien hablo",
			},
		},
	}
	res, err := t.api.Trainedmodels.Predict(t.projectNumber, t.modelName, input).Do()
	if err != nil {
		log.Fatalf("unable to get trained prediction: %v", err)
	}
	fmt.Printf("Predict result: language=%v\n", res.OutputLabel)
	for _, m := range res.OutputMulti {
		fmt.Printf("%v Score: %v\n", m.Label, m.Score)
	}

	// Now analyze the model.
	an, err := t.api.Trainedmodels.Analyze(t.projectNumber, t.modelName).Do()
	if err != nil {
		log.Fatalf("unable to analyze trained model: %v", err)
	}

	// The nested parts of the analysis are pointers that the response may
	// leave unset; guard each one instead of panicking on a nil dereference.
	desc := an.DataDescription
	if desc == nil {
		return
	}
	if desc.OutputFeature != nil {
		for _, f := range desc.OutputFeature.Text {
			fmt.Printf("analyze: output feature text=%v\n", f)
		}
	}
	for _, f := range desc.Features {
		// Skip features that carry no text description.
		if f == nil || f.Text == nil {
			continue
		}
		fmt.Printf("feature text count=%v\n", f.Text.Count)
	}
}
|