Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 111 additions & 13 deletions nexum_ai/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import logging
import numpy as np
import time
from typing import Optional, List, Dict, Any
import json
import os
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, similarity_threshold: float = 0.95, cache_file: str = "semant
self.cache: List[Dict] = []
self.similarity_threshold = similarity_threshold
self.model = None
self.max_age_seconds: Optional[float] = None # None = no TTL

# Support environment variable for cache file path
cache_file_env = os.environ.get('NEXUMDB_CACHE_FILE', cache_file)
Expand Down Expand Up @@ -115,11 +117,58 @@ def cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:

return float(dot_product / (norm1 * norm2))

def _is_entry_expired(self, entry: Dict, now: Optional[float] = None) -> bool:
    """Check whether a cache entry has exceeded the configured TTL.

    Args:
        entry: Cache entry dict. Its ``'timestamp'`` value (Unix seconds),
            when present, is compared against ``self.max_age_seconds``.
        now: Current time as a Unix timestamp. If *None*, ``time.time()``
            is called. Callers iterating over many entries should snapshot
            the time once and pass it in, so every entry is judged against
            the same instant.

    Returns:
        True if the entry is older than ``self.max_age_seconds``, False
        otherwise. Always False when no TTL is configured, or when the
        entry has no usable timestamp (missing, ``None``, or a value that
        cannot be converted to a number).
    """
    if self.max_age_seconds is None:
        return False
    raw_timestamp = entry.get('timestamp')
    if raw_timestamp is None:
        # Legacy entries without a timestamp are kept (not expired)
        return False
    try:
        timestamp = float(raw_timestamp)
    except (TypeError, ValueError, OverflowError):
        # A malformed timestamp (e.g. from a hand-edited or corrupted cache
        # file) is treated like a missing one instead of letting a
        # TypeError escape from every lookup that walks the cache.
        return False
    if now is None:
        now = time.time()
    return (now - timestamp) > self.max_age_seconds

def _evict_expired(self) -> int:
    """Drop every cache entry whose TTL has elapsed.

    All entries are judged against a single ``time.time()`` snapshot.
    ``self.cache`` is rebound to a new list holding only the survivors.

    Returns:
        How many entries were removed; 0 when no TTL is configured.
    """
    if self.max_age_seconds is None:
        return 0

    snapshot = time.time()
    survivors = []
    for item in self.cache:
        if self._is_entry_expired(item, now=snapshot):
            continue
        survivors.append(item)

    removed = len(self.cache) - len(survivors)
    self.cache = survivors
    if removed > 0:
        logger.info(f"Evicted {removed} expired cache entries")
    return removed

def get(self, query: str) -> Optional[str]:
"""Retrieve cached result if similar query exists"""
"""Retrieve cached result if similar query exists.

Expired entries (based on TTL) are skipped during lookup.
"""
query_vec = self.vectorize(query)
now = time.time()

for entry in self.cache:
# Skip expired entries
if self._is_entry_expired(entry, now=now):
continue
similarity = self.cosine_similarity(query_vec, entry['vector'])
if similarity >= self.similarity_threshold:
logger.info(f"Cache hit! Similarity: {similarity:.4f}")
Expand All @@ -128,12 +177,13 @@ def get(self, query: str) -> Optional[str]:
return None

def put(self, query: str, result: str) -> None:
"""Store query and result in cache"""
"""Store query and result in cache with a creation timestamp."""
query_vec = self.vectorize(query)
self.cache.append({
'query': query,
'vector': query_vec,
'result': result
'result': result,
'timestamp': time.time()
})
logger.info(f"Cached query: {query[:50]}...")

Expand Down Expand Up @@ -248,7 +298,8 @@ def save_cache_json(self, filepath: Optional[str] = None) -> None:
'cache': self.cache,
'similarity_threshold': self.similarity_threshold,
'cache_size': len(self.cache),
'format_version': '1.0'
'format_version': '1.1',
'max_age_seconds': self.max_age_seconds,
}

with open(filepath, 'w') as f:
Expand Down Expand Up @@ -278,8 +329,16 @@ def load_cache_json(self, filepath: Optional[str] = None) -> None:

self.cache = data.get('cache', [])
self.similarity_threshold = data.get('similarity_threshold', self.similarity_threshold)

# Restore persisted TTL setting (if any)
saved_max_age = data.get('max_age_seconds')
if saved_max_age is not None:
self.max_age_seconds = float(saved_max_age)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

logger.info(f"Semantic cache loaded from JSON: {filepath} ({len(self.cache)} entries)")

# Evict entries that became stale while the process was down
self._evict_expired()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

except Exception:
logger.exception("Error loading cache from JSON")
Expand All @@ -288,14 +347,26 @@ def load_cache_json(self, filepath: Optional[str] = None) -> None:
logger.debug(f"No JSON cache file found at {filepath}")

def get_cache_stats(self) -> Dict[str, Any]:
    """Return a snapshot of cache statistics, including TTL information.

    Returns:
        A dict with ``total_entries``, ``similarity_threshold``,
        ``cache_file``, ``cache_exists`` and ``cache_size_bytes``. When a
        TTL is configured (``self.max_age_seconds`` is not None) it also
        contains ``max_age_hours`` and ``expired_entries``, the number of
        entries currently past the TTL but still held in ``self.cache``.
    """
    # A single stat() call feeds both file-related fields, so they can
    # never contradict each other (e.g. "exists" but size 0 because the
    # file vanished between two separate filesystem calls).
    try:
        cache_size_bytes = self.cache_path.stat().st_size
        cache_exists = True
    except OSError:
        cache_size_bytes = 0
        cache_exists = False

    stats: Dict[str, Any] = {
        'total_entries': len(self.cache),
        'similarity_threshold': self.similarity_threshold,
        'cache_file': str(self.cache_path),
        'cache_exists': cache_exists,
        'cache_size_bytes': cache_size_bytes,
    }
    if self.max_age_seconds is not None:
        # Snapshot the clock once so all entries are compared consistently.
        now = time.time()
        stats['max_age_hours'] = self.max_age_seconds / 3600.0
        stats['expired_entries'] = sum(
            1 for e in self.cache if self._is_entry_expired(e, now=now)
        )
    return stats
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def explain_query(self, query: str) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -345,8 +416,11 @@ def explain_query(self, query: str) -> Dict[str, Any]:
best_match = None
best_similarity = 0.0

# Analyze cache entries safely
# Analyze cache entries safely (skip expired)
now = time.time()
for i, entry in enumerate(self.cache):
if self._is_entry_expired(entry, now=now):
continue
try:
similarity = self.cosine_similarity(query_vec, entry.get('vector', []))
except Exception as e:
Expand Down Expand Up @@ -389,11 +463,35 @@ def explain_query(self, query: str) -> Dict[str, Any]:
'top_matches': cache_analysis[:5] # Top 5 similar cached queries
}

def set_cache_expiration(self, max_age_hours: float = 24) -> int:
    """Set the TTL and immediately evict entries older than *max_age_hours*.

    The TTL is stored on the instance as ``self.max_age_seconds``; lookups
    that consult ``_is_entry_expired`` will then skip entries older than
    it, and ``save_cache_json`` writes the value out with the cache.

    Args:
        max_age_hours: Maximum age of a cache entry in hours.
            Must be a positive number.

    Returns:
        Number of expired entries that were evicted.

    Raises:
        ValueError: If *max_age_hours* is not positive (zero, negative,
            or NaN).
    """
    # Written as ``not (x > 0)`` rather than ``x <= 0`` so that NaN, for
    # which every comparison is False, is rejected instead of silently
    # installing a TTL under which no entry can ever expire.
    if not max_age_hours > 0:
        raise ValueError("max_age_hours must be a positive number")

    self.max_age_seconds = max_age_hours * 3600.0
    evicted = self._evict_expired()
    logger.info(
        "Cache expiration set to %.2f hours – evicted %d stale entries",
        max_age_hours,
        evicted,
    )
    return evicted
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def optimize_cache(self, max_entries: int = 1000) -> None:
"""Remove oldest entries if cache exceeds max size"""
Expand Down
153 changes: 153 additions & 0 deletions nexum_ai/tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Unit tests for optimizer.py - Query optimization logic
"""

import time

import pytest

from nexum_ai.optimizer import SemanticCache, QueryOptimizer


Expand Down Expand Up @@ -114,6 +118,155 @@ def test_multiple_cache_entries(self):
assert len(cache.cache) == 3


class TestSemanticCacheTTL:
    """Test suite for the SemanticCache TTL / expiration feature.

    Several tests append hand-built entries directly to ``cache.cache`` or
    overwrite an entry's ``'timestamp'`` so that expiry can be exercised
    without sleeping.
    """

    def test_put_stores_timestamp(self):
        """Entries created via put() must carry a timestamp."""
        cache = SemanticCache()
        cache.put("SELECT 1", "one")
        entry = cache.cache[0]
        assert 'timestamp' in entry
        assert isinstance(entry['timestamp'], float)
        # Timestamp should be very recent (within last 5 seconds)
        assert time.time() - entry['timestamp'] < 5

    def test_set_cache_expiration_rejects_non_positive(self):
        """set_cache_expiration must reject zero or negative hours."""
        cache = SemanticCache()
        with pytest.raises(ValueError):
            cache.set_cache_expiration(0)
        with pytest.raises(ValueError):
            cache.set_cache_expiration(-1)

    def test_set_cache_expiration_sets_max_age(self):
        """set_cache_expiration stores the TTL internally, in seconds."""
        cache = SemanticCache()
        cache.set_cache_expiration(2)
        assert cache.max_age_seconds == 2 * 3600.0

    def test_expired_entries_are_evicted(self):
        """Entries older than the TTL are removed by set_cache_expiration."""
        cache = SemanticCache()
        # Insert an entry with a timestamp 2 hours in the past
        cache.cache.append({
            'query': 'old query',
            'vector': [0.0] * 384,
            'result': 'old',
            'timestamp': time.time() - 7200,  # 2 hours ago
        })
        cache.put("new query", "new")  # fresh entry
        assert len(cache.cache) == 2

        evicted = cache.set_cache_expiration(1)  # 1 hour TTL
        assert evicted == 1
        assert len(cache.cache) == 1
        assert cache.cache[0]['result'] == 'new'

    def test_get_skips_expired_entries(self):
        """get() must not return results from expired entries."""
        cache = SemanticCache(similarity_threshold=0.95)
        cache.put("SELECT * FROM users", "result_users")
        # Artificially expire the entry
        cache.cache[0]['timestamp'] = time.time() - 7200
        cache.max_age_seconds = 3600.0  # 1 hour TTL

        result = cache.get("SELECT * FROM users")
        assert result is None  # expired, should miss

    def test_get_returns_valid_entries(self):
        """get() still returns non-expired entries."""
        cache = SemanticCache(similarity_threshold=0.95)
        cache.put("SELECT * FROM users", "result_users")
        cache.max_age_seconds = 3600.0

        result = cache.get("SELECT * FROM users")
        assert result == "result_users"

    def test_no_ttl_means_no_eviction(self):
        """When max_age_seconds is None nothing is evicted."""
        cache = SemanticCache()
        cache.cache.append({
            'query': 'ancient',
            'vector': [0.0] * 384,
            'result': 'data',
            'timestamp': 0,  # epoch – very old
        })
        assert cache.max_age_seconds is None
        assert cache._evict_expired() == 0
        assert len(cache.cache) == 1

    def test_legacy_entries_without_timestamp_survive(self):
        """Entries loaded from old caches (no timestamp) are not evicted."""
        cache = SemanticCache()
        cache.cache.append({
            'query': 'legacy',
            'vector': [0.0] * 384,
            'result': 'legacy_data',
            # no 'timestamp' key
        })
        cache.set_cache_expiration(1)
        assert len(cache.cache) == 1  # kept, not evicted

    def test_explain_query_skips_expired(self):
        """explain_query should ignore expired entries."""
        cache = SemanticCache(similarity_threshold=0.5)
        cache.put("SELECT * FROM users", "result_users")
        # Expire the entry
        cache.cache[0]['timestamp'] = time.time() - 7200
        cache.max_age_seconds = 3600.0

        explanation = cache.explain_query("SELECT * FROM users")
        # cache_entries_checked reports len(self.cache) which includes expired
        # entries; top_matches only contains entries that were actually analysed.
        assert explanation['cache_entries_checked'] == 1
        assert len(explanation['top_matches']) == 0

    def test_get_cache_stats_includes_ttl_info(self):
        """Stats dict must include TTL fields only when a TTL is active."""
        cache = SemanticCache()
        # Without TTL
        stats = cache.get_cache_stats()
        assert 'max_age_hours' not in stats

        cache.set_cache_expiration(12)
        stats = cache.get_cache_stats()
        assert stats['max_age_hours'] == 12.0
        assert 'expired_entries' in stats

    def test_ttl_persists_across_save_load(self, tmp_path):
        """max_age_seconds should survive a save/load cycle."""
        # tmp_path is pytest's per-test temporary directory fixture; it
        # replaces manual tempfile/os.path handling and is cleaned up by
        # pytest itself.
        path = str(tmp_path / "ttl_test.json")

        cache1 = SemanticCache()
        cache1.set_cache_expiration(6)
        cache1.put("q1", "r1")
        cache1.save_cache_json(path)

        cache2 = SemanticCache()
        cache2.load_cache_json(path)
        assert cache2.max_age_seconds == 6 * 3600.0
        assert len(cache2.cache) == 1

    def test_evict_expired_returns_count(self):
        """_evict_expired returns the number of removed entries."""
        cache = SemanticCache()
        now = time.time()
        for i in range(5):
            cache.cache.append({
                'query': f'q{i}',
                'vector': [0.0] * 384,
                'result': f'r{i}',
                'timestamp': now - (i * 3600),  # 0h, 1h, 2h, 3h, 4h ago
            })
        cache.max_age_seconds = 2.5 * 3600  # 2.5 hour TTL
        removed = cache._evict_expired()
        assert removed == 2  # entries at 3h and 4h ago
        assert len(cache.cache) == 3


class TestQueryOptimizer:
"""Test suite for QueryOptimizer class"""

Expand Down
Loading