File: prepared.go

package info (click to toggle)
golang-github-ziutek-mymysql 1.5.4+git20170206.23.0582bcf-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye, sid, trixie
  • size: 388 kB
  • sloc: makefile: 8; sh: 2
file content (163 lines) | stat: -rw-r--r-- 3,617 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package native

import (
	"github.com/ziutek/mymysql/mysql"
	"log"
)

// Stmt is a server-side prepared statement created on a native Conn.
type Stmt struct {
	// Connection the statement was prepared on (set by getPrepareOkPacket).
	my *Conn

	// Statement id from the prepare OK packet; sent with COM_STMT_EXECUTE.
	id  uint32
	sql string // For reprepare during reconnect

	params []paramValue // Parameters binding
	// When true, the next sendCmdExec also sends the parameter type block;
	// sendCmdExec clears it after sending.
	rebind bool
	// NOTE(review): not read or written in this file; presumably records
	// whether parameters have been bound — confirm against the Bind code.
	binded bool

	// Field descriptions: sized from the prepare OK packet's FieldCount and
	// filled in by getPrepareResult.
	fields []*mysql.Field

	// Number of field-definition packets stored in fields so far.
	field_count int
	// Number of parameter-definition packets consumed (skipped) so far.
	param_count int
	// Warning count; set from the prepare OK packet and from EOF packets
	// read by getPrepareResult.
	warning_count int
	// Server status; getPrepareResult stores the EOF packet's status here.
	status mysql.ConnStatus

	// One bit per parameter, (len(params)+7)/8 bytes long. sendCmdExec
	// sets the bit for every parameter whose encoded length is zero.
	null_bitmap []byte
}

// Fields returns the field descriptions received while the statement was
// prepared. The returned slice is the statement's own storage, not a copy.
func (stmt *Stmt) Fields() []*mysql.Field {
	return stmt.fields
}

// NumParam returns the number of parameter definitions consumed while the
// statement was prepared. After a complete prepare this equals the
// parameter count reported in the server's prepare OK packet.
func (stmt *Stmt) NumParam() int {
	return stmt.param_count
}

// WarnCount returns the warning count stored for the statement. During
// preparation it is taken from the prepare OK packet and then updated from
// the terminating EOF packet.
func (stmt *Stmt) WarnCount() int {
	return stmt.warning_count
}

// sendCmdExec writes a COM_STMT_EXECUTE packet for stmt, using the values
// currently held in stmt.params.
//
// Packet layout, in write order:
//   - command (1)
//   - statement id (4)
//   - flags (1)
//   - iteration count (4)
//   - NULL bitmap
//   - new-params-bound flag (1)
//   - a 2-byte type per parameter, only when stmt.rebind is set
//   - the encoded parameter values
//
// A parameter whose encoded length (Len) is zero has its bit set in the
// NULL bitmap. After sending, stmt.rebind is cleared, so later executions
// omit the type block unless rebind is set again.
func (stmt *Stmt) sendCmdExec() {
	// Fixed part (command + id + flags + iteration count + bound flag) plus
	// the NULL bitmap.
	pktLen := 1 + 4 + 1 + 4 + 1 + len(stmt.null_bitmap)

	// Rebuild the NULL bitmap from scratch for this execution.
	for i := range stmt.null_bitmap {
		stmt.null_bitmap[i] = 0
	}
	for i, param := range stmt.params {
		n := param.Len()
		pktLen += n
		if n == 0 {
			stmt.null_bitmap[i>>3] |= byte(1) << uint(i&7)
		}
	}
	if stmt.rebind {
		// Size the type block from the same slice the writer loop below
		// iterates, so the declared length always matches the bytes sent.
		pktLen += 2 * len(stmt.params)
	}

	// Reset sequence number: each command starts a new packet sequence.
	stmt.my.seq = 0

	pw := stmt.my.newPktWriter(pktLen)
	pw.writeByte(_COM_STMT_EXECUTE)
	pw.writeU32(stmt.id)
	pw.writeByte(0) // flags = CURSOR_TYPE_NO_CURSOR
	pw.writeU32(1)  // iteration_count
	pw.write(stmt.null_bitmap)
	if stmt.rebind {
		pw.writeByte(1) // new-params-bound: parameter types follow
		for _, param := range stmt.params {
			pw.writeU16(param.typ)
		}
	} else {
		pw.writeByte(0)
	}
	// Values
	for i := range stmt.params {
		pw.writeValue(&stmt.params[i])
	}

	if stmt.my.Debug {
		log.Printf("[%2d <-] Exec command packet: len=%d, null_bitmap=%v, rebind=%t",
			stmt.my.seq-1, pktLen, stmt.null_bitmap, stmt.rebind)
	}

	// Types have been sent; clear the flag so the next execution omits them
	// unless rebind is set again.
	stmt.rebind = false
}

// getPrepareResult reads the server's reply to a prepare request.
//
// When stmt is nil, it expects the prepare OK packet and returns the *Stmt
// built by getPrepareOkPacket.
//
// When stmt is non-nil, it consumes definition packets until an EOF packet,
// then stores the EOF warning count and status in stmt (and the status in
// the connection) and returns stmt:
//   - While stmt.param_count < len(stmt.params), each packet is a parameter
//     definition; it is skipped and param_count is incremented.
//   - Otherwise each packet is decoded as a field definition into
//     stmt.fields.
//
// Error packets (first byte 255) are passed to my.getErrorPacket.
// NOTE(review): that call is presumably expected to panic — confirm.
// Any other unexpected packet panics with mysql.ErrUnkResultPkt.
func (my *Conn) getPrepareResult(stmt *Stmt) interface{} {
	for {
		pr := my.newPktReader() // New reader for each packet
		first := pr.readByte()

		if first == 255 {
			// Error packet
			my.getErrorPacket(pr)
		}

		if stmt == nil {
			if first != 0 {
				panic(mysql.ErrUnkResultPkt)
			}
			// OK packet
			return my.getPrepareOkPacket(pr)
		}

		paramsPending := stmt.param_count < len(stmt.params)

		if first == 254 {
			// EOF packet terminates the current run of definitions.
			stmt.warning_count, stmt.status = my.getEofPacket(pr)
			stmt.my.status = stmt.status
			return stmt
		}

		fieldsPending := stmt.field_count < len(stmt.fields)
		if first == 0 || first >= 251 || !(fieldsPending || paramsPending) {
			panic(mysql.ErrUnkResultPkt)
		}

		if paramsPending {
			// Parameter definitions are read and ignored. Sentence from the
			// MySQL source: "skip parameters data: we don't support it yet".
			pr.skipAll()
			stmt.param_count++
			continue
		}

		field := my.getFieldPacket(pr)
		stmt.fields[stmt.field_count] = field
		stmt.field_count++
	}
}

// getPrepareOkPacket decodes the body of a prepare OK packet and returns a
// new Stmt bound to my. The packet's first (status) byte has already been
// consumed by getPrepareResult.
//
// Fields read, in order:
//   - statement id (4)
//   - field count (2)
//   - parameter count (2)
//   - one skipped filler byte
//   - warning count (2)
//
// The reader must then be at the end of the packet (pr.checkEof).
// stmt.fields and stmt.params are allocated with the reported counts, ready
// to be filled in by getPrepareResult and the binding code. The NULL bitmap
// is allocated only when the statement has parameters.
func (my *Conn) getPrepareOkPacket(pr *pktReader) *Stmt {
	if my.Debug {
		log.Printf("[%2d ->] Prepared OK packet:", my.seq-1)
	}

	stmt := new(Stmt)
	stmt.my = my
	stmt.id = pr.readU32()
	stmt.fields = make([]*mysql.Field, int(pr.readU16())) // FieldCount
	if pl := int(pr.readU16()); pl > 0 { // ParamCount
		stmt.params = make([]paramValue, pl)
		// One NULL-indicator bit per parameter, rounded up to whole bytes.
		stmt.null_bitmap = make([]byte, (pl+7)>>3)
	}
	pr.skipN(1) // filler byte
	stmt.warning_count = int(pr.readU16())
	pr.checkEof()

	if my.Debug {
		log.Printf(tab8s+"ID=0x%x ParamCount=%d FieldsCount=%d WarnCount=%d",
			stmt.id, len(stmt.params), len(stmt.fields), stmt.warning_count,
		)
	}
	return stmt
}