File: __init__.py

package info (click to toggle)
python-sqlite-migrate 0.1~beta0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 144 kB
  • sloc: python: 371; makefile: 3
file content (115 lines) | stat: -rw-r--r-- 3,525 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from dataclasses import dataclass
import datetime
from typing import cast, Callable, List, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sqlite_utils.db import Database, Table


class Migrations:
    """
    A named, ordered set of migration functions for a SQLite database.

    Functions are registered by using an instance as a decorator. Once a
    migration has run it is recorded in the ``_sqlite_migrations`` table,
    keyed by ``(migration_set, name)``, so it is not applied again.
    """

    migrations_table = "_sqlite_migrations"

    @dataclass
    class _Migration:
        # Name recorded in the migrations table once this migration has run.
        name: str
        # Called with the Database object when the migration is applied.
        fn: Callable

    @dataclass
    class _AppliedMigration:
        name: str
        # Text, not a datetime: apply() stores
        # str(datetime.datetime.now(datetime.timezone.utc)) and applied()
        # returns the stored column value unparsed.
        applied_at: str

    def __init__(self, name: str):
        """
        :param name: The name of the migration set. This should be unique.
        """
        self.name = name
        # Registered migrations, in registration (and therefore apply) order.
        self._migrations: List[Migrations._Migration] = []

    def __call__(
        self, fn: Optional[Callable] = None, *, name: Optional[str] = None
    ) -> Callable:
        """
        Register a migration function.

        Supports ``@migrations()``, ``@migrations(name="...")`` and the bare
        ``@migrations`` form.

        :param fn: The migration function, when used as a bare decorator
        :param name: The name to use for this migration - if not provided,
          the name of the function will be used
        :return: The function unchanged (bare form), otherwise a decorator
          that registers a function and returns it unchanged
        """

        def inner(func: Callable) -> Callable:
            # NOTE(review): names are not checked for uniqueness; two pending
            # migrations sharing a name would collide on the
            # (migration_set, name) primary key when apply() records the
            # second one.
            self._migrations.append(self._Migration(name or func.__name__, func))
            return func

        if fn is not None:
            # Bare ``@migrations`` usage: register immediately.
            return inner(fn)
        return inner

    def pending(self, db: "Database") -> List["Migrations._Migration"]:
        """
        Return the registered migrations that have not yet been applied to
        ``db``, in registration order.

        Creates or upgrades the migrations table if necessary.
        """
        self.ensure_migrations_table(db)
        already_applied = {
            row["name"]
            for row in _table(db, self.migrations_table).rows_where(
                "migration_set = ?", [self.name]
            )
        }
        return [m for m in self._migrations if m.name not in already_applied]

    def applied(self, db: "Database") -> List["Migrations._AppliedMigration"]:
        """
        Return the migrations from this set that are recorded as applied
        in ``db``.

        ``applied_at`` is the text value that :meth:`apply` stored.
        Creates or upgrades the migrations table if necessary.
        """
        self.ensure_migrations_table(db)
        return [
            self._AppliedMigration(name=row["name"], applied_at=row["applied_at"])
            for row in _table(db, self.migrations_table).rows_where(
                "migration_set = ?", [self.name]
            )
        ]

    def apply(self, db: "Database", *, stop_before: Optional[str] = None):
        """
        Apply any pending migrations to the database, in registration order.

        Each migration is recorded immediately after its function returns.
        If a migration function raises, the exception propagates, and that
        migration and every later one stay unrecorded. The function call and
        its record are not wrapped in a shared transaction here.

        :param stop_before: If given, stop when the pending migration with
          this name is reached, without applying it or anything after it.
          A name that is not pending (unknown or already applied) has no
          effect.
        """
        # pending() already ensures the migrations table exists.
        for migration in self.pending(db):
            if migration.name == stop_before:
                return
            migration.fn(db)
            _table(db, self.migrations_table).insert(
                {
                    "migration_set": self.name,
                    "name": migration.name,
                    "applied_at": str(datetime.datetime.now(datetime.timezone.utc)),
                }
            )

    def ensure_migrations_table(self, db: "Database"):
        """
        Ensure the _sqlite_migrations table exists and has the correct schema.

        If the table is missing, create it with a composite
        ``(migration_set, name)`` primary key. If it exists with any other
        primary key (the older scheme), transform it to use that key.
        """
        table = _table(db, self.migrations_table)
        if not table.exists():
            table.create(
                {
                    "migration_set": str,
                    "name": str,
                    "applied_at": str,
                },
                pk=("migration_set", "name"),
            )
        elif table.pks != ["migration_set", "name"]:
            # This has the old primary key scheme, upgrade it
            table.transform(pk=("migration_set", "name"))

    def __repr__(self):
        return "<Migrations '{}': [{}]>".format(
            self.name, ", ".join(m.name for m in self._migrations)
        )


def _table(db: "Database", name: str) -> "Table":
    """
    Return ``db[name]``, typed as a ``Table`` for static type checking.

    ``typing.cast`` returns its argument unchanged at runtime, so this only
    tells mypy to treat the lookup result as a ``Table``. Presumably this is
    needed because ``db[name]`` is not statically typed as a ``Table``
    (it may be a view).
    """
    looked_up = db[name]
    return cast("Table", looked_up)