Skip to content

Commit 02032ec

Browse files
committed
wip invenio multipart
1 parent 7455395 commit 02032ec

3 files changed

Lines changed: 327 additions & 0 deletions

File tree

lib/galaxy/files/sources/_rdm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@
2525
class RDMFileSourceTemplateConfiguration(BaseFileSourceTemplateConfiguration):
2626
token: Optional[Union[str, TemplateExpansion]] = None
2727
public_name: Optional[Union[str, TemplateExpansion]] = None
28+
multipart_threshold: Optional[Union[int, TemplateExpansion]] = None # bytes
29+
multipart_chunk_size: Optional[Union[int, TemplateExpansion]] = None # bytes
2830

2931

3032
class RDMFileSourceConfiguration(BaseFileSourceConfiguration):
3133
token: Optional[str] = None
3234
public_name: Optional[str] = None
35+
multipart_threshold: Optional[int] = None # bytes
36+
multipart_chunk_size: Optional[int] = None # bytes
3337

3438

3539
class ContainerAndFileIdentifier(NamedTuple):

lib/galaxy/files/sources/invenio.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import datetime
22
import json
3+
import logging
4+
import math
5+
import os
36
import re
47
import urllib.request
8+
from concurrent.futures import (
9+
as_completed,
10+
ThreadPoolExecutor,
11+
)
512
from typing import (
613
Any,
714
cast,
@@ -11,6 +18,8 @@
1118
from urllib.error import HTTPError
1219
from urllib.parse import quote
1320

21+
log = logging.getLogger(__name__)
22+
1423
from typing_extensions import (
1524
TypedDict,
1625
)
@@ -104,6 +113,53 @@ class RecordLinks(TypedDict):
104113
reserve_doi: str
105114

106115

116+
# AWS S3 multipart limits (used by Invenio RDM)
117+
MIN_UPLOAD_PART_SIZE = 50 * 1024 * 1024 # 50 MiB
118+
MAX_UPLOAD_PART_SIZE = 5 * 1024**3 # 5 GiB
119+
MAX_UPLOAD_PARTS = 10_000
120+
121+
122+
def calculate_multipart_params(file_size: int, preferred_part_size: int | None = None) -> tuple[int, int]:
123+
"""Calculate optimal parts count and part size for multipart upload.
124+
125+
Args:
126+
file_size: Total file size in bytes
127+
preferred_part_size: Preferred part size in bytes (optional)
128+
129+
Returns:
130+
Tuple of (parts_count, part_size)
131+
132+
Note:
133+
Maximum uploadable file size is MAX_UPLOAD_PARTS * MAX_UPLOAD_PART_SIZE (~48.8 TiB).
134+
Files larger than this will still return valid params but would fail server-side.
135+
"""
136+
if file_size == 0:
137+
return 1, 0
138+
139+
# Start with preferred or minimum part size
140+
part_size = preferred_part_size or MIN_UPLOAD_PART_SIZE
141+
142+
# Ensure part_size is within bounds
143+
part_size = max(part_size, MIN_UPLOAD_PART_SIZE)
144+
part_size = min(part_size, MAX_UPLOAD_PART_SIZE)
145+
146+
# Calculate parts needed
147+
parts = math.ceil(file_size / part_size)
148+
149+
# If too many parts, increase part size (up to max)
150+
while parts > MAX_UPLOAD_PARTS and part_size < MAX_UPLOAD_PART_SIZE:
151+
part_size = min(part_size * 2, MAX_UPLOAD_PART_SIZE)
152+
parts = math.ceil(file_size / part_size)
153+
154+
# For extremely large files, cap parts at MAX_UPLOAD_PARTS
155+
# This means part_size may effectively be larger than calculated
156+
# but such files would likely fail server-side anyway
157+
if parts > MAX_UPLOAD_PARTS:
158+
parts = MAX_UPLOAD_PARTS
159+
160+
return parts, part_size
161+
162+
107163
class InvenioRecord(TypedDict):
108164
id: str
109165
title: str
@@ -331,6 +387,25 @@ def upload_file_to_draft_container(
331387
file_path: str,
332388
context: FilesSourceRuntimeContext[RDMFileSourceConfiguration],
333389
):
390+
file_size = os.path.getsize(file_path)
391+
threshold = context.config.multipart_threshold
392+
393+
use_multipart = threshold is not None and threshold > 0 and file_size >= threshold
394+
395+
if use_multipart:
396+
log.info(f"Using multipart upload for file '{filename}' ({file_size} bytes >= threshold {threshold})")
397+
self._upload_file_multipart(record_id, filename, file_path, file_size, context)
398+
else:
399+
self._upload_file_single(record_id, filename, file_path, context)
400+
401+
def _upload_file_single(
402+
self,
403+
record_id: str,
404+
filename: str,
405+
file_path: str,
406+
context: FilesSourceRuntimeContext[RDMFileSourceConfiguration],
407+
):
408+
"""Upload a file using single PUT request."""
334409
record = self._get_draft_record(record_id, context)
335410
upload_file_url = record["links"]["files"]
336411
headers = self._get_request_headers(context, auth_required=True)
@@ -352,6 +427,136 @@ def upload_file_to_draft_container(
352427
response = requests.post(commit_file_upload_url, headers=headers)
353428
self._ensure_response_has_expected_status_code(response, 200)
354429

430+
def _upload_file_multipart(
431+
self,
432+
record_id: str,
433+
filename: str,
434+
file_path: str,
435+
file_size: int,
436+
context: FilesSourceRuntimeContext[RDMFileSourceConfiguration],
437+
):
438+
"""Upload a file using multipart upload.
439+
440+
Flow:
441+
1. Calculate parts/part_size
442+
2. POST with transfer metadata
443+
3. Server returns links.parts[] with URL for each part
444+
4. Upload parts (parallel for > 2 parts)
445+
5. POST to commit URL
446+
"""
447+
preferred_part_size = context.config.multipart_chunk_size
448+
num_parts, part_size = calculate_multipart_params(file_size, preferred_part_size)
449+
450+
log.info(f"Multipart upload: {num_parts} parts of {part_size} bytes each for '{filename}'")
451+
452+
record = self._get_draft_record(record_id, context)
453+
upload_file_url = record["links"]["files"]
454+
headers = self._get_request_headers(context, auth_required=True)
455+
456+
# Initialize multipart upload with transfer metadata
457+
file_metadata = {
458+
"key": filename,
459+
"size": file_size,
460+
"transfer": {
461+
"type": "M",
462+
"parts": num_parts,
463+
"part_size": part_size,
464+
},
465+
}
466+
response = requests.post(upload_file_url, json=[file_metadata], headers=headers)
467+
self._ensure_response_has_expected_status_code(response, 201)
468+
469+
# Get part upload URLs from response
470+
entries = response.json()["entries"]
471+
file_entry = next(entry for entry in entries if entry["key"] == filename)
472+
commit_url = file_entry["links"]["commit"]
473+
part_links = file_entry.get("links", {}).get("parts", [])
474+
475+
if len(part_links) != num_parts:
476+
raise Exception(
477+
f"Server returned {len(part_links)} part URLs but expected {num_parts} for file '{filename}'"
478+
)
479+
480+
# Upload parts
481+
self._upload_parts(file_path, file_size, part_size, part_links, headers)
482+
483+
# Commit multipart upload
484+
response = requests.post(commit_url, json={}, headers=headers)
485+
self._ensure_response_has_expected_status_code(response, 200)
486+
log.info(f"Multipart upload completed for '{filename}'")
487+
488+
def _upload_parts(
489+
self,
490+
file_path: str,
491+
file_size: int,
492+
part_size: int,
493+
part_links: list[dict],
494+
headers: dict,
495+
):
496+
"""Upload all parts, sequentially for <=2 parts, parallel otherwise."""
497+
num_parts = len(part_links)
498+
499+
if num_parts <= 2:
500+
# Sequential upload for small number of parts
501+
for part_index, part_info in enumerate(part_links):
502+
self._upload_single_part(file_path, file_size, part_size, part_index, part_info, headers)
503+
else:
504+
# Parallel upload for larger number of parts
505+
max_workers = min(4, num_parts)
506+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
507+
futures = {}
508+
for part_index, part_info in enumerate(part_links):
509+
future = executor.submit(
510+
self._upload_single_part,
511+
file_path,
512+
file_size,
513+
part_size,
514+
part_index,
515+
part_info,
516+
headers,
517+
)
518+
futures[future] = part_index
519+
520+
for future in as_completed(futures):
521+
part_index = futures[future]
522+
try:
523+
future.result()
524+
except Exception as e:
525+
log.error(f"Failed to upload part {part_index}: {e}")
526+
raise
527+
528+
def _upload_single_part(
529+
self,
530+
file_path: str,
531+
file_size: int,
532+
part_size: int,
533+
part_index: int,
534+
part_info: dict,
535+
headers: dict,
536+
):
537+
"""Upload a single part of a multipart upload."""
538+
part_url = part_info.get("url")
539+
if not part_url:
540+
raise Exception(f"No URL provided for part {part_index}")
541+
542+
# Calculate byte range for this part
543+
start_byte = part_index * part_size
544+
end_byte = min(start_byte + part_size, file_size)
545+
part_content_length = end_byte - start_byte
546+
547+
log.debug(f"Uploading part {part_index}: bytes {start_byte}-{end_byte-1} ({part_content_length} bytes)")
548+
549+
with open(file_path, "rb") as f:
550+
f.seek(start_byte)
551+
part_data = f.read(part_content_length)
552+
553+
part_headers = headers.copy()
554+
part_headers["Content-Length"] = str(part_content_length)
555+
part_headers["Content-Type"] = "application/octet-stream"
556+
557+
response = requests.put(part_url, data=part_data, headers=part_headers)
558+
self._ensure_response_has_expected_status_code(response, 200)
559+
355560
def download_file_from_container(
356561
self,
357562
container_id: str,
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Unit tests for Invenio multipart upload functionality."""
2+
3+
import pytest
4+
5+
from galaxy.files.sources.invenio import (
6+
calculate_multipart_params,
7+
MAX_UPLOAD_PART_SIZE,
8+
MAX_UPLOAD_PARTS,
9+
MIN_UPLOAD_PART_SIZE,
10+
)
11+
12+
13+
class TestCalculateMultipartParams:
14+
"""Tests for calculate_multipart_params function."""
15+
16+
def test_calculate_multipart_params_zero_byte(self):
17+
"""Zero-byte files should return (1, 0)."""
18+
parts, part_size = calculate_multipart_params(0)
19+
assert parts == 1
20+
assert part_size == 0
21+
22+
def test_calculate_multipart_params_small_file(self):
23+
"""Files under 50 MiB should use minimum part size."""
24+
# 10 MiB file
25+
file_size = 10 * 1024 * 1024
26+
parts, part_size = calculate_multipart_params(file_size)
27+
assert parts == 1
28+
assert part_size == MIN_UPLOAD_PART_SIZE
29+
30+
def test_calculate_multipart_params_medium_file(self):
31+
"""Files between 50 MiB and 100 MiB."""
32+
# 75 MiB file
33+
file_size = 75 * 1024 * 1024
34+
parts, part_size = calculate_multipart_params(file_size)
35+
assert parts == 2
36+
assert part_size == MIN_UPLOAD_PART_SIZE
37+
38+
def test_calculate_multipart_params_large_file(self):
39+
"""Large files requiring multiple parts."""
40+
# 250 MiB file
41+
file_size = 250 * 1024 * 1024
42+
parts, part_size = calculate_multipart_params(file_size)
43+
assert parts == 5
44+
assert part_size == MIN_UPLOAD_PART_SIZE
45+
46+
def test_calculate_multipart_params_respects_max_parts(self):
47+
"""Very large files should not exceed MAX_UPLOAD_PARTS."""
48+
# File larger than MAX_UPLOAD_PARTS * MIN_UPLOAD_PART_SIZE
49+
file_size = MAX_UPLOAD_PARTS * MIN_UPLOAD_PART_SIZE + MIN_UPLOAD_PART_SIZE
50+
parts, part_size = calculate_multipart_params(file_size)
51+
assert parts <= MAX_UPLOAD_PARTS
52+
assert part_size >= MIN_UPLOAD_PART_SIZE
53+
54+
def test_calculate_multipart_params_extremely_large_file(self):
55+
"""Extremely large files should hit both MAX limits.
56+
57+
Note: Files larger than MAX_UPLOAD_PARTS * MAX_UPLOAD_PART_SIZE (~48.8 TiB)
58+
cannot be uploaded via multipart, but we cap params rather than fail here.
59+
The upload would fail server-side anyway.
60+
"""
61+
# 100 TiB file - exceeds theoretical maximum (~48.8 TiB)
62+
file_size = 100 * 1024**4 # 100 TiB
63+
parts, part_size = calculate_multipart_params(file_size)
64+
# Parts should be capped at MAX_UPLOAD_PARTS
65+
assert parts == MAX_UPLOAD_PARTS
66+
# Part size should hit max
67+
assert part_size == MAX_UPLOAD_PART_SIZE
68+
69+
def test_calculate_multipart_params_respects_preferred_part_size(self):
70+
"""Should use preferred part size when provided and valid."""
71+
# 150 MiB file with 100 MiB preferred part size
72+
file_size = 150 * 1024 * 1024
73+
preferred_part_size = 100 * 1024 * 1024 # 100 MiB
74+
parts, part_size = calculate_multipart_params(file_size, preferred_part_size)
75+
assert parts == 2
76+
assert part_size == preferred_part_size
77+
78+
def test_calculate_multipart_params_preferred_too_small(self):
79+
"""Should use minimum part size if preferred is too small."""
80+
# 100 MiB file with 1 MiB preferred part size (too small)
81+
file_size = 100 * 1024 * 1024
82+
preferred_part_size = 1 * 1024 * 1024 # 1 MiB - too small
83+
parts, part_size = calculate_multipart_params(file_size, preferred_part_size)
84+
assert part_size == MIN_UPLOAD_PART_SIZE # Should be bumped to minimum
85+
86+
def test_calculate_multipart_params_preferred_exceeds_max(self):
87+
"""Should cap at MAX_UPLOAD_PART_SIZE if preferred exceeds it."""
88+
# Small file with huge preferred part size
89+
file_size = 100 * 1024 * 1024
90+
preferred_part_size = MAX_UPLOAD_PART_SIZE * 2 # Exceeds max
91+
parts, part_size = calculate_multipart_params(file_size, preferred_part_size)
92+
assert parts == 1
93+
assert part_size == MAX_UPLOAD_PART_SIZE
94+
95+
def test_calculate_multipart_params_exact_multiple(self):
96+
"""File size that's an exact multiple of part size."""
97+
# Exactly 3 * MIN_UPLOAD_PART_SIZE
98+
file_size = 3 * MIN_UPLOAD_PART_SIZE
99+
parts, part_size = calculate_multipart_params(file_size)
100+
assert parts == 3
101+
assert part_size == MIN_UPLOAD_PART_SIZE
102+
103+
def test_calculate_multipart_params_one_byte_over(self):
104+
"""File size one byte over an exact multiple."""
105+
# 3 * MIN_UPLOAD_PART_SIZE + 1 byte
106+
file_size = 3 * MIN_UPLOAD_PART_SIZE + 1
107+
parts, part_size = calculate_multipart_params(file_size)
108+
assert parts == 4 # Need 4 parts for 3 full + 1 byte
109+
assert part_size == MIN_UPLOAD_PART_SIZE
110+
111+
def test_calculate_multipart_params_at_boundary(self):
112+
"""Test file at MAX_UPLOAD_PARTS boundary."""
113+
# Exactly at the boundary where we need to increase part size
114+
file_size = (MAX_UPLOAD_PARTS + 1) * MIN_UPLOAD_PART_SIZE
115+
parts, part_size = calculate_multipart_params(file_size)
116+
assert parts <= MAX_UPLOAD_PARTS
117+
# Part size should have increased
118+
assert part_size > MIN_UPLOAD_PART_SIZE

0 commit comments

Comments
 (0)