
Commit 6e2a005

br: adjust sequence length cap for fixed memory requirements
Signed-off-by: Brian Roland <broland@nvidia.com>
1 parent 03ef6e6 commit 6e2a005
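
In short, this change replaces the fixed MEM_REQUIREMENT_1B_GB / MEM_REQUIREMENT_7B_GB constants with a single helper, determine_memory_requirement_and_skip_if_not_met, which skips a test when free GPU memory is below a per-model budget and returns the sequence-length cap the test should apply. A minimal sketch of the memory check the helper performs (the thresholds are the ones in this diff; everything else here is illustrative):

    import torch

    # Free GPU memory in GB, as queried by the helper in this commit.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    gb_available = free_bytes / 1024**3
    # The helper calls pytest.skip when gb_available falls below 17 GB for the
    # 1b checkpoints or 32 GB for the 7b checkpoints, and otherwise returns a
    # per-model sequence-length cap (6000 for 1b, 4000 for 7b).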

1 file changed: 52 additions & 29 deletions

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_evo2.py

@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from calendar import c
 import logging
 import os
 import time
@@ -48,9 +49,44 @@
 logger.setLevel(logging.DEBUG) # Capture all levels in the logger itself
 
 
-MEM_REQUIREMENT_1B_GB = 18 # add 0.6 GB to max mem reserved, and round up
-MEM_REQUIREMENT_7B_GB = 48
+def determine_memory_requirement_and_skip_if_not_met(
+    ckpt_name: str, flash_decode: bool | None = None
+) -> int:
+    """Determine the memory requirement for a given checkpoint and flash decode condition.
+    ckpt_name : str
+        the name of the checkpoint to test
+    flash_decode: bool | None
+        whether to test with flash decode
+    Returns:
+        The input sequence length cap, for the model in the checkpoint, given certain memory requirements.
+        If the memory requirement is not met, the test is skipped.
+    """
+
+    if "1b" in ckpt_name:
+        model_size = "1b"
+        seq_len_cap = 6000
+        memory_needed_by_test = 17 # max reserved rounded up, for stand-alone test
+
+    elif "7b" in ckpt_name:
+        model_size = "7b"
+        seq_len_cap = 4000
+        memory_needed_by_test = 32 # max reserved rounded up, for stand-alone test
+    else:
+        raise ValueError(f"{ckpt_name=} is not supported for testing")
+
+    skip_condition_flash = (flash_decode is None or flash_decode)
+    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
+    skip_condition = gb_available < memory_needed_by_test and skip_condition_flash
 
+    if skip_condition:
+        pytest.skip(
+            ", ".join([
+                f"Inference API requires at least {memory_needed_by_test}GB of available memory for {model_size} models",
+                f"{gb_available=}"
+            ])
+        )
+
+    return seq_len_cap
 
 
 def load_weights_sharded_inplace_nemo2_to_mcore(
@@ -369,11 +405,8 @@ def check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True):
 )
 def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float]):
     assert len(sequences) > 0
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    seq_len_cap = determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
@@ -385,7 +418,7 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
         )
     matchrates = []
     for seq in sequences:
-        seq = seq[:6000] # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
+        seq = seq[:seq_len_cap] # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
         with torch.no_grad():
            device = torch.cuda.current_device()
            tokens = torch.tensor([mcore_tokenizer.tokenize(seq)], device=device)
@@ -431,13 +464,12 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
 )
 def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float], flash_decode: bool):
     assert len(sequences) > 0
+    seq_len_cap = determine_memory_requirement_and_skip_if_not_met(ckpt_name, flash_decode)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and flash_decode) or (gb_available < MEM_REQUIREMENT_7B_GB and flash_decode and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+
+
     vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
@@ -484,7 +516,7 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc
     forward_kwargs = {}
     matchrates = []
     for seq in sequences:
-        seq = seq[:6000] # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
+        seq = seq[:seq_len_cap] # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
         with torch.no_grad():
            device = torch.cuda.current_device()
            # tokens = torch.tensor([tokenizer.tokenize(seq)], device=device)
@@ -547,12 +579,9 @@ def test_batch_generate(
     sequences: list[str], ckpt_name: str, model_tokenizer_provider: Callable, expected_matchpercents: list[float]
 ):
     assert len(sequences) > 0
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
@@ -619,11 +648,8 @@ def test_batch_generate_coding_sequences(
     expected_matchpercents: list[float],
 ):
     assert len(coding_sequences) > 0
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
@@ -728,11 +754,8 @@ def test_generate_speed(
     expected_tokens_sec: float,
 ):
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
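
The 17 GB and 32 GB budgets in the new helper correspond, per the in-diff comments, to the peak reserved memory observed when each test runs stand-alone, rounded up. A hedged sketch of how such a figure can be measured with the standard PyTorch memory APIs (measure_peak_reserved_gb is an illustrative name, not part of this commit):

    import torch

    def measure_peak_reserved_gb(run_test_body) -> float:
        """Illustrative helper: report peak reserved CUDA memory (GB) for one test run."""
        torch.cuda.reset_peak_memory_stats()
        run_test_body()  # execute the inference test once
        return torch.cuda.max_memory_reserved() / 1024**3  # round up to pick the skip threshold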
