# coding=utf-8

import random
from typing import Dict, Optional, Tuple


class Sampler:
    """
    Handles sampling decision logic for Scout APM.

    This class encapsulates all sampling-related functionality including:
    - Loading and managing sampling configuration
    - Pattern matching for operations (endpoints and jobs)
    - Making sampling decisions based on operation type and patterns
    """

    # Constants for operation type detection
    CONTROLLER_PREFIX = "Controller/"
    JOB_PREFIX = "Job/"

    def __init__(self, config):
        """
        Initialize sampler with Scout configuration.

        The configuration values are read once here; later changes to the
        config object are not picked up by an existing Sampler.

        Args:
            config: ScoutConfig instance containing sampling configuration
        """
        self.config = config
        self.sample_rate = config.value("sample_rate")
        self.sample_endpoints = config.value("sample_endpoints")
        self.sample_jobs = config.value("sample_jobs")
        # "ignore" is merged with "ignore_endpoints"; both are treated as
        # endpoint-name prefixes.
        self.ignore_endpoints = set(
            config.value("ignore_endpoints") + config.value("ignore")
        )
        self.ignore_jobs = set(config.value("ignore_jobs"))
        self.endpoint_sample_rate = config.value("endpoint_sample_rate")
        self.job_sample_rate = config.value("job_sample_rate")

    def _any_sampling(self) -> bool:
        """
        Check if any sampling is enabled.

        Returns:
            True if any sampling or ignore setting deviates from
            "sample everything", False otherwise
        """
        # bool() so callers get a real boolean rather than whichever
        # dict/set operand happened to be truthy.
        return bool(
            self.sample_rate < 100
            or self.sample_endpoints
            or self.sample_jobs
            or self.ignore_endpoints
            or self.ignore_jobs
            or self.endpoint_sample_rate is not None
            or self.job_sample_rate is not None
        )

    def _find_matching_rate(
        self, name: str, patterns: Dict[str, int]
    ) -> Optional[int]:
        """
        Finds the matching sample rate for a given operation name.

        Patterns are tried in the mapping's iteration order and the first
        pattern that is a prefix of ``name`` wins, so more specific patterns
        must be listed before shorter ones.

        Args:
            name: The operation name to match
            patterns: Dictionary of pattern to sample rate mappings

        Returns:
            The sample rate for the matching pattern or None if no match found
        """
        for pattern, rate in patterns.items():
            if name.startswith(pattern):
                return rate
        return None

    def _get_operation_type_and_name(
        self, operation: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Determines if an operation is an endpoint or job and extracts its name.

        Args:
            operation: The full operation string (e.g. "Controller/users/show")

        Returns:
            Tuple of (type, name) where type is either 'endpoint' or 'job',
            and name is the operation name without the prefix. Returns
            (None, None) when the operation has neither known prefix.
        """
        if operation.startswith(self.CONTROLLER_PREFIX):
            return "endpoint", operation[len(self.CONTROLLER_PREFIX) :]
        if operation.startswith(self.JOB_PREFIX):
            return "job", operation[len(self.JOB_PREFIX) :]
        return None, None

    def get_effective_sample_rate(self, operation: str, is_ignored: bool) -> int:
        """
        Determines the effective sample rate for a given operation.

        Prioritization:
        1. Sampling rate for specific endpoint or job
        2. Specified ignore pattern or flag for operation
        3. Global endpoint or job sample rate
        4. Global sample rate

        Operations that are neither endpoints nor jobs (or that have an empty
        name after the prefix) always use the global sample rate.

        Args:
            operation: The operation string (e.g. "Controller/users/show")
            is_ignored: boolean for if the specific transaction is ignored

        Returns:
            Integer between 0 and 100 representing sample rate
        """
        op_type, name = self._get_operation_type_and_name(operation)
        if not op_type or not name:
            return self.sample_rate

        if op_type == "endpoint":
            patterns = self.sample_endpoints
            ignores = self.ignore_endpoints
            default_operation_rate = self.endpoint_sample_rate
        else:
            patterns = self.sample_jobs
            ignores = self.ignore_jobs
            default_operation_rate = self.job_sample_rate

        matching_rate = self._find_matching_rate(name, patterns)
        if matching_rate is not None:
            return matching_rate

        # The ignore flag must be honoured even when no ignore patterns are
        # configured, so it is checked outside of the pattern scan.
        if is_ignored or any(name.startswith(prefix) for prefix in ignores):
            return 0

        if default_operation_rate is not None:
            return default_operation_rate

        # Fall back to global sample rate
        return self.sample_rate

    def should_sample(self, operation: str, is_ignored: bool) -> bool:
        """
        Determines if an operation should be sampled.
        If no sampling is enabled, always return True.

        Args:
            operation: The operation string (e.g. "Controller/users/show"
                      or "Job/mailer")
            is_ignored: boolean for if the specific transaction is ignored;
                      only consulted when some sampling setting is enabled

        Returns:
            Boolean indicating whether to sample this operation
        """
        if not self._any_sampling():
            return True
        # randint is inclusive on both ends, so a rate of 100 always samples
        # and a rate of 0 never does.
        return random.randint(1, 100) <= self.get_effective_sample_rate(
            operation, is_ignored
        )
def test_should_sample_endpoint_partial(sampler):
    """The "test" endpoint pattern is sampled at its configured 20% rate."""
    with mock.patch("random.randint", return_value=10):
        assert sampler.should_sample("Controller/test/endpoint", False) is True
    with mock.patch("random.randint", return_value=30):
        assert sampler.should_sample("Controller/test/endpoint", False) is False


def test_should_sample_job_always(sampler):
    """A job pattern configured at 100 is always sampled."""
    assert sampler.should_sample("Job/critical-job", False) is True


def test_should_sample_job_never(sampler):
    """A job listed in ignore_jobs is never sampled."""
    assert sampler.should_sample("Job/test-job", False) is False


def test_should_sample_job_partial(sampler):
    """The "batch" job pattern is sampled at its configured 30% rate."""
    with mock.patch("random.randint", return_value=10):
        assert sampler.should_sample("Job/batch-process", False) is True
    with mock.patch("random.randint", return_value=40):
        assert sampler.should_sample("Job/batch-process", False) is False


def test_should_sample_unknown_operation(sampler):
    """Operations without a known prefix use the global 50% rate."""
    with mock.patch("random.randint", return_value=10):
        assert sampler.should_sample("Unknown/operation", False) is True
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Unknown/operation", False) is False


def test_should_sample_no_sampling_enabled(config):
    """With every sampling option at its default, everything is sampled."""
    config.set(
        sample_rate=100,  # Return config to defaults
        sample_endpoints={},
        sample_jobs={},
        ignore_endpoints=[],
        ignore_jobs=[],
        endpoint_sample_rate=None,
        job_sample_rate=None,
    )
    sampler = Sampler(config)
    assert sampler.should_sample("Controller/any_endpoint", False) is True
    assert sampler.should_sample("Job/any_job", False) is True


def test_should_sample_endpoint_default_rate(sampler):
    """Endpoints without a specific pattern use endpoint_sample_rate (70%)."""
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Controller/unspecified", False) is True
    with mock.patch("random.randint", return_value=80):
        assert sampler.should_sample("Controller/unspecified", False) is False


def test_should_sample_job_default_rate(sampler):
    """Jobs without a specific pattern use job_sample_rate (40%)."""
    with mock.patch("random.randint", return_value=30):
        assert sampler.should_sample("Job/unspecified-job", False) is True
    with mock.patch("random.randint", return_value=50):
        assert sampler.should_sample("Job/unspecified-job", False) is False


def test_should_sample_ignored_flag(sampler):
    """The per-transaction ignore flag yields 0% unless a specific rate matches."""
    assert sampler.should_sample("Controller/unspecified", True) is False
    assert sampler.should_sample("Job/unspecified-job", True) is False
    # A specific per-pattern rate has priority over the ignore flag.
    assert sampler.should_sample("Controller/users", True) is True


def test_should_sample_endpoint_fallback_to_global_rate(config):
    """Without endpoint_sample_rate, endpoints use the global 50% rate."""
    config.set(endpoint_sample_rate=None)
    sampler = Sampler(config)
    with mock.patch("random.randint", return_value=40):
        assert sampler.should_sample("Controller/unspecified", False) is True
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Controller/unspecified", False) is False


def test_should_sample_job_fallback_to_global_rate(config):
    """Without job_sample_rate, jobs use the global 50% rate."""
    config.set(job_sample_rate=None)
    sampler = Sampler(config)
    with mock.patch("random.randint", return_value=40):
        assert sampler.should_sample("Job/unspecified-job", False) is True
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Job/unspecified-job", False) is False


def test_should_handle_legacy_ignore_with_specific_sampling(config):
    """Test that specific sampling rates override legacy ignore patterns."""
    config.set(
        # Legacy ignore pattern covering every foo endpoint; without it the
        # override below would not actually be exercised.
        ignore_endpoints=["foo"],
        sample_endpoints={
            "foo/bar": 50,  # Should override the ignore pattern for specific endpoint
            "foo": 0,  # Ignore all other foo endpoints
        },
    )
    sampler = Sampler(config)

    # foo/bar should be sampled at 50%
    with mock.patch("random.randint", return_value=40):
        assert sampler.should_sample("Controller/foo/bar", False) is True
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Controller/foo/bar", False) is False

    # foo/other should be ignored (0% sampling)
    assert sampler.should_sample("Controller/foo/other", False) is False


def test_prefix_matching_precedence(config):
    """Test that more specific prefixes win when listed before shorter ones.

    The sampler returns the first pattern (in dict order) that prefixes the
    operation name, so the patterns below are ordered most specific first.
    """
    config.set(
        sample_endpoints={
            "api/users/vip": 100,  # Sample all VIP user endpoints
            "api/users": 50,  # Sample 50% of user endpoints
            "api": 0,  # Ignore all API endpoints by default
        }
    )
    sampler = Sampler(config)

    # Regular API endpoint should be ignored
    assert sampler.should_sample("Controller/api/status", False) is False

    # Users API should be sampled at 50%
    with mock.patch("random.randint", return_value=40):
        assert sampler.should_sample("Controller/api/users/list", False) is True
    with mock.patch("random.randint", return_value=60):
        assert sampler.should_sample("Controller/api/users/list", False) is False

    # VIP users API should always be sampled
    assert sampler.should_sample("Controller/api/users/vip/list", False) is True