From feed1c79ae453da5e782a74f160ae95599268d66 Mon Sep 17 00:00:00 2001
From: gmhdbjd
Date: Wed, 1 Apr 2026 16:09:03 +0800
Subject: [PATCH 1/3] importsdk, importer: fix sampled source size in import
 estimate

---
 pkg/executor/importer/sampler.go      | 48 ++++++++++++++++---------
 pkg/executor/importer/sampler_test.go | 51 +++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 16 deletions(-)

diff --git a/pkg/executor/importer/sampler.go b/pkg/executor/importer/sampler.go
index 10f74352cb944..39657240e7c47 100644
--- a/pkg/executor/importer/sampler.go
+++ b/pkg/executor/importer/sampler.go
@@ -39,7 +39,6 @@ import (
 	plannercore "github.com/pingcap/tidb/pkg/planner/core"
 	"github.com/pingcap/tidb/pkg/table"
 	"github.com/pingcap/tidb/pkg/table/tables"
-	"github.com/pingcap/tidb/pkg/types"
 	contextutil "github.com/pingcap/tidb/pkg/util/context"
 	"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
 	"go.uber.org/zap"
@@ -352,27 +351,30 @@ func (s *kvSizeSampler) sampleOneFile(
 	}()
 
 	var (
-		count        int
-		readRowCache []types.Datum
-		readFn       = parserEncodeReader(parser, chunk.Chunk.EndOffset, chunk.GetKey())
-		kvBatch      = newEncodedKVGroupBatch(ksCodec, maxRowCount)
+		count   int
+		kvBatch = newEncodedKVGroupBatch(ksCodec, maxRowCount)
 	)
 	for count < maxRowCount {
-		row, closed, readErr := readFn(ctx, readRowCache)
-		if readErr != nil {
-			return 0, 0, 0, readErr
-		}
-		if closed {
+		startPos, _ := parser.Pos()
+		if s.cfg.Format != DataFormatParquet && startPos >= chunk.Chunk.EndOffset {
 			break
 		}
-		readRowCache = row.row
-		if rowDelta := row.endOffset - row.startPos; rowDelta > 0 {
-			sourceSize += rowDelta
+
+		readErr := parser.ReadRow()
+		if readErr != nil {
+			if errors.Cause(readErr) == io.EOF {
+				break
+			}
+			return 0, 0, 0, common.ErrEncodeKV.Wrap(readErr).GenWithStackByArgs(chunk.GetKey(), startPos)
 		}
-		kvs, encodeErr := encoder.Encode(row.row, row.rowID)
-		row.resetFn()
+
+		lastRow := parser.LastRow()
+		sourceSize += s.sampledRowSourceSize(parser, startPos, lastRow)
+
+		kvs, encodeErr := encoder.Encode(lastRow.Row, lastRow.RowID)
+		parser.RecycleRow(lastRow)
 		if encodeErr != nil {
-			return 0, 0, 0, common.ErrEncodeKV.Wrap(encodeErr).GenWithStackByArgs(chunk.GetKey(), row.startPos)
+			return 0, 0, 0, common.ErrEncodeKV.Wrap(encodeErr).GenWithStackByArgs(chunk.GetKey(), startPos)
 		}
 		if _, err = kvBatch.add(kvs); err != nil {
 			return 0, 0, 0, err
@@ -382,3 +384,17 @@ func (s *kvSizeSampler) sampleOneFile(
 	dataKVSize, indexKVSize = kvBatch.groupChecksum.DataAndIndexSumSize()
 	return sourceSize, dataKVSize, indexKVSize, nil
 }
+
+func (s *kvSizeSampler) sampledRowSourceSize(parser mydump.Parser, startPos int64, row mydump.Row) int64 {
+	// Sampling needs per-row source bytes, not buffered reader progress.
+	// SQL/CSV parsers expose byte offsets through Pos(), while parquet Pos()
+	// is row-count based and must fall back to the row-size estimate.
+	if s.cfg.Format == DataFormatParquet {
+		return int64(row.Length)
+	}
+	endPos, _ := parser.Pos()
+	if rowDelta := endPos - startPos; rowDelta > 0 {
+		return rowDelta
+	}
+	return int64(row.Length)
+}
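The core of the fix is the helper at the end of this hunk: charge each sampled row either the number of bytes Pos() advanced while that row was read, or, when offsets are not byte-based, the parser's own size estimate for the row. Below is a minimal, self-contained sketch of that decision; fakeParser, row, and the byteOffsets flag are illustrative stand-ins, not the real mydump API:

```go
package main

import "fmt"

// row stands in for mydump.Row; length plays the role of row.Length,
// the parser's own estimate of the row's source size.
type row struct{ length int }

// fakeParser stands in for mydump.Parser. For SQL/CSV the position is a
// byte offset into the (uncompressed) input; for parquet it counts rows,
// so byte deltas between positions are meaningless.
type fakeParser struct {
	pos         int64
	byteOffsets bool
}

// sampledRowSourceSize mirrors the patched helper: prefer the byte delta
// consumed while reading the row, and fall back to the row-size estimate
// when offsets are row counts (parquet) or the delta is not positive.
func sampledRowSourceSize(p *fakeParser, startPos int64, r row) int64 {
	if !p.byteOffsets { // the real code keys this off DataFormatParquet
		return int64(r.length)
	}
	if delta := p.pos - startPos; delta > 0 {
		return delta
	}
	return int64(r.length)
}

func main() {
	// CSV/SQL-like parser: reading one row consumed 42 bytes of input.
	csv := &fakeParser{pos: 42, byteOffsets: true}
	fmt.Println(sampledRowSourceSize(csv, 0, row{length: 30})) // 42

	// Parquet-like parser: the position advanced by one *row*, so the
	// row-size estimate is used instead of the delta.
	pq := &fakeParser{pos: 1, byteOffsets: false}
	fmt.Println(sampledRowSourceSize(pq, 0, row{length: 30})) // 30
}
```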
diff --git a/pkg/executor/importer/sampler_test.go b/pkg/executor/importer/sampler_test.go
index 3b0dda715440d..260c65397ac42 100644
--- a/pkg/executor/importer/sampler_test.go
+++ b/pkg/executor/importer/sampler_test.go
@@ -238,6 +238,57 @@ func TestSampleIndexSizeRatio(t *testing.T) {
 		require.Error(t, err)
 		require.True(t, reader.closed)
 	})
+
+	t.Run("sql_source_size_uses_consumed_bytes_not_buffered_progress", func(t *testing.T) {
+		dir := t.TempDir()
+		var fileSB strings.Builder
+		fileSB.WriteString("INSERT INTO t VALUES\n")
+		for i := 0; i < 20; i++ {
+			_, err := fmt.Fprintf(&fileSB, "(%d,'v%d','w%d','x%d')", i, i, i, i)
+			require.NoError(t, err)
+			if i < 19 {
+				fileSB.WriteString(",\n")
+				continue
+			}
+			fileSB.WriteString(";\n")
+		}
+		content := fileSB.String()
+		require.NoError(t, os.WriteFile(filepath.Join(dir, "001.sql"), []byte(content), 0o644))
+
+		p := parser.New()
+		node, err := p.ParseOneStmt(`create table t (a int, b text, c text, d text, index idx(a));`, "", "")
+		require.NoError(t, err)
+		sctx := utilmock.NewContext()
+		tblInfo, err := ddl.MockTableInfo(sctx, node.(*ast.CreateTableStmt), 1)
+		require.NoError(t, err)
+		tblInfo.State = model.StatePublic
+		table := tables.MockTableFromMeta(tblInfo)
+
+		ctrl, err := NewLoadDataController(&Plan{
+			Path:         filepath.Join(dir, "*.sql"),
+			Format:       DataFormatSQL,
+			InImportInto: true,
+		}, table, &ASTArgs{})
+		require.NoError(t, err)
+		ctrl.logger = zap.Must(zap.NewDevelopment())
+		ctx := context.Background()
+		require.NoError(t, ctrl.InitDataFiles(ctx))
+
+		sampled, err := SampleFileImportKVSize(
+			ctx,
+			ctrl.buildKVSizeSampleConfig(),
+			table,
+			ctrl.dataStore,
+			ctrl.dataFiles,
+			nil,
+			ctrl.logger,
+		)
+		require.NoError(t, err)
+		require.Positive(t, sampled.SourceSize)
+		require.Positive(t, sampled.TotalKVSize())
+		require.Greater(t, sampled.SourceSize, int64(len(content)/2))
+		require.Less(t, sampled.SourceSize, int64(len(content)*2))
+	})
 }
 
 func TestSampleIndexSizeRatioVeryLongRows(t *testing.T) {
 	simpleTbl := `create table t (a int, b text, c text, d text, index idx(a));`

From ce12082dceaf1804513c90ea41f21f2e9828ddbe Mon Sep 17 00:00:00 2001
From: gmhdbjd
Date: Wed, 1 Apr 2026 18:52:08 +0800
Subject: [PATCH 2/3] importer: clarify sampled source size comment

---
 pkg/executor/importer/sampler.go | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pkg/executor/importer/sampler.go b/pkg/executor/importer/sampler.go
index 39657240e7c47..e69a3a222e238 100644
--- a/pkg/executor/importer/sampler.go
+++ b/pkg/executor/importer/sampler.go
@@ -387,8 +387,10 @@
 func (s *kvSizeSampler) sampledRowSourceSize(parser mydump.Parser, startPos int64, row mydump.Row) int64 {
 	// Sampling needs per-row source bytes, not buffered reader progress.
-	// SQL/CSV parsers expose byte offsets through Pos(), while parquet Pos()
-	// is row-count based and must fall back to the row-size estimate.
+	// SQL/CSV parsers expose byte offsets through Pos(), including compressed
+	// input where Pos() tracks uncompressed bytes and stays aligned with the
+	// RealSize-based source totals. Parquet Pos() is row-count based and must
+	// fall back to the row-size estimate.
 	if s.cfg.Format == DataFormatParquet {
 		return int64(row.Length)
 	}
 	endPos, _ := parser.Pos()
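The clarified comment rests on one subtlety: for compressed files the parser consumes *decompressed* bytes, so Pos() deltas measure uncompressed size, which is what RealSize-style source totals count. The standalone snippet below (plain standard library, not TiDB code) demonstrates how counting below and above the decompressor diverges:

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

// countingReader counts every byte that passes through it.
type countingReader struct {
	r io.Reader
	n int64
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n += int64(n)
	return n, err
}

func main() {
	// Build a small gzip payload from very repetitive (compressible) rows.
	raw := bytes.Repeat([]byte("1,alice\n2,bob\n"), 1000)
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	zw.Write(raw)
	zw.Close()

	// Count compressed bytes below the gzip reader and uncompressed bytes
	// above it; a parser's position sits on the upper, decompressed side.
	compressed := &countingReader{r: bytes.NewReader(buf.Bytes())}
	zr, err := gzip.NewReader(compressed)
	if err != nil {
		panic(err)
	}
	uncompressed := &countingReader{r: zr}
	io.Copy(io.Discard, uncompressed)

	fmt.Println("compressed bytes read:  ", compressed.n)   // far smaller
	fmt.Println("uncompressed bytes read:", uncompressed.n) // == len(raw)
}
```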
From b62a461179c611c30aa80940ae8c78f985187fee Mon Sep 17 00:00:00 2001
From: gmhdbjd
Date: Thu, 2 Apr 2026 16:35:53 +0800
Subject: [PATCH 3/3] importsdk: support multi-statement schema in size
 estimator

---
 pkg/importsdk/file_scanner.go      | 51 +++++++++++++++++++++++++++---
 pkg/importsdk/file_scanner_test.go | 35 ++++++++++++++++++++
 2 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/pkg/importsdk/file_scanner.go b/pkg/importsdk/file_scanner.go
index 2e4bcf6ceaf20..7af7d0f60d3db 100644
--- a/pkg/importsdk/file_scanner.go
+++ b/pkg/importsdk/file_scanner.go
@@ -393,13 +393,13 @@ func (s *fileScanner) buildEstimateTableInfo(ctx context.Context, tblMeta *mydum
 	}
 	p := parser.New()
 	p.SetSQLMode(s.config.sqlMode)
-	stmt, err := p.ParseOneStmt(schemaSQL, "", "")
+	stmts, _, err := p.ParseSQL(schemaSQL)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
-	createStmt, ok := stmt.(*ast.CreateTableStmt)
-	if !ok {
-		return nil, errors.Errorf("schema file %s does not contain a CREATE TABLE statement", tblMeta.SchemaFile.FileMeta.Path)
+	createStmt, err := buildEstimateCreateTableStmt(stmts, tblMeta)
+	if err != nil {
+		return nil, err
 	}
 	tableInfo, err := ddl.BuildTableInfoFromAST(metabuild.NewContext(), createStmt)
 	if err != nil {
@@ -408,6 +408,49 @@ func (s *fileScanner) buildEstimateTableInfo(ctx context.Context, tblMeta *mydum
 	return tableInfo, nil
 }
 
+func buildEstimateCreateTableStmt(stmts []ast.StmtNode, tblMeta *mydump.MDTableMeta) (*ast.CreateTableStmt, error) {
+	var (
+		firstCreateStmt *ast.CreateTableStmt
+		createStmtCount int
+	)
+	for _, stmt := range stmts {
+		createStmt, ok := stmt.(*ast.CreateTableStmt)
+		if !ok {
+			continue
+		}
+		if firstCreateStmt == nil {
+			firstCreateStmt = createStmt
+		}
+		createStmtCount++
+		if estimateCreateTableStmtMatchesMeta(createStmt, tblMeta) {
+			return createStmt, nil
+		}
+	}
+	if createStmtCount == 1 {
+		return firstCreateStmt, nil
+	}
+	if createStmtCount == 0 {
+		return nil, errors.Errorf("schema file %s does not contain a CREATE TABLE statement", tblMeta.SchemaFile.FileMeta.Path)
+	}
+	return nil, errors.Errorf(
+		"schema file %s contains %d CREATE TABLE statements but none match table %s.%s",
+		tblMeta.SchemaFile.FileMeta.Path,
+		createStmtCount,
+		tblMeta.DB,
+		tblMeta.Name,
+	)
+}
+
+func estimateCreateTableStmtMatchesMeta(createStmt *ast.CreateTableStmt, tblMeta *mydump.MDTableMeta) bool {
+	if !strings.EqualFold(createStmt.Table.Name.String(), tblMeta.Name) {
+		return false
+	}
+	if createStmt.Table.Schema.String() == "" {
+		return true
+	}
+	return strings.EqualFold(createStmt.Table.Schema.String(), tblMeta.DB)
+}
+
 func sourceTypeToImportFormat(tp mydump.SourceType) (string, error) {
 	switch tp {
 	case mydump.SourceTypeCSV:
diff --git a/pkg/importsdk/file_scanner_test.go b/pkg/importsdk/file_scanner_test.go
index 3b632e9bad654..ed69868de673f 100644
--- a/pkg/importsdk/file_scanner_test.go
+++ b/pkg/importsdk/file_scanner_test.go
@@ -266,6 +266,41 @@ func TestFileScanner(t *testing.T) {
 		require.Equal(t, estimate.Tables[0].SourceSize, estimate.TotalSourceSize)
 		require.Equal(t, estimate.Tables[0].TiKVSize, estimate.TotalTiKVSize)
 	})
+
+	t.Run("EstimateImportDataSizeMultiStatementSchema", func(t *testing.T) {
+		estimateDir := t.TempDir()
+		require.NoError(t, os.WriteFile(filepath.Join(estimateDir, "test_db-schema-create.sql"), []byte("CREATE DATABASE test_db;"), 0o644))
+		require.NoError(t, os.WriteFile(
+			filepath.Join(estimateDir, "test_db.users-schema.sql"),
[]byte(strings.Join([]string{ + "CREATE DATABASE IF NOT EXISTS test_db;", + "USE test_db;", + "DROP TABLE IF EXISTS users;", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255), KEY idx_name (name));", + }, "\n")), + 0o644, + )) + require.NoError(t, os.WriteFile( + filepath.Join(estimateDir, "test_db.users.001.csv"), + []byte("1,alice\n2,bob\n"), + 0o644, + )) + + cfg := defaultSDKConfig() + cfg.skipInvalidFiles = true + estimateScanner, err := NewFileScanner(ctx, "file://"+estimateDir, db, cfg) + require.NoError(t, err) + defer estimateScanner.Close() + + estimate, err := estimateScanner.EstimateImportDataSize(ctx) + require.NoError(t, err) + require.Len(t, estimate.Tables, 1) + require.Equal(t, "users", estimate.Tables[0].Table) + require.Positive(t, estimate.Tables[0].SourceSize) + require.Positive(t, estimate.Tables[0].TiKVSize) + require.Equal(t, estimate.Tables[0].SourceSize, estimate.TotalSourceSize) + require.Equal(t, estimate.Tables[0].TiKVSize, estimate.TotalTiKVSize) + }) } func TestFileScannerWithEstimateRealSize(t *testing.T) {
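Stripped of parser details, the selection rule buildEstimateCreateTableStmt implements is: take the CREATE TABLE whose table name (and schema, when one is written) matches the table meta; accept a lone CREATE TABLE unconditionally; otherwise fail. A compact sketch of that rule follows, with simplified stand-in types (stmt, tableMeta) in place of the real ast/mydump ones:

```go
package main

import (
	"fmt"
	"strings"
)

// stmt stands in for ast.StmtNode; only CREATE TABLE statements carry a
// schema/table name here.
type stmt struct {
	isCreateTable bool
	schema, table string
}

// tableMeta stands in for mydump.MDTableMeta.
type tableMeta struct{ db, name string }

// matchesMeta mirrors estimateCreateTableStmtMatchesMeta: the table name
// must match, and a CREATE TABLE without an explicit schema is assumed to
// belong to the file's database.
func matchesMeta(s stmt, meta tableMeta) bool {
	if !strings.EqualFold(s.table, meta.name) {
		return false
	}
	return s.schema == "" || strings.EqualFold(s.schema, meta.db)
}

// pickCreateTable mirrors buildEstimateCreateTableStmt: prefer a matching
// CREATE TABLE, accept a lone one as-is, and error out when there are none
// or several non-matching ones.
func pickCreateTable(stmts []stmt, meta tableMeta) (stmt, error) {
	var first *stmt
	count := 0
	for i := range stmts {
		if !stmts[i].isCreateTable {
			continue
		}
		if first == nil {
			first = &stmts[i]
		}
		count++
		if matchesMeta(stmts[i], meta) {
			return stmts[i], nil
		}
	}
	switch {
	case count == 1:
		return *first, nil
	case count == 0:
		return stmt{}, fmt.Errorf("no CREATE TABLE statement in schema file")
	default:
		return stmt{}, fmt.Errorf("%d CREATE TABLE statements, none match %s.%s", count, meta.db, meta.name)
	}
}

func main() {
	stmts := []stmt{
		{}, // CREATE DATABASE IF NOT EXISTS test_db;
		{}, // USE test_db; DROP TABLE IF EXISTS users;
		{isCreateTable: true, table: "users"},
	}
	got, err := pickCreateTable(stmts, tableMeta{db: "test_db", name: "users"})
	fmt.Println(got.table, err) // users <nil>
}
```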