blob: 42b7551e6d3b1402d52afbc5b5ae41fa7142d8d8 [file] [log] [blame]
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 Julien Schmidt. All rights reserved.
// http://www.julienschmidt.com
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
9package mysql
10
import (
	"database/sql/driver"
	"fmt"
	"math"
	"reflect"
	"time"
)
17
18type stmtContent struct {
19 mc *mysqlConn
20 id uint32
21 query string
22 paramCount int
23 params []*mysqlField
24 args *[]driver.Value
25 newParamsBound bool
26}
27
28type mysqlStmt struct {
29 *stmtContent
30}
31
32func (stmt mysqlStmt) Close() error {
33 e := stmt.mc.writeCommandPacket(COM_STMT_CLOSE, stmt.id)
34 stmt.params = nil
35 stmt.mc = nil
36 return e
37}
38
39func (stmt mysqlStmt) NumInput() int {
40 return stmt.paramCount
41}
42
43func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
44 stmt.mc.affectedRows = 0
45 stmt.mc.insertId = 0
46
47 // Send command
48 e := stmt.buildExecutePacket(&args)
49 if e != nil {
50 return nil, e
51 }
52
53 // Read Result
54 var resLen int
55 resLen, e = stmt.mc.readResultSetHeaderPacket()
56 if e != nil {
57 return nil, e
58 }
59
60 if resLen > 0 {
61 _, e = stmt.mc.readUntilEOF()
62 if e != nil {
63 return nil, e
64 }
65
66 stmt.mc.affectedRows, e = stmt.mc.readUntilEOF()
67 if e != nil {
68 return nil, e
69 }
70 }
71 if e != nil {
72 return nil, e
73 }
74
75 if stmt.mc.affectedRows == 0 {
76 return driver.ResultNoRows, nil
77 }
78
79 return &mysqlResult{
80 affectedRows: int64(stmt.mc.affectedRows),
81 insertId: int64(stmt.mc.insertId)},
82 nil
83}
84
85func (stmt mysqlStmt) Query(args []driver.Value) (dr driver.Rows, e error) {
86 // Send command
87 e = stmt.buildExecutePacket(&args)
88 if e != nil {
89 return nil, e
90 }
91
92 // Get Result
93 var resLen int
94 rows := new(mysqlRows)
95 rows.content = new(rowsContent)
96 resLen, e = stmt.mc.readResultSetHeaderPacket()
97 if e != nil {
98 return nil, e
99 }
100
101 if resLen > 0 {
102 // Columns
103 rows.content.columns, e = stmt.mc.readColumns(resLen)
104 if e != nil {
105 return
106 }
107
108 // Rows
109 e = stmt.mc.readBinaryRows(rows.content)
110 if e != nil {
111 return
112 }
113 }
114
115 dr = rows
116 return
117}
118
119/* Command Packet
120Bytes Name
121----- ----
1221 code
1234 statement_id
1241 flags
1254 iteration_count
126 if param_count > 0:
127(param_count+7)/8 null_bit_map
1281 new_parameter_bound_flag
129 if new_params_bound == 1:
130n*2 type of parameters
131n values for the parameters
132*/
133func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
134 if len(*args) < stmt.paramCount {
135 return fmt.Errorf(
136 "Not enough Arguments to call STMT_EXEC (Got: %d Has: %d",
137 len(*args),
138 stmt.paramCount)
139 }
140
141 // Reset packet-sequence
142 stmt.mc.sequence = 0
143
144 data := make([]byte, 0, 10)
145
146 // code [1 byte]
147 data = append(data, byte(COM_STMT_EXECUTE))
148
149 // statement_id [4 bytes]
150 data = append(data, uint32ToBytes(stmt.id)...)
151
152 // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
153 data = append(data, byte(0))
154
155 // iteration_count [4 bytes]
156 data = append(data, uint32ToBytes(1)...)
157
158 if stmt.paramCount > 0 {
159 var i int
160
161 // build nullBitMap
162 nullBitMap := make([]byte, (stmt.paramCount+7)/8)
163 bitMask := uint64(0)
164
165 // Check for NULL fields
166 for i = 0; i < stmt.paramCount; i++ {
167 if (*args)[i] == nil {
168 fmt.Println("nil", i, (*args)[i])
169 bitMask += 1 << uint(i)
170 }
171 }
172 // Convert bitMask to bytes
173 for i = 0; i < len(nullBitMap); i++ {
174 nullBitMap[i] = byte(bitMask >> uint(i*8))
175 }
176
177 // append nullBitMap [(param_count+7)/8 bytes]
178 data = append(data, nullBitMap...)
179
180 // Check for changed Params
181 newParamsBound := true
182 if stmt.args != nil {
183 for i := 0; i < len(*args); i++ {
184 if (*args)[i] != (*stmt.args)[i] {
185 fmt.Println((*args)[i], "!=", (*stmt.args)[i])
186 newParamsBound = false
187 break
188 }
189 }
190 }
191
192 // No (new) Parameters bound or rebound
193 if !newParamsBound {
194 //newParameterBoundFlag 0 [1 byte]
195 data = append(data, byte(0))
196 } else {
197 // newParameterBoundFlag 1 [1 byte]
198 data = append(data, byte(1))
199
200 // append types and cache values
201 paramValues := make([]byte, 0)
202 var pv reflect.Value
203 for i = 0; i < stmt.paramCount; i++ {
204 switch (*args)[i].(type) {
205 case nil:
206 data = append(data, []byte{
207 byte(FIELD_TYPE_NULL),
208 0x0}...)
209 continue
210 case []byte:
211 fmt.Println("[]byte", (*args)[i])
212 case time.Time:
213 fmt.Println("time.Time", (*args)[i])
214 }
215
216 pv = reflect.ValueOf((*args)[i])
217 switch pv.Kind() {
218 case reflect.Int64:
219 data = append(data, []byte{
220 byte(FIELD_TYPE_LONGLONG),
221 0x0}...)
222 paramValues = append(paramValues, int64ToBytes(pv.Int())...)
223 fmt.Println("int64", (*args)[i])
224
225 case reflect.Float64:
226 fmt.Println("float64", (*args)[i])
227
228 case reflect.Bool:
229 data = append(data, []byte{
230 byte(FIELD_TYPE_TINY),
231 0x0}...)
232 val := pv.Bool()
233 if val {
234 paramValues = append(paramValues, byte(1))
235 } else {
236 paramValues = append(paramValues, byte(0))
237 }
238 fmt.Println("bool", (*args)[i])
239
240 case reflect.String:
241 data = append(data, []byte{
242 byte(FIELD_TYPE_STRING),
243 0x0}...)
244 val := pv.String()
245 paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
246 paramValues = append(paramValues, []byte(val)...)
247 fmt.Println("string", string([]byte(val)))
248
249 default:
250 return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
251 }
252 }
253
254 // append cached values
255 data = append(data, paramValues...)
256 fmt.Println("data", string(data))
257 }
258
259 // Save args
260 stmt.args = args
261 }
262 return stmt.mc.writePacket(data)
263}
264
265// ColumnConverter returns a ValueConverter for the provided
266// column index. If the type of a specific column isn't known
267// or shouldn't be handled specially, DefaultValueConverter
268// can be returned.
269func (stmt mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
270 debug(fmt.Sprintf("ColumnConverter(%d)", idx))
271 return driver.DefaultParameterConverter
272}