Commit 5637f59

Move back to models with protobuf
Signed-off-by: SimJeg <sjegou@nvidia.com>
1 parent 2060dfb commit 5637f59

File tree

5 files changed, +30 −25 lines changed


.github/workflows/test.yml

Lines changed: 7 additions & 4 deletions
@@ -10,15 +10,21 @@ jobs:
   test:
     runs-on: linux-amd64-gpu-l4-latest-1
     container:
-      image: nvcr.io/nvidia/pytorch:25.10-py3
+      image: nvidia/cuda:13.0.0-devel-ubuntu24.04
     steps:
       - uses: actions/checkout@v3
 
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.12
+
       - name: Verify environment
         run: |
           nvidia-smi
           nvcc --version
           python3 --version
+          echo "CUDA_HOME=/usr/local/cuda" >> $GITHUB_ENV
       - name: Install uv
         uses: astral-sh/setup-uv@v6
         with:
@@ -28,9 +34,6 @@ jobs:
         run: |
           uv sync --all-groups
           uv pip install torch==2.10
-        env:
-          UV_HTTP_TIMEOUT: 300
-
       - run: make test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
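Since the job now starts from a bare nvidia/cuda devel image and installs Python and torch explicitly (rather than inheriting them from the NGC PyTorch container), a quick check along these lines, a sketch rather than part of the workflow, confirms the pinned torch build can actually see the GPU and the CUDA toolkit:

# Sanity-check sketch (not part of this commit): run inside the container after the
# install step to verify the torch==2.10 wheel pinned in the workflow sees the L4 GPU
# and the CUDA_HOME exported above.
import os

import torch

print(torch.__version__)            # expected: the 2.10 build pinned in the workflow
print(torch.cuda.is_available())    # expected: True on the linux-amd64-gpu-l4 runner
print(os.environ.get("CUDA_HOME"))  # /usr/local/cuda, set by the "Verify environment" step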

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ dependencies = [
     # transformers<4.54 is not supported due to refactoring of the transformers library.
     # transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
     "transformers>=4.56,<5.0.0",
+    "sentencepiece>=0.2.0,<0.3",
+    "protobuf>=5.27.2,<6",
     "datasets>=2.21.0,<3",
     "pandas>=2.2.2,<3",
     "accelerate>=1.0.0,<2",

tests/fixtures.py

Lines changed: 7 additions & 7 deletions
@@ -14,38 +14,38 @@ def get_device():
 
 @pytest.fixture(scope="session")
 def unit_test_model():
-    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").eval()
+    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
 def unit_test_model_output_attention():
     model = AutoModelForCausalLM.from_pretrained(
-        "Qwen/Qwen3-0.6B", attn_implementation="eager"
+        "MaxJeblick/llama2-0b-unit-test", attn_implementation="eager"
     ).eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
-def qwen3_600m_model():
-    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").eval()
+def danube_500m_model():
+    model = AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube3-500m-chat").eval()
     return model.to(get_device())
 
 
 @pytest.fixture(scope="session")
 def kv_press_unit_test_pipeline():
     return pipeline(
         "kv-press-text-generation",
-        model="Qwen/Qwen3-0.6B",
+        model="maxjeblick/llama2-0b-unit-test",
         device=get_device(),
     )
 
 
 @pytest.fixture(scope="session")
-def kv_press_qwen3_600m_pipeline():
+def kv_press_danube_pipeline():
     return pipeline(
         "kv-press-text-generation",
-        model="Qwen/Qwen3-0.6B",
+        model="h2oai/h2o-danube3-500m-chat",
         device=get_device(),
     )
 
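For reference, a sketch (not from the repository) of how the renamed session-scoped fixtures are consumed; the import-plus-noqa pattern mirrors tests/test_pipeline.py below, and calling the pipeline without a press assumes press is optional, as test_pipeline_no_press_works suggests:

# Hypothetical test: importing the fixture makes it visible to pytest, which then
# injects the pipeline by argument name (fixture name taken from the diff above).
from tests.fixtures import kv_press_danube_pipeline  # noqa: F401


def test_danube_pipeline_smoke(kv_press_danube_pipeline):  # noqa: F811
    context = "This is a test article. It was written on 2022-01-01."
    answers = kv_press_danube_pipeline(context, questions=["When was this article written?"])["answers"]
    assert len(answers) == 1 and isinstance(answers[0], str)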

tests/test_decoding_compression.py

Lines changed: 7 additions & 7 deletions
@@ -31,7 +31,7 @@ def test_decoding_compression(token_buffer_size):
     """Test that DecodingPress compresses the cache during decoding."""
 
     # Initialize pipeline with a small model
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Create a DecodingPress with KnormPress
     press = DecodingPress(
@@ -65,7 +65,7 @@ def test_prefill_decoding_press_calls_both_phases():
     """Test that PrefillDecodingPress calls both prefilling and decoding presses."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Create PrefillDecodingPress with both presses
     combined_press = PrefillDecodingPress(
@@ -99,7 +99,7 @@ def test_decoding_press_without_prefill():
     """Test that DecodingPress works correctly when used standalone (no prefill compression)."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Create DecodingPress only
     decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)
@@ -129,7 +129,7 @@ def test_prefill_decoding_press_decoding_only():
     """Test PrefillDecodingPress with only decoding press (no prefill compression)."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Create PrefillDecodingPress with only decoding press
     combined_press = PrefillDecodingPress(
@@ -167,7 +167,7 @@ def test_decoding_press_equivalence():
     torch.manual_seed(42)
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Create standalone decoding press
     decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52)
@@ -222,7 +222,7 @@ def test_all_presses_work_with_decoding_press(press_config):
     """Test that all default presses work as base presses for DecodingPress."""
 
     # Initialize pipeline
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     # Get press class and use the first (easier) configuration
     press_cls = press_config["cls"]
@@ -274,7 +274,7 @@ def test_all_presses_work_with_decoding_press(press_config):
 def test_compression_actually_reduces_memory():
     """Test that compression actually reduces memory usage compared to no compression."""
 
-    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen3-0.6B", device_map="auto")
+    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
 
     context = "The quick brown fox jumps over the lazy dog. " * 15  # Long context
     question = "What animal jumps over the dog?"
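Every test above builds the same pipeline on the tiny unit-test model and wraps a base press in DecodingPress; condensed into a standalone sketch below, with argument values copied from the diff (the top-level import path for the presses is an assumption, since the diff only shows call sites):

# Condensed sketch of the setup these tests repeat; importing kvpress is assumed
# to register the "kv-press-text-generation" task with transformers.
from transformers import pipeline

from kvpress import DecodingPress, KnormPress

pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)

context = "The quick brown fox jumps over the lazy dog. " * 15  # long context, as in the memory test
answers = pipe(context, questions=["What animal jumps over the dog?"], press=press)["answers"]
print(answers[0])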

tests/test_pipeline.py

Lines changed: 7 additions & 7 deletions
@@ -11,8 +11,8 @@
 
 from kvpress import ExpectedAttentionPress
 from kvpress.pipeline import KVPressTextGenerationPipeline
-from tests.fixtures import qwen3_600m_model  # noqa: F401
-from tests.fixtures import kv_press_qwen3_600m_pipeline  # noqa: F401
+from tests.fixtures import danube_500m_model  # noqa: F401
+from tests.fixtures import kv_press_danube_pipeline  # noqa: F401
 from tests.fixtures import unit_test_model  # noqa: F401
 from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_unit_test_pipeline  # noqa: F401
 
@@ -94,9 +94,9 @@ def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog):  # noqa:
     kv_press_unit_test_pipeline(context, question=question)
 
 
-def test_pipeline_answer_is_correct(qwen3_600m_model, caplog):  # noqa: F811
+def test_pipeline_answer_is_correct(danube_500m_model, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
-        answers = generate_answer(qwen3_600m_model)
+        answers = generate_answer(danube_500m_model)
 
     for answer in answers:
         assert answer == "This article was written on January 1, 2022."
@@ -107,13 +107,13 @@ def test_pipeline_answer_is_correct(qwen3_600m_model, caplog):  # noqa: F811
 
 
 @pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
-def test_pipeline_with_quantized_cache(kv_press_qwen3_600m_pipeline, caplog):  # noqa: F811
+def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noqa: F811
     with caplog.at_level(logging.DEBUG):
         context = "This is a test article. It was written on 2022-01-01."
         questions = ["When was this article written?"]
         press = ExpectedAttentionPress(compression_ratio=0.4)
-        cache = QuantoQuantizedCache(config=kv_press_qwen3_600m_pipeline.model.config, nbits=4)
-        answers = kv_press_qwen3_600m_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
+        cache = QuantoQuantizedCache(config=kv_press_danube_pipeline.model.config, nbits=4)
+        answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
 
     assert len(answers) == 1
     assert isinstance(answers[0], str)
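The quantized-cache path exercised above, pulled out of pytest into a standalone sketch (nbits=4, the config= keyword, and the call signature come from the diff; the import location of QuantoQuantizedCache and the need for optimum-quanto are assumptions):

# Standalone sketch of the quantized-cache test path; requires optimum-quanto
# (the test is skipped without it) and assumes QuantoQuantizedCache is exported
# at the transformers top level.
from transformers import QuantoQuantizedCache, pipeline

from kvpress import ExpectedAttentionPress

pipe = pipeline("kv-press-text-generation", model="h2oai/h2o-danube3-500m-chat", device_map="auto")
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = QuantoQuantizedCache(config=pipe.model.config, nbits=4)
answers = pipe(
    "This is a test article. It was written on 2022-01-01.",
    questions=["When was this article written?"],
    press=press,
    cache=cache,
)["answers"]
assert len(answers) == 1 and isinstance(answers[0], str)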
