1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
package internal
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os/exec"
"runtime"
"strconv"
"strings"
"unicode"
)
// newline holds the line terminator written between and after encoded result
// sets: "\r\n" when runtime.GOOS is "windows" and "\n" for every other value.
// It is assigned exactly once, in init.
var newline []byte

// init picks the line terminator matching the operating system target.
func init() {
	switch runtime.GOOS {
	case "windows":
		newline = []byte("\r\n")
	default:
		newline = []byte("\n")
	}
}
// ResultSet is the shared interface for a result set.
//
// The method names and signatures are the same as the corresponding methods
// on database/sql's *Rows.
type ResultSet interface {
	// Next advances to the next row, reporting whether one is available.
	Next() bool
	// Scan copies the values of the current row into the given destinations.
	Scan(...any) error
	// Err returns the error, if any, hit during iteration. PsqlEncode checks
	// it once the Next loop has finished.
	Err() error
	// NextResultSet advances to the following result set, reporting whether
	// there is one.
	NextResultSet() bool

	// Columns and Close are not called anywhere in this file.
	// NOTE(review): presumably they follow the *sql.Rows contract (column
	// names, release of resources) — confirm against the implementations.
	Columns() ([]string, error)
	Close() error
}
// PsqlEncodeAll calls PsqlEncode for the current result set and then for each
// further one reported by resultSet.NextResultSet, writing the captured output
// to w. The system newline is written between consecutive result sets and once
// more after the last one. The first error from PsqlEncode or from w ends
// processing and is returned.
func PsqlEncodeAll(w io.Writer, resultSet ResultSet, params map[string]string, dsn string) error {
	for more := true; more; {
		if err := PsqlEncode(w, resultSet, params, dsn); err != nil {
			return err
		}
		// The separator is only emitted when another result set follows.
		if more = resultSet.NextResultSet(); more {
			if _, err := w.Write(newline); err != nil {
				return err
			}
		}
	}
	_, err := w.Write(newline)
	return err
}
// PsqlEncode reads the remaining rows of the current result set, renders them
// as a single VALUES query, pipes that query to the psql command-line client
// connected to dsn, and writes psql's captured standard output to w.
//
// Every row is scanned into three values (id, name, z). The name must be a
// string or a []byte; it is escaped with psqlEsc, and z is rendered by
// psqlEnc. Each entry of params becomes a "\pset key 'value'" meta-command
// placed ahead of the query. Trailing whitespace is trimmed from psql's output
// and the system newline is appended.
//
// An error is returned when scanning or iteration fails, when name has an
// unsupported type, when psql cannot be started or exits unsuccessfully (the
// error wraps the exec error and carries whatever psql printed on standard
// error), or when writing to w fails.
func PsqlEncode(w io.Writer, resultSet ResultSet, params map[string]string, dsn string) error {
	// Collect the rows as a comma-separated VALUES list.
	var vals strings.Builder
	for first := true; resultSet.Next(); first = false {
		var id, name, z any
		if err := resultSet.Scan(&id, &name, &z); err != nil {
			return err
		}
		// Accept the two textual forms a driver may hand back instead of
		// panicking on a failed type assertion.
		var n string
		switch v := name.(type) {
		case string:
			n = v
		case []byte:
			n = string(v)
		default:
			return fmt.Errorf("psql encode: name column has type %T, want string", name)
		}
		if !first {
			vals.WriteString(",")
		}
		fmt.Fprintf(&vals, "\n (%v,E'%s', %s)", id, psqlEsc(n), psqlEnc(n, z))
	}
	if err := resultSet.Err(); err != nil {
		return err
	}

	// Build the \pset preamble.
	// NOTE(review): keys and values are interpolated without escaping, and map
	// iteration order is unspecified, so the order of the meta-commands varies
	// between runs — confirm that no caller relies on a particular order.
	var pset strings.Builder
	for k, v := range params {
		fmt.Fprintf(&pset, "\n\\pset %s '%s'", k, v)
	}

	// Run psql with the script on stdin, capturing both output streams so a
	// failure can be reported with psql's own diagnostics.
	var stdout, stderr bytes.Buffer
	cmd := exec.Command("psql", dsn, "-qX")
	cmd.Stdin = strings.NewReader(fmt.Sprintf(psqlValuesQuery, pset.String(), vals.String()))
	cmd.Stdout, cmd.Stderr = &stdout, &stderr
	if err := cmd.Run(); err != nil {
		if msg := bytes.TrimSpace(stderr.Bytes()); len(msg) != 0 {
			return fmt.Errorf("psql: %w: %s", err, msg)
		}
		return fmt.Errorf("psql: %w", err)
	}

	if _, err := w.Write(bytes.TrimRightFunc(stdout.Bytes(), unicode.IsSpace)); err != nil {
		return err
	}
	_, err := w.Write(newline)
	return err
}
// psqlValuesQuery is the fmt template for the script sent to psql. The first
// verb receives the \pset meta-commands and the second the comma-separated
// VALUES rows; the resulting columns are named author_id, name and z.
const psqlValuesQuery = `%s
SELECT * FROM (
VALUES%s
) AS t (author_id, name, z);`
// psqlEscReplacer rewrites, in one pass, the characters that must not appear
// verbatim inside a PostgreSQL escape string constant (E'...').
var psqlEscReplacer = strings.NewReplacer(
	// A raw backslash would start an escape sequence, so it is doubled.
	`\`, `\\`,
	// A raw single quote would terminate the literal, so it is doubled.
	`'`, `''`,
	"\n", `\n`,
	"\r", `\r`,
	"\t", `\t`,
	"\b", `\b`,
	// NOTE(review): U+8888 is the only non-ASCII rune rewritten here; all
	// other non-ASCII text passes through unchanged.
	"袈", `\u8888`,
)

// psqlEsc escapes s so that it can be placed between E' and ' in a SQL
// statement. Backslashes and single quotes are doubled; newline, carriage
// return, tab and backspace become their backslash escapes; the rune U+8888
// becomes \u8888. All other characters are returned as-is.
func psqlEsc(s string) string {
	return psqlEscReplacer.Replace(s)
}
// psqlEnc renders v as a SQL expression chosen by the row name n.
//
// When n is "javascript" or "slice", v is marshaled to indented JSON and
// returned as a PostgreSQL escape string constant (E'...'). For any other
// name the SQL keyword NULL is returned and v is ignored.
//
// It panics if v cannot be marshaled to JSON; the signature has no error
// return to report that through.
func psqlEnc(n string, v any) string {
	if n != "javascript" && n != "slice" {
		return "NULL"
	}
	buf, err := json.MarshalIndent(v, "", " ")
	if err != nil {
		panic(err)
	}
	// QuoteToASCII yields a double-quoted Go literal with backslash escapes for
	// backslashes, double quotes, control characters and non-ASCII runes; the
	// surrounding double quotes are dropped.
	quoted := strconv.QuoteToASCII(string(buf))
	// Go quoting leaves single quotes alone, but inside E'...' they would end
	// the literal, so each one is doubled.
	body := strings.ReplaceAll(quoted[1:len(quoted)-1], "'", "''")
	return "E'" + body + "'"
}
|