Skip to content

Commit 28a9be4

Browse files
deacon-mp, uruwhy, and Fiona McCrae
authored
fix: sanitize object IDs to prevent path traversal in BaseApiManager (#3299)
* fix(security): sanitize id field to prevent path traversal in objectives API
* fix sanitization
* Added warning log for when ID is sanitized, and corresponding test
* style fix

---------

Co-authored-by: Daniel Matthews <58484522+uruwhy@users.noreply.github.com>
Co-authored-by: Fiona McCrae <fmccrae@mitre.org>
1 parent c928eb2 commit 28a9be4

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

app/api/v2/managers/base_api_manager.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import logging
22
import os
3+
import re
34
import uuid
45
import yaml
56

67
from marshmallow.schema import SchemaMeta
78
from typing import Any, List
89
from base64 import b64encode, b64decode
910

11+
from app.api.v2.errors import DataValidationError
1012
from app.utility.base_world import BaseWorld
1113

1214

@@ -64,6 +66,7 @@ def create_object_from_schema(self, schema: SchemaMeta, data: dict, access: Base
6466

6567
async def create_on_disk_object(self, data: dict, access: dict, ram_key: str, id_property: str, obj_class: type):
6668
obj_id = data.get(id_property) or str(uuid.uuid4())
69+
obj_id = self._sanitize_id(obj_id)
6770
data[id_property] = obj_id
6871

6972
file_path = await self._get_new_object_file_path(data[id_property], ram_key)
@@ -121,18 +124,34 @@ async def remove_object_from_memory_by_id(self, identifier: str, ram_key: str, i
121124
await self._data_svc.remove(ram_key, {id_property: identifier})
122125

123126
async def remove_object_from_disk_by_id(self, identifier: str, ram_key: str):
    """Delete the file backing the object with the given id, if it exists.

    The id is sanitized before it is used to resolve the file path.
    """
    safe_id = self._sanitize_id(identifier)
    target_path = await self._get_existing_object_file_path(safe_id, ram_key)

    # Nothing to delete when the resolved path is not present on disk.
    if not os.path.exists(target_path):
        return
    os.remove(target_path)
128132

133+
@staticmethod
def _sanitize_id(obj_id) -> str:
    """Return ``obj_id`` with every character outside ``[a-zA-Z0-9_-]`` removed.

    Callers in this module place the result into ``<id>.yml`` file names, so
    path separators and dots are stripped along with everything else that is
    not an ASCII letter, digit, underscore or hyphen.

    A warning is logged whenever sanitization changes the id.

    :param obj_id: The id to sanitize; must be a ``str``.
    :return: The sanitized id.
    :raises DataValidationError: If ``obj_id`` is not a string, or if no
        characters remain after sanitization.
    """
    if not isinstance(obj_id, str):
        raise DataValidationError(message=f'Invalid id type: expected str, got {type(obj_id).__name__}', name='id', value=obj_id)
    sanitized_id = re.sub(r'[^a-zA-Z0-9_-]', '', obj_id)
    if not sanitized_id:
        # Report the id as supplied: the sanitized value is always empty in
        # this branch and would hide what was actually rejected.
        raise DataValidationError(message=f"Invalid id: {obj_id!r}", name='id', value=obj_id)
    if sanitized_id != obj_id:
        logging.getLogger(DEFAULT_LOGGER_NAME).warning(f"Sanitized ID: {sanitized_id}")
    return sanitized_id
145+
129146
@staticmethod
async def _get_new_object_file_path(identifier: str, ram_key: str) -> str:
    """Build the path for a new object's file: data/<ram_key>/<sanitized id>.yml"""
    safe_id = BaseApiManager._sanitize_id(identifier)
    file_name = f'{safe_id}.yml'
    return os.path.join('data', ram_key, file_name)
133151

134152
async def _get_existing_object_file_path(self, identifier: str, ram_key: str) -> str:
135153
"""Find file path for existing object (by id)"""
154+
identifier = self._sanitize_id(identifier)
136155
_, file_path = await self._file_svc.find_file_path(f'{identifier}.yml', location=ram_key)
137156
if not file_path:
138157
file_path = await self._get_new_object_file_path(identifier, ram_key)

tests/api/v2/managers/test_base_api_manager.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import pytest
12
import marshmallow as ma
23

4+
from app.api.v2.errors import DataValidationError
35
from app.api.v2.managers.base_api_manager import BaseApiManager
46
from app.objects.interfaces.i_object import FirstClassObjectInterface
57
from app.utility.base_object import BaseObject
@@ -248,3 +250,28 @@ def test_replace_object(agent):
248250

249251
assert len(stub_data_svc.ram['tests']) == 1
250252
assert not stub_data_svc.ram['tests'][0].value
253+
254+
255+
def test_sanitize_id():
    """Valid ids pass through, unsafe characters are stripped, unusable ids are rejected."""
    sanitize = BaseApiManager._sanitize_id
    valid = '766be199-7316-4b26-b3db-e272aaf7e0d4'

    # Ids made only of allowed characters come back unchanged
    for untouched in (valid, valid.upper(), 'testid123TEST'):
        assert untouched == sanitize(untouched)

    # Everything outside the allowed character set is stripped
    assert valid == sanitize('../.././&%$!"#766be19:9-73[16-]4b}26-b{3d!b-e272..\\//aaf/*7e0d4')

    # Ids with nothing left after sanitization, and non-string ids, raise a DataValidationError
    for rejected in ('../../.', '', 12345):
        with pytest.raises(DataValidationError):
            sanitize(rejected)
268+
269+
270+
def test_sanitize_id_logs_warning_when_changed(caplog):
    """A warning naming the sanitized id is logged when sanitization alters the id."""
    caplog.set_level('WARNING')

    sanitized = BaseApiManager._sanitize_id('abc/def?ghi')
    assert sanitized == 'abcdefghi'

    # At least one captured record must carry both the marker text and the cleaned id
    messages = [rec.getMessage() for rec in caplog.records]
    assert any('Sanitized ID' in msg and sanitized in msg for msg in messages)

0 commit comments

Comments
 (0)