Skip to content

Commit 7d21d69

Browse files
committed
Add the vulnerability report
closes: #6773
1 parent b24f501 commit 7d21d69

7 files changed

Lines changed: 347 additions & 0 deletions

File tree

CHANGES/6773.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added the vulnerability report data model.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from django.db import models
2+
3+
from pulpcore.plugin.models import BaseModel
4+
from pulpcore.plugin.util import get_domain_pk
5+
6+
7+
class VulnerabilityReport(BaseModel):
    """
    Database record holding the vulnerability data of a report.
    """

    # JSON document with the vulnerability data of the report.
    vulns = models.JSONField()
    # Domain owning the report; deleting the domain cascades to its reports.
    pulp_domain = models.ForeignKey(
        "core.Domain",
        default=get_domain_pk,
        on_delete=models.CASCADE,
    )

    class Meta:
        default_related_name = "%(app_label)s_%(model_name)s"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from rest_framework import serializers
2+
3+
from pulpcore.app.models.vulnerability_report import VulnerabilityReport
4+
from pulpcore.plugin.serializers import IdentityField, ModelSerializer
5+
6+
7+
class VulnerabilityReportSerializer(ModelSerializer):
    """
    Serializes VulnerabilityReport objects, exposing their ``vulns`` JSON payload.
    """

    # Link to the detail view of the report.
    pulp_href = IdentityField(view_name="vuln_report-detail")
    # Raw vulnerability data stored on the model.
    vulns = serializers.JSONField()

    class Meta:
        model = VulnerabilityReport
        # Base serializer fields plus the report payload.
        fields = ModelSerializer.Meta.fields + ("vulns",)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import aiohttp
2+
import json
3+
4+
from asgiref.sync import sync_to_async
5+
from queue import Empty, Queue
6+
from threading import Thread
7+
from typing import Callable, Optional, Tuple, Dict, Any
8+
9+
10+
from pulpcore.plugin.util import get_domain
11+
from pulpcore.plugin.models import CreatedResource
12+
from pulpcore.constants import (
13+
OSV_QUERY_URL,
14+
VULNERABILITY_TASK_THREAD_TIMEOUT,
15+
)
16+
from pulpcore.app.models.vulnerability_report import VulnerabilityReport
17+
18+
# Create a thread-safe queue to share Content units between threads
19+
content_queue = Queue()
20+
21+
22+
async def check_content(func: Callable, args: Optional[Tuple] = None) -> None:
    """
    Run ``func`` in a background thread and scan the packages it puts on the content queue.

    Args:
        func (callable): The function that populates the content_queue Queue with the
            content units. Each item in the queue must follow the osv.dev request data format.
        args (tuple): The positional arguments to pass on to the func.
    """
    if not args:
        args = ()
    background_thread = Thread(target=func, args=args)
    # The producer has to be running before the queue is consumed: a thread that was never
    # started puts nothing on the queue and cannot be joined (join() raises RuntimeError).
    background_thread.start()
    await _scan_packages(background_thread)
    background_thread.join()
37+
38+
39+
async def _scan_packages(background_thread: Thread) -> None:
    """
    Scan the queued packages and persist the outcome as a vulnerability report.

    Args:
        background_thread (Thread): The thread that populates the queue. It is handed to the
            scanning step, which checks it when the queue stops delivering items.
    """
    report_data = await _scan_packages_from_queue(background_thread=background_thread)
    await _save_vulnerability_report(report_data)
51+
52+
53+
async def _scan_packages_from_queue(
    background_thread: Thread, http_client: Optional[aiohttp.ClientSession] = None
) -> Dict[str, Any]:
    """
    Scan the packages found in the queue, using ``http_client`` when one is given.

    Args:
        background_thread (Thread): Thread populating the queue
        http_client (Optional[aiohttp.ClientSession]): HTTP client for making requests. When
            omitted, a session is opened for the duration of the scan and closed afterwards.

    Returns:
        Dict[str, Any]: Dictionary mapping package names to vulnerability data
    """
    if not http_client:
        # No client supplied: own the session so it is closed once the scan is done.
        async with aiohttp.ClientSession() as session:
            return await _process_queue_with_client(background_thread, session)

    return await _process_queue_with_client(background_thread, http_client)
73+
74+
75+
async def _process_queue_with_client(
    background_thread: Thread,
    session: aiohttp.ClientSession,
) -> Dict[str, Any]:
    """
    Process queue items using the provided HTTP client session.

    Items are taken from the module-level ``content_queue`` until a ``None`` sentinel is
    received. Paginated responses are followed immediately and the vulnerabilities of all
    pages are merged into the entry of the package being scanned.

    Args:
        background_thread (Thread): Thread populating the queue
        session (aiohttp.ClientSession): HTTP client session

    Returns:
        Dict[str, Any]: Dictionary mapping package names to vulnerability data. Packages
            without vulnerabilities are not included.

    Raises:
        RuntimeError: If an exception instance is received from the queue, or if no item
            arrives within VULNERABILITY_TASK_THREAD_TIMEOUT seconds.
    """
    scanned_packages: Dict[str, Any] = {}
    try:
        for osv_data in iter(
            lambda: content_queue.get(timeout=VULNERABILITY_TASK_THREAD_TIMEOUT), None
        ):
            if isinstance(osv_data, Exception):
                raise RuntimeError(f"Background vuln report task failed to execute: {osv_data}")

            package_name = _get_package_name(osv_data)
            package_vulns = []
            request_data = osv_data
            while request_data is not None:
                vulnerability_data = await _query_osv_api(session, request_data)
                package_vulns.extend(vulnerability_data.get("vulns") or [])

                # Follow pagination right here instead of re-queueing the request: a
                # re-queued request could land behind the sentinel and never be processed,
                # and every page must be merged rather than replace the previous one.
                if next_page_token := vulnerability_data.get("next_page_token"):
                    request_data = build_osv_data(
                        osv_data["package"]["name"],
                        osv_data["package"]["ecosystem"],
                        osv_data.get("version"),
                        next_page_token,
                    )
                else:
                    request_data = None

            if package_vulns:
                scanned_packages[package_name] = package_vulns

    except Empty as exc:
        # The queue stayed empty for the whole timeout: tell a dead producer from a slow one.
        if not background_thread.is_alive():
            raise RuntimeError("Vuln report task thread died unexpectedly.") from exc
        raise RuntimeError("Background vuln report thread took too long.") from exc

    return scanned_packages
120+
121+
122+
async def _query_osv_api(
    session: aiohttp.ClientSession, osv_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    POST one query to the OSV API and return the decoded response.

    Args:
        session (aiohttp.ClientSession): HTTP client session
        osv_data (Dict[str, Any]): OSV query data, sent as a JSON-encoded request body

    Returns:
        Dict[str, Any]: JSON response from OSV API
    """
    payload = json.dumps(osv_data)
    async with session.post(url=OSV_QUERY_URL, data=payload) as response:
        # The body is read as text and decoded explicitly.
        return json.loads(await response.text())
139+
140+
141+
def _get_package_name(osv_data: Dict[str, Any]) -> str:
142+
"""
143+
Extract package name from OSV data.
144+
145+
Args:
146+
osv_data (Dict[str, Any]): OSV query data
147+
148+
Returns:
149+
str: Formatted package name
150+
"""
151+
osv_package_name = osv_data["package"]["name"]
152+
osv_package_version = osv_data.get("version", "")
153+
return "{package}-{version}".format(
154+
package=osv_package_name,
155+
version=osv_package_version,
156+
)
157+
158+
159+
async def _save_vulnerability_report(scanned_packages: Dict[str, Any]) -> None:
    """
    Store the scan result as a VulnerabilityReport of the current domain.

    An existing report with the same data and domain is reused. Only a newly created report
    is recorded as a created resource.

    Args:
        scanned_packages (Dict[str, Any]): Dictionary mapping package names to vulnerability data
    """
    get_or_create = sync_to_async(VulnerabilityReport.objects.get_or_create)
    report, is_new = await get_or_create(vulns=scanned_packages, pulp_domain=get_domain())
    if not is_new:
        return
    await CreatedResource.objects.acreate(content_object=report)
171+
172+
173+
def build_osv_data(
    name: str, ecosystem: str, version: Optional[str] = None, next_page_token: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build the osv.dev query payload for a package.

    Args:
        name (str): Name of the package. Should match the name used in the package ecosystem.
        ecosystem (str): The ecosystem for this package. Check osv.dev documentation for
            the complete list of valid ecosystem names.
        version (Optional[str]): The version string to query for. Left out of the payload
            when empty.
        next_page_token (Optional[str]): Token of the next page when the previous query was
            paginated. Sent as ``page_token``; left out of the payload when empty.

    Returns (Dict[str, Any]): The package whose vulnerabilities will be checked. A dictionary
        following the expected payload format from osv.dev.
    """
    payload: Dict[str, Any] = {"package": {"name": name, "ecosystem": ecosystem}}
    optional_fields = {"version": version, "page_token": next_page_token}
    # Only truthy optional values are sent.
    payload.update({key: value for key, value in optional_fields.items() if value})
    return payload
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from rest_framework.mixins import DestroyModelMixin, ListModelMixin, RetrieveModelMixin
2+
3+
from pulpcore.app.models.vulnerability_report import VulnerabilityReport as VulnReport
4+
from pulpcore.app.serializers.vulnerability_report import VulnerabilityReportSerializer
5+
from pulpcore.plugin.viewsets import NamedModelViewSet
6+
7+
8+
class VulnerabilityReport(
    NamedModelViewSet,
    ListModelMixin,
    RetrieveModelMixin,
    DestroyModelMixin,
):
    # The mixins provide list, retrieve and destroy actions for vulnerability reports.

    queryset = VulnReport.objects.all()
    serializer_class = VulnerabilityReportSerializer
    endpoint_name = "vuln_report"

pulpcore/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,9 @@
127127
# The upper boundary represents an unsigned 32-bit integer and prevents overflow
128128
ORPHAN_PROTECTION_TIME_LOWER_BOUND = 0
129129
ORPHAN_PROTECTION_TIME_UPPER_BOUND = 4294967295 # (2^32)-1
130+
131+
# OSV API URL
132+
OSV_QUERY_URL = "https://api.osv.dev/v1/query"
133+
134+
# Timeout when waiting on tasks scan thread queue to avoid indefinite blocking.
135+
VULNERABILITY_TASK_THREAD_TIMEOUT = 60
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import pytest
2+
import json
3+
4+
from queue import Queue
5+
from threading import Thread
6+
from unittest.mock import Mock, AsyncMock, patch
7+
8+
from pulpcore.app.tasks.vulnerability_report import (
9+
build_osv_data,
10+
_query_osv_api,
11+
_process_queue_with_client,
12+
)
13+
from pulpcore.constants import OSV_QUERY_URL
14+
15+
16+
@pytest.mark.asyncio
async def test_process_queue_with_client():
    """A queued package with vulnerabilities ends up in the returned mapping."""
    package_query = {"package": {"name": "django", "ecosystem": "PyPI"}, "version": "5.1"}
    osv_response = {"vulns": [{"id": "GHSA-test-1234", "summary": "Test vulnerability"}]}

    # The producer thread is reported as still running.
    producer = Mock(spec=Thread)
    producer.is_alive.return_value = True

    session = Mock()

    # One package followed by the sentinel marking the end of the queue.
    fake_queue = Queue()
    for item in (package_query, None):
        fake_queue.put(item)

    with patch("pulpcore.app.tasks.vulnerability_report.content_queue", fake_queue):
        with patch(
            "pulpcore.app.tasks.vulnerability_report._query_osv_api", return_value=osv_response
        ) as query_mock:
            with patch(
                "pulpcore.app.tasks.vulnerability_report._get_package_name",
                return_value="django-5.1",
            ) as name_mock:
                scanned = await _process_queue_with_client(producer, session)

                assert scanned == {"django-5.1": osv_response["vulns"]}
                query_mock.assert_called_once_with(session, package_query)
                name_mock.assert_called_once_with(package_query)
45+
46+
47+
@pytest.mark.asyncio
async def test_query_osv_api():
    """The query is posted as JSON and the decoded response body is returned."""
    package_query = {"package": {"name": "django", "ecosystem": "PyPI"}, "version": "5.1"}
    vulnerability = {
        "id": "GHSA-abcd-1234",
        "summary": "Test vulnerability",
        "details": "This is a test vulnerability",
    }
    osv_response = {"vulns": [vulnerability]}

    # Response object whose body is the JSON-encoded payload.
    response = AsyncMock()
    response.text.return_value = json.dumps(osv_response)

    # Async context manager returned by session.post().
    post_context = AsyncMock()
    post_context.__aenter__.return_value = response
    post_context.__aexit__.return_value = None

    session = Mock()
    session.post.return_value = post_context

    outcome = await _query_osv_api(session, package_query)

    assert outcome == osv_response
    session.post.assert_called_once_with(url=OSV_QUERY_URL, data=json.dumps(package_query))
    response.text.assert_called_once()
77+
78+
79+
def test_build_osv_data():
    """The payload always has the package and only the optional fields that were given."""
    name, ecosystem = "test", "npm"
    package = {"name": name, "ecosystem": ecosystem}

    # Required fields only.
    assert build_osv_data(name, ecosystem) == {"package": package}

    # Optional version.
    release = "beta"
    assert build_osv_data(name, ecosystem, release) == {"package": package, "version": release}

    # Optional pagination token.
    page = "1234"
    assert build_osv_data(name, ecosystem, next_page_token=page) == {
        "package": package,
        "page_token": page,
    }

0 commit comments

Comments
 (0)