diff --git a/rewrite-core/src/main/java/org/openrewrite/CsvDataTableStore.java b/rewrite-core/src/main/java/org/openrewrite/CsvDataTableStore.java index 0f5942fc565..e641eece7a5 100644 --- a/rewrite-core/src/main/java/org/openrewrite/CsvDataTableStore.java +++ b/rewrite-core/src/main/java/org/openrewrite/CsvDataTableStore.java @@ -28,6 +28,7 @@ import java.io.*; import java.lang.reflect.Field; +import java.lang.reflect.ParameterizedType; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -90,6 +91,10 @@ public class CsvDataTableStore implements DataTableStore, AutoCloseable { private final String fileExtension; private final Map prefixColumns; private final Map suffixColumns; + private static final ObjectMapper ROW_MAPPER = new ObjectMapper() + .registerModule(new ParameterNamesModule()) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + private final ConcurrentHashMap writers = new ConcurrentHashMap<>(); private final ConcurrentHashMap rowMetadata = new ConcurrentHashMap<>(); private final ConcurrentHashMap> knownTables = new ConcurrentHashMap<>(); @@ -180,15 +185,30 @@ private static InputStream defaultInputStream(Path path) { @Override public void insertRow(DataTable dataTable, ExecutionContext ctx, Row row) { String metaKey = metaKey(dataTable.getName(), dataTable.getGroup()); - rowMetadata.computeIfAbsent(metaKey, k -> RowMetadata.from(dataTable)); + rowMetadata.computeIfAbsent(metaKey, k -> RowMetadata.of(dataTable.getType())); knownTables.putIfAbsent(fileKey(dataTable), dataTable); String fileKey = fileKey(dataTable); BucketWriter writer = writers.computeIfAbsent(fileKey, k -> createBucketWriter(dataTable)); writer.writeRow(row); } + @Deprecated @Override public Stream getRows(String dataTableName, @Nullable String group) { + RowMetadata meta = rowMetadata.get(metaKey(dataTableName, group)); + return readRows(dataTableName, group, meta); + } + + @SuppressWarnings("unchecked") + @Override + public Stream getRows(Class> dataTableClass, @Nullable String group) { + Class rowType = (Class) ((ParameterizedType) dataTableClass.getGenericSuperclass()) + .getActualTypeArguments()[0]; + return readRows(dataTableClass.getName(), group, RowMetadata.of(rowType)); + } + + @SuppressWarnings("unchecked") + private Stream readRows(String dataTableName, @Nullable String group, @Nullable RowMetadata meta) { // Close (not just flush) matching writers so that compression trailers // (e.g., GZIP footer) are written, making the files fully readable. // Removed writers will be lazily re-created in append mode on the next insertRow(). @@ -203,8 +223,6 @@ public Stream getRows(String dataTableName, @Nullable String group) { } } - RowMetadata meta = rowMetadata.get(metaKey(dataTableName, group)); - List allRows = new ArrayList<>(); //noinspection DataFlowIssue File[] files = outputDir.toFile().listFiles((dir, name) -> name.endsWith(fileExtension)); @@ -241,6 +259,7 @@ public Stream getRows(String dataTableName, @Nullable String group) { try (InputStream is = inputStreamFactory.apply(file.toPath())) { CsvParserSettings settings = new CsvParserSettings(); + settings.setMaxCharsPerColumn(-1); settings.setHeaderExtractionEnabled(true); settings.getFormat().setComment('#'); CsvParser parser = new CsvParser(settings); @@ -265,7 +284,7 @@ public Stream getRows(String dataTableName, @Nullable String group) { } } - return allRows.stream(); + return (Stream) allRows.stream(); } @Override @@ -385,32 +404,27 @@ private static String metaKey(String dataTableName, @Nullable String group) { } /** - * Holds the row class and its @Column field names so that - * String[] rows read from CSV can be deserialized back to typed objects - * via Jackson's {@link ObjectMapper#convertValue}. + * Caches the {@link Column @Column} field names for a row class so they + * are only computed once, and converts CSV {@code String[]} rows back to + * typed objects via Jackson. */ private static class RowMetadata { - private static final ObjectMapper MAPPER = new ObjectMapper() - .registerModule(new ParameterNamesModule()) - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - - final String rowClassName; + final Class rowClass; final List fieldNames; - RowMetadata(String rowClassName, List fieldNames) { - this.rowClassName = rowClassName; + private RowMetadata(Class rowClass, List fieldNames) { + this.rowClass = rowClass; this.fieldNames = fieldNames; } - static RowMetadata from(DataTable dataTable) { - Class rowClass = dataTable.getType(); + static RowMetadata of(Class rowClass) { List names = new ArrayList<>(); for (Field f : rowClass.getDeclaredFields()) { if (f.isAnnotationPresent(Column.class)) { names.add(f.getName()); } } - return new RowMetadata(rowClass.getName(), names); + return new RowMetadata(rowClass, names); } Object toRow(String[] values) { @@ -418,11 +432,7 @@ Object toRow(String[] values) { for (int i = 0; i < fieldNames.size(); i++) { map.put(fieldNames.get(i), i < values.length ? values[i] : ""); } - try { - return MAPPER.convertValue(map, Class.forName(rowClassName)); - } catch (ClassNotFoundException e) { - throw new IllegalStateException("Row class not found: " + rowClassName, e); - } + return ROW_MAPPER.convertValue(map, rowClass); } } @@ -486,7 +496,7 @@ public static String sanitize(String value) { prefix = prefix.substring(0, lastDash); } } - String hash = sha256Prefix(value, 4); + String hash = sha256Prefix(value); return prefix + "-" + hash; } @@ -528,18 +538,18 @@ public static String sanitize(String value) { return new DataTableDescriptor(name, name, instanceName, "", group, Collections.emptyList()); } - private static String sha256Prefix(String input, int hexChars) { + private static String sha256Prefix(String input) { try { MessageDigest digest = MessageDigest.getInstance("SHA-256"); byte[] hash = digest.digest(input.getBytes(StandardCharsets.UTF_8)); StringBuilder hex = new StringBuilder(); for (byte b : hash) { hex.append(String.format("%02x", b)); - if (hex.length() >= hexChars) { + if (hex.length() >= 4) { break; } } - return hex.substring(0, Math.min(hexChars, hex.length())); + return hex.substring(0, Math.min(4, hex.length())); } catch (NoSuchAlgorithmException e) { throw new IllegalStateException(e); } diff --git a/rewrite-core/src/main/java/org/openrewrite/DataTableStore.java b/rewrite-core/src/main/java/org/openrewrite/DataTableStore.java index b337fe202b5..c936ef19b29 100644 --- a/rewrite-core/src/main/java/org/openrewrite/DataTableStore.java +++ b/rewrite-core/src/main/java/org/openrewrite/DataTableStore.java @@ -60,9 +60,36 @@ static DataTableStore noop() { * @param dataTableName the fully qualified class name of the data table * @param group the group identifying the bucket, or null for ungrouped * @return a stream of rows, or an empty stream if no rows exist + * @deprecated Use {@link #getRows(Class)} or {@link #getRows(Class, String)} for type-safe deserialization. */ + @Deprecated Stream getRows(String dataTableName, @Nullable String group); + /** + * Stream typed rows for a specific data table class and group. + * The row type is inferred from the data table's generic parameter. + * + * @param dataTableClass the data table class (e.g., {@code ServiceEndpoints.class}) + * @param group the group identifying the bucket, or null for ungrouped + * @param the row type + * @return a stream of typed rows, or an empty stream if no rows exist + */ + @SuppressWarnings("unchecked") + default Stream getRows(Class> dataTableClass, @Nullable String group) { + return (Stream) getRows(dataTableClass.getName(), group); + } + + /** + * Stream typed rows for a specific data table class (ungrouped). + * + * @param dataTableClass the data table class (e.g., {@code ServiceEndpoints.class}) + * @param the row type + * @return a stream of typed rows, or an empty stream if no rows exist + */ + default Stream getRows(Class> dataTableClass) { + return getRows(dataTableClass, null); + } + /** * Get the set of {@link DataTable} instances that have received rows. * diff --git a/rewrite-core/src/main/java/org/openrewrite/RecipeRun.java b/rewrite-core/src/main/java/org/openrewrite/RecipeRun.java index a350766d8e9..8c55a3276ba 100644 --- a/rewrite-core/src/main/java/org/openrewrite/RecipeRun.java +++ b/rewrite-core/src/main/java/org/openrewrite/RecipeRun.java @@ -44,13 +44,30 @@ public class RecipeRun { return null; } + /** + * @deprecated Use {@link #getDataTableRows(Class)} for type-safe deserialization. + */ + @Deprecated public List getDataTableRows(String name) { return getDataTableRows(name, null); } + /** + * @deprecated Use {@link #getDataTableRows(Class, String)} for type-safe deserialization. + */ @SuppressWarnings("unchecked") + @Deprecated public List getDataTableRows(String name, @Nullable String group) { return (List) dataTableStore.getRows(name, group) .collect(Collectors.toList()); } + + public List getDataTableRows(Class> dataTableClass) { + return getDataTableRows(dataTableClass, null); + } + + public List getDataTableRows(Class> dataTableClass, @Nullable String group) { + return dataTableStore.getRows(dataTableClass, group) + .collect(Collectors.toList()); + } } diff --git a/rewrite-core/src/test/java/org/openrewrite/DataTableStoreTest.java b/rewrite-core/src/test/java/org/openrewrite/DataTableStoreTest.java index 88b6763e629..b0eb73132a1 100644 --- a/rewrite-core/src/test/java/org/openrewrite/DataTableStoreTest.java +++ b/rewrite-core/src/test/java/org/openrewrite/DataTableStoreTest.java @@ -486,14 +486,8 @@ public List getInitialValue(ExecutionContext ctx) { // On cycle 2+, the store already contains rows written in cycle 1 List readBack = new ArrayList<>(); DataTableStore store = DataTableExecutionContextView.view(ctx).getDataTableStore(); - try (Stream rows = store.getRows(table.getName(), null)) { - rows.forEach(row -> { - if (row instanceof TestTable.Row) { - readBack.add(((TestTable.Row) row).getName()); - } else { - readBack.add(((String[]) row)[0]); - } - }); + try (Stream rows = store.getRows(TestTable.class)) { + rows.forEach(row -> readBack.add(row.getName())); } return readBack; } diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/search/FindFieldsOfTypeTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/search/FindFieldsOfTypeTest.java index bd0737d4813..1c9cc076885 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/search/FindFieldsOfTypeTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/search/FindFieldsOfTypeTest.java @@ -136,7 +136,7 @@ void verifyDataTable() { rewriteRun( spec -> spec.recipe(new FindFieldsOfType("java.util.List", true)) .afterRecipe(recipeRun -> { - List fields = recipeRun.getDataTableRows(FieldsOfTypeUses.class.getName()); + List fields = recipeRun.getDataTableRows(FieldsOfTypeUses.class); assertThat(fields).containsExactlyInAnyOrderElementsOf(expectedFields); }), java( diff --git a/rewrite-test/src/main/java/org/openrewrite/test/RecipeSpec.java b/rewrite-test/src/main/java/org/openrewrite/test/RecipeSpec.java index 7a8273cade7..bcbde761bf5 100644 --- a/rewrite-test/src/main/java/org/openrewrite/test/RecipeSpec.java +++ b/rewrite-test/src/main/java/org/openrewrite/test/RecipeSpec.java @@ -250,8 +250,7 @@ public RecipeSpec dataTableAsCsv(String name, String expect) { } } assertThat(dataTable).isNotNull(); - @SuppressWarnings("unchecked") - List rows = (List) store.getRows(dataTable.getName(), dataTable.getGroup()) + List rows = store.getRows(dataTable.getName(), dataTable.getGroup()) .collect(java.util.stream.Collectors.toList()); StringWriter writer = new StringWriter(); CsvMapper mapper = CsvMapper.builder()