 # See the License for the specific language governing permissions and
 # limitations under the License.

 import logging
 import os
 import time
@@ ... @@
 logger.setLevel(logging.DEBUG)  # Capture all levels in the logger itself


-MEM_REQUIREMENT_1B_GB = 18  # add 0.6 GB to max mem reserved, and round up
-MEM_REQUIREMENT_7B_GB = 48
+def determine_memory_requirement_and_skip_if_not_met(
+    ckpt_name: str, flash_decode: bool | None = None
+) -> int:
55+ """Determine the memory requirement for a given checkpoint and flash decode condition.
56+ ckpt_name : str
57+ the name of the checkpoint to test
58+ flash_decode: bool | None
59+ whether to test with flash decode
60+ Returns:
61+ The input sequence length cap, for the model sin the checkpoint, given certain memory requirements.
62+ If the memory requirement is not met, the test is skipped.
63+ """
+
+    if "1b" in ckpt_name:
+        model_size = "1b"
+        seq_len_cap = 6000
+        memory_needed_by_test = 17  # max reserved rounded up, for stand-alone test
+    elif "7b" in ckpt_name:
+        model_size = "7b"
+        seq_len_cap = 4000
+        memory_needed_by_test = 32  # max reserved rounded up, for stand-alone test
+    else:
+        raise ValueError(f"{ckpt_name=} is not supported for testing")
+
+    skip_condition_flash = flash_decode is None or flash_decode
+    gb_available = torch.cuda.mem_get_info()[0] / 1024**3  # free device memory, in GiB
+    skip_condition = gb_available < memory_needed_by_test and skip_condition_flash
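+    # Note: the memory gate only applies when flash decode is enabled or
+    # unspecified (None); runs with flash_decode=False are assumed to fit
+    # within a smaller budget, mirroring the old per-test skip logic.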
+    if skip_condition:
+        pytest.skip(
+            ", ".join([
+                f"Inference API requires at least {memory_needed_by_test} GB of available memory for {model_size} models",
+                f"{gb_available=}",
+            ])
+        )
+
+    return seq_len_cap
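+
+# Minimal usage sketch (hypothetical checkpoint name): call the helper at the top of a
+# test so low-memory runners skip before the model is built, then cap the input length:
+#     seq_len_cap = determine_memory_requirement_and_skip_if_not_met("evo2/1b-8k:1.0")
+#     seq = seq[:seq_len_cap]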


 def load_weights_sharded_inplace_nemo2_to_mcore(
@@ -369,11 +405,8 @@ def check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True):
 )
 def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float]):
     assert len(sequences) > 0
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    seq_len_cap = determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
@@ -385,7 +418,7 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
     )
     matchrates = []
     for seq in sequences:
-        seq = seq[:6000]  # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
+        seq = seq[:seq_len_cap]  # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
         with torch.no_grad():
             device = torch.cuda.current_device()
             tokens = torch.tensor([mcore_tokenizer.tokenize(seq)], device=device)
@@ -431,13 +464,12 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
 )
 def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float], flash_decode: bool):
     assert len(sequences) > 0
+    seq_len_cap = determine_memory_requirement_and_skip_if_not_met(ckpt_name, flash_decode)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and flash_decode) or (gb_available < MEM_REQUIREMENT_7B_GB and flash_decode and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+
     vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
@@ -484,7 +516,7 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc
         forward_kwargs = {}
     matchrates = []
     for seq in sequences:
-        seq = seq[:6000]  # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
+        seq = seq[:seq_len_cap]  # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences
         with torch.no_grad():
             device = torch.cuda.current_device()
             # tokens = torch.tensor([tokenizer.tokenize(seq)], device=device)
@@ -547,12 +579,9 @@ def test_batch_generate(
     sequences: list[str], ckpt_name: str, model_tokenizer_provider: Callable, expected_matchpercents: list[float]
 ):
     assert len(sequences) > 0
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
@@ -619,11 +648,8 @@ def test_batch_generate_coding_sequences(
     expected_matchpercents: list[float],
 ):
     assert len(coding_sequences) > 0
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
@@ -728,11 +754,8 @@ def test_generate_speed(
     expected_tokens_sec: float,
 ):
     is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
-    gb_available = torch.cuda.mem_get_info()[0] / 1024**3
-    if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name):
-        pytest.skip(
-            f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}"
-        )
+    determine_memory_requirement_and_skip_if_not_met(ckpt_name)
+
     skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
     if skip:
         # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.