dependency upgrades for v2.18 release cycle (#2314)
Some checks failed
build / build (push) Has been cancelled
ghcr / Build (push) Has been cancelled

This commit is contained in:
Shivaram Lingamneni 2025-12-23 00:07:39 -05:00 committed by GitHub
parent 462e568f00
commit 748700877e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
134 changed files with 7198 additions and 1595 deletions

View file

@ -13,29 +13,41 @@
Aaron Hopkins <go-sql-driver at die.net>
Achille Roussel <achille.roussel at gmail.com>
Aidan <aidan.liu at pingcap.com>
Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Artur Melanchyk <artur.melanchyk@gmail.com>
Asta Xie <xiemengjun at gmail.com>
B Lamarche <blam413 at gmail.com>
Bes Dollma <bdollma@thousandeyes.com>
Bogdan Constantinescu <bog.con.bc at gmail.com>
Brad Higgins <brad at defined.net>
Brian Hendriks <brian at dolthub.com>
Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Kirkland <chriskirkland at github.com>
Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com>
Daemonxiao <735462752 at qq.com>
Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Elias <evan at skeema.net>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Gusted <postmaster at gusted.xyz>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
@ -45,13 +57,18 @@ ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
Jakub Adamus <kratky at zobak.cz>
James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net>
Jason Ng <oblitorum at gmail.com>
Jean-Yves Pellé <jy at pelle.link>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Jennifer Purevsuren <jennifer at dolthub.com>
Jerome Meyer <jxmeyer at gmail.com>
Jiajia Zhong <zhong2plus at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joe Mann <contact at joemann.co.uk>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
@ -72,17 +89,23 @@ Lunny Xiao <xiaolunwen at gmail.com>
Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com>
Nao Yokotsuka <yokotukanao at gmail.com>
Nathanial Murphy <nathanial.murphy at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
Oliver Bone <owbone at github.com>
Olivier Mengué <dolmen at cpan.org>
oscarzhao <oscarzhaosl at gmail.com>
Paul Bonser <misterpib at gmail.com>
Paulius Lozys <pauliuslozys at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com>
Phil Porada <philporada at gmail.com>
Minh Quang <minhquang4334 at gmail.com>
Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com>
Samantha Frank <hello at entropy.cat>
Santhosh Kumar Tekuri <santhosh.tekuri at gmail.com>
Sho Iizuka <sho.i518 at gmail.com>
Sho Ikeda <suicaicoca at gmail.com>
@ -93,6 +116,7 @@ Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Tan Jinhua <312841925 at qq.com>
Tetsuro Aoki <t.aoki1130 at gmail.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
@ -102,6 +126,7 @@ Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Xuehong Chan <chanxuehong at gmail.com>
Zhang Xiang <angwerzx at 126.com>
Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com>
Ziheng Lyu <zihenglv at gmail.com>
@ -110,15 +135,21 @@ Ziheng Lyu <zihenglv at gmail.com>
Barracuda Networks, Inc.
Counting Ltd.
Defined Networking Inc.
DigitalOcean Inc.
Dolthub Inc.
dyves labs AG
Facebook Inc.
GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Microsoft Corp.
Multiplay Ltd.
Percona LLC
PingCAP Inc.
Pivotal Inc.
Shattered Silicon Ltd.
Stripe Inc.
ThousandEyes
Zendesk Inc.

View file

@ -1,3 +1,110 @@
# Changelog
## v1.9.3 (2025-06-13)
* `tx.Commit()` and `tx.Rollback()` returned `ErrInvalidConn` always.
Now they return cached real error if present. (#1690)
* Optimize reading small resultsets to fix performance regression
introduced by compression protocol support. (#1707)
* Fix `db.Ping()` on compressed connection. (#1723)
## v1.9.2 (2025-04-07)
v1.9.2 is a re-release of v1.9.1 due to a release process issue; no changes were made to the content.
## v1.9.1 (2025-03-21)
### Major Changes
* Add Charset() option. (#1679)
### Bugfixes
* go.mod: fix go version format (#1682)
* Fix FormatDSN missing ConnectionAttributes (#1619)
## v1.9.0 (2025-02-18)
### Major Changes
- Implement zlib compression. (#1487)
- Supported Go version is updated to Go 1.21+. (#1639)
- Add support for VECTOR type introduced in MySQL 9.0. (#1609)
- Config object can have custom dial function. (#1527)
### Bugfixes
- Fix auth errors when username/password are too long. (#1625)
- Check if MySQL supports CLIENT_CONNECT_ATTRS before sending client attributes. (#1640)
- Fix auth switch request handling. (#1666)
### Other changes
- Add "filename:line" prefix to log in go-mysql. Custom loggers now show it. (#1589)
- Improve error handling. It reduces the "busy buffer" errors. (#1595, #1601, #1641)
- Use `strconv.Atoi` to parse max_allowed_packet. (#1661)
- `rejectReadOnly` option now handles ER_READ_ONLY_MODE (1290) error too. (#1660)
## Version 1.8.1 (2024-03-26)
Bugfixes:
- fix race condition when context is canceled in [#1562](https://github.com/go-sql-driver/mysql/pull/1562) and [#1570](https://github.com/go-sql-driver/mysql/pull/1570)
## Version 1.8.0 (2024-03-09)
Major Changes:
- Use `SET NAMES charset COLLATE collation`. by @methane in [#1437](https://github.com/go-sql-driver/mysql/pull/1437)
- Older go-mysql-driver used `collation_id` in the handshake packet. But it caused collation mismatch in some situation.
- If you don't specify charset nor collation, go-mysql-driver sends `SET NAMES utf8mb4` for new connection. This uses server's default collation for utf8mb4.
- If you specify charset, go-mysql-driver sends `SET NAMES <charset>`. This uses the server's default collation for `<charset>`.
- If you specify collation and/or charset, go-mysql-driver sends `SET NAMES charset COLLATE collation`.
- PathEscape dbname in DSN. by @methane in [#1432](https://github.com/go-sql-driver/mysql/pull/1432)
- This is backward incompatible in rare case. Check your DSN.
- Drop Go 1.13-17 support by @methane in [#1420](https://github.com/go-sql-driver/mysql/pull/1420)
- Use Go 1.18+
- Parse numbers on text protocol too by @methane in [#1452](https://github.com/go-sql-driver/mysql/pull/1452)
- When text protocol is used, go-mysql-driver passed bare `[]byte` to database/sql for avoid unnecessary allocation and conversion.
- If user specified `*any` to `Scan()`, database/sql passed the `[]byte` into the target variable.
- This confused users because most user doesn't know when text/binary protocol used.
- go-mysql-driver 1.8 converts integer/float values into int64/double even in text protocol. This doesn't increase allocation compared to `[]byte` and conversion cost is negatable.
- New options start using the Functional Option Pattern to avoid increasing technical debt in the Config object. Future version may introduce Functional Option for existing options, but not for now.
- Make TimeTruncate functional option by @methane in [1552](https://github.com/go-sql-driver/mysql/pull/1552)
- Add BeforeConnect callback to configuration object by @ItalyPaleAle in [#1469](https://github.com/go-sql-driver/mysql/pull/1469)
Other changes:
- Adding DeregisterDialContext to prevent memory leaks with dialers we don't need anymore by @jypelle in https://github.com/go-sql-driver/mysql/pull/1422
- Make logger configurable per connection by @frozenbonito in https://github.com/go-sql-driver/mysql/pull/1408
- Fix ColumnType.DatabaseTypeName for mediumint unsigned by @evanelias in https://github.com/go-sql-driver/mysql/pull/1428
- Add connection attributes by @Daemonxiao in https://github.com/go-sql-driver/mysql/pull/1389
- Stop `ColumnTypeScanType()` from returning `sql.RawBytes` by @methane in https://github.com/go-sql-driver/mysql/pull/1424
- Exec() now provides access to status of multiple statements. by @mherr-google in https://github.com/go-sql-driver/mysql/pull/1309
- Allow to change (or disable) the default driver name for registration by @dolmen in https://github.com/go-sql-driver/mysql/pull/1499
- Add default connection attribute '_server_host' by @oblitorum in https://github.com/go-sql-driver/mysql/pull/1506
- QueryUnescape DSN ConnectionAttribute value by @zhangyangyu in https://github.com/go-sql-driver/mysql/pull/1470
- Add client_ed25519 authentication by @Gusted in https://github.com/go-sql-driver/mysql/pull/1518
## Version 1.7.1 (2023-04-25)
Changes:
- bump actions/checkout@v3 and actions/setup-go@v3 (#1375)
- Add go1.20 and mariadb10.11 to the testing matrix (#1403)
- Increase default maxAllowedPacket size. (#1411)
Bugfixes:
- Use SET syntax as specified in the MySQL documentation (#1402)
## Version 1.7 (2022-11-29)
Changes:
@ -149,7 +256,7 @@ New Features:
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
- Support for returning table alias on Columns() (#289, #359, #382)
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490)
- Placeholder interpolation, can be activated with the DSN parameter `interpolateParams=true` (#309, #318, #490)
- Support for uint64 parameters with high bit set (#332, #345)
- Cleartext authentication plugin support (#327)
- Exported ParseDSN function and the Config struct (#403, #419, #429)
@ -193,7 +300,7 @@ Changes:
- Also exported the MySQLWarning type
- mysqlConn.Close returns the first error encountered instead of ignoring all errors
- writePacket() automatically writes the packet size to the header
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets
- readPacket() uses an iterative approach instead of the recursive approach to merge split packets
New Features:
@ -241,7 +348,7 @@ Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly
- Split packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values

View file

@ -38,17 +38,26 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
* Optional `time.Time` parsing
* Optional placeholder interpolation
* Supports zlib compression.
## Requirements
* Go 1.13 or higher. We aim to support the 3 latest versions of Go.
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
* Go 1.21 or higher. We aim to support the 3 latest versions of Go.
* MySQL (5.7+) and MariaDB (10.5+) are supported.
* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP.
* Do not ask questions about TiDB in our issue tracker or forum.
* [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang)
* [Forum](https://ask.pingcap.com/)
* go-mysql would work with Percona Server, Google CloudSQL or Sphinx (2.2.3+).
* Maintainers won't support them. Do not expect issues are investigated and resolved by maintainers.
* Investigate issues yourself and please send a pull request to fix it.
---------------------------------------
## Installation
Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell:
```bash
$ go get -u github.com/go-sql-driver/mysql
go get -u github.com/go-sql-driver/mysql
```
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
@ -114,6 +123,12 @@ This has the same effect as an empty DSN string:
```
`dbname` is escaped by [PathEscape()](https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes:
```
/dbname%2Fwithslash
```
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
#### Password
@ -121,7 +136,7 @@ Passwords can consist of any character. Escaping is **not** necessary.
#### Protocol
See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available.
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
In general you should use a Unix domain socket if available and TCP otherwise for best performance.
#### Address
For TCP and UDP networks, addresses have the form `host[:port]`.
@ -145,7 +160,7 @@ Default: false
```
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
[*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)
##### `allowCleartextPasswords`
@ -194,10 +209,9 @@ Valid Values: <name>
Default: none
```
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
Usage of the `charset` parameter is discouraged because it issues additional queries to the server.
Unless you need the fallback behavior, please use `collation` instead.
See also [Unicode Support](#unicode-support).
##### `checkConnLiveness`
@ -226,6 +240,7 @@ The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You s
Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)).
See also [Unicode Support](#unicode-support).
##### `clientFoundRows`
@ -253,6 +268,16 @@ SELECT u.id FROM users as u
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
##### `compress`
```
Type: bool
Valid Values: true, false
Default: false
```
Toggles zlib compression. false by default.
##### `interpolateParams`
```
@ -279,13 +304,22 @@ Note that this sets the location for time.Time values but does not change MySQL'
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
##### `timeTruncate`
```
Type: duration
Default: 0
```
[Truncate time values](https://pkg.go.dev/time#Duration.Truncate) to the specified duration. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `maxAllowedPacket`
```
Type: decimal number
Default: 4194304
Default: 64*1024*1024
```
Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
##### `multiStatements`
@ -295,9 +329,25 @@ Valid Values: true, false
Default: false
```
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
Allow multiple statements in one query. This can be used to bach multiple queries. Use [Rows.NextResultSet()](https://pkg.go.dev/database/sql#Rows.NextResultSet) to get result of the second and subsequent queries.
When `multiStatements` is used, `?` parameters must only be used in the first statement.
When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
```go
conn, _ := db.Conn(ctx)
conn.Raw(func(conn any) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
UPDATE point SET x = 1 WHERE y = 2;
UPDATE point SET x = 2 WHERE y = 3;
`, nil)
// Both slices have 2 elements.
log.Print(res.(mysql.Result).AllRowsAffected())
log.Print(res.(mysql.Result).AllLastInsertIds())
})
```
##### `parseTime`
@ -393,6 +443,15 @@ Default: 0
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `connectionAttributes`
```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```
[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.
##### System Variables
@ -465,12 +524,15 @@ user:password@/
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
## `ColumnType` Support
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`.
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`.
## `context.Context` Support
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
> [!IMPORTANT]
> The `QueryContext`, `ExecContext`, etc. variants provided by `database/sql` will cause the connection to be closed if the provided context is cancelled or timed out before the result is received by the driver.
### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
@ -478,7 +540,7 @@ For this feature you need direct access to the package. Therefore you must chang
import "github.com/go-sql-driver/mysql"
```
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)).
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
@ -496,9 +558,11 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v
### Unicode support
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
Other collations / charsets can be set using the [`collation`](#collation) DSN parameter.
Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter.
Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default.
- When only the `charset` is specified, the `SET NAMES <charset>` query is sent and the server's default collation is used.
- When both the `charset` and `collation` are specified, the `SET NAMES <charset> COLLATE <collation>` query is sent.
- When only the `collation` is specified, the collation is specified in the protocol handshake and the `SET NAMES` query is not sent. This can save one roundtrip, but note that the server may ignore the specified collation silently and use the server's default charset/collation instead.
See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support.

View file

@ -1,19 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.19
// +build go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
type atomicBool = atomic.Bool

View file

@ -1,47 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !go1.19
// +build !go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an implementation of atomic.Bool for older version of Go.
// it is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_ noCopy
value uint32
}
// Load returns whether the current boolean value is true
func (ab *atomicBool) Load() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Store sets the value of the bool regardless of the previous value
func (ab *atomicBool) Store(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// Swap sets the value of the bool and returns the old value.
func (ab *atomicBool) Swap(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) > 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}

View file

@ -13,10 +13,13 @@ import (
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"
"filippo.io/edwards25519"
)
// server pub keys registry
@ -33,7 +36,7 @@ var (
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified.
//
// data, err := ioutil.ReadFile("mykey.pem")
// data, err := os.ReadFile("mykey.pem")
// if err != nil {
// log.Fatal(err)
// }
@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
}
// authEd25519 does ed25519 authentication used by MariaDB.
func authEd25519(scramble []byte, password string) ([]byte, error) {
// Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c
// Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207
h := sha512.Sum512([]byte(password))
s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
}
A := (&edwards25519.Point{}).ScalarBaseMult(s)
mh := sha512.New()
mh.Write(h[32:])
mh.Write(scramble)
messageDigest := mh.Sum(nil)
r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest)
if err != nil {
return nil, err
}
R := (&edwards25519.Point{}).ScalarBaseMult(r)
kh := sha512.New()
kh.Write(R.Bytes())
kh.Write(A.Bytes())
kh.Write(scramble)
hramDigest := kh.Sum(nil)
k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest)
if err != nil {
return nil, err
}
S := k.MultiplyAdd(k, s, r)
return append(R.Bytes(), S.Bytes()...), nil
}
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil {
@ -290,8 +331,14 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, err
case "client_ed25519":
if len(authData) != 32 {
return nil, ErrMalformPkt
}
return authEd25519(authData, mc.cfg.Passwd)
default:
errLog.Print("unknown auth plugin:", plugin)
mc.log("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
}
}
@ -338,7 +385,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
switch plugin {
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
// https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/
case "caching_sha2_password":
switch len(authData) {
case 0:
@ -346,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
if err = mc.resultUnchanged().readResultOK(); err == nil {
return nil // auth successful
}
@ -376,13 +423,13 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
}
if data[0] != iAuthMoreData {
return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication")
return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")
}
// parse public key
block, rest := pem.Decode(data[1:])
if block == nil {
return fmt.Errorf("No Pem data found, data: %s", rest)
return fmt.Errorf("no pem data found, data: %s", rest)
}
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
@ -397,7 +444,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err
}
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()
default:
return ErrMalformPkt
@ -426,7 +473,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
if err != nil {
return err
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()
}
default:

View file

@ -10,54 +10,47 @@ package mysql
import (
"io"
"net"
"time"
)
const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024
// readerFunc is a function that compatible with io.Reader.
// We use this function type instead of io.Reader because we want to
// just pass mc.readWithTimeout.
type readerFunc func([]byte) (int, error)
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
buf []byte // read buffer.
cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize.
}
// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
fg := make([]byte, defaultBufSize)
func newBuffer() buffer {
return buffer{
buf: fg,
nc: nc,
dbuf: [2][]byte{fg, nil},
cachedBuf: make([]byte, defaultBufSize),
}
}
// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
func (b *buffer) flip() {
b.flipcnt += 1
// busy returns true if the read buffer is not empty.
func (b *buffer) busy() bool {
return len(b.buf) > 0
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// fill data into its double-buffering target: if we've called
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move
// the contents of the current buffer to the front before filling it
dest := b.dbuf[b.flipcnt&1]
// len returns how many bytes in the read buffer.
func (b *buffer) len() int {
return len(b.buf)
}
// fill reads into the read buffer until at least _need_ bytes are in it.
func (b *buffer) fill(need int, r readerFunc) error {
// we'll move the contents of the current buffer to dest before filling it.
dest := b.cachedBuf
// grow buffer if necessary to fit the whole packet.
if need > len(dest) {
@ -67,64 +60,41 @@ func (b *buffer) fill(need int) error {
// if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest
b.cachedBuf = dest
}
}
// if we're filling the fg buffer, move the existing data to the start of it.
// if we're filling the bg buffer, copy over the data
if n > 0 {
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
// move the existing data to the start of the buffer.
n := len(b.buf)
copy(dest[:n], b.buf)
for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
nn, err := r(dest[n:])
n += nn
switch err {
case nil:
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
if err == nil && n < need {
continue
}
b.buf = dest[:n]
if err == io.EOF {
if n < need {
err = io.ErrUnexpectedEOF
} else {
err = nil
}
}
return err
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
func (b *buffer) readNext(need int) []byte {
data := b.buf[:need:need]
b.buf = b.buf[need:]
return data
}
// takeBuffer returns a buffer with the requested size.
@ -132,18 +102,18 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
if b.busy() {
return nil, ErrBusyBuffer
}
// test (cheap) general case first
if length <= cap(b.buf) {
return b.buf[:length], nil
if length <= len(b.cachedBuf) {
return b.cachedBuf[:length], nil
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf, nil
if length < maxCachedBufSize {
b.cachedBuf = make([]byte, length)
return b.cachedBuf, nil
}
// buffer is larger than we want to store.
@ -154,10 +124,10 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) {
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
if b.busy() {
return nil, ErrBusyBuffer
}
return b.buf[:length], nil
return b.cachedBuf[:length], nil
}
// takeCompleteBuffer returns the complete existing buffer.
@ -165,18 +135,15 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
if b.busy() {
return nil, ErrBusyBuffer
}
return b.buf, nil
return b.cachedBuf, nil
}
// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
func (b *buffer) store(buf []byte) {
if cap(buf) <= maxCachedBufSize && cap(buf) > cap(b.cachedBuf) {
b.cachedBuf = buf[:cap(buf)]
}
return nil
}

View file

@ -8,8 +8,8 @@
package mysql
const defaultCollation = "utf8mb4_general_ci"
const binaryCollation = "binary"
const defaultCollationID = 45 // utf8mb4_general_ci
const binaryCollationID = 63
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:

213
vendor/github.com/go-sql-driver/mysql/compress.go generated vendored Normal file
View file

@ -0,0 +1,213 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
)
var (
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
)
func init() {
zrPool = &sync.Pool{
New: func() any { return nil },
}
zwPool = &sync.Pool{
New: func() any {
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
if err != nil {
panic(err) // compress/zlib return non-nil error only if level is invalid
}
return zw
},
}
}
func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
br := bytes.NewReader(src)
var zr io.ReadCloser
var err error
if a := zrPool.Get(); a == nil {
if zr, err = zlib.NewReader(br); err != nil {
return 0, err
}
} else {
zr = a.(io.ReadCloser)
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
return 0, err
}
}
n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
err = zr.Close() // zr.Close() may return chuecksum error.
zrPool.Put(zr)
return int(n), err
}
func zCompress(src []byte, dst io.Writer) error {
zw := zwPool.Get().(*zlib.Writer)
zw.Reset(dst)
if _, err := zw.Write(src); err != nil {
return err
}
err := zw.Close()
zwPool.Put(zw)
return err
}
type compIO struct {
mc *mysqlConn
buff bytes.Buffer
}
func newCompIO(mc *mysqlConn) *compIO {
return &compIO{
mc: mc,
}
}
func (c *compIO) reset() {
c.buff.Reset()
}
func (c *compIO) readNext(need int) ([]byte, error) {
for c.buff.Len() < need {
if err := c.readCompressedPacket(); err != nil {
return nil, err
}
}
data := c.buff.Next(need)
return data[:need:need], nil // prevent caller writes into c.buff
}
func (c *compIO) readCompressedPacket() error {
header, err := c.mc.readNext(7)
if err != nil {
return err
}
_ = header[6] // bounds check hint to compiler; guaranteed by readNext
// compressed header structure
comprLength := getUint24(header[0:3])
compressionSequence := header[3]
uncompressedLength := getUint24(header[4:7])
if debug {
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
}
// Do not return ErrPktSync here.
// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
// before receiving all packets from client. In this case, seqnr is younger than expected.
// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
if debug && compressionSequence != c.mc.compressSequence {
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
c.mc.compressSequence, compressionSequence)
}
c.mc.compressSequence = compressionSequence + 1
comprData, err := c.mc.readNext(comprLength)
if err != nil {
return err
}
// if payload is uncompressed, its length will be specified as zero, and its
// true length is contained in comprLength
if uncompressedLength == 0 {
c.buff.Write(comprData)
return nil
}
// use existing capacity in bytesBuf if possible
c.buff.Grow(uncompressedLength)
nread, err := zDecompress(comprData, &c.buff)
if err != nil {
return err
}
if nread != uncompressedLength {
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
uncompressedLength, nread)
}
return nil
}
const minCompressLength = 150
const maxPayloadLen = maxPacketSize - 4
// writePackets sends one or some packets with compression.
// Use this instead of mc.netConn.Write() when mc.compress is true.
func (c *compIO) writePackets(packets []byte) (int, error) {
totalBytes := len(packets)
blankHeader := make([]byte, 7)
buf := &c.buff
for len(packets) > 0 {
payloadLen := min(maxPayloadLen, len(packets))
payload := packets[:payloadLen]
uncompressedLen := payloadLen
buf.Reset()
buf.Write(blankHeader) // Buffer.Write() never returns error
// If payload is less than minCompressLength, don't compress.
if uncompressedLen < minCompressLength {
buf.Write(payload)
uncompressedLen = 0
} else {
err := zCompress(payload, buf)
if debug && err != nil {
fmt.Printf("zCompress error: %v", err)
}
// do not compress if compressed data is larger than uncompressed data
// I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes.
if err != nil || buf.Len() >= uncompressedLen {
buf.Reset()
buf.Write(blankHeader)
buf.Write(payload)
uncompressedLen = 0
}
}
if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
// To allow returning ErrBadConn when sending really 0 bytes, we sum
// up compressed bytes that is returned by underlying Write().
return totalBytes - len(packets) + n, err
}
packets = packets[payloadLen:]
}
return totalBytes, nil
}
// writeCompressedPacket writes a compressed packet with header.
// data should start with 7 size space for header followed by payload.
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
mc := c.mc
comprLength := len(data) - 7
if debug {
fmt.Printf(
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v\n",
comprLength, uncompressedLen, mc.compressSequence)
}
// compression header
putUint24(data[0:3], comprLength)
data[3] = mc.compressSequence
putUint24(data[4:7], uncompressedLen)
mc.compressSequence++
return mc.writeWithTimeout(data)
}

View file

@ -13,28 +13,32 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"io"
"net"
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"
)
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
compIO *compIO
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
compressSequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
compress bool
// for context support (Go 1.8+)
watching bool
@ -42,61 +46,92 @@ type mysqlConn struct {
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
closed atomic.Bool // set when conn is closed, before closech is closed
}
// Helper function to call per-connection logger.
func (mc *mysqlConn) log(v ...any) {
_, filename, lineno, ok := runtime.Caller(1)
if ok {
pos := strings.LastIndexByte(filename, '/')
if pos != -1 {
filename = filename[pos+1:]
}
prefix := fmt.Sprintf("%s:%d ", filename, lineno)
v = append([]any{prefix}, v...)
}
mc.cfg.Logger.Print(v...)
}
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
to := mc.cfg.ReadTimeout
if to > 0 {
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Read(b)
}
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
to := mc.cfg.WriteTimeout
if to > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Write(b)
}
func (mc *mysqlConn) resetSequence() {
mc.sequence = 0
mc.compressSequence = 0
}
// syncSequence must be called when finished writing some packet and before start reading.
func (mc *mysqlConn) syncSequence() {
// Syncs compressionSequence to sequence.
// This is not documented but done in `net_flush()` in MySQL and MariaDB.
// https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171
// https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293
if mc.compress {
mc.sequence = mc.compressSequence
mc.compIO.reset()
}
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
for param, val := range mc.cfg.Params {
switch param {
// Charset: character_set_connection, character_set_client, character_set_results
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// Other system vars accumulated in a single SET command
default:
if cmdSet.Len() == 0 {
// Heuristic: 29 chars for each other key=value to reduce reallocations
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.WriteString("SET ")
} else {
cmdSet.WriteByte(',')
}
cmdSet.WriteString(param)
cmdSet.WriteByte('=')
cmdSet.WriteString(val)
for param, val := range mc.cfg.Params {
if cmdSet.Len() == 0 {
// Heuristic: 29 chars for each other key=value to reduce reallocations
cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.WriteString("SET ")
} else {
cmdSet.WriteString(", ")
}
cmdSet.WriteString(param)
cmdSet.WriteString(" = ")
cmdSet.WriteString(val)
}
if cmdSet.Len() > 0 {
err = mc.exec(cmdSet.String())
if err != nil {
return
}
}
return
}
// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn.
// This function is used to return driver.ErrBadConn only when safe to retry.
func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
if err == errBadConnNoWrite {
return driver.ErrBadConn
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
return err
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
@ -105,7 +140,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
@ -126,12 +160,16 @@ func (mc *mysqlConn) Close() (err error) {
if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
mc.close()
return
}
// close closes the network connection and clear results without sending COM_QUIT.
func (mc *mysqlConn) close() {
mc.cleanup()
mc.clearResult()
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
@ -143,12 +181,16 @@ func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
conn := mc.rawConn
if conn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
if err := conn.Close(); err != nil {
mc.log("closing connection:", err)
}
// This function can be called from multiple goroutines.
// So we can not mc.clearResult() here.
// Caller should do it if they are in safe goroutine.
}
func (mc *mysqlConn) error() error {
@ -163,14 +205,13 @@ func (mc *mysqlConn) error() error {
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
errLog.Print(err)
mc.log(err)
return nil, driver.ErrBadConn
}
@ -204,8 +245,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(err)
return "", ErrInvalidConn
mc.cleanup()
// interpolateParams would be called before sending any query.
// So its safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
argPos := 0
@ -246,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
@ -296,7 +339,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
@ -310,28 +352,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
@ -348,7 +387,7 @@ func (mc *mysqlConn) exec(query string) error {
}
}
return mc.discardResults()
return handleOk.discardResults()
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@ -356,8 +395,9 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
}
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
handleOk := mc.clearResult()
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
@ -373,43 +413,47 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if err != nil {
return nil, mc.markBadConn(err)
}
if resLen == 0 {
rows.rs.done = true
// Read Result
var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
rows := new(textRows)
rows.mc = mc
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
if resLen == 0 {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
return nil, mc.markBadConn(err)
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
handleOk := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
@ -430,7 +474,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
return nil, err
}
// finish is called when the query has canceled.
// cancel is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
@ -451,7 +495,6 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
@ -460,11 +503,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
}
defer mc.finish()
handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}
return mc.readResultOK()
return handleOk.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
@ -636,15 +680,42 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() {
if mc.closed.Load() || mc.buf.busy() {
return driver.ErrBadConn
}
mc.reset = true
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.cfg.CheckConnLiveness {
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
if err != nil {
mc.log("closing bad idle connection: ", err)
return driver.ErrBadConn
}
}
return nil
}
// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load()
return !mc.closed.Load() && !mc.buf.busy()
}
var _ driver.SessionResetter = &mysqlConn{}
var _ driver.Validator = &mysqlConn{}

View file

@ -11,11 +11,55 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
"strings"
)
type connector struct {
cfg *Config // immutable private copy.
cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}
func encodeConnectionAttributes(cfg *Config) string {
connAttrsBuf := make([]byte, 0)
// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
serverHost, _, _ := net.SplitHostPort(cfg.Addr)
if serverHost != "" {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
}
// user-defined connection attributes
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
k, v, found := strings.Cut(connAttr, ":")
if !found {
continue
}
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
return string(connAttrsBuf)
}
func newConnector(cfg *Config) *connector {
encodedAttributes := encodeConnectionAttributes(cfg)
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}
}
// Connect implements driver.Connector interface.
@ -23,43 +67,56 @@ type connector struct {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
// Invoke beforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.beforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.beforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
}
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
cfg: cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
if c.cfg.DialFunc != nil {
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
} else {
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{}
mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr)
}
}
if err != nil {
return nil, err
}
mc.rawConn = mc.netConn
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
c.cfg.Logger.Print(err)
}
}
@ -71,11 +128,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
defer mc.finish()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
mc.buf = newBuffer()
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
@ -92,7 +145,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin)
if err != nil {
@ -114,6 +167,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
mc.compress = true
mc.compIO = newCompIO(mc)
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
@ -123,12 +180,36 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
n, err := strconv.Atoi(string(maxap))
if err != nil {
mc.Close()
return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err)
}
mc.maxAllowedPacket = n - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Charset: character_set_connection, character_set_client, character_set_results
if len(mc.cfg.charsets) > 0 {
for _, cs := range mc.cfg.charsets {
// ignore errors here - a charset may not exist
if mc.cfg.Collation != "" {
err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation)
} else {
err = mc.exec("SET NAMES " + cs)
}
if err == nil {
break
}
}
if err != nil {
mc.Close()
return nil, err
}
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {

View file

@ -8,12 +8,27 @@
package mysql
import "runtime"
const (
debug = false // for debugging. Set true only in development.
defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999"
// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "Go-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
connAttrServerHost = "_server_host"
)
// MySQL constants documentation:
@ -112,7 +127,10 @@ const (
fieldTypeBit
)
const (
fieldTypeJSON fieldType = iota + 0xf5
fieldTypeVector fieldType = iota + 0xf2
fieldTypeInvalid
fieldTypeBool
fieldTypeJSON
fieldTypeNewDecimal
fieldTypeEnum
fieldTypeSet

View file

@ -55,6 +55,15 @@ func RegisterDialContext(net string, dial DialContextFunc) {
dials[net] = dial
}
// DeregisterDialContext removes the custom dial function registered with the given net.
func DeregisterDialContext(net string) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials != nil {
delete(dials, net)
}
}
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
@ -74,14 +83,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c := &connector{
cfg: cfg,
}
c := newConnector(cfg)
return c.Connect(context.Background())
}
// This variable can be replaced with -ldflags like below:
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
var driverName = "mysql"
func init() {
sql.Register("mysql", &MySQLDriver{})
if driverName != "" {
sql.Register(driverName, &MySQLDriver{})
}
}
// NewConnector returns new driver.Connector.
@ -92,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return &connector{cfg: cfg}, nil
return newConnector(cfg), nil
}
// OpenConnector implements driver.DriverContext.
@ -101,7 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
return newConnector(cfg), nil
}

View file

@ -10,6 +10,7 @@ package mysql
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"errors"
@ -34,22 +35,29 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
// non boolean fields
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation. When set, this will be set in SET NAMES <charset> COLLATE <collation> query
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
// DialFunc specifies the dial function for creating connections
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// boolean fields
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
@ -63,17 +71,83 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
// unexported fields. new options should be come here.
// boolean first. alphabetical order.
compress bool // Enable zlib compression
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
charsets []string // Connection charset. When set, this will be set in SET NAMES <charset> query
}
// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type Option func(*Config) error
// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
return &Config{
Collation: defaultCollation,
cfg := &Config{
Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket,
Logger: defaultLogger,
AllowNativePasswords: true,
CheckConnLiveness: true,
}
return cfg
}
// Apply applies the given options to the Config object.
func (c *Config) Apply(opts ...Option) error {
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
return nil
}
// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func TimeTruncate(d time.Duration) Option {
return func(cfg *Config) error {
cfg.timeTruncate = d
return nil
}
}
// BeforeConnect sets the function to be invoked before a connection is established.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
return func(cfg *Config) error {
cfg.beforeConnect = fn
return nil
}
}
// EnableCompress sets the compression mode.
func EnableCompression(yes bool) Option {
return func(cfg *Config) error {
cfg.compress = yes
return nil
}
}
// Charset sets the connection charset and collation.
//
// charset is the connection charset.
// collation is the connection collation. It can be null or empty string.
//
// When collation is not specified, `SET NAMES <charset>` command is sent when the connection is established.
// When collation is specified, `SET NAMES <charset> COLLATE <collation>` command is sent when the connection is established.
func Charset(charset, collation string) Option {
return func(cfg *Config) error {
cfg.charsets = []string{charset}
cfg.Collation = collation
return nil
}
}
func (cfg *Config) Clone() *Config {
@ -97,7 +171,7 @@ func (cfg *Config) Clone() *Config {
}
func (cfg *Config) normalize() error {
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] {
return errInvalidDSNUnsafeCollation
}
@ -153,6 +227,10 @@ func (cfg *Config) normalize() error {
}
}
if cfg.Logger == nil {
cfg.Logger = defaultLogger
}
return nil
}
@ -171,6 +249,8 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
//
// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config].
func (cfg *Config) FormatDSN() string {
var buf bytes.Buffer
@ -196,7 +276,7 @@ func (cfg *Config) FormatDSN() string {
// /dbname
buf.WriteByte('/')
buf.WriteString(cfg.DBName)
buf.WriteString(url.PathEscape(cfg.DBName))
// [?param1=value1&...&paramN=valueN]
hasParam := false
@ -230,7 +310,11 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
}
if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
if charsets := cfg.charsets; len(charsets) > 0 {
writeDSNParam(&buf, &hasParam, "charset", strings.Join(charsets, ","))
}
if col := cfg.Collation; col != "" {
writeDSNParam(&buf, &hasParam, "collation", col)
}
@ -238,6 +322,14 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
}
if cfg.ConnectionAttributes != "" {
writeDSNParam(&buf, &hasParam, "connectionAttributes", url.QueryEscape(cfg.ConnectionAttributes))
}
if cfg.compress {
writeDSNParam(&buf, &hasParam, "compress", "true")
}
if cfg.InterpolateParams {
writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
}
@ -254,6 +346,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "parseTime", "true")
}
if cfg.timeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
}
if cfg.ReadTimeout > 0 {
writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
}
@ -358,7 +454,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
break
}
}
cfg.DBName = dsn[i+1 : j]
dbname := dsn[i+1 : j]
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
}
break
}
@ -378,13 +478,13 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
key, value, found := strings.Cut(v, "=")
if !found {
continue
}
// cfg params
switch value := param[1]; param[0] {
switch key {
// Disable INFILE allowlist / enable all files
case "allowAllFiles":
var isBool bool
@ -441,6 +541,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value)
}
// charset
case "charset":
cfg.charsets = strings.Split(value, ",")
// Collation
case "collation":
cfg.Collation = value
@ -454,7 +558,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// Compression
case "compress":
return errors.New("compression not implemented yet")
var isBool bool
cfg.compress, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Enable client side placeholder substitution
case "interpolateParams":
@ -490,6 +598,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value)
}
// time.Time truncation
case "timeTruncate":
cfg.timeTruncate, err = time.ParseDuration(value)
if err != nil {
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
}
// I/O read Timeout
case "readTimeout":
cfg.ReadTimeout, err = time.ParseDuration(value)
@ -554,13 +669,22 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}
// Connection attributes
case "connectionAttributes":
connectionAttributes, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid connectionAttributes value: %v", err)
}
cfg.ConnectionAttributes = connectionAttributes
default:
// lazy init
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
if cfg.Params[key], err = url.QueryUnescape(value); err != nil {
return
}
}

View file

@ -21,36 +21,42 @@ var (
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
ErrBusyBuffer = errors.New("busy buffer")
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend.
// to trigger a resend. Use mc.markBadConn(err) to do this.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))
// Logger is used to log critical error messages.
type Logger interface {
Print(v ...interface{})
Print(v ...any)
}
// SetLogger is used to set the logger for critical errors.
// NopLogger is a nop implementation of the Logger interface.
type NopLogger struct{}
// Print implements Logger interface.
func (nl *NopLogger) Print(_ ...any) {}
// SetLogger is used to set the default logger for critical errors.
// The initial logger is os.Stderr.
func SetLogger(logger Logger) error {
if logger == nil {
return errors.New("logger is nil")
}
errLog = logger
defaultLogger = logger
return nil
}

View file

@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeBit:
return "BIT"
case fieldTypeBLOB:
if mf.charSet != collations[binaryCollation] {
if mf.charSet != binaryCollationID {
return "TEXT"
}
return "BLOB"
@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeGeometry:
return "GEOMETRY"
case fieldTypeInt24:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED MEDIUMINT"
}
return "MEDIUMINT"
case fieldTypeJSON:
return "JSON"
@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string {
}
return "INT"
case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] {
if mf.charSet != binaryCollationID {
return "LONGTEXT"
}
return "LONGBLOB"
@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string {
}
return "BIGINT"
case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] {
if mf.charSet != binaryCollationID {
return "MEDIUMTEXT"
}
return "MEDIUMBLOB"
@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string {
}
return "SMALLINT"
case fieldTypeString:
if mf.charSet == collations[binaryCollation] {
if mf.flags&flagEnum != 0 {
return "ENUM"
} else if mf.flags&flagSet != 0 {
return "SET"
}
if mf.charSet == binaryCollationID {
return "BINARY"
}
return "CHAR"
@ -88,43 +96,47 @@ func (mf *mysqlField) typeDatabaseName() string {
}
return "TINYINT"
case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] {
if mf.charSet != binaryCollationID {
return "TINYTEXT"
}
return "TINYBLOB"
case fieldTypeVarChar:
if mf.charSet == collations[binaryCollation] {
if mf.charSet == binaryCollationID {
return "VARBINARY"
}
return "VARCHAR"
case fieldTypeVarString:
if mf.charSet == collations[binaryCollation] {
if mf.charSet == binaryCollationID {
return "VARBINARY"
}
return "VARCHAR"
case fieldTypeYear:
return "YEAR"
case fieldTypeVector:
return "VECTOR"
default:
return ""
}
}
var (
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
scanTypeUnknown = reflect.TypeOf(new(interface{}))
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(any))
)
type mysqlField struct {
@ -187,12 +199,18 @@ func (mf *mysqlField) scanType() reflect.Type {
}
return scanTypeNullFloat
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeVector:
if mf.charSet == binaryCollationID {
return scanTypeBytes
}
fallthrough
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeTime:
return scanTypeRawBytes
fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
if mf.flags&flagNotNULL != 0 {
return scanTypeString
}
return scanTypeNullString
case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime:

View file

@ -1,25 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build gofuzz
// +build gofuzz
package mysql
import (
"database/sql"
)
func Fuzz(data []byte) int {
db, err := sql.Open("mysql", string(data))
if err != nil {
return 0
}
db.Close()
return 1
}

View file

@ -17,7 +17,7 @@ import (
)
var (
fileRegister map[string]bool
fileRegister map[string]struct{}
fileRegisterLock sync.RWMutex
readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex
@ -37,10 +37,10 @@ func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock()
// lazy map init
if fileRegister == nil {
fileRegister = make(map[string]bool)
fileRegister = make(map[string]struct{})
}
fileRegister[strings.Trim(filePath, `"`)] = true
fileRegister[strings.Trim(filePath, `"`)] = struct{}{}
fileRegisterLock.Unlock()
}
@ -93,9 +93,8 @@ func deferredClose(err *error, closer io.Closer) {
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := defaultPacketSize
if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize
@ -116,17 +115,17 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
defer deferredClose(&err, cl)
}
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
err = fmt.Errorf("reader '%s' is <nil>", name)
}
} else {
err = fmt.Errorf("Reader '%s' is not registered", name)
err = fmt.Errorf("reader '%s' is not registered", name)
}
} else { // File
name = strings.Trim(name, `"`)
fileRegisterLock.RLock()
fr := fileRegister[name]
_, exists := fileRegister[name]
fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr {
if mc.cfg.AllowAllFiles || exists {
var file *os.File
var fi os.FileInfo
@ -147,14 +146,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
}
// send content packets
var data []byte
// if packetSize == 0, the Reader contains no data
if err == nil && packetSize > 0 {
data := make([]byte, 4+packetSize)
data = make([]byte, 4+packetSize)
var n int
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
@ -168,15 +169,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr
}
mc.conn().syncSequence()
// read OK packet
if err == nil {
return mc.readResultOK()
}
mc.readPacket()
mc.conn().readPacket()
return err
}

View file

@ -38,7 +38,7 @@ type NullTime sql.NullTime
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
func (nt *NullTime) Scan(value any) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
return fmt.Errorf("can't convert %T to time.Time", value)
}
// Value implements the driver Valuer interface.

View file

@ -14,75 +14,108 @@ import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"os"
"strconv"
"time"
)
// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
// MySQL client/server protocol documentations.
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
// https://mariadb.com/kb/en/clientserver-protocol/
// read n bytes from mc.buf
func (mc *mysqlConn) readNext(n int) ([]byte, error) {
if mc.buf.len() < n {
err := mc.buf.fill(n, mc.readWithTimeout)
if err != nil {
return nil, err
}
}
return mc.buf.readNext(n), nil
}
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
invalidSequence := false
readNext := mc.readNext
if mc.compress {
readNext = mc.compIO.readNext
}
for {
// read packet header
data, err := mc.buf.readNext(4)
data, err := readNext(4)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
mc.Close()
mc.log(err)
return nil, ErrInvalidConn
}
// packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
pktLen := getUint24(data[:3])
seq := data[3]
// check packet sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
if seq != mc.sequence {
mc.log(fmt.Sprintf("[warn] unexpected sequence nr: expected %v, got %v", mc.sequence, seq))
// MySQL and MariaDB doesn't check packet nr in compressed packet.
if !mc.compress {
// For large packets, we stop reading as soon as sync error.
if len(prevData) > 0 {
mc.close()
return nil, ErrPktSyncMul
}
invalidSequence = true
}
return nil, ErrPktSync
}
mc.sequence++
mc.sequence = seq + 1
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
errLog.Print(ErrMalformPkt)
mc.Close()
mc.log(ErrMalformPkt)
mc.close()
return nil, ErrInvalidConn
}
return prevData, nil
}
// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
data, err = readNext(pktLen)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
mc.Close()
mc.log(err)
return nil, ErrInvalidConn
}
// return data if this was the last packet
if pktLen < maxPacketSize {
// zero allocations for non-split packets
if prevData == nil {
return data, nil
if prevData != nil {
data = append(prevData, data...)
}
return append(prevData, data...), nil
if invalidSequence {
mc.close()
// return sync error only for regular packet.
// error packets may have wrong sequence number.
if data[0] != iERR {
return nil, ErrPktSync
}
}
return data, nil
}
prevData = append(prevData, data...)
@ -92,88 +125,52 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4
if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge
}
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
writeFunc := mc.writeWithTimeout
if mc.compress {
writeFunc = mc.compIO.writePackets
}
for {
var size int
if pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
size := min(maxPacketSize, pktLen)
putUint24(data[:3], size)
data[3] = mc.sequence
// Write packet
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
return err
}
if debug {
fmt.Fprintf(os.Stderr, "writePacket: size=%v seq=%v\n", size, mc.sequence)
}
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}
// Handle error
if err == nil { // n != len(data)
n, err := writeFunc(data[:4+size])
if err != nil {
mc.cleanup()
errLog.Print(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
mc.log(err)
return errBadConnNoWrite
} else {
return err
}
mc.cleanup()
errLog.Print(err)
}
return ErrInvalidConn
if n != 4+size {
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
mc.cleanup()
return io.ErrShortWrite
}
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
}
}
@ -186,11 +183,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return
}
@ -234,12 +226,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 3
// capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2
// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 11
// second part of the password cipher [mininum 13 bytes],
// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
//
// The web documentation is ambiguous about the length. However,
@ -285,12 +280,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientConnectAttrs |
mc.flags&clientLongFlag
sendConnectAttrs := mc.flags&clientConnectAttrs != 0
if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows
}
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
clientFlags |= clientCompress
}
// To enable TLS / SSL
if mc.cfg.TLS != nil {
clientFlags |= clientSSL
@ -318,34 +318,38 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}
// encode length of the connection attributes
var connAttrsLEI []byte
if sendConnectAttrs {
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
}
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
mc.cleanup()
return err
}
// ClientFlags [32 bit]
data[4] = byte(clientFlags)
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
// MaxPacketSize [32 bit] (none)
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
binary.LittleEndian.PutUint32(data[8:], 0)
// Charset [1 byte]
var found bool
data[12], found = collations[mc.cfg.Collation]
if !found {
// Note possibility for false negatives:
// could be triggered although the collation is valid if the
// collations map does not contain entries the server supports.
return errors.New("unknown collation")
// Collation ID [1 byte]
data[12] = defaultCollationID
if cname := mc.cfg.Collation; cname != "" {
colID, ok := collations[cname]
if ok {
data[12] = colID
} else if len(mc.cfg.charsets) > 0 {
// When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
return fmt.Errorf("unknown collation: %q", cname)
}
}
// Filler [23 bytes] (all 0x00)
@ -365,11 +369,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
return err
}
mc.rawConn = mc.netConn
mc.netConn = tlsConn
mc.buf.nc = tlsConn
}
// User [null terminated string]
@ -394,6 +399,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00
pos++
// Connection Attributes
if sendConnectAttrs {
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
}
// Send Auth packet
return mc.writePacket(data[:pos])
}
@ -401,11 +412,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data, err := mc.buf.takeSmallBuffer(pktLen)
data, err := mc.buf.takeBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
mc.cleanup()
return err
}
// Add the auth data [EOF]
@ -419,32 +429,30 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
return err
}
// Add command byte
data[4] = command
// Send CMD packet
return mc.writePacket(data)
err = mc.writePacket(data)
mc.syncSequence()
return err
}
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence
mc.sequence = 0
mc.resetSequence()
pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
return err
}
// Add command byte
@ -454,31 +462,30 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
copy(data[5:], arg)
// Send CMD packet
return mc.writePacket(data)
err = mc.writePacket(data)
mc.syncSequence()
return err
}
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
return err
}
// Add command byte
data[4] = command
// Add arg [32 bit]
data[5] = byte(arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)
binary.LittleEndian.PutUint32(data[5:], arg)
// Send CMD packet
return mc.writePacket(data)
err = mc.writePacket(data)
mc.syncSequence()
return err
}
/******************************************************************************
@ -495,7 +502,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
switch data[0] {
case iOK:
return nil, "", mc.handleOkPacket(data)
// resultUnchanged, since auth happens before any queries or
// commands have been executed.
return nil, "", mc.resultUnchanged().handleOkPacket(data)
case iAuthMoreData:
return data[1:], "", err
@ -511,6 +520,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
}
plugin := string(data[1:pluginEndIndex])
authData := data[pluginEndIndex+1:]
if len(authData) > 0 && authData[len(authData)-1] == 0 {
authData = authData[:len(authData)-1]
}
return authData, plugin, nil
default: // Error otherwise
@ -518,9 +530,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
}
}
// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
data, err := mc.readPacket()
// Returns error if Packet is not a 'Result OK'-Packet
func (mc *okHandler) readResultOK() error {
data, err := mc.conn().readPacket()
if err != nil {
return err
}
@ -528,35 +540,37 @@ func (mc *mysqlConn) readResultOK() error {
if data[0] == iOK {
return mc.handleOkPacket(data)
}
return mc.handleErrorPacket(data)
return mc.conn().handleErrorPacket(data)
}
// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
if err == nil {
switch data[0] {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
}
return 0, ErrMalformPkt
data, err := mc.conn().readPacket()
if err != nil {
return 0, err
}
return 0, err
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
}
// Error Packet
@ -573,7 +587,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
// 1836: ER_READ_ONLY_MODE
if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
// Oops; we are connected to a read-only connection, and won't be able
// to issue any write statements. Since RejectReadOnly is configured,
// we throw away this connection hoping this one would have write
@ -607,18 +622,61 @@ func readStatus(b []byte) statusFlag {
return statusFlag(b[0]) | statusFlag(b[1])<<8
}
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
// need to be cleared first (e.g. during authentication, or while additional
// resultsets are being fetched.)
func (mc *mysqlConn) resultUnchanged() *okHandler {
return (*okHandler)(mc)
}
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
// 1. first clear it using clearResult(), or
// 2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
type okHandler mysqlConn
// Exposes the underlying type's methods.
func (mc *okHandler) conn() *mysqlConn {
return (*mysqlConn)(mc)
}
// clearResult clears the connection's stored affectedRows and insertIds
// fields.
//
// It returns a handler that can process OK responses.
func (mc *mysqlConn) clearResult() *okHandler {
mc.result = mysqlResult{}
return (*okHandler)(mc)
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int
var affectedRows, insertId uint64
// 0x00 [1 byte]
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
insertId, _, m = readLengthEncodedInteger(data[1+n:])
// Update for the current statement result (only used by
// readResultSetHeaderPacket).
if len(mc.result.affectedRows) > 0 {
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
@ -769,7 +827,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
for i := range dest {
// Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err != nil {
@ -781,19 +840,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
continue
}
if !mc.parseTime {
continue
}
// Parse time field
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp,
fieldTypeDateTime,
fieldTypeDate,
fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
return err
if mc.parseTime {
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
}
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
dest[i] = buf
}
if err != nil {
return err
}
}
@ -875,32 +955,26 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
pktLen = dataOffset + argLen
}
stmt.mc.sequence = 0
// Add command byte [1 byte]
data[4] = comStmtSendLongData
// Add stmtID [32 bit]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
binary.LittleEndian.PutUint32(data[5:], stmt.id)
// Add paramID [16 bit]
data[9] = byte(paramID)
data[10] = byte(paramID >> 8)
binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
// Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen])
// Every COM_LONG_DATA packet reset Packet Sequence
stmt.mc.resetSequence()
if err == nil {
data = data[pktLen-dataOffset:]
continue
}
return err
}
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil
}
@ -925,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}
// Reset packet-sequence
mc.sequence = 0
mc.resetSequence()
var data []byte
var err error
@ -937,28 +1011,20 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
return errBadConnNoWrite
return err
}
// command [1 byte]
data[4] = comStmtExecute
// statement_id [4 bytes]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
binary.LittleEndian.PutUint32(data[5:], stmt.id)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
binary.LittleEndian.PutUint32(data[10:], 1)
if len(args) > 0 {
pos := minPktLen
@ -1012,50 +1078,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
case int64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
case uint64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x80 // type is unsigned
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
case float64:
paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))
case bool:
paramTypes[i+i] = byte(fieldTypeTiny)
@ -1116,7 +1149,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc))
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return err
}
@ -1136,20 +1169,21 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
mc.buf.store(data) // allow this buffer to be reused
}
pos += len(paramValues)
data = data[:pos]
}
return mc.writePacket(data)
err = mc.writePacket(data)
mc.syncSequence()
return err
}
func (mc *mysqlConn) discardResults() error {
// For each remaining resultset in the stream, discards its rows and updates
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
@ -1157,11 +1191,11 @@ func (mc *mysqlConn) discardResults() error {
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.conn().readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
if err := mc.conn().readUntilEOF(); err != nil {
return err
}
}
@ -1268,7 +1302,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeVector:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])

View file

@ -8,15 +8,43 @@
package mysql
import "database/sql/driver"
// Result exposes data not available through *connection.Result.
//
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
// executed statement.
AllRowsAffected() []int64
// AllLastInsertIds returns a slice containing the last inserted ID for each
// executed statement.
AllLastInsertIds() []int64
}
type mysqlResult struct {
affectedRows int64
insertId int64
// One entry in both slices is created for every executed statement result.
affectedRows []int64
insertIds []int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
return res.insertIds[len(res.insertIds)-1], nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
return res.affectedRows[len(res.affectedRows)-1], nil
}
func (res *mysqlResult) AllLastInsertIds() []int64 {
return append([]int64{}, res.insertIds...) // defensive copy
}
func (res *mysqlResult) AllRowsAffected() []int64 {
return append([]int64{}, res.affectedRows...) // defensive copy
}

View file

@ -111,19 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
return err
}
// flip the buffer for this connection if we need to drain it.
// note that for a successful query (i.e. one where rows.next()
// has been called until it returns false), `rows.mc` will be nil
// by the time the user calls `(*Rows).Close`, so we won't reach this
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
mc.buf.flip()
// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()
}
if err == nil {
if err = mc.discardResults(); err != nil {
handleOk := mc.clearResult()
if err = handleOk.discardResults(); err != nil {
return err
}
}
@ -160,7 +154,15 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
return 0, io.EOF
}
rows.rs = resultSet{}
return rows.mc.readResultSetHeaderPacket()
// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
// nextResultSet.
resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
if err != nil {
// Clean up about multi-results flag
rows.rs.done = true
rows.mc.status = rows.mc.status & (^statusMoreResultsExists)
}
return resLen, err
}
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {

View file

@ -24,11 +24,12 @@ type mysqlStmt struct {
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.Load() {
// driver.Stmt.Close can be called more than once, thus this function
// has to be idempotent.
// See also Issue #450 and golang/go#16019.
//errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
// driver.Stmt.Close could be called more than once, thus this function
// had to be idempotent. See also Issue #450 and golang/go#16019.
// This bug has been fixed in Go 1.8.
// https://github.com/golang/go/commit/90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a
// But we keep this function idempotent because it is safer.
return nil
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
@ -51,7 +52,6 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
handleOk := stmt.mc.clearResult()
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
}
}
if err := mc.discardResults(); err != nil {
if err := handleOk.discardResults(); err != nil {
return nil, err
}
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
copied := mc.result
return &copied, nil
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
@ -99,7 +95,6 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
@ -111,7 +106,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
handleOk := stmt.mc.clearResult()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@ -144,7 +140,7 @@ type converter struct{}
// implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
func (c converter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}

View file

@ -13,18 +13,32 @@ type mysqlTx struct {
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.Load() {
if tx.mc == nil {
return ErrInvalidConn
}
if tx.mc.closed.Load() {
err = tx.mc.error()
if err == nil {
err = ErrInvalidConn
}
return
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.Load() {
if tx.mc == nil {
return ErrInvalidConn
}
if tx.mc.closed.Load() {
err = tx.mc.error()
if err == nil {
err = ErrInvalidConn
}
return
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return

View file

@ -36,7 +36,7 @@ var (
// registering it.
//
// rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// pem, err := os.ReadFile("/path/ca-cert.pem")
// if err != nil {
// log.Fatal(err)
// }
@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
}
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) {
if timeTruncate > 0 {
t = t.Truncate(timeTruncate)
}
year, month, day := t.Date()
hour, min, sec := t.Clock()
nsec := t.Nanosecond()
@ -486,17 +490,16 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
* Convert from and to bytes *
******************************************************************************/
func uint64ToBytes(n uint64) []byte {
return []byte{
byte(n),
byte(n >> 8),
byte(n >> 16),
byte(n >> 24),
byte(n >> 32),
byte(n >> 40),
byte(n >> 48),
byte(n >> 56),
}
// 24bit integer: used for packet headers.
func putUint24(data []byte, n int) {
data[2] = byte(n >> 16)
data[1] = byte(n >> 8)
data[0] = byte(n)
}
func getUint24(data []byte) int {
return int(data[2])<<16 | int(data[1])<<8 | int(data[0])
}
func uint64ToString(n uint64) []byte {
@ -521,16 +524,6 @@ func uint64ToString(n uint64) []byte {
return a[i:]
}
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
@ -582,18 +575,15 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
return uint64(getUint24(b[1:])), false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9
}
// 0-250: value of first byte
@ -607,13 +597,19 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
b = append(b, 0xfc)
return binary.LittleEndian.AppendUint16(b, uint16(n))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
b = append(b, 0xfe)
return binary.LittleEndian.AppendUint64(b, n)
}
func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
}
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.