1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
package convert
import (
	"errors"
	"fmt"
	"go/ast"
	"go/token"
)
/*
 * importsForRootNode scans the top-level declarations of rootNode and returns
 * the first generic declaration (*ast.GenDecl) whose first spec is an
 * *ast.ImportSpec. Declarations with no specs are skipped.
 *
 * If no such declaration exists, imports is nil and err is a non-nil error
 * whose message includes a Go-syntax dump (%#v) of rootNode.
 */
func importsForRootNode(rootNode *ast.File) (imports *ast.GenDecl, err error) {
	for _, d := range rootNode.Decls {
		gen, isGen := d.(*ast.GenDecl)
		if !isGen || len(gen.Specs) == 0 {
			continue
		}
		// Only the first spec is inspected: a GenDecl holds specs of a single kind.
		if _, isImport := gen.Specs[0].(*ast.ImportSpec); isImport {
			return gen, nil
		}
	}
	msg := fmt.Sprintf("Could not find imports for root node:\n\t%#v\n", rootNode)
	return nil, errors.New(msg)
}
/*
 * Removes the "testing" import from the file's import declaration, if present.
 * When there is no "testing" import the declaration is left untouched.
 * Panics if the file has no import declaration (importsForRootNode fails).
 */
func removeTestingImport(rootNode *ast.File) {
	importDecl, err := importsForRootNode(rootNode)
	if err != nil {
		panic(err.Error())
	}

	// -1 marks "not found". Defaulting to 0 would splice out the first,
	// unrelated import whenever "testing" is absent.
	index := -1
	for i, spec := range importDecl.Specs {
		importSpec, ok := spec.(*ast.ImportSpec)
		if !ok {
			continue
		}
		if importSpec.Path.Value == "\"testing\"" {
			index = i
			break
		}
	}
	if index < 0 {
		return
	}
	importDecl.Specs = append(importDecl.Specs[:index], importDecl.Specs[index+1:]...)
}
/*
 * Ensures the file imports "github.com/onsi/ginkgo". If a spec with exactly
 * that path already exists (under any local name), nothing changes. Otherwise
 * a dot-import (. "github.com/onsi/ginkgo") is appended to the file's import
 * declaration. Panics if no import declaration can be found.
 */
func addGinkgoImports(rootNode *ast.File) {
	importDecl, err := importsForRootNode(rootNode)
	if err != nil {
		panic(err.Error())
	}
	// NOTE(review): appears unreachable, since importsForRootNode skips
	// declarations with no specs. It is kept as a defensive guard.
	if len(importDecl.Specs) == 0 {
		// TODO: might need to create a import decl here
		panic("unimplemented : expected to find an imports block")
	}

	const ginkgoPath = "\"github.com/onsi/ginkgo\""
	alreadyImported := false
	for _, spec := range importDecl.Specs {
		if is, ok := spec.(*ast.ImportSpec); ok && is.Path.Value == ginkgoPath {
			alreadyImported = true
		}
	}
	if alreadyImported {
		return
	}
	importDecl.Specs = append(importDecl.Specs, createImport(".", ginkgoPath))
}
/*
 * createImport builds an import spec for the given path under the given local
 * name (e.g. "." for a dot-import). An empty name yields an unnamed import
 * (Name == nil) rather than an invalid empty identifier.
 *
 * path is stored verbatim as the literal's source text, so it must already
 * be a quoted Go string literal, e.g. "\"github.com/onsi/ginkgo\"".
 */
func createImport(name, path string) *ast.ImportSpec {
	spec := &ast.ImportSpec{
		// token.STRING replaces the former magic number 9 (same value).
		Path: &ast.BasicLit{Kind: token.STRING, Value: path},
	}
	if name != "" {
		spec.Name = &ast.Ident{Name: name}
	}
	return spec
}
|