File: objects.go

package info (click to toggle)
incus 6.0.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 24,392 kB
  • sloc: sh: 16,313; ansic: 3,121; python: 457; makefile: 337; ruby: 51; sql: 50; lisp: 6
file content (110 lines) | stat: -rw-r--r-- 2,829 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package query

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
)

// SelectObjects runs the prepared statement stmt with args and, for every row
// it yields, calls rowFunc with that row's scan function (in result order).
//
// The first non-nil error returned by rowFunc aborts the iteration and is
// returned as-is. Otherwise the error (if any) reported by the result set
// after iteration is returned. Once the query succeeds, the rows are closed
// before SelectObjects returns.
func SelectObjects(ctx context.Context, stmt *sql.Stmt, rowFunc Dest, args ...any) error {
	cursor, queryErr := stmt.QueryContext(ctx, args...)
	if queryErr != nil {
		return queryErr
	}

	// The Close error is explicitly discarded; iteration failures are
	// reported through cursor.Err() instead.
	defer func() { _ = cursor.Close() }()

	for cursor.Next() {
		if cbErr := rowFunc(cursor.Scan); cbErr != nil {
			return cbErr
		}
	}

	return cursor.Err()
}

// Scan runs query inside tx with inArgs as its arguments and calls rowFunc
// once per returned row, passing that row's scan function.
//
// The first non-nil error returned by rowFunc aborts the iteration and is
// returned as-is. Otherwise the error (if any) reported by the result set
// after iteration is returned. Once the query succeeds, the rows are closed
// before Scan returns.
func Scan(ctx context.Context, tx *sql.Tx, query string, rowFunc Dest, inArgs ...any) error {
	// The SQL text parameter is named "query" rather than "sql" so it does not
	// shadow the database/sql package inside this function.
	rows, err := tx.QueryContext(ctx, query, inArgs...)
	if err != nil {
		return err
	}

	// The Close error is explicitly discarded; iteration failures are
	// reported through rows.Err() instead.
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		err = rowFunc(rows.Scan)
		if err != nil {
			return err
		}
	}

	return rows.Err()
}

// Dest is the per-row callback used by SelectObjects and Scan.
//
// For each row yielded by the query it is invoked once with that row's scan
// function (the row set's Scan method). The callback is expected to call scan
// with the destinations for the row's columns and return any error. A non-nil
// error stops the iteration and is returned to the caller of SelectObjects or
// Scan.
type Dest func(scan func(dest ...any) error) error

// UpsertObject writes one row into table with an SQLite-style
// "INSERT OR REPLACE" statement and returns the ID reported by
// sql.Result.LastInsertId. columns names the target columns and values holds
// the corresponding values in the same order. For example:
//
//	UpsertObject(tx, "cars", []string{"id", "brand"}, []any{1, "ferrari"})
//
// columns must be non-empty and the same length as values. Whenever an error
// is returned, the returned ID is -1.
//
// table and columns are interpolated verbatim into the SQL text, so they must
// come from trusted code; only values are passed as bound parameters.
func UpsertObject(tx *sql.Tx, table string, columns []string, values []any) (int64, error) {
	count := len(columns)
	switch {
	case count == 0:
		return -1, errors.New("columns length is zero")
	case count != len(values):
		return -1, errors.New("columns length does not match values length")
	}

	// NOTE(review): Params is defined elsewhere; presumably it renders a
	// parenthesized list of count placeholders — confirm.
	query := fmt.Sprintf(
		"INSERT OR REPLACE INTO %s (%s) VALUES %s",
		table, strings.Join(columns, ", "), Params(count))

	res, execErr := tx.Exec(query, values...)
	if execErr != nil {
		return -1, fmt.Errorf("insert or replaced row: %w", execErr)
	}

	lastID, idErr := res.LastInsertId()
	if idErr != nil {
		return -1, fmt.Errorf("get last inserted ID: %w", idErr)
	}

	return lastID, nil
}

// DeleteObject deletes the row of table whose "id" column equals id. The
// table is expected to have a primary key column named "id".
//
// The boolean result reports whether a matching row was found and deleted.
// If the statement affected more than one row, the boolean is true and an
// error is returned alongside it.
//
// table is interpolated verbatim into the SQL text, so it must come from
// trusted code; id is passed as a bound parameter.
func DeleteObject(tx *sql.Tx, table string, id int64) (bool, error) {
	res, err := tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE id=?", table), id)
	if err != nil {
		return false, err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return false, err
	case affected > 1:
		return true, errors.New("more than one row was deleted")
	default:
		return affected == 1, nil
	}
}