Skip to content

Commit 2a444a0

Browse files
committed
add select random mysql host
1 parent 83984f4 commit 2a444a0

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

pkg/loader/util.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ import (
1818
gosql "database/sql"
1919
"fmt"
2020
"hash/crc32"
21+
"math/rand"
2122
"net/url"
2223
"strconv"
2324
"strings"
2425
"sync/atomic"
26+
"time"
2527

2628
"github.com/go-sql-driver/mysql"
2729
"github.com/pingcap/errors"
@@ -146,7 +148,18 @@ func createDBWitSessions(dsn string, params map[string]string) (db *gosql.DB, er
146148

147149
// CreateDBWithSQLMode return sql.DB
148150
func CreateDBWithSQLMode(user string, password string, host string, port int, tlsConfig *tls.Config, sqlMode *string, params map[string]string) (db *gosql.DB, err error) {
149-
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port)
151+
hosts := strings.Split(host, ",")
152+
153+
if len(hosts) < 1 {
154+
return nil, errors.Annotate(err, "You must provide at least one mysql address")
155+
}
156+
157+
random := rand.New(rand.NewSource(time.Now().UnixNano()))
158+
159+
index := random.Intn(len(hosts))
160+
h := hosts[index]
161+
162+
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, h, port)
150163
if sqlMode != nil {
151164
// same as "set sql_mode = '<sqlMode>'"
152165
dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'"

tests/util/db.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import (
1818
"database/sql"
1919
"fmt"
2020
"log"
21+
"math/rand"
2122
"net/url"
23+
"strings"
2224
"time"
2325

2426
"github.com/pingcap/errors"
@@ -47,14 +49,25 @@ func (c *DBConfig) String() string {
4749
}
4850

4951
// CreateDB create a mysql fd
50-
func CreateDB(cfg DBConfig) (*sql.DB, error) {
52+
func CreateDB(cfg DBConfig) (db *sql.DB, err error) {
5153
// just set to the same timezone so the timestamp field of mysql will return the same value
5254
// timestamp field will be display as the time zone of the Local time of drainer when write to kafka, so we set it to local time to pass CI now
5355
_, offset := time.Now().Zone()
5456
zone := fmt.Sprintf("'+%02d:00'", offset/3600)
5557

56-
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&interpolateParams=true&multiStatements=true&time_zone=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, url.QueryEscape(zone))
57-
db, err := sql.Open("mysql", dbDSN)
58+
hosts := strings.Split(cfg.Host, ",")
59+
60+
if len(hosts) < 1 {
61+
return nil, errors.Annotate(err, "You must provide at least one mysql address")
62+
}
63+
64+
random := rand.New(rand.NewSource(time.Now().UnixNano()))
65+
66+
index := random.Intn(len(hosts))
67+
h := hosts[index]
68+
69+
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&interpolateParams=true&multiStatements=true&time_zone=%s", cfg.User, cfg.Password, h, cfg.Port, cfg.Name, url.QueryEscape(zone))
70+
db, err = sql.Open("mysql", dbDSN)
5871
if err != nil {
5972
return nil, errors.Trace(err)
6073
}

0 commit comments

Comments
 (0)