-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclients.py
More file actions
58 lines (38 loc) · 1.46 KB
/
clients.py
File metadata and controls
58 lines (38 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from typing import Optional, Protocol, Tuple
from .exceptions import UnsupportedDatabaseError
# Alias for a single result row; used as the row type in ``Cursor.fetchone``.
# It is the bare ``typing.Tuple``, so column count and types are unconstrained.
TupleRow = Tuple
class Cursor(Protocol):
    """Structural type for the result handle produced by ``Client.execute``.

    Any object exposing a compatible ``fetchone`` method satisfies it.
    """

    def fetchone(self) -> Optional[TupleRow]:
        """Return a single row as a tuple, or ``None``.

        Presumably ``None`` signals that no row is available; the exact
        meaning is up to the concrete driver cursor.
        """
        ...
class Client(Protocol):
    """Structural type for a database client that can run SQL text."""

    def execute(self, sql: str) -> Cursor:
        """Run *sql* and return a ``Cursor`` for reading the result."""
        ...
def get_client() -> Client:
    """Build the client matching the database configured via ``DB_URL``.

    Backends are probed in order (SQLite first, then PostgreSQL) and the
    first one whose predicate matches is instantiated.

    Raises:
        UnsupportedDatabaseError: if no backend recognises ``DB_URL``.
    """
    # Order matters: SQLite is checked first, exactly like the predicates'
    # own defaults expect (an unset DB_URL selects SQLite).
    backends = (
        (_is_sqlite, SQLiteClient),
        (_is_postgres, PostgresClient),
    )
    for matches, factory in backends:
        if matches():
            return factory()
    raise UnsupportedDatabaseError("Unsupported database")
def _is_postgres() -> bool:
    """Report whether ``DB_URL`` points at a PostgreSQL database.

    Both URI scheme spellings accepted by libpq are recognised:
    ``postgresql://`` and its short alias ``postgres://``. An unset
    ``DB_URL`` is treated as an empty string and therefore never matches.

    Returns:
        True if ``DB_URL`` starts with ``postgres`` (which also covers
        ``postgresql``), False otherwise.
    """
    # "postgres" is a prefix of "postgresql", so one check covers both
    # spellings while still matching everything the old check matched.
    return os.getenv("DB_URL", "").startswith("postgres")
def _is_sqlite() -> bool:
    """Report whether ``DB_URL`` selects SQLite.

    When ``DB_URL`` is unset, the default ``sqlite:///db.sqlite`` is
    assumed, so the answer is True.
    """
    db_url = os.environ.get("DB_URL", "sqlite:///db.sqlite")
    return db_url.startswith("sqlite")
class SQLiteClient:
    """Client that runs SQL against the SQLite database named by ``DB_URL``."""

    def execute(self, sql: str) -> "Cursor":
        """Run every statement in *sql* and return the cursor that ran them.

        The text is split into individual statements because
        ``sqlite3.Cursor.execute`` accepts only one statement per call.
        Blank statements (for example from a trailing ``;``) are skipped.

        Args:
            sql: One or more SQL statements separated by ``;``.

        Returns:
            The cursor after the last statement, so the caller can fetch
            that statement's rows.

        Raises:
            sqlite3.Error: whatever SQLite raises for an invalid statement.
        """
        import sqlite3

        # The connection context manager commits on success and rolls back
        # on error; it does not close the connection, so the returned
        # cursor stays usable.
        with sqlite3.connect(self._get_database_name()) as connection:
            cursor = connection.cursor()
            pending = ""
            for fragment in sql.split(";"):
                if not pending and not fragment.strip():
                    # Nothing between two separators: no statement to run.
                    continue
                # A ';' inside a string literal, comment or trigger body
                # does not end the statement; keep gluing fragments until
                # SQLite itself says the statement is complete.
                pending = f"{pending}{fragment};"
                if sqlite3.complete_statement(pending):
                    cursor.execute(pending)
                    pending = ""
            if pending:
                # Never became complete (e.g. unterminated literal or a
                # trailing comment); hand it to SQLite as-is so errors
                # surface instead of being silently dropped.
                cursor.execute(pending)
            return cursor

    def _get_database_name(self) -> str:
        """Return the database path: the part of ``DB_URL`` after ``sqlite:///``.

        Falls back to ``db.sqlite`` when ``DB_URL`` is unset.
        """
        db_url = os.getenv("DB_URL", "sqlite:///db.sqlite")
        return db_url.split("sqlite:///")[-1]
class PostgresClient:
    """Client that runs SQL against the PostgreSQL server named by ``DB_URL``."""

    def execute(self, sql) -> Cursor:
        """Run *sql* on a fresh connection and return the executed cursor.

        The whole *sql* string is passed to the driver in a single
        ``execute`` call; no statement splitting is done here.
        """
        import psycopg

        dsn = self._get_connection_string()
        # NOTE(review): the returned cursor outlives the ``with`` block;
        # confirm it is still readable once psycopg has finalized the
        # connection on exit.
        with psycopg.connect(dsn) as conn:
            pg_cursor = conn.cursor()
            executed = pg_cursor.execute(sql)
            return executed

    def _get_connection_string(self) -> str:
        """Return ``DB_URL`` verbatim, or an empty string when it is unset."""
        return os.environ.get("DB_URL", "")