Skip to content

Commit ad30df4

Browse files
committed
fix imports
1 parent e4b038d commit ad30df4

3 files changed

Lines changed: 348 additions & 48 deletions

File tree

examples/slackbot/api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
from pydantic_ai.agent import AgentRunResult
2020
from pydantic_ai.messages import ModelMessage
2121
from settings import settings
22+
from slack import SlackPayload, get_channel_name, post_slack_message
23+
from strings import count_tokens, slice_tokens
2224
from wrap import WatchToolCalls
2325

24-
from marvin.utilities.slack import SlackPayload, get_channel_name, post_slack_message
25-
from marvin.utilities.strings import count_tokens
26-
2726
BOT_MENTION = r"<@(\w+)>"
2827

2928

@@ -72,10 +71,12 @@ async def handle_message(payload: SlackPayload, db: Database):
7271
logger.warning(
7372
f"Message too long by {msg_len - USER_MESSAGE_MAX_TOKENS} tokens"
7473
)
75-
exceeded = msg_len - USER_MESSAGE_MAX_TOKENS
7674
assert event.channel is not None, "No channel found"
7775
await post_slack_message(
78-
message=f"Your message was too long by {exceeded} tokens...",
76+
message=(
77+
"Your message was too long, here's your message at the allowed limit:"
78+
f"\n{slice_tokens(cleaned_message, USER_MESSAGE_MAX_TOKENS)}"
79+
),
7980
channel_id=event.channel,
8081
thread_ts=thread_ts,
8182
)

examples/slackbot/slack.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
"""Module for Slack-related utilities."""
2+
3+
import os
4+
import re
5+
from typing import Any, List, Union
6+
7+
import httpx
8+
from pydantic import BaseModel, ValidationInfo, field_validator, model_validator
9+
10+
import marvin
11+
12+
13+
class EventBlockElement(BaseModel):
14+
type: str
15+
text: str | None = None
16+
user_id: str | None = None
17+
18+
19+
class EventBlockElementGroup(BaseModel):
20+
type: str
21+
elements: List[EventBlockElement]
22+
23+
24+
class EventBlock(BaseModel):
25+
type: str
26+
block_id: str
27+
elements: List[Union[EventBlockElement, EventBlockElementGroup]]
28+
29+
30+
class SlackEvent(BaseModel):
31+
client_msg_id: str | None = None
32+
type: str
33+
text: str | None = None
34+
user: str | dict[str, Any] | None = None
35+
ts: str | None = None
36+
team: str | None = None
37+
channel: str | None = None
38+
event_ts: str
39+
thread_ts: str | None = None
40+
parent_user_id: str | None = None
41+
blocks: list[EventBlock] | None = None
42+
43+
@model_validator(mode="before")
44+
@classmethod
45+
def extract_user_id(cls, data: dict[str, Any]) -> dict[str, Any]:
46+
if isinstance(data.get("user"), dict):
47+
data["user"] = data["user"].get("id")
48+
return data
49+
50+
51+
class EventAuthorization(BaseModel):
52+
enterprise_id: str | None = None
53+
team_id: str
54+
user_id: str
55+
is_bot: bool
56+
is_enterprise_install: bool
57+
58+
59+
class SlackPayload(BaseModel):
60+
token: str
61+
type: str
62+
team_id: str | None = None
63+
api_app_id: str | None = None
64+
event: SlackEvent | None = None
65+
event_id: str | None = None
66+
event_time: int | None = None
67+
authorizations: list[EventAuthorization] | None = None
68+
is_ext_shared_channel: bool | None = None
69+
event_context: str | None = None
70+
challenge: str | None = None
71+
72+
@field_validator("event")
73+
def validate_event(
74+
cls, v: SlackEvent | None, info: ValidationInfo
75+
) -> SlackEvent | None:
76+
if v is None and info.data.get("type") != "url_verification":
77+
raise ValueError("event is required")
78+
return v
79+
80+
81+
async def get_token() -> str:
82+
"""Get the Slack bot token from the environment."""
83+
try:
84+
token = (
85+
marvin.settings.slack_api_token
86+
) # set `MARVIN_SLACK_API_TOKEN` in `~/.marvin/.env
87+
except AttributeError:
88+
if token := os.getenv("MARVIN_SLACK_API_TOKEN"):
89+
return token
90+
try: # TODO: clean this up
91+
from prefect.blocks.system import Secret
92+
93+
return (await Secret.load("slack-api-token")).get()
94+
except ImportError:
95+
pass
96+
raise ValueError(
97+
"`MARVIN_SLACK_API_TOKEN` not found in environment."
98+
" Please set it in `~/.marvin/.env` or as an environment variable."
99+
)
100+
return token
101+
102+
103+
def convert_md_links_to_slack(text) -> str:
104+
md_link_pattern = r"\[(?P<text>[^\]]+)]\((?P<url>[^\)]+)\)"
105+
106+
# converting Markdown links to Slack-style links
107+
def to_slack_link(match):
108+
return f"<{match.group('url')}|{match.group('text')}>"
109+
110+
# Replace Markdown links with Slack-style links
111+
slack_text = re.sub(md_link_pattern, to_slack_link, text)
112+
113+
slack_text = re.sub(r"\*\*(.*?)\*\*", r"*\1*", slack_text)
114+
115+
return slack_text
116+
117+
118+
async def post_slack_message(
119+
message: str,
120+
channel_id: str,
121+
attachments: Union[list[dict[str, Any]], None] = None,
122+
thread_ts: Union[str, None] = None,
123+
auth_token: Union[str, None] = None,
124+
) -> httpx.Response:
125+
if not auth_token:
126+
auth_token = await get_token()
127+
128+
post_data = {
129+
"channel": channel_id,
130+
"text": convert_md_links_to_slack(message),
131+
"attachments": attachments if attachments else [],
132+
}
133+
134+
if thread_ts:
135+
post_data["thread_ts"] = thread_ts
136+
137+
async with httpx.AsyncClient() as client:
138+
response = await client.post(
139+
"https://slack.com/api/chat.postMessage",
140+
headers={"Authorization": f"Bearer {auth_token}"},
141+
json=post_data,
142+
)
143+
response_data = response.json()
144+
145+
if response_data.get("ok") is not True:
146+
raise ValueError(f"Error posting Slack message: {response_data.get('error')}")
147+
return response
148+
149+
150+
async def get_thread_messages(channel: str, thread_ts: str) -> list:
151+
"""Get all messages from a slack thread."""
152+
async with httpx.AsyncClient() as client:
153+
response = await client.get(
154+
"https://slack.com/api/conversations.replies",
155+
headers={"Authorization": f"Bearer {await get_token()}"},
156+
params={"channel": channel, "ts": thread_ts},
157+
)
158+
response.raise_for_status()
159+
return response.json().get("messages", [])
160+
161+
162+
async def get_user_name(user_id: str) -> str:
163+
async with httpx.AsyncClient() as client:
164+
response = await client.get(
165+
"https://slack.com/api/users.info",
166+
params={"user": user_id},
167+
headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
168+
)
169+
return (
170+
response.json().get("user", {}).get("name", user_id)
171+
if response.status_code == 200
172+
else user_id
173+
)
174+
175+
176+
async def get_channel_name(channel_id: str) -> str:
177+
async with httpx.AsyncClient() as client:
178+
response = await client.get(
179+
"https://slack.com/api/conversations.info",
180+
params={"channel": channel_id},
181+
headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
182+
)
183+
return (
184+
response.json().get("channel", {}).get("name", channel_id)
185+
if response.status_code == 200
186+
else channel_id
187+
)
188+
189+
190+
async def fetch_current_message_text(channel: str, ts: str) -> str:
191+
"""Fetch the current text of a specific Slack message using its timestamp."""
192+
async with httpx.AsyncClient() as client:
193+
response = await client.get(
194+
"https://slack.com/api/conversations.replies",
195+
params={"channel": channel, "ts": ts},
196+
headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
197+
)
198+
response.raise_for_status()
199+
messages = response.json().get("messages", [])
200+
if not messages:
201+
raise ValueError("Message not found")
202+
203+
return messages[0]["text"]
204+
205+
206+
async def edit_slack_message(
207+
new_text: str,
208+
channel_id: str,
209+
thread_ts: str,
210+
mode: str = "append",
211+
delimiter: Union[str, None] = None,
212+
) -> httpx.Response:
213+
"""Edit an existing Slack message by appending new text or replacing it.
214+
215+
Args:
216+
channel (str): The Slack channel ID.
217+
ts (str): The timestamp of the message to edit.
218+
new_text (str): The new text to append or replace in the message.
219+
mode (str): The mode of text editing, 'append' (default) or 'replace'.
220+
221+
Returns:
222+
httpx.Response: The response from the Slack API.
223+
"""
224+
if mode == "append":
225+
current_text = await fetch_current_message_text(channel_id, thread_ts)
226+
delimiter = "\n\n" if delimiter is None else delimiter
227+
updated_text = f"{current_text}{delimiter}{convert_md_links_to_slack(new_text)}"
228+
elif mode == "replace":
229+
updated_text = convert_md_links_to_slack(new_text)
230+
else:
231+
raise ValueError("Invalid mode. Use 'append' or 'replace'.")
232+
233+
async with httpx.AsyncClient() as client:
234+
response = await client.post(
235+
"https://slack.com/api/chat.update",
236+
headers={"Authorization": f"Bearer {await get_token()}"},
237+
json={"channel": channel_id, "ts": thread_ts, "text": updated_text},
238+
)
239+
240+
response.raise_for_status()
241+
return response
242+
243+
244+
async def search_slack_messages(
245+
query: str,
246+
max_messages: int = 3,
247+
channel: Union[str, None] = None,
248+
user_auth_token: Union[str, None] = None,
249+
) -> list:
250+
"""
251+
Search for messages in Slack workspace based on a query.
252+
253+
Args:
254+
query (str): The search query.
255+
max_messages (int): The maximum number of messages to retrieve.
256+
channel (str, optional): The specific channel to search in. Defaults to None,
257+
which searches all channels.
258+
259+
Returns:
260+
list: A list of message contents and permalinks matching the query.
261+
"""
262+
all_messages = []
263+
next_cursor = None
264+
265+
if not user_auth_token:
266+
user_auth_token = await get_token()
267+
268+
async with httpx.AsyncClient() as client:
269+
while len(all_messages) < max_messages:
270+
params = {
271+
"query": query,
272+
"limit": min(max_messages - len(all_messages), 10),
273+
}
274+
if channel:
275+
params["channel"] = channel
276+
if next_cursor:
277+
params["cursor"] = next_cursor
278+
279+
response = await client.get(
280+
"https://slack.com/api/search.messages",
281+
headers={"Authorization": f"Bearer {user_auth_token}"},
282+
params=params,
283+
)
284+
285+
response.raise_for_status()
286+
data = response.json().get("messages", {}).get("matches", [])
287+
for message in data:
288+
all_messages.append(
289+
{
290+
"content": message.get("text", ""),
291+
"permalink": message.get("permalink", ""),
292+
}
293+
)
294+
295+
next_cursor = (
296+
response.json().get("response_metadata", {}).get("next_cursor")
297+
)
298+
299+
if not next_cursor:
300+
break
301+
302+
return all_messages[:max_messages]
303+
304+
305+
async def get_workspace_info(slack_bot_token: Union[str, None] = None) -> dict:
306+
if not slack_bot_token:
307+
slack_bot_token = await get_token()
308+
309+
async with httpx.AsyncClient() as client:
310+
response = await client.get(
311+
"https://slack.com/api/team.info",
312+
headers={"Authorization": f"Bearer {slack_bot_token}"},
313+
)
314+
response.raise_for_status()
315+
return response.json().get("team", {})

0 commit comments

Comments
 (0)