File: dev.go

package info (click to toggle)
golang-ariga-atlas 0.7.2-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 5,676 kB
  • sloc: javascript: 592; sql: 404; makefile: 10
file content (132 lines) | stat: -rw-r--r-- 3,804 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package sqlx

import (
	"context"
	"fmt"
	"hash/fnv"
	"time"

	"ariga.io/atlas/sql/migrate"
	"ariga.io/atlas/sql/schema"
)

// DevDriver is a driver that provides additional functionality
// to interact with the development database.
type DevDriver struct {
	// A Driver connected to the dev database.
	migrate.Driver

	// MaxNameLen configures the max length of object names in
	// the connected database (e.g. 64 in MySQL). Longer names
	// are trimmed and suffixed with their hash.
	MaxNameLen int

	// DropClause holds optional clauses that
	// can be added to the DropSchema change.
	DropClause []schema.Clause

	// PatchColumn allows providing a custom function to patch
	// columns that hold a schema reference.
	PatchColumn func(*schema.Schema, *schema.Column)
}

// NormalizeRealm implements the schema.Normalizer interface.
//
// The implementation converts schema objects in "natural form" (e.g. HCL or DSL)
// to their "normal presentation" in the database, by creating them temporarily in
// a "dev database", and then inspects them from there.
func (d *DevDriver) NormalizeRealm(ctx context.Context, r *schema.Realm) (nr *schema.Realm, err error) {
	var (
		names   = make(map[string]string)
		changes = make([]schema.Change, 0, len(r.Schemas))
		reverse = make([]schema.Change, 0, len(r.Schemas))
		opts    = &schema.InspectRealmOption{
			Schemas: make([]string, 0, len(r.Schemas)),
		}
	)
	for _, s := range r.Schemas {
		if s.Realm != r {
			s.Realm = r
		}
		dev := d.formatName(s.Name)
		names[dev] = s.Name
		s.Name = dev
		opts.Schemas = append(opts.Schemas, s.Name)
		// Skip adding the schema.IfNotExists clause
		// to fail if the schema exists.
		st := schema.New(dev).AddAttrs(s.Attrs...)
		changes = append(changes, &schema.AddSchema{S: st})
		reverse = append(reverse, &schema.DropSchema{S: st, Extra: append(d.DropClause, &schema.IfExists{})})
		for _, t := range s.Tables {
			// If objects are not strongly connected.
			if t.Schema != s {
				t.Schema = s
			}
			for _, c := range t.Columns {
				if e, ok := c.Type.Type.(*schema.EnumType); ok && e.Schema != s {
					e.Schema = s
				}
				if d.PatchColumn != nil {
					d.PatchColumn(s, c)
				}
			}
			changes = append(changes, &schema.AddTable{T: t})
		}
	}
	patch := func(r *schema.Realm) {
		for _, s := range r.Schemas {
			s.Name = names[s.Name]
		}
	}
	// Delete the dev resources, and return
	// the source realm to its initial state.
	defer func() {
		patch(r)
		if rerr := d.ApplyChanges(ctx, reverse); rerr != nil {
			if err != nil {
				rerr = fmt.Errorf("%w: %v", err, rerr)
			}
			err = rerr
		}
	}()
	if err := d.ApplyChanges(ctx, changes); err != nil {
		return nil, err
	}
	if nr, err = d.InspectRealm(ctx, opts); err != nil {
		return nil, err
	}
	patch(nr)
	return nr, nil
}

// NormalizeSchema returns the normal representation of the given database. See NormalizeRealm for more info.
func (d *DevDriver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) {
	r := &schema.Realm{}
	if s.Realm != nil {
		r.Attrs = s.Realm.Attrs
	}
	r.Schemas = append(r.Schemas, s)
	nr, err := d.NormalizeRealm(ctx, r)
	if err != nil {
		return nil, err
	}
	ns, ok := nr.Schema(s.Name)
	if !ok {
		return nil, fmt.Errorf("missing normalized schema %q", s.Name)
	}
	return ns, nil
}

func (d *DevDriver) formatName(name string) string {
	dev := fmt.Sprintf("atlas_dev_%s_%d", name, time.Now().Unix())
	if d.MaxNameLen == 0 || len(dev) <= d.MaxNameLen {
		return dev
	}
	h := fnv.New128()
	h.Write([]byte(dev))
	return fmt.Sprintf("%s_%x", dev[:d.MaxNameLen-1-h.Size()*2], h.Sum(nil))
}