Skip to content
Merged
21 changes: 20 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher, AfterCommandExecutionEvent, OnErrorEvent,
EventDispatcher, AfterCommandExecutionEvent, OnErrorEvent, OnPubSubMessageEvent,
)
from redis.exceptions import (
ConnectionError,
Expand All @@ -59,6 +59,7 @@
from redis.maint_notifications import (
MaintNotificationsConfig,
)
from redis.observability.attributes import PubSubDirection
from redis.retry import Retry
from redis.utils import (
_set_info_logger,
Expand Down Expand Up @@ -1382,6 +1383,24 @@ def handle_message(self, response, ignore_subscribe_messages=False):
"data": response[2],
}

if message_type in ["message", "pmessage"]:
channel = str_if_bytes(message["channel"])
self._event_dispatcher.dispatch(
OnPubSubMessageEvent(
direction=PubSubDirection.RECEIVE,
channel=channel,
)
)
elif message_type == "smessage":
channel = str_if_bytes(message["channel"])
self._event_dispatcher.dispatch(
OnPubSubMessageEvent(
direction=PubSubDirection.RECEIVE,
channel=channel,
sharded=True,
)
)

# if this is an unsubscribe message, remove it from memory
if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
if message_type == "punsubscribe":
Expand Down
45 changes: 40 additions & 5 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@
deprecated_function,
experimental_args,
experimental_method,
extract_expire_flags,
extract_expire_flags, str_if_bytes,
)

from .helpers import at_most_one_value_set, list_or_args
from ..event import OnPubSubMessageEvent, OnStreamMessageReceivedEvent
from ..observability.attributes import PubSubDirection

if TYPE_CHECKING:
import redis.asyncio.client
Expand Down Expand Up @@ -4211,7 +4213,15 @@ def xread(
keys, values = zip(*streams.items())
pieces.extend(keys)
pieces.extend(values)
return self.execute_command("XREAD", *pieces, keys=keys)
response = self.execute_command("XREAD", *pieces, keys=keys)

self._event_dispatcher.dispatch(
OnStreamMessageReceivedEvent(
response=response
)
)

return response

def xreadgroup(
self,
Expand Down Expand Up @@ -4271,7 +4281,17 @@ def xreadgroup(
pieces.append(b"STREAMS")
pieces.extend(streams.keys())
pieces.extend(streams.values())
return self.execute_command("XREADGROUP", *pieces, **options)
response = self.execute_command("XREADGROUP", *pieces, **options)

self._event_dispatcher.dispatch(
OnStreamMessageReceivedEvent(
response=response,
consumer_group=groupname,
consumer_name=consumername,
)
)
Comment thread
vladvildanov marked this conversation as resolved.

return response

def xrevrange(
self,
Expand Down Expand Up @@ -6038,7 +6058,14 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT

For more information, see https://redis.io/commands/publish
"""
return self.execute_command("PUBLISH", channel, message, **kwargs)
response = self.execute_command("PUBLISH", channel, message, **kwargs)
self._event_dispatcher.dispatch(
OnPubSubMessageEvent(
direction=PubSubDirection.PUBLISH,
channel=str_if_bytes(channel),
)
)
return response
Comment thread
petyaslavova marked this conversation as resolved.

def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
"""
Expand All @@ -6047,7 +6074,15 @@ def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:

For more information, see https://redis.io/commands/spublish
"""
return self.execute_command("SPUBLISH", shard_channel, message)
response = self.execute_command("SPUBLISH", shard_channel, message)
self._event_dispatcher.dispatch(
OnPubSubMessageEvent(
direction=PubSubDirection.PUBLISH,
channel=str_if_bytes(shard_channel),
sharded=True,
)
)
return response
Comment thread
vladvildanov marked this conversation as resolved.

def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
"""
Expand Down
87 changes: 85 additions & 2 deletions redis/event.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import asyncio
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Union, Callable
from typing import Dict, List, Optional, Type, Union, Callable, Iterable
Comment thread
vladvildanov marked this conversation as resolved.

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
from redis.observability.attributes import PubSubDirection
from redis.observability.recorder import record_operation_duration, record_error_count, record_maint_notification_count, \
record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff
record_connection_create_time, init_connection_count, record_connection_relaxed_timeout, record_connection_handoff, \
record_pubsub_message, record_streaming_lag
from redis.utils import str_if_bytes


class EventListenerInterface(ABC):
Expand Down Expand Up @@ -109,6 +113,12 @@ def __init__(
AfterConnectionHandoffEvent: [
ExportConnectionHandoffMetric(),
],
OnPubSubMessageEvent: [
ExportPubSubMessageMetric(),
],
OnStreamMessageReceivedEvent: [
ExportStreamingLagMetric(),
],
}

self._lock = threading.Lock()
Expand Down Expand Up @@ -337,6 +347,24 @@ class OnErrorEvent:
is_internal: bool = True
retry_attempts: Optional[int] = None

@dataclass
class OnPubSubMessageEvent:
"""
Event fired whenever a pub/sub message is published/received.
"""
direction: PubSubDirection
channel: str
sharded: bool = False

@dataclass
class OnStreamMessageReceivedEvent:
"""
Event fired whenever a stream message is received.
"""
response: Union[list, dict]
consumer_group: Optional[str] = None
consumer_name: Optional[str] = None

@dataclass
class OnMaintenanceNotificationEvent:
"""
Expand Down Expand Up @@ -621,3 +649,58 @@ def listen(self, event: AfterConnectionHandoffEvent):
record_connection_handoff(
pool_name=repr(event.connection_pool),
)

class ExportPubSubMessageMetric(EventListenerInterface):
"""
Listener that exports pubsub message metric.
"""
def listen(self, event: OnPubSubMessageEvent):
record_pubsub_message(
direction=event.direction,
channel=event.channel,
sharded=event.sharded,
)

class ExportStreamingLagMetric(EventListenerInterface):
"""
Listener that exports streaming lag metric per stream message.
"""
def listen(self, event: OnStreamMessageReceivedEvent):
now = datetime.now().timestamp()

if not event.response:
return

# RESP3
if isinstance(event.response, dict):
for stream_name, stream_messages in event.response.items():
for messages in stream_messages:
for message in messages:
message_id, message = message
message_id = str_if_bytes(message_id)
timestamp, _ = message_id.split("-")
lag_seconds = now - int(timestamp) / 1000

record_streaming_lag(
lag_seconds=lag_seconds,
stream_name=str_if_bytes(stream_name),
consumer_group=event.consumer_group,
consumer_name=event.consumer_name,
)
Comment thread
vladvildanov marked this conversation as resolved.
else:
# RESP 2
for stream_entry in event.response:
stream_name = str_if_bytes(stream_entry[0])

for message in stream_entry[1]:
message_id, message = message
message_id = str_if_bytes(message_id)
timestamp, _ = message_id.split("-")
lag_seconds = now - int(timestamp) / 1000

record_streaming_lag(
lag_seconds=lag_seconds,
stream_name=stream_name,
consumer_group=event.consumer_group,
consumer_name=event.consumer_name,
)
Comment thread
vladvildanov marked this conversation as resolved.
12 changes: 6 additions & 6 deletions redis/observability/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,12 @@ def record_streaming_lag(
return

# try:
_metrics_collector.record_streaming_lag(
lag_seconds=lag_seconds,
stream_name=stream_name,
consumer_group=consumer_group,
consumer_name=consumer_name,
)
_metrics_collector.record_streaming_lag(
lag_seconds=lag_seconds,
stream_name=stream_name,
consumer_group=consumer_group,
consumer_name=consumer_name,
)
# except Exception:
# pass

Expand Down
1 change: 1 addition & 0 deletions redis/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@


class CommandsProtocol(Protocol):
_event_dispatcher: "EventDispatcherInterface"
def execute_command(self, *args, **options) -> ResponseT: ...


Expand Down