1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
|
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
// Package base defines shared basic pieces of the ent command.
package base
import (
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"strings"
	"text/template"
	"unicode"
	"unicode/utf8"

	"github.com/facebook/ent/cmd/internal/printer"
	"github.com/facebook/ent/entc"
	"github.com/facebook/ent/entc/gen"
	"github.com/facebook/ent/schema/field"
	"github.com/spf13/cobra"
)
// IDType is a custom ID implementation for pflag. It stores the field.Type
// chosen through the --idtype flag and satisfies the flag.Value interface
// through Set, Type and String.
type IDType field.Type

// idTypeChoices lists, in help-text order, the field types that IDType.Set
// accepts.
var idTypeChoices = []field.Type{
	field.TypeInt,
	field.TypeInt64,
	field.TypeUint,
	field.TypeUint64,
	field.TypeString,
}

// Set implements the Set method of the flag.Value interface. It accepts the
// string form of any type in idTypeChoices and rejects anything else.
func (t *IDType) Set(s string) error {
	for _, choice := range idTypeChoices {
		if s == choice.String() {
			*t = IDType(choice)
			return nil
		}
	}
	return fmt.Errorf("invalid type %q", s)
}

// Type returns the type representation of the id option for help command,
// i.e. the formatted list of accepted type names.
func (IDType) Type() string {
	return fmt.Sprintf("%v", idTypeChoices)
}

// String returns the default value for the help command. It always reports
// field.TypeInt, independent of the value currently held.
func (IDType) String() string {
	return field.TypeInt.String()
}
// InitCmd returns the init command for ent/c packages.
//
// The command takes zero or more schema names. Each must begin with an
// uppercase letter. It scaffolds the schemas under the directory given by
// --target, which defaults to defaultSchema. Any failure is logged with an
// "ent/init:" prefix and terminates the process.
func InitCmd() *cobra.Command {
	var target string
	cmd := &cobra.Command{
		Use:   "init [flags] [schemas]",
		Short: "initialize an environment with zero or more schemas",
		Example: examples(
			"ent init Example",
			"ent init --target entv1/schema User Group",
		),
		Args: func(_ *cobra.Command, names []string) error {
			for _, name := range names {
				// Decode the first rune instead of indexing the first byte:
				// name[0] panics on an empty argument and misreads multi-byte
				// letters. An empty name decodes to utf8.RuneError, which is
				// not uppercase, so it is rejected here.
				if r, _ := utf8.DecodeRuneInString(name); !unicode.IsUpper(r) {
					return errors.New("schema names must begin with uppercase")
				}
			}
			return nil
		},
		Run: func(_ *cobra.Command, names []string) {
			if err := initEnv(target, names); err != nil {
				log.Fatalln(fmt.Errorf("ent/init: %w", err))
			}
		},
	}
	cmd.Flags().StringVar(&target, "target", defaultSchema, "target directory for schemas")
	return cmd
}
// DescribeCmd returns the describe command for ent/c packages.
//
// The command loads the schema graph from its single path argument, which may
// be a directory or a Go package path, and prints a description of it to
// stdout. A load failure is logged with an "ent/describe:" prefix and
// terminates the process.
func DescribeCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "describe [flags] path",
		Short: "print a description of the graph schema",
		Example: examples(
			"ent describe ./ent/schema",
			"ent describe github.com/a8m/x",
		),
		Args: cobra.ExactArgs(1),
		Run: func(_ *cobra.Command, path []string) {
			graph, err := entc.LoadGraph(path[0], &gen.Config{})
			if err != nil {
				log.Fatalln(fmt.Errorf("ent/describe: %w", err))
			}
			printer.Fprint(os.Stdout, graph)
		},
	}
}
// GenerateCmd returns the generate command for ent/c packages.
//
// The command runs entc.Generate on its single path argument. It is configured
// by the --idtype, --storage, --header, --target, --feature and --template
// flags. After generation succeeds, each postRun hook is called, in order,
// with the final config. Any failure is logged with an "ent/generate:" prefix
// and terminates the process.
func GenerateCmd(postRun ...func(*gen.Config)) *cobra.Command {
	var (
		cfg       gen.Config
		storage   string
		features  []string
		templates []string
		idtype    = IDType(field.TypeInt)
		cmd       = &cobra.Command{
			Use:   "generate [flags] path",
			Short: "generate go code for the schema directory",
			Example: examples(
				"ent generate ./ent/schema",
				"ent generate github.com/a8m/x",
			),
			Args: cobra.ExactArgs(1),
			Run: func(_ *cobra.Command, path []string) {
				tmplOpts, err := templateOptions(templates)
				if err != nil {
					log.Fatalln(fmt.Errorf("ent/generate: %w", err))
				}
				// Storage and features come first, then templates, in flag order.
				opts := append([]entc.Option{
					entc.Storage(storage),
					entc.FeatureNames(features...),
				}, tmplOpts...)
				// If the target directory is not inferred from
				// the schema path, resolve its package path.
				if cfg.Target != "" {
					pkgPath, err := PkgPath(DefaultConfig, cfg.Target)
					if err != nil {
						log.Fatalln(fmt.Errorf("ent/generate: %w", err))
					}
					cfg.Package = pkgPath
				}
				cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)}
				if err := entc.Generate(path[0], &cfg, opts...); err != nil {
					log.Fatalln(fmt.Errorf("ent/generate: %w", err))
				}
				for _, fn := range postRun {
					fn(&cfg)
				}
			},
		}
	)
	cmd.Flags().Var(&idtype, "idtype", "type of the id field")
	cmd.Flags().StringVar(&storage, "storage", "sql", "storage driver to support in codegen")
	cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header")
	cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen")
	cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features")
	cmd.Flags().StringSliceVarP(&templates, "template", "", nil, "external templates to execute")
	return cmd
}

// templateOptions converts --template flag values into entc options. A value
// is either a bare path, which is treated as a template directory, or
// "type=path" where type is one of "dir", "file" or "glob".
func templateOptions(templates []string) ([]entc.Option, error) {
	opts := make([]entc.Option, 0, len(templates))
	for _, value := range templates {
		typ, target := "dir", value
		if parts := strings.SplitN(value, "=", 2); len(parts) == 2 {
			typ, target = parts[0], parts[1]
		}
		switch typ {
		case "dir":
			opts = append(opts, entc.TemplateDir(target))
		case "file":
			opts = append(opts, entc.TemplateFiles(target))
		case "glob":
			opts = append(opts, entc.TemplateGlob(target))
		default:
			return nil, fmt.Errorf("unsupported template type %q", typ)
		}
	}
	return opts, nil
}
// initEnv initialize an environment for ent codegen.
func initEnv(target string, names []string) error {
if err := createDir(target); err != nil {
return fmt.Errorf("create dir %s: %w", target, err)
}
for _, name := range names {
if err := gen.ValidSchemaName(name); err != nil {
return fmt.Errorf("init schema %s: %w", name, err)
}
b := bytes.NewBuffer(nil)
if err := tmpl.Execute(b, name); err != nil {
return fmt.Errorf("executing template %s: %w", name, err)
}
newFileTarget := filepath.Join(target, strings.ToLower(name+".go"))
if err := ioutil.WriteFile(newFileTarget, b.Bytes(), 0644); err != nil {
return fmt.Errorf("writing file %s: %w", newFileTarget, err)
}
}
return nil
}
// createDir makes sure target exists as a directory.
//
// An existing directory is left untouched. An existing non-directory is
// reported as an error, rather than failing later with a confusing write
// error. Otherwise the directory is created, along with any parents. Only in
// that case, and only when target is defaultSchema, genFile is also written to
// generate.go in target's parent directory.
func createDir(target string) error {
	info, err := os.Stat(target)
	switch {
	case err == nil && !info.IsDir():
		return errors.New("path exists and is not a directory")
	case err == nil:
		return nil
	case !os.IsNotExist(err):
		return err
	}
	if err := os.MkdirAll(target, os.ModePerm); err != nil {
		return fmt.Errorf("creating schema directory: %w", err)
	}
	if target != defaultSchema {
		return nil
	}
	// Derive the path from target (equal to defaultSchema here) instead of
	// repeating the "ent/" prefix as a separate literal.
	genPath := filepath.Join(filepath.Dir(target), "generate.go")
	if err := ioutil.WriteFile(genPath, []byte(genFile), 0644); err != nil {
		return fmt.Errorf("creating generate.go file: %w", err)
	}
	return nil
}
// tmpl is the schema file template for the "init" command. Its data ({{ . }})
// is the schema name. It renders a "package schema" file that declares the
// named type embedding ent.Schema, with Fields and Edges methods returning nil.
// template.Must panics at package initialization if the template fails to parse.
var tmpl = template.Must(template.New("schema").
Parse(`package schema
import "github.com/facebook/ent"
// {{ . }} holds the schema definition for the {{ . }} entity.
type {{ . }} struct {
ent.Schema
}
// Fields of the {{ . }}.
func ({{ . }}) Fields() []ent.Field {
return nil
}
// Edges of the {{ . }}.
func ({{ . }}) Edges() []ent.Edge {
return nil
}
`))
// defaultSchema is the directory that "ent init" populates when --target is
// not given.
const defaultSchema = "ent/schema"

// genFile is the content of the ent/generate.go file, used for the
// "go generate" command. It stays an interpreted (escaped) string literal so
// that no physical line of this source file begins with the directive text.
// "go generate" scans source lines textually, so such a line would be picked
// up when generating this package.
const genFile = "package ent\n\n" +
	"//go:generate go run github.com/facebook/ent/cmd/ent generate ./schema\n"
// examples formats the given examples for the cli help output. Each example is
// indented with two spaces and examples are joined by newlines. The input
// slice is not modified, which matters when callers pass an existing slice
// with ex... .
func examples(ex ...string) string {
	lines := make([]string, len(ex))
	for i, e := range ex {
		lines[i] = "  " + e // indent each row with 2 spaces.
	}
	return strings.Join(lines, "\n")
}
|