Skip to content

Commit b90b274

Browse files
committed
Fix staging and resume for typed record inputs with nested Path fields
Signed-off-by: Stephen Kazakoff <sh.kazakoff@gmail.com>
1 parent 656ff4e commit b90b274

13 files changed

Lines changed: 314 additions & 24 deletions

File tree

modules/nextflow/src/main/groovy/nextflow/processor/TaskContext.groovy

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package nextflow.processor
1919
import nextflow.NF
2020
import nextflow.script.ScriptMeta
2121

22+
import java.lang.reflect.Modifier
2223
import java.nio.file.Path
2324
import java.nio.file.Paths
2425
import java.nio.file.Files
@@ -33,7 +34,9 @@ import groovyx.gpars.dataflow.DataflowWriteChannel
3334
import nextflow.Global
3435
import nextflow.exception.ProcessException
3536
import nextflow.script.ScriptBinding
37+
import nextflow.script.types.Record
3638
import nextflow.util.KryoHelper
39+
import nextflow.util.RecordMap
3740
import nextflow.util.TestOnly
3841
/**
3942
* Map used to delegate variable resolution to script scope
@@ -196,7 +199,7 @@ class TaskContext implements Map<String,Object>, Cloneable {
196199
map.remove(TaskProcessor.TASK_CONTEXT_PROPERTY_NAME)
197200
}
198201

199-
return KryoHelper.serialize(map)
202+
return KryoHelper.serialize(sanitizeValue(map))
200203
}
201204
catch( Exception e ) {
202205
log.warn "Cannot serialize context map. Cause: ${e.cause} -- Resume will not work on this process"
@@ -210,6 +213,25 @@ class TaskContext implements Map<String,Object>, Cloneable {
210213
new TaskContext(processor, map)
211214
}
212215

216+
/**
 * Recursively rebuild a value so that typed records are replaced by
 * {@link RecordMap} instances. It is applied to the context map right
 * before the map is passed to {@code KryoHelper.serialize}.
 *
 * Conversion rules, checked in this order:
 * - a map is copied entry by entry with every value sanitized; a
 *   {@code RecordMap} input yields a {@code RecordMap}, any other map
 *   yields the plain copy
 * - a collection is copied into a list with every item sanitized
 * - a {@link Record} becomes a {@code RecordMap} holding its public,
 *   non-static, non-synthetic fields, inserted in field-name order
 * - anything else (including {@code null}) is returned as is
 *
 * @param value The value to sanitize
 * @return The sanitized copy, or {@code value} itself when no rule applies
 */
private static Object sanitizeValue(Object value) {
    if( value instanceof Map ) {
        final copy = value.collectEntries { entry -> [(entry.key): sanitizeValue(entry.value)] }
        if( value instanceof RecordMap )
            return new RecordMap(copy)
        return copy
    }

    if( value instanceof Collection ) {
        final items = new ArrayList(value.size())
        for( item in value )
            items.add(sanitizeValue(item))
        return items
    }

    if( value instanceof Record ) {
        final Map<String,Object> entries = new LinkedHashMap<String,Object>()
        value.getClass().getFields()
            .findAll { f -> !(Modifier.isStatic(f.modifiers) || f.synthetic) }
            .sort { it.name }
            .each { f -> entries.put(f.name, sanitizeValue(f.get(value))) }
        return new RecordMap(entries)
    }

    return value
}
234+
213235

214236
@PackageScope
215237
static String dumpMap( Map map ) {

modules/nextflow/src/main/groovy/nextflow/processor/TaskInputResolver.groovy

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package nextflow.processor
1818

19+
import java.lang.reflect.Modifier
1920
import java.nio.file.Path
2021
import java.util.regex.Matcher
2122
import java.util.regex.Pattern
@@ -32,6 +33,7 @@ import nextflow.file.LogicalDataPath
3233
import nextflow.script.ScriptType
3334
import nextflow.script.params.FileInParam
3435
import nextflow.script.params.v2.ProcessFileInput
36+
import nextflow.script.types.Record
3537
import nextflow.util.ArrayBag
3638
import nextflow.util.BlankSeparatedList
3739
import nextflow.util.RecordMap
@@ -111,14 +113,47 @@ class TaskInputResolver {
111113
return value.collect { el -> normalizeValue(el, holders) }
112114
}
113115

116+
if( value instanceof RecordMap ) {
117+
final normalized = new LinkedHashMap<String,Object>()
118+
for( final entry : value.entrySet() )
119+
normalized.put(entry.key, normalizeValue(entry.value, holders))
120+
return new RecordMap(normalized)
121+
}
122+
114123
if( value instanceof Map ) {
115124
final normalized = value.collectEntries { k, v -> [k, normalizeValue(v, holders)] }
116-
return value instanceof RecordMap ? new RecordMap(normalized as Map<String,?>) : normalized
125+
return normalized
126+
}
127+
128+
if( value instanceof Record ) {
129+
return normalizeRecord(value, holders)
117130
}
118131

119132
return value
120133
}
121134

135+
/**
 * Normalize the fields of a typed record.
 *
 * Every public, non-static, non-synthetic field is read (in field-name
 * order) and passed through {@code normalizeValue}. The method then tries
 * to create a new instance of the same record class through its no-arg
 * constructor and assign the normalized values to it. When instantiation
 * or assignment fails for any reason, the normalized values are returned
 * as a {@link RecordMap} instead.
 *
 * @param value The record to normalize
 * @param holders The file holders passed on to {@code normalizeValue}
 * @return A new record of the same class, or a {@code RecordMap} fallback
 */
private Object normalizeRecord(Record value, Map<Path,FileHolder> holders) {
    final recordType = value.getClass()
    final publicFields = recordType.getFields()
        .findAll { f -> !(Modifier.isStatic(f.modifiers) || f.synthetic) }
        .sort { it.name }

    // field values are read and normalized outside the try block, so a failure
    // here propagates to the caller rather than triggering the fallback
    final Map<String,Object> normalized = new LinkedHashMap<String,Object>()
    publicFields.each { f -> normalized.put(f.name, normalizeValue(f.get(value), holders)) }

    try {
        final copy = recordType.getDeclaredConstructor().newInstance()
        publicFields.each { f -> f.set(copy, normalized.get(f.name)) }
        return copy
    }
    catch( Exception e ) {
        log.debug "Unable to normalize record as ${value.getClass().name}; using RecordMap instead", e
        return new RecordMap(normalized)
    }
}
156+
122157
private Path normalizePath(Path value, Map<Path,FileHolder> holders) {
123158
return holders.containsKey(value)
124159
? new TaskPath(holders[value])

modules/nextflow/src/main/groovy/nextflow/util/SerializationHelper.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class KryoHelper {
182182
}
183183
finally {
184184
if( prev ) {
185-
kryo.setClassLoader(loader)
185+
kryo.setClassLoader(prev)
186186
}
187187
}
188188
}

modules/nextflow/src/test/groovy/nextflow/processor/TaskContextTest.groovy

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package nextflow.processor
1818

1919
import java.nio.file.Files
20+
import java.nio.file.Path
2021
import java.nio.file.Paths
2122

2223
import groovy.runtime.metaclass.ExtensionProvider
@@ -32,13 +33,19 @@ import nextflow.script.ScriptMeta
3233
import nextflow.util.BlankSeparatedList
3334
import nextflow.util.Duration
3435
import nextflow.util.MemoryUnit
36+
import nextflow.util.RecordMap
3537
import spock.lang.Specification
3638
/**
3739
*
3840
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
3941
*/
4042
class TaskContextTest extends Specification {
4143

44+
static class SampleRecord implements nextflow.script.types.Record {
45+
public String id
46+
public Path path
47+
}
48+
4249
def setupSpec() {
4350
NF.init()
4451
}
@@ -80,6 +87,27 @@ class TaskContextTest extends Specification {
8087

8188
}
8289

90+
def 'should serialize typed records as record maps' () {
91+
setup:
92+
def processor = Mock(TaskProcessor) {
93+
getTaskBody() >> new BodyDef(null,'source')
94+
}
95+
def path = Paths.get('/some/input.txt')
96+
def map = new TaskContext(processor, [:])
97+
map.sample = new SampleRecord(id: 'alpha', path: path)
98+
map.samples = [new SampleRecord(id: 'beta', path: path)]
99+
100+
when:
101+
def buffer = map.serialize()
102+
def result = TaskContext.deserialize(processor, buffer)
103+
104+
then:
105+
result.sample instanceof RecordMap
106+
result.sample == new RecordMap(id: 'alpha', path: path)
107+
result.samples[0] instanceof RecordMap
108+
result.samples[0] == new RecordMap(id: 'beta', path: path)
109+
}
110+
83111
def 'should dehydrate rehydrate'() {
84112

85113
setup:

modules/nextflow/src/test/groovy/nextflow/processor/TaskInputResolverTest.groovy

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ import spock.lang.Unroll
4040
*/
4141
class TaskInputResolverTest extends Specification {
4242

43+
static class SampleRecord implements nextflow.script.types.Record {
44+
public String id
45+
public Path path
46+
public List<Path> paths
47+
}
48+
4349
def holdersMap(List<FileHolder> holders) {
4450
final result = [:]
4551
for( final holder : holders )
@@ -172,6 +178,26 @@ class TaskInputResolverTest extends Specification {
172178
task.context.input.bam.toString() == 'input.bam'
173179
}
174180

181+
def 'should normalize typed record fields' () {
182+
given:
183+
def resolver = new TaskInputResolver(Mock(TaskRun), Mock(FilePorter.Batch), Mock(Executor))
184+
def source = Path.of('/some/input.txt')
185+
def holder = FileHolder.get(source, 'input.txt')
186+
def record = new SampleRecord(id: 'alpha', path: source, paths: [source])
187+
188+
when:
189+
def result = resolver.normalizeValue(record, holdersMap([holder]))
190+
191+
then:
192+
result instanceof SampleRecord
193+
!result.is(record)
194+
result.id == 'alpha'
195+
result.path instanceof TaskPath
196+
result.path.toString() == 'input.txt'
197+
result.paths[0] instanceof TaskPath
198+
result.paths[0].toString() == 'input.txt'
199+
}
200+
175201
def 'should return single item or collection'() {
176202

177203
setup:

modules/nf-commons/src/main/nextflow/util/HashBuilder.java

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.io.File;
2020
import java.io.IOException;
2121
import java.io.OutputStream;
22+
import java.lang.reflect.Field;
23+
import java.lang.reflect.Modifier;
2224
import java.nio.file.FileVisitResult;
2325
import java.nio.file.Files;
2426
import java.nio.file.Path;
@@ -29,6 +31,7 @@
2931
import java.util.HashMap;
3032
import java.util.Map;
3133
import java.util.Set;
34+
import java.util.TreeMap;
3235
import java.util.UUID;
3336
import java.util.concurrent.ExecutionException;
3437

@@ -47,6 +50,7 @@
4750
import nextflow.extension.FilesEx;
4851
import nextflow.io.SerializableMarker;
4952
import nextflow.script.types.Bag;
53+
import nextflow.script.types.Record;
5054
import org.slf4j.Logger;
5155
import org.slf4j.LoggerFactory;
5256
import static nextflow.Const.DEFAULT_ROOT;
@@ -149,20 +153,19 @@ else if( value instanceof Object[]) {
149153
else if( value instanceof CacheFunnel )
150154
((CacheFunnel)value).funnel(hasher, mode);
151155

152-
else if( value instanceof Map )
153-
hashUnorderedCollection(hasher, ((Map) value).entrySet(), mode);
156+
else if( value instanceof Map<?,?> map )
157+
hashUnorderedCollection(hasher, map.entrySet(), mode, basePath);
154158

155-
else if( value instanceof Map.Entry ) {
156-
Map.Entry entry = (Map.Entry)value;
159+
else if( value instanceof Map.Entry<?,?> entry ) {
157160
with(entry.getKey());
158161
with(entry.getValue());
159162
}
160163

161-
else if( value instanceof Bag || value instanceof Set )
162-
hashUnorderedCollection(hasher, (Collection) value, mode);
164+
else if( value instanceof Bag<?> || value instanceof Set<?> )
165+
hashUnorderedCollection(hasher, (Collection<?>) value, mode, basePath);
163166

164-
else if( value instanceof Collection)
165-
for( Object item : ((Collection)value) )
167+
else if( value instanceof Collection<?> collection)
168+
for( Object item : collection )
166169
with(item);
167170

168171
else if( value instanceof Path )
@@ -179,6 +182,9 @@ else if( value instanceof UUID ) {
179182
else if( value instanceof VersionNumber )
180183
hasher.putInt( value.hashCode() );
181184

185+
else if( value instanceof Record )
186+
hashUnorderedCollection(hasher, recordFields((Record)value).entrySet(), mode, basePath);
187+
182188
else if( value instanceof SerializableMarker)
183189
hasher.putInt( value.hashCode() );
184190

@@ -461,11 +467,11 @@ static HashCode hashContent( Path file, HashFunction function ) {
461467
return hashFileContent(hasher, file).hash();
462468
}
463469

464-
static private Hasher hashUnorderedCollection(Hasher hasher, Collection collection, HashMode mode) {
470+
static private Hasher hashUnorderedCollection(Hasher hasher, Collection<?> collection, HashMode mode, Path basePath) {
465471
byte[] resultBytes = new byte[HASH_BYTES];
466472
for (Object item : collection) {
467473
// hash the collection item
468-
byte[] nextBytes = hashBytes(item, mode);
474+
byte[] nextBytes = hashBytes(item, mode, basePath);
469475
// sum the hash bytes to the "resultBytes" accumulator
470476
// since the sum is a commutative operation the order does not matter
471477
sumBytes(resultBytes, nextBytes);
@@ -475,7 +481,33 @@ static private Hasher hashUnorderedCollection(Hasher hasher, Collection collecti
475481
}
476482

477483
static private byte[] hashBytes(Object item, HashMode mode) {
478-
return hasher(defaultHasher(), item, mode).hash().asBytes();
484+
return hashBytes(item, mode, null);
485+
}
486+
487+
static private byte[] hashBytes(Object item, HashMode mode, Path basePath) {
488+
return new HashBuilder()
489+
.withHasher(defaultHasher())
490+
.withMode(mode)
491+
.withBasePath(basePath)
492+
.with(item)
493+
.getHasher()
494+
.hash()
495+
.asBytes();
496+
}
497+
498+
static private Map<String,Object> recordFields(Record record) {
499+
final Map<String,Object> result = new TreeMap<>();
500+
for( Field field : record.getClass().getFields() ) {
501+
if( Modifier.isStatic(field.getModifiers()) || field.isSynthetic() )
502+
continue;
503+
try {
504+
result.put(field.getName(), field.get(record));
505+
}
506+
catch( IllegalAccessException e ) {
507+
throw new IllegalStateException("Unable to access record field: " + field.getName(), e);
508+
}
509+
}
510+
return result;
479511
}
480512

481513
/**

modules/nf-commons/src/test/nextflow/util/HashBuilderTest.groovy

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
package nextflow.util
1818

1919
import java.nio.file.Files
20+
import java.nio.file.Path
2021
import java.nio.file.Paths
2122

2223
import com.google.common.hash.Hashing
2324
import nextflow.Global
2425
import nextflow.Session
26+
import nextflow.script.types.Record
2527
import org.apache.commons.codec.digest.DigestUtils
2628
import spock.lang.Specification
2729
import test.TestHelper
@@ -31,6 +33,15 @@ import test.TestHelper
3133
*/
3234
class HashBuilderTest extends Specification {
3335

36+
static class SampleRecord implements Record {
37+
public String id
38+
public Path path
39+
}
40+
41+
static class OtherSampleRecord implements Record {
42+
public String id
43+
public Path path
44+
}
3445

3546
def testHashContent() {
3647
setup:
@@ -122,6 +133,19 @@ class HashBuilderTest extends Specification {
122133
HashBuilder.hashFileSha256Impl0(file) == DigestUtils.sha256Hex(file.bytes)
123134
}
124135

136+
def 'should hash records structurally'() {
137+
given:
138+
def file = TestHelper.createInMemTempFile('foo', 'Hello world')
139+
def record = new SampleRecord(id: 'alpha', path: file)
140+
def sameRecord = new OtherSampleRecord(id: 'alpha', path: file)
141+
def differentRecord = new SampleRecord(id: 'beta', path: file)
142+
143+
expect:
144+
new HashBuilder().with(record).build() == new HashBuilder().with(sameRecord).build()
145+
new HashBuilder().with(record).build() == new HashBuilder().with(new RecordMap(id: 'alpha', path: file)).build()
146+
new HashBuilder().with(record).build() != new HashBuilder().with(differentRecord).build()
147+
}
148+
125149
def 'should hash dir content with sha256'() {
126150
given:
127151
def folder = TestHelper.createInMemTempDir()

0 commit comments

Comments
 (0)