[BUG] SizePartitioner is not handling test_range properly #2430

@Fabio-Trindade

According to its docstring, test_range can be interpreted in three ways:

        test_range (int or float or str, optional): The size of the partial
            test dataset to load.
            If None, the entire test dataset will be loaded.
            If int or float, the random partial dataset will be loaded with the
            specified size.
            If str, the partial dataset will be loaded with the
            specified index list (e.g. "[:100]" for the first 100 examples,
            "[100:200]" for the second 100 examples, etc.). Defaults to None.

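However, the actual_size computation assumes that test_range is always a slice string, because it appends test_range directly to a range expression:

        actual_size = eval('len(range(self.dataset_size[dataset_abbr])'
                           f'{test_range})')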
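To make the mismatch concrete, here is a minimal sketch that mirrors the shape of the expression above. The helper name `eval_based_size` and the local `total_size` (standing in for `self.dataset_size[dataset_abbr]`) are assumptions made for illustration, not names from the project.

```python
def eval_based_size(total_size, test_range):
    """Hypothetical stand-in for the eval-based size computation."""
    # Build the same kind of expression: the range call followed
    # directly by whatever test_range formats to.
    expr = f'len(range(total_size){test_range})'
    # total_size is passed explicitly; builtins such as len and range
    # are added to the empty globals dict automatically by eval.
    return eval(expr, {}, {'total_size': total_size})
```

With a slice string such as `'[:100]'` the expression becomes `len(range(total_size)[:100])` and evaluates to 100. With an int such as `100` or a float such as `0.5`, the expression becomes `len(range(total_size)100)` or `len(range(total_size)0.5)`, which is not valid Python, so `eval` raises `SyntaxError`.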

Based on the dataset logic:

        total_size = len(dataset)
        index_list = list(range(total_size))
        if isinstance(size, (int, float)):
            if size >= total_size or size <= 0:
                return dataset
            if size > 0 and size < 1:
                size = int(size * total_size)
            rand = random.Random(x=size)
            rand.shuffle(index_list)
            dataset = dataset.select(index_list[:size])
        elif isinstance(size, str):
            dataset = dataset.select(eval(f'index_list{size}'))
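The selection logic can be tried out without a real dataset. The following is only a sketch: `select_indices` is a hypothetical helper, a plain list of indices stands in for the dataset and for `dataset.select`, and returning every index for `None` is an assumption taken from the docstring rather than from the quoted code.

```python
import random


def select_indices(total_size, size):
    """Return the indices the quoted logic would keep (hypothetical helper)."""
    index_list = list(range(total_size))
    if isinstance(size, (int, float)):
        # Out-of-range sizes keep the whole dataset.
        if size >= total_size or size <= 0:
            return index_list
        # A value strictly between 0 and 1 is treated as a fraction.
        if 0 < size < 1:
            size = int(size * total_size)
        # Seed with the size, as the quoted code does, so the
        # shuffle is reproducible for a given size.
        rand = random.Random(size)
        rand.shuffle(index_list)
        return index_list[:size]
    if isinstance(size, str):
        # A slice string such as "[:100]" is applied to the index list.
        return eval(f'index_list{size}')
    # Assumption: None (or anything else) leaves the dataset untouched.
    return index_list
```

For a dataset of 1000 examples, `len(select_indices(1000, 100))` is 100, `len(select_indices(1000, 0.25))` is 250, and `len(select_indices(1000, '[100:200]'))` is 100. A size computation that is meant to agree with this logic has to return the same numbers.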

If my understanding is correct, actual_size should be computed similarly to:

        def get_size(test_range, total_size):
            # None means the entire test dataset is loaded.
            if test_range is None:
                return total_size

            if isinstance(test_range, (int, float)):
                # Non-positive sizes fall back to the full dataset.
                if test_range <= 0:
                    return total_size

                # A float in (0, 1) is a fraction of the dataset.
                if isinstance(test_range, float) and 0 < test_range < 1:
                    return int(test_range * total_size)

                # Otherwise it is an absolute size, capped at the dataset size.
                return min(int(test_range), total_size)

            if isinstance(test_range, str):
                # Expect a slice string such as "[:100]" or "[100:200]".
                rng = test_range.strip().strip("[]")

                if ":" in rng:
                    start, end = rng.split(":")

                    start = int(start) if start.strip() else 0
                    end = int(end) if end.strip() else total_size

                    # Resolve negative indices relative to the end.
                    if start < 0:
                        start += total_size
                    if end < 0:
                        end += total_size

                    # Clamp both bounds to [0, total_size].
                    start = max(0, min(start, total_size))
                    end = max(0, min(end, total_size))

                    if start > end:
                        raise ValueError(f"Invalid range: {test_range}")

                    return end - start

                raise ValueError(f"Invalid range format: {test_range}")

            raise TypeError(f"Unsupported type: {type(test_range)}")
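For the string case, one possible alternative to parsing the bounds by hand is to let Python's built-in `slice.indices` resolve negative and out-of-range bounds. The sketch below is not part of the proposal above: `slice_length` is a hypothetical name, and it assumes the string is a bracketed slice of the form `"[start:stop]"` or `"[start:stop:step]"`.

```python
def slice_length(test_range: str, total_size: int) -> int:
    """Length of index_list[test_range] for a bracketed slice string.

    Hypothetical helper; accepts "[a:b]" or "[a:b:c]" only.
    """
    body = test_range.strip()
    if not (body.startswith('[') and body.endswith(']')):
        raise ValueError(f"Invalid range format: {test_range}")

    parts = body[1:-1].split(':')
    if not 2 <= len(parts) <= 3:
        raise ValueError(f"Invalid range format: {test_range}")

    # Empty fields become None, exactly as in a literal slice.
    bounds = [int(p) if p.strip() else None for p in parts]

    # slice.indices clamps the bounds to the sequence length and
    # resolves negatives, so the range has the same length as the
    # sliced list would.
    return len(range(*slice(*bounds).indices(total_size)))
```

With `total_size = 1000`, this gives 100 for `"[:100]"`, `"[100:200]"`, and `"[-100:]"`, and 500 for `"[::2]"`. Unlike the version above, a range whose start is past its end, such as `"[200:100]"`, yields 0 instead of raising, which matches what slicing a list would select.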
