diff --git a/samcli/cli/types.py b/samcli/cli/types.py index 261c3ee70e..802a1b3799 100644 --- a/samcli/cli/types.py +++ b/samcli/cli/types.py @@ -124,7 +124,7 @@ class CfnParameterOverridesType(click.ParamType): name = "list,object,string" - def convert(self, values, param, ctx, seen_files=None): + def convert(self, values, param, ctx, seen_files=None, current_file: Optional[Path] = None): """ Takes parameter overrides loaded from various supported config file formats and flattens and normalizes them into a dictionary where all keys and values are strings. @@ -143,6 +143,8 @@ def convert(self, values, param, ctx, seen_files=None): Click context for error reporting. seen_files : set List of files processed in the current execution branch, used to detect infinite recursion + current_file : Path + Path to the file currently being processed, used to resolve relative paths in nested includes Returns ------- @@ -169,6 +171,11 @@ def convert(self, values, param, ctx, seen_files=None): # If the string is a file reference (e.g., 'file://params.yaml') if value.startswith("file://"): file_path = Path(value[7:]) + # Resolve relative paths against current file's directory + if not file_path.is_absolute() and current_file: + file_path = current_file.parent / file_path + file_path = file_path.resolve() + if not file_path.is_file(): self.fail(f"{value} was not found or is a directory", param, ctx) file_manager = FILE_MANAGER_MAPPER.get(file_path.suffix, None) @@ -181,7 +188,7 @@ def convert(self, values, param, ctx, seen_files=None): seen_files.add(file_path) try: nested_values = file_manager.read(file_path) - parameters.update(self.convert(nested_values, param, ctx, seen_files)) + parameters.update(self.convert(nested_values, param, ctx, seen_files, file_path)) finally: seen_files.remove(file_path) else: @@ -205,6 +212,15 @@ def convert(self, values, param, ctx, seen_files=None): if v is None: # Unset value if previously set parameters[str(k)] = "" + elif k == "$include": + # Process includes (can be string or list) + if not isinstance(v, (str, list, CommentedSeq)): + self.fail( + f"$include must be a string or list of strings, got {type(v).__name__}", + param, + ctx, + ) + parameters.update(self.convert(v, param, ctx, seen_files, current_file)) elif isinstance(v, (list, CommentedSeq)): # Join list elements into comma-separated string, strip whitespace and ignore empty entries parameters[str(k)] = ",".join(str(x).strip() for x in v if x not in (None, "")) diff --git a/tests/unit/cli/test_types.py b/tests/unit/cli/test_types.py index b65ab3348a..b937f5e8b3 100644 --- a/tests/unit/cli/test_types.py +++ b/tests/unit/cli/test_types.py @@ -174,6 +174,18 @@ def test_successful_parsing(self, input, expected): ("file://nested.yaml",), {"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml"}, ), + ( + ("file://nested.toml",), + {"A": "toml", "B": "toml", "Toml": "toml", "Env": "production"}, + ), + ( + ("file://nested-list.toml",), + {"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml", "Env": "prod"}, + ), + ( + ("file://nested-list.yaml",), + {"A": "yaml", "B": "yaml", "Toml": "toml", "Yaml": "yaml", "Env": "prod"}, + ), ] ) def test_merge_file_parsing(self, inputs, expected): @@ -182,6 +194,9 @@ def test_merge_file_parsing(self, inputs, expected): "params.yaml": "A: yaml\nB: yaml\nYaml: yaml", "list.yaml": "- - - - - - A: a\n- List:\n - 1\n - 2\n - 3\n", "nested.yaml": "- file://params.toml\n- file://params.yaml", + "nested.toml": "'$include' = 'file://params.toml'\nEnv = 'production'", + "nested-list.toml": "'$include' = ['file://params.toml', 'file://params.yaml']\nEnv = 'prod'", + "nested-list.yaml": "$include:\n - file://params.toml\n - file://params.yaml\nEnv: prod", } def mock_read_text(file_path): @@ -196,6 +211,27 @@ def mock_is_file(file_path): print(result) self.assertEqual(result, expected, msg="Failed with Input = " + str(inputs)) + def test_include_infinite_recursion_protection(self): + mock_files = { + "A.yaml": "$include: file://B.yaml", + "B.yaml": "$include: file://C.yaml", + "C.yaml": "$include: file://A.yaml", + } + + def mock_read_text(file_path): + file_name = file_path.name + return mock_files.get(file_name, "") + + def mock_is_file(file_path): + return file_path.name in mock_files + + with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch( + "pathlib.Path.read_text", new=mock_read_text + ): + self.param_type.convert("file://A.yaml", None, MagicMock()) + + self.assertIn("Infinite recursion detected in file references", str(exception.exception)) + def test_infinite_recursion_protection(self): mock_files = { "A.yaml": "- file://B.yaml", @@ -213,10 +249,67 @@ def mock_is_file(file_path): with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch( "pathlib.Path.read_text", new=mock_read_text ): - self.param_type.convert(f"file://A.yaml", None, MagicMock()) + self.param_type.convert("file://A.yaml", None, MagicMock()) self.assertIn("Infinite recursion detected in file references", str(exception.exception)) + @parameterized.expand( + [ + ( + { + "config/default.yaml": "Base: default\nEnv: dev", + "config/prod.yaml": "- file://default.yaml\n- Env: production\n- Region: us-east-1", + }, + "config/prod.yaml", + {"Base": "default", "Env": "production", "Region": "us-east-1"}, + ), + ( + { + "config/shared/common.yaml": "Common: shared", + "config/prod.yaml": "- file://shared/common.yaml\n- Env: prod", + }, + "config/prod.yaml", + {"Common": "shared", "Env": "prod"}, + ), + ] + ) + def test_relative_path_nested_includes(self, file_contents, entry_file, expected): + """Test that nested files can reference other files using relative paths""" + from pathlib import Path + import tempfile + + with tempfile.TemporaryDirectory() as temp: + temp_path = Path(temp) + + # Create all files + for file_path, content in file_contents.items(): + full_path = temp_path / file_path + full_path.parent.mkdir(parents=True, exist_ok=True) + full_path.write_text(content) + + entry_path = temp_path / entry_file + result = self.param_type.convert(f"file://{entry_path}", None, MagicMock()) + self.assertEqual(result, expected) + + def test_include_key_invalid_type(self): + """Test that $include with invalid type fails with clear error""" + mock_files = { + "invalid.yaml": "$include: 123\nEnv: prod", + } + + def mock_read_text(file_path): + return mock_files.get(file_path.name, "") + + def mock_is_file(file_path): + return file_path.name in mock_files + + with self.assertRaises(BadParameter) as exception, patch("pathlib.Path.is_file", new=mock_is_file), patch( + "pathlib.Path.read_text", new=mock_read_text + ): + self.param_type.convert("file://invalid.yaml", None, MagicMock()) + + self.assertIn("$include must be a string or list of strings", str(exception.exception)) + class TestCfnMetadataType(TestCase): def setUp(self):