Skip to content

Commit 267a950

Browse files
authored
DM: Improve MariaDB permission handling (pingcap#12404)
close pingcap#12207
1 parent 88ac061 commit 267a950

File tree

6 files changed

+40
-25
lines changed

6 files changed

+40
-25
lines changed

dm/checker/checker.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ func (c *Checker) Init(ctx context.Context) (err error) {
304304
c.checkList = append(c.checkList, checker.NewTargetPrivilegeChecker(
305305
c.instances[0].targetDB.DB,
306306
c.instances[0].targetDBInfo,
307+
c.instances[0].targetDB.Version,
307308
))
308309
}
309310
// sourceID -> DB
@@ -334,6 +335,7 @@ func (c *Checker) Init(ctx context.Context) (err error) {
334335
c.checkList = append(c.checkList, checker.NewSourceDumpPrivilegeChecker(
335336
instance.sourceDB.DB,
336337
instance.sourceDBinfo,
338+
instance.sourceDB.Version,
337339
info.sourceID2SourceTables[sourceID],
338340
exportCfg.Consistency,
339341
c.dumpWholeInstance,
@@ -364,7 +366,7 @@ func (c *Checker) Init(ctx context.Context) (err error) {
364366
c.checkList = append(c.checkList, checker.NewMySQLBinlogRowImageChecker(instance.sourceDB.DB, instance.sourceDBinfo))
365367
}
366368
if _, ok := c.checkingItems[config.ReplicationPrivilegeChecking]; ok {
367-
c.checkList = append(c.checkList, checker.NewSourceReplicationPrivilegeChecker(instance.sourceDB.DB, instance.sourceDBinfo))
369+
c.checkList = append(c.checkList, checker.NewSourceReplicationPrivilegeChecker(instance.sourceDB.DB, instance.sourceDBinfo, instance.sourceDB.Version))
368370
}
369371
if _, ok := c.checkingItems[config.OnlineDDLChecking]; c.onlineDDL != nil && ok {
370372
c.checkList = append(c.checkList, checker.NewOnlineDDLChecker(instance.sourceDB.DB, info.sourceID2InterestedDB[i], c.onlineDDL, instance.baList))

dm/pkg/checker/conn_checker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (c *connNumberChecker) check(ctx context.Context, checkerName string, neede
9494
markCheckError(result, err)
9595
return result
9696
}
97-
err2 := verifyPrivilegesWithResult(result, grants, neededPriv)
97+
err2 := verifyPrivilegesWithResult(result, grants, neededPriv, c.toCheckDB.Version)
9898
if err2 != nil {
9999
// no enough privilege to check the user's connection number
100100
result.State = StateWarning

dm/pkg/checker/privilege.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ type SourceDumpPrivilegeChecker struct {
5454
checkTables []filter.Table
5555
consistency string
5656
dumpWholeInstance bool
57+
version string
5758
}
5859

5960
// NewSourceDumpPrivilegeChecker returns a RealChecker.
6061
func NewSourceDumpPrivilegeChecker(
6162
db *sql.DB,
6263
dbinfo *dbutil.DBConfig,
64+
version string,
6365
checkTables []filter.Table,
6466
consistency string,
6567
dumpWholeInstance bool,
@@ -70,6 +72,7 @@ func NewSourceDumpPrivilegeChecker(
7072
checkTables: checkTables,
7173
consistency: consistency,
7274
dumpWholeInstance: dumpWholeInstance,
75+
version: version,
7376
}
7477
}
7578

@@ -107,7 +110,7 @@ func (pc *SourceDumpPrivilegeChecker) Check(ctx context.Context) *Result {
107110
dumpRequiredPrivs[mysql.LockTablesPriv] = priv{needGlobal: true}
108111
}
109112

110-
err2 := verifyPrivilegesWithResult(result, grants, dumpRequiredPrivs)
113+
err2 := verifyPrivilegesWithResult(result, grants, dumpRequiredPrivs, pc.version)
111114
if err2 != nil {
112115
result.Errors = append(result.Errors, err2)
113116
result.Instruction = "Please grant the required privileges to the account."
@@ -126,13 +129,14 @@ func (pc *SourceDumpPrivilegeChecker) Name() string {
126129

127130
// SourceReplicatePrivilegeChecker checks replication privileges of source DB.
128131
type SourceReplicatePrivilegeChecker struct {
129-
db *sql.DB
130-
dbinfo *dbutil.DBConfig
132+
db *sql.DB
133+
dbinfo *dbutil.DBConfig
134+
version string
131135
}
132136

133137
// NewSourceReplicationPrivilegeChecker returns a RealChecker.
134-
func NewSourceReplicationPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig) RealChecker {
135-
return &SourceReplicatePrivilegeChecker{db: db, dbinfo: dbinfo}
138+
func NewSourceReplicationPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig, version string) RealChecker {
139+
return &SourceReplicatePrivilegeChecker{db: db, dbinfo: dbinfo, version: version}
136140
}
137141

138142
// Check implements the RealChecker interface.
@@ -154,7 +158,7 @@ func (pc *SourceReplicatePrivilegeChecker) Check(ctx context.Context) *Result {
154158
mysql.ReplicationSlavePriv: {needGlobal: true},
155159
mysql.ReplicationClientPriv: {needGlobal: true},
156160
}
157-
err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs)
161+
err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs, pc.version)
158162
if err2 != nil {
159163
result.Errors = append(result.Errors, err2)
160164
result.State = StateFailure
@@ -169,12 +173,13 @@ func (pc *SourceReplicatePrivilegeChecker) Name() string {
169173
}
170174

171175
type TargetPrivilegeChecker struct {
172-
db *sql.DB
173-
dbinfo *dbutil.DBConfig
176+
db *sql.DB
177+
dbinfo *dbutil.DBConfig
178+
version string
174179
}
175180

176-
func NewTargetPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig) RealChecker {
177-
return &TargetPrivilegeChecker{db: db, dbinfo: dbinfo}
181+
func NewTargetPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig, version string) RealChecker {
182+
return &TargetPrivilegeChecker{db: db, dbinfo: dbinfo, version: version}
178183
}
179184

180185
func (t *TargetPrivilegeChecker) Name() string {
@@ -203,7 +208,7 @@ func (t *TargetPrivilegeChecker) Check(ctx context.Context) *Result {
203208
mysql.DropPriv: {needGlobal: true},
204209
mysql.IndexPriv: {needGlobal: true},
205210
}
206-
err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs)
211+
err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs, t.version)
207212
if err2 != nil {
208213
result.Errors = append(result.Errors, err2)
209214
// because we cannot be very precisely sure about which table
@@ -217,8 +222,9 @@ func verifyPrivilegesWithResult(
217222
result *Result,
218223
grants []string,
219224
requiredPriv map[mysql.PrivilegeType]priv,
225+
version string,
220226
) *Error {
221-
lackedPriv, err := VerifyPrivileges(grants, requiredPriv)
227+
lackedPriv, err := VerifyPrivileges(grants, requiredPriv, version)
222228
if err != nil {
223229
// nolint
224230
return NewError("%s", err.Error())
@@ -284,12 +290,19 @@ func LackedPrivilegesAsStr(lackPriv map[mysql.PrivilegeType]priv) string {
284290
func VerifyPrivileges(
285291
grants []string,
286292
lackPrivs map[mysql.PrivilegeType]priv,
293+
version string,
287294
) (map[mysql.PrivilegeType]priv, error) {
288295
if len(grants) == 0 {
289296
return nil, errors.New("there is no such grant defined for current user on host '%%'")
290297
}
291298

292299
p := parser.New()
300+
301+
// Support for BINLOG MONITOR and other MariaDB things
302+
if strings.Contains(version, "MariaDB") {
303+
p.SetMariaDB(true)
304+
}
305+
293306
for _, grant := range grants {
294307
if len(lackPrivs) == 0 {
295308
break

dm/pkg/checker/privilege_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func TestVerifyDumpPrivileges(t *testing.T) {
213213
if cs.dumpWholeInstance {
214214
dumpRequiredPrivs[mysql.SelectPriv] = priv{needGlobal: true}
215215
}
216-
err := verifyPrivilegesWithResult(result, cs.grants, dumpRequiredPrivs)
216+
err := verifyPrivilegesWithResult(result, cs.grants, dumpRequiredPrivs, "8.0.11")
217217
if cs.dumpState == StateSuccess {
218218
require.Nil(t, err, "grants: %v", cs.grants)
219219
} else {
@@ -325,7 +325,7 @@ func TestVerifyReplicationPrivileges(t *testing.T) {
325325
mysql.ReplicationSlavePriv: {needGlobal: true},
326326
mysql.ReplicationClientPriv: {needGlobal: true},
327327
}
328-
err := verifyPrivilegesWithResult(result, cs.grants, replRequiredPrivs)
328+
err := verifyPrivilegesWithResult(result, cs.grants, replRequiredPrivs, "8.0.11")
329329
if cs.replicationState == StateSuccess {
330330
require.Nil(t, err, "grants: %v", cs.grants)
331331
} else {
@@ -405,7 +405,7 @@ func TestVerifyPrivilegesWildcard(t *testing.T) {
405405
dbs: genTableLevelPrivs(cs.checkTables),
406406
},
407407
}
408-
err := verifyPrivilegesWithResult(result, cs.grants, requiredPrivs)
408+
err := verifyPrivilegesWithResult(result, cs.grants, requiredPrivs, "8.0.11")
409409
if cs.replicationState == StateSuccess {
410410
require.Nil(t, err, "grants: %v", cs.grants)
411411
} else {
@@ -469,7 +469,7 @@ func TestVerifyTargetPrivilege(t *testing.T) {
469469
mysql.AlterPriv: {needGlobal: true},
470470
mysql.DropPriv: {needGlobal: true},
471471
}
472-
err := verifyPrivilegesWithResult(result, cs.grants, replRequiredPrivs)
472+
err := verifyPrivilegesWithResult(result, cs.grants, replRequiredPrivs, "8.0.11")
473473
if cs.checkState == StateSuccess {
474474
require.Nil(t, err, "grants: %v", cs.grants)
475475
} else {

dm/pkg/conn/basedb.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ type BaseDB struct {
201201
doNotClose bool
202202

203203
// SELECT VERSION()
204-
version string
204+
Version string
205205
}
206206

207207
// NewBaseDB returns *BaseDB object for test.
@@ -213,7 +213,7 @@ func NewBaseDB(db *sql.DB, scope terror.ErrScope, version string, doFuncInClose
213213
Retry: &retry.FiniteRetryStrategy{},
214214
Scope: scope,
215215
doFuncInClose: doFuncInClose,
216-
version: version,
216+
Version: version,
217217
}
218218
}
219219

@@ -238,7 +238,7 @@ func NewBaseDBForTestWithVersion(db *sql.DB, version string, doFuncInClose ...fu
238238
Retry: &retry.FiniteRetryStrategy{},
239239
Scope: terror.ScopeNotSet,
240240
doFuncInClose: doFuncInClose,
241-
version: version,
241+
Version: version,
242242
}
243243
}
244244

@@ -398,15 +398,15 @@ func (d *BaseDB) needsModernTerminology() bool {
398398
// - https://mariadb.com/docs/server/reference/sql-statements/administrative-sql-statements/show/show-replica-hosts
399399
//
400400
// Old syntax is still accepted.
401-
if strings.Contains(d.version, "MariaDB") {
401+
if strings.Contains(d.Version, "MariaDB") {
402402
return false
403403
}
404404

405405
// https://dev.mysql.com/doc/relnotes/mysql/8.4/en/news-8-4-0.html#mysqld-8-4-0-deprecation-removal
406406
// MySQL 8.4 removed `SHOW MASTER STATUS`.
407407
minVer := semver.New("8.4.0")
408408

409-
v, err := semver.NewVersion(d.version)
409+
v, err := semver.NewVersion(d.Version)
410410
if err != nil {
411411
return false
412412
}

dm/pkg/conn/basedb_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestNeedsModernTerminology(t *testing.T) {
121121
}
122122

123123
for _, tc := range cases {
124-
b.version = tc.version
125-
require.Equal(t, tc.modern, b.needsModernTerminology(), b.version)
124+
b.Version = tc.version
125+
require.Equal(t, tc.modern, b.needsModernTerminology(), b.Version)
126126
}
127127
}

0 commit comments

Comments
 (0)