|
15 | 15 | from pydantic_ai.agent import AgentRunResult |
16 | 16 | from pydantic_ai.messages import ModelMessage |
17 | 17 |
|
18 | | -from slackbot._internal.templates import WELCOME_MESSAGE |
| 18 | +from slackbot._internal.constants import WORKSPACE_TO_CHANNEL_ID |
| 19 | +from slackbot._internal.templates import CHANNEL_REDIRECT_MESSAGE, WELCOME_MESSAGE |
19 | 20 | from slackbot.assets import summarize_thread |
20 | 21 | from slackbot.core import ( |
21 | 22 | Database, |
|
40 | 41 | logger = get_logger(__name__) |
41 | 42 |
|
42 | 43 |
|
| 44 | +def get_designated_channel_for_workspace(team_id: str) -> str | None: |
| 45 | + """Get the designated channel ID for a given workspace team ID.""" |
| 46 | + return WORKSPACE_TO_CHANNEL_ID.get(team_id) |
| 47 | + |
| 48 | + |
| 49 | +def check_if_designated_channel(channel_id: str, team_id: str) -> bool: |
| 50 | + """Check if the given channel is the designated channel for the workspace.""" |
| 51 | + designated_channel = get_designated_channel_for_workspace(team_id) |
| 52 | + if not designated_channel: |
| 53 | + # If no designated channel is configured, allow all channels |
| 54 | + return True |
| 55 | + return channel_id == designated_channel |
| 56 | + |
| 57 | + |
43 | 58 | @task(name="run agent loop") |
44 | 59 | async def run_agent( |
45 | 60 | cleaned_message: str, |
@@ -122,6 +137,33 @@ async def handle_message(payload: SlackPayload, db: Database): |
122 | 137 | return Completed(message="Message too long", name="SKIPPED") |
123 | 138 |
|
124 | 139 | if re.search(BOT_MENTION, user_message) and payload.authorizations: |
| 140 | + # Check if this is the designated channel |
| 141 | + team_id = payload.team_id or "" |
| 142 | + is_designated = check_if_designated_channel(event.channel, team_id) |
| 143 | + |
| 144 | + if not is_designated: |
| 145 | + # Send redirect message to the designated channel |
| 146 | + designated_channel_id = get_designated_channel_for_workspace(team_id) |
| 147 | + if designated_channel_id: |
| 148 | + logger.info( |
| 149 | + f"Redirecting user from {event.channel} to {designated_channel_id}" |
| 150 | + ) |
| 151 | + await post_slack_message( |
| 152 | + message=CHANNEL_REDIRECT_MESSAGE.format( |
| 153 | + channel_id=designated_channel_id |
| 154 | + ), |
| 155 | + channel_id=event.channel, |
| 156 | + thread_ts=thread_ts, |
| 157 | + ) |
| 158 | + return Completed( |
| 159 | + message="Redirected to designated channel", |
| 160 | + name="REDIRECTED", |
| 161 | + data=dict( |
| 162 | + from_channel=event.channel, |
| 163 | + to_channel=designated_channel_id, |
| 164 | + ), |
| 165 | + ) |
| 166 | + |
125 | 167 | logger.info( |
126 | 168 | f"Processing message in thread {thread_ts}\nUser message: {cleaned_message}" |
127 | 169 | ) |
@@ -182,6 +224,11 @@ async def handle_message(payload: SlackPayload, db: Database): |
182 | 224 | @handle_message.on_completion |
183 | 225 | async def summarize_thread_so_far(flow: Flow, flow_run: FlowRun, state: State[Any]): |
184 | 226 | result = await state.result() |
| 227 | + |
| 228 | + # Skip summarization for redirects and other non-conversation states |
| 229 | + if not isinstance(result, dict) or "conversation" not in result: |
| 230 | + return |
| 231 | + |
185 | 232 | conversation = result["conversation"] |
186 | 233 |
|
187 | 234 | if len(conversation) % 4 != 0: # only summarize thread every 4 messages |
|
0 commit comments