62 changes: 36 additions & 26 deletions rewrite-core/src/main/java/org/openrewrite/CsvDataTableStore.java
@@ -28,6 +28,7 @@

import java.io.*;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -90,6 +91,10 @@ public class CsvDataTableStore implements DataTableStore, AutoCloseable {
private final String fileExtension;
private final Map<String, String> prefixColumns;
private final Map<String, String> suffixColumns;
private static final ObjectMapper ROW_MAPPER = new ObjectMapper()
.registerModule(new ParameterNamesModule())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

private final ConcurrentHashMap<String, BucketWriter> writers = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, RowMetadata> rowMetadata = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, DataTable<?>> knownTables = new ConcurrentHashMap<>();
@@ -180,15 +185,30 @@ private static InputStream defaultInputStream(Path path) {
@Override
public <Row> void insertRow(DataTable<Row> dataTable, ExecutionContext ctx, Row row) {
String metaKey = metaKey(dataTable.getName(), dataTable.getGroup());
rowMetadata.computeIfAbsent(metaKey, k -> RowMetadata.from(dataTable));
rowMetadata.computeIfAbsent(metaKey, k -> RowMetadata.of(dataTable.getType()));
knownTables.putIfAbsent(fileKey(dataTable), dataTable);
String fileKey = fileKey(dataTable);
BucketWriter writer = writers.computeIfAbsent(fileKey, k -> createBucketWriter(dataTable));
writer.writeRow(row);
}

@Deprecated
@Override
public Stream<?> getRows(String dataTableName, @Nullable String group) {
RowMetadata meta = rowMetadata.get(metaKey(dataTableName, group));
return readRows(dataTableName, group, meta);
}

@SuppressWarnings("unchecked")
@Override
public <Row> Stream<Row> getRows(Class<? extends DataTable<Row>> dataTableClass, @Nullable String group) {
Class<Row> rowType = (Class<Row>) ((ParameterizedType) dataTableClass.getGenericSuperclass())
.getActualTypeArguments()[0];
return readRows(dataTableClass.getName(), group, RowMetadata.of(rowType));
}

@SuppressWarnings("unchecked")
private <T> Stream<T> readRows(String dataTableName, @Nullable String group, @Nullable RowMetadata meta) {
// Close (not just flush) matching writers so that compression trailers
// (e.g., GZIP footer) are written, making the files fully readable.
// Removed writers will be lazily re-created in append mode on the next insertRow().
@@ -203,8 +223,6 @@ public Stream<?> getRows(String dataTableName, @Nullable String group) {
}
}

RowMetadata meta = rowMetadata.get(metaKey(dataTableName, group));

List<Object> allRows = new ArrayList<>();
//noinspection DataFlowIssue
File[] files = outputDir.toFile().listFiles((dir, name) -> name.endsWith(fileExtension));
@@ -241,6 +259,7 @@ public Stream<?> getRows(String dataTableName, @Nullable String group) {

try (InputStream is = inputStreamFactory.apply(file.toPath())) {
CsvParserSettings settings = new CsvParserSettings();
settings.setMaxCharsPerColumn(-1);
settings.setHeaderExtractionEnabled(true);
settings.getFormat().setComment('#');
CsvParser parser = new CsvParser(settings);
@@ -265,7 +284,7 @@ public Stream<?> getRows(String dataTableName, @Nullable String group) {
}
}

return allRows.stream();
return (Stream<T>) allRows.stream();
}
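
Aside: the close-not-flush comment above is load-bearing for compressed output. A minimal, self-contained sketch (not part of this PR) of why a GZIP stream is unreadable until closed:

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPOutputStream;

public class GzipTrailerSketch {
    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        GZIPOutputStream gz = new GZIPOutputStream(buf);
        gz.write("name,value\nfoo,1\n".getBytes(StandardCharsets.UTF_8));
        gz.flush(); // without SYNC_FLUSH, data can still sit in the Deflater; no trailer exists yet
        int atFlush = buf.size();
        gz.close(); // finish() drains the Deflater and writes the 8-byte GZIP trailer (CRC-32 + ISIZE)
        System.out.println("bytes produced only by close(): " + (buf.size() - atFlush));
        // Reading the flushed-but-unclosed bytes with GZIPInputStream would typically fail with
        // "EOFException: Unexpected end of ZLIB input stream" -- hence readRows() closes matching
        // writers and lets insertRow() lazily re-create them in append mode.
    }
}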

@Override
@@ -385,44 +404,35 @@ private static String metaKey(String dataTableName, @Nullable String group) {
}

/**
* Holds the row class and its @Column field names so that
* String[] rows read from CSV can be deserialized back to typed objects
* via Jackson's {@link ObjectMapper#convertValue}.
* Caches the {@link Column @Column} field names for a row class so they
* are only computed once, and converts CSV {@code String[]} rows back to
* typed objects via Jackson.
*/
private static class RowMetadata {
private static final ObjectMapper MAPPER = new ObjectMapper()
.registerModule(new ParameterNamesModule())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

final String rowClassName;
final Class<?> rowClass;
final List<String> fieldNames;

RowMetadata(String rowClassName, List<String> fieldNames) {
this.rowClassName = rowClassName;
private RowMetadata(Class<?> rowClass, List<String> fieldNames) {
this.rowClass = rowClass;
this.fieldNames = fieldNames;
}

static RowMetadata from(DataTable<?> dataTable) {
Class<?> rowClass = dataTable.getType();
static RowMetadata of(Class<?> rowClass) {
List<String> names = new ArrayList<>();
for (Field f : rowClass.getDeclaredFields()) {
if (f.isAnnotationPresent(Column.class)) {
names.add(f.getName());
}
}
return new RowMetadata(rowClass.getName(), names);
return new RowMetadata(rowClass, names);
}

Object toRow(String[] values) {
Map<String, String> map = new LinkedHashMap<>();
for (int i = 0; i < fieldNames.size(); i++) {
map.put(fieldNames.get(i), i < values.length ? values[i] : "");
}
try {
return MAPPER.convertValue(map, Class.forName(rowClassName));
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Row class not found: " + rowClassName, e);
}
return ROW_MAPPER.convertValue(map, rowClass);
}
}
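
Aside: a minimal sketch (not part of this PR) of the Jackson conversion toRow() performs; the Row record below is hypothetical and stands in for a real @Column-annotated row class:

import java.util.LinkedHashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;

public class ConvertValueSketch {
    // Hypothetical row type; in real data tables each field carries @Column.
    public record Row(String word, int count) {}

    public static void main(String[] args) {
        // Same configuration as ROW_MAPPER above.
        ObjectMapper mapper = new ObjectMapper()
                .registerModule(new ParameterNamesModule())
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

        // toRow() zips the cached field names with the raw CSV column values...
        Map<String, String> raw = new LinkedHashMap<>();
        raw.put("word", "hello");
        raw.put("count", "3");

        // ...and convertValue() coerces the strings onto the typed fields.
        Row row = mapper.convertValue(raw, Row.class);
        System.out.println(row); // Row[word=hello, count=3]
    }
}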

@@ -486,7 +496,7 @@ public static String sanitize(String value) {
prefix = prefix.substring(0, lastDash);
}
}
String hash = sha256Prefix(value, 4);
String hash = sha256Prefix(value);
return prefix + "-" + hash;
}

@@ -528,18 +538,18 @@ public static String sanitize(String value) {
return new DataTableDescriptor(name, name, instanceName, "", group, Collections.emptyList());
}

private static String sha256Prefix(String input, int hexChars) {
private static String sha256Prefix(String input) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
byte[] hash = digest.digest(input.getBytes(StandardCharsets.UTF_8));
StringBuilder hex = new StringBuilder();
for (byte b : hash) {
hex.append(String.format("%02x", b));
if (hex.length() >= hexChars) {
if (hex.length() >= 4) {
break;
}
}
return hex.substring(0, Math.min(hexChars, hex.length()));
return hex.substring(0, Math.min(4, hex.length()));
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException(e);
}
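
Aside: the typed getRows(Class, String) overload added above recovers Row from the generic superclass signature. A self-contained sketch (not part of this PR; Table and WordCounts are stand-ins for DataTable and a concrete table class):

import java.lang.reflect.ParameterizedType;

public class GenericSuperclassSketch {
    abstract static class Table<Row> {}

    static final class WordCounts extends Table<WordCounts.Row> {
        record Row(String word, int count) {}
    }

    public static void main(String[] args) {
        // Works because WordCounts binds Row concretely in "extends Table<WordCounts.Row>".
        Class<?> rowType = (Class<?>) ((ParameterizedType) WordCounts.class
                .getGenericSuperclass()).getActualTypeArguments()[0];
        System.out.println(rowType.getName()); // ...$WordCounts$Row
        // Caveat: a subclass that extends the table type raw, or through another generic
        // layer, would fail the ParameterizedType cast at runtime, so tables must extend
        // the base class directly with a concrete row type.
    }
}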
27 changes: 27 additions & 0 deletions rewrite-core/src/main/java/org/openrewrite/DataTableStore.java
@@ -60,9 +60,36 @@ static DataTableStore noop() {
* @param dataTableName the fully qualified class name of the data table
* @param group the group identifying the bucket, or null for ungrouped
* @return a stream of rows, or an empty stream if no rows exist
* @deprecated Use {@link #getRows(Class)} or {@link #getRows(Class, String)} for type-safe deserialization.
*/
@Deprecated
Stream<?> getRows(String dataTableName, @Nullable String group);

/**
* Stream typed rows for a specific data table class and group.
* The row type is inferred from the data table's generic parameter.
*
* @param dataTableClass the data table class (e.g., {@code ServiceEndpoints.class})
* @param group the group identifying the bucket, or null for ungrouped
* @param <Row> the row type
* @return a stream of typed rows, or an empty stream if no rows exist
*/
@SuppressWarnings("unchecked")
default <Row> Stream<Row> getRows(Class<? extends DataTable<Row>> dataTableClass, @Nullable String group) {
return (Stream<Row>) getRows(dataTableClass.getName(), group);
}

/**
* Stream typed rows for a specific data table class (ungrouped).
*
* @param dataTableClass the data table class (e.g., {@code ServiceEndpoints.class})
* @param <Row> the row type
* @return a stream of typed rows, or an empty stream if no rows exist
*/
default <Row> Stream<Row> getRows(Class<? extends DataTable<Row>> dataTableClass) {
return getRows(dataTableClass, null);
}
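
Aside: a hedged usage sketch of the new accessors. ServiceEndpoints mirrors the javadoc example above and is assumed to extend DataTable<ServiceEndpoints.Row>; the group key is illustrative.

// Assumes: DataTableStore store = ...;
// Ungrouped bucket; the stream is typed, so no caller-side cast is needed.
try (Stream<ServiceEndpoints.Row> rows = store.getRows(ServiceEndpoints.class)) {
    rows.forEach(row -> System.out.println(row));
}

// Grouped bucket, selected by its group key.
try (Stream<ServiceEndpoints.Row> rows = store.getRows(ServiceEndpoints.class, "billing")) {
    System.out.println(rows.count());
}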

/**
* Get the set of {@link DataTable} instances that have received rows.
*
17 changes: 17 additions & 0 deletions rewrite-core/src/main/java/org/openrewrite/RecipeRun.java
@@ -44,13 +44,30 @@ public class RecipeRun {
return null;
}

/**
* @deprecated Use {@link #getDataTableRows(Class)} for type-safe deserialization.
*/
@Deprecated
public <E> List<E> getDataTableRows(String name) {
return getDataTableRows(name, null);
}

/**
* @deprecated Use {@link #getDataTableRows(Class, String)} for type-safe deserialization.
*/
@SuppressWarnings("unchecked")
@Deprecated
public <E> List<E> getDataTableRows(String name, @Nullable String group) {
return (List<E>) dataTableStore.getRows(name, group)
.collect(Collectors.toList());
}

public <E> List<E> getDataTableRows(Class<? extends DataTable<E>> dataTableClass) {
return getDataTableRows(dataTableClass, null);
}

public <E> List<E> getDataTableRows(Class<? extends DataTable<E>> dataTableClass, @Nullable String group) {
return dataTableStore.getRows(dataTableClass, group)
.collect(Collectors.toList());
}
}
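
Aside: the before/after this enables for callers, using FieldsOfTypeUses as in the updated test further down:

// Before -- deprecated, relies on an unchecked cast inside getDataTableRows(String):
List<FieldsOfTypeUses.Row> rows = recipeRun.getDataTableRows(FieldsOfTypeUses.class.getName());

// After -- the row type is inferred from the table's generic parameter:
List<FieldsOfTypeUses.Row> typed = recipeRun.getDataTableRows(FieldsOfTypeUses.class);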
(another changed file; path not captured in this view)
@@ -486,14 +486,8 @@ public List<String> getInitialValue(ExecutionContext ctx) {
// On cycle 2+, the store already contains rows written in cycle 1
List<String> readBack = new ArrayList<>();
DataTableStore store = DataTableExecutionContextView.view(ctx).getDataTableStore();
try (Stream<?> rows = store.getRows(table.getName(), null)) {
rows.forEach(row -> {
if (row instanceof TestTable.Row) {
readBack.add(((TestTable.Row) row).getName());
} else {
readBack.add(((String[]) row)[0]);
}
});
try (Stream<TestTable.Row> rows = store.getRows(TestTable.class)) {
rows.forEach(row -> readBack.add(row.getName()));
}
return readBack;
}
(another changed file; path not captured in this view)
@@ -136,7 +136,7 @@ void verifyDataTable() {
rewriteRun(
spec -> spec.recipe(new FindFieldsOfType("java.util.List", true))
.afterRecipe(recipeRun -> {
List<FieldsOfTypeUses.Row> fields = recipeRun.getDataTableRows(FieldsOfTypeUses.class.getName());
List<FieldsOfTypeUses.Row> fields = recipeRun.getDataTableRows(FieldsOfTypeUses.class);
assertThat(fields).containsExactlyInAnyOrderElementsOf(expectedFields);
}),
java(
(another changed file; path not captured in this view)
@@ -250,8 +250,7 @@ public <E, V> RecipeSpec dataTableAsCsv(String name, String expect) {
}
}
assertThat(dataTable).isNotNull();
@SuppressWarnings("unchecked")
List<E> rows = (List<E>) store.getRows(dataTable.getName(), dataTable.getGroup())
List<?> rows = store.getRows(dataTable.getName(), dataTable.getGroup())
.collect(java.util.stream.Collectors.toList());
StringWriter writer = new StringWriter();
CsvMapper mapper = CsvMapper.builder()