Skip to content

Commit 3144a46

Browse files
authored
executor: optimize load data assignment expressions (#46082)
close #46081
1 parent 4e6af00 commit 3144a46

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

executor/builder.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,11 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor {
944944
b.err = err
945945
return nil
946946
}
947+
err = loadDataInfo.initColAssignExprs()
948+
if err != nil {
949+
b.err = err
950+
return nil
951+
}
947952
loadDataExec := &LoadDataExec{
948953
baseExecutor: newBaseExecutor(b.ctx, nil, v.ID()),
949954
IsLocal: v.IsLocal,

executor/load_data.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/pingcap/tidb/parser/model"
3030
"github.com/pingcap/tidb/parser/mysql"
3131
"github.com/pingcap/tidb/sessionctx"
32+
"github.com/pingcap/tidb/sessionctx/stmtctx"
3233
"github.com/pingcap/tidb/table"
3334
"github.com/pingcap/tidb/types"
3435
"github.com/pingcap/tidb/util/chunk"
@@ -116,7 +117,11 @@ type LoadDataInfo struct {
116117
rows [][]types.Datum
117118
Drained bool
118119

119-
ColumnAssignments []*ast.Assignment
120+
ColumnAssignments []*ast.Assignment
121+
ColumnAssignmentExprs []expression.Expression
122+
// exprWarnings holds the warnings the session context generated while rewriting the AST nodes into expressions.
123+
// We should generate such warnings for each row encoded.
124+
exprWarnings []stmtctx.SQLWarn
120125
ColumnsAndUserVars []*ast.ColumnNameOrUserVar
121126
FieldMappings []*FieldMapping
122127

@@ -211,6 +216,23 @@ func (e *LoadDataInfo) initLoadColumns(columnNames []string) error {
211216
return nil
212217
}
213218

219+
// initColAssignExprs creates the column assignment expressions from e.ColumnAssignments using the session context and stores them in e.ColumnAssignmentExprs.
220+
// RewriteAstExpr rewrites the AST node in place (due to xxNode.Accept), but it doesn't change the node's content.
221+
func (e *LoadDataInfo) initColAssignExprs() error {
222+
for _, assign := range e.ColumnAssignments {
223+
newExpr, err := expression.RewriteAstExpr(e.Ctx, assign.Expr, nil, nil)
224+
if err != nil {
225+
return err
226+
}
227+
// The warnings of a column assignment expression are static, yet they should be generated for each row processed,
228+
// so we save them in e.exprWarnings and clear them from the statement context here.
229+
e.exprWarnings = append(e.exprWarnings, e.Ctx.GetSessionVars().StmtCtx.GetWarnings()...)
230+
e.Ctx.GetSessionVars().StmtCtx.SetWarnings(nil)
231+
e.ColumnAssignmentExprs = append(e.ColumnAssignmentExprs, newExpr)
232+
}
233+
return nil
234+
}
235+
214236
// initFieldMappings makes a field mapping slice to implicitly map input fields to table columns or user-defined variables;
215237
// the slice's order is the same as the order of the input fields.
216238
// Returns a slice of same ordered column names without user defined variable names.
@@ -664,15 +686,19 @@ func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []field) []types.Datu
664686

665687
row = append(row, types.NewDatum(string(cols[i].str)))
666688
}
667-
for i := 0; i < len(e.ColumnAssignments); i++ {
689+
690+
for i := 0; i < len(e.ColumnAssignmentExprs); i++ {
668691
// eval expression of `SET` clause
669-
d, err := expression.EvalAstExpr(e.Ctx, e.ColumnAssignments[i].Expr)
692+
d, err := e.ColumnAssignmentExprs[i].Eval(chunk.Row{})
670693
if err != nil {
671694
e.handleWarning(err)
672695
return nil
673696
}
674697
row = append(row, d)
675698
}
699+
if len(e.exprWarnings) > 0 {
700+
e.Ctx.GetSessionVars().StmtCtx.AppendWarnings(e.exprWarnings)
701+
}
676702

677703
// a new row buffer will be allocated in getRow
678704
newRow, err := e.getRow(ctx, row)

server/server_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,73 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
15761576
require.NoError(t, rows.Close())
15771577
dbt.MustExec("drop table if exists pn")
15781578
})
1579+
1580+
err = fp.Close()
1581+
require.NoError(t, err)
1582+
err = os.Remove(path)
1583+
require.NoError(t, err)
1584+
1585+
fp, err = os.Create(path)
1586+
require.NoError(t, err)
1587+
require.NotNil(t, fp)
1588+
1589+
_, err = fp.WriteString(
1590+
`1,2` + "\n" +
1591+
`1,2,,4` + "\n" +
1592+
`1,2,3` + "\n" +
1593+
`,,,` + "\n" +
1594+
`,,3` + "\n" +
1595+
`1,,,4` + "\n")
1596+
require.NoError(t, err)
1597+
1598+
nullInt32 := func(val int32, valid bool) sql.NullInt32 {
1599+
return sql.NullInt32{Int32: val, Valid: valid}
1600+
}
1601+
expects := []struct {
1602+
col1 sql.NullInt32
1603+
col2 sql.NullInt32
1604+
col3 sql.NullInt32
1605+
col4 sql.NullInt32
1606+
}{
1607+
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(0, false)},
1608+
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(4, true)},
1609+
{nullInt32(1, true), nullInt32(2, true), nullInt32(3, true), nullInt32(0, false)},
1610+
{nullInt32(0, true), nullInt32(0, false), nullInt32(0, false), nullInt32(0, false)},
1611+
{nullInt32(0, true), nullInt32(0, false), nullInt32(3, true), nullInt32(0, false)},
1612+
{nullInt32(1, true), nullInt32(0, false), nullInt32(0, false), nullInt32(4, true)},
1613+
}
1614+
1615+
cli.runTestsOnNewDB(t, func(config *mysql.Config) {
1616+
config.AllowAllFiles = true
1617+
config.Params["sql_mode"] = "''"
1618+
}, "LoadData", func(dbt *testkit.DBTestKit) {
1619+
dbt.MustExec("drop table if exists pn")
1620+
dbt.MustExec("create table pn (c1 int, c2 int, c3 int, c4 int)")
1621+
dbt.MustExec("set @@tidb_dml_batch_size = 1")
1622+
_, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4)
1623+
SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`, path))
1624+
require.NoError(t, err1)
1625+
var (
1626+
a sql.NullInt32
1627+
b sql.NullInt32
1628+
c sql.NullInt32
1629+
d sql.NullInt32
1630+
)
1631+
rows := dbt.MustQuery("select * from pn")
1632+
for _, expect := range expects {
1633+
require.Truef(t, rows.Next(), "unexpected data")
1634+
err = rows.Scan(&a, &b, &c, &d)
1635+
require.NoError(t, err)
1636+
require.Equal(t, expect.col1, a)
1637+
require.Equal(t, expect.col2, b)
1638+
require.Equal(t, expect.col3, c)
1639+
require.Equal(t, expect.col4, d)
1640+
}
1641+
1642+
require.Falsef(t, rows.Next(), "unexpected data")
1643+
require.NoError(t, rows.Close())
1644+
dbt.MustExec("drop table if exists pn")
1645+
})
15791646
}
15801647

15811648
func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {

0 commit comments

Comments
 (0)