Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions samcli/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class CfnParameterOverridesType(click.ParamType):

name = "list,object,string"

def convert(self, values, param, ctx, seen_files=None):
def convert(self, values, param, ctx, seen_files=None, current_file: Optional[Path] = None):
"""
Takes parameter overrides loaded from various supported config file formats and
flattens and normalizes them into a dictionary where all keys and values are strings.
Expand All @@ -143,6 +143,8 @@ def convert(self, values, param, ctx, seen_files=None):
Click context for error reporting.
seen_files : set
List of files processed in the current execution branch, used to detect infinite recursion
current_file : Path
Path to the file currently being processed, used to resolve relative paths in nested includes

Returns
-------
Expand All @@ -169,6 +171,11 @@ def convert(self, values, param, ctx, seen_files=None):
# If the string is a file reference (e.g., 'file://params.yaml')
if value.startswith("file://"):
file_path = Path(value[7:])
# Resolve relative paths against current file's directory
if not file_path.is_absolute() and current_file:
file_path = current_file.parent / file_path
file_path = file_path.resolve()

if not file_path.is_file():
self.fail(f"{value} was not found or is a directory", param, ctx)
file_manager = FILE_MANAGER_MAPPER.get(file_path.suffix, None)
Expand All @@ -181,7 +188,7 @@ def convert(self, values, param, ctx, seen_files=None):
seen_files.add(file_path)
try:
nested_values = file_manager.read(file_path)
parameters.update(self.convert(nested_values, param, ctx, seen_files))
parameters.update(self.convert(nested_values, param, ctx, seen_files, file_path))
finally:
seen_files.remove(file_path)
else:
Expand All @@ -205,6 +212,15 @@ def convert(self, values, param, ctx, seen_files=None):
if v is None:
# Unset value if previously set
parameters[str(k)] = ""
elif k == "$include":
# Process includes (can be string or list)
if not isinstance(v, (str, list, CommentedSeq)):
self.fail(
f"$include must be a string or list of strings, got {type(v).__name__}",
param,
ctx,
)
parameters.update(self.convert(v, param, ctx, seen_files, current_file))
elif isinstance(v, (list, CommentedSeq)):
# Join list elements into comma-separated string, strip whitespace and ignore empty entries
parameters[str(k)] = ",".join(str(x).strip() for x in v if x not in (None, ""))
Expand Down
95 changes: 94 additions & 1 deletion tests/unit/cli/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ def test_successful_parsing(self, input, expected):
("file://nested.yaml",),
{"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml"},
),
(
("file://nested.toml",),
{"A": "toml", "B": "toml", "Toml": "toml", "Env": "production"},
),
(
("file://nested-list.toml",),
{"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml", "Env": "prod"},
),
(
("file://nested-list.yaml",),
{"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml", "Env": "prod"},
),
]
)
def test_merge_file_parsing(self, inputs, expected):
Expand All @@ -182,6 +194,9 @@ def test_merge_file_parsing(self, inputs, expected):
"params.yaml": "A: yaml\nB: yaml\nYaml: yaml",
"list.yaml": "- - - - - - A: a\n- List:\n - 1\n - 2\n - 3\n",
"nested.yaml": "- file://params.toml\n- file://params.yaml",
"nested.toml": "'$include' = 'file://params.toml'\nEnv = 'production'",
"nested-list.toml": "'$include' = ['file://params.toml', 'file://params.yaml']\nEnv = 'prod'",
"nested-list.yaml": "$include:\n - file://params.toml\n - file://params.yaml\nEnv: prod",
}

def mock_read_text(file_path):
Expand All @@ -196,6 +211,27 @@ def mock_is_file(file_path):
print(result)
self.assertEqual(result, expected, msg="Failed with Input = " + str(inputs))

def test_include_infinite_recursion_protection(self):
mock_files = {
"A.yaml": "$include: file://B.yaml",
"B.yaml": "$include: file://C.yaml",
"C.yaml": "$include: file://A.yaml",
}

def mock_read_text(file_path):
file_name = file_path.name
return mock_files.get(file_name, "")

def mock_is_file(file_path):
return file_path.name in mock_files

with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch(
"pathlib.Path.read_text", new=mock_read_text
):
self.param_type.convert("file://A.yaml", None, MagicMock())

self.assertIn("Infinite recursion detected in file references", str(exception.exception))

def test_infinite_recursion_protection(self):
mock_files = {
"A.yaml": "- file://B.yaml",
Expand All @@ -213,10 +249,67 @@ def mock_is_file(file_path):
with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch(
"pathlib.Path.read_text", new=mock_read_text
):
self.param_type.convert(f"file://A.yaml", None, MagicMock())
self.param_type.convert("file://A.yaml", None, MagicMock())

self.assertIn("Infinite recursion detected in file references", str(exception.exception))

@parameterized.expand(
[
(
{
"config/default.yaml": "Base: default\nEnv: dev",
"config/prod.yaml": "- file://default.yaml\n- Env: production\n- Region: us-east-1",
},
"config/prod.yaml",
{"Base": "default", "Env": "production", "Region": "us-east-1"},
),
(
{
"config/shared/common.yaml": "Common: shared",
"config/prod.yaml": "- file://shared/common.yaml\n- Env: prod",
},
"config/prod.yaml",
{"Common": "shared", "Env": "prod"},
),
]
)
def test_relative_path_nested_includes(self, file_contents, entry_file, expected):
"""Test that nested files can reference other files using relative paths"""
from pathlib import Path
import tempfile

with tempfile.TemporaryDirectory() as temp:
temp_path = Path(temp)

# Create all files
for file_path, content in file_contents.items():
full_path = temp_path / file_path
full_path.parent.mkdir(parents=True, exist_ok=True)
full_path.write_text(content)

entry_path = temp_path / entry_file
result = self.param_type.convert(f"file://{entry_path}", None, MagicMock())
self.assertEqual(result, expected)

def test_include_key_invalid_type(self):
"""Test that $include with invalid type fails with clear error"""
mock_files = {
"invalid.yaml": "$include: 123\nEnv: prod",
}

def mock_read_text(file_path):
return mock_files.get(file_path.name, "")

def mock_is_file(file_path):
return file_path.name in mock_files

with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch(
"pathlib.Path.read_text", new=mock_read_text
):
self.param_type.convert("file://invalid.yaml", None, MagicMock())

self.assertIn("$include must be a string or list of strings", str(exception.exception))


class TestCfnMetadataType(TestCase):
def setUp(self):
Expand Down
Loading