Skip to content

Commit 0a242b8

Browse files
Add Gradio-compatible /gradio_api routes on Spaces (#515)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent d54d290 commit 0a242b8

3 files changed

Lines changed: 355 additions & 4 deletions

File tree

.changeset/silly-coats-go.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": minor
3+
---
4+
5+
feat:Add Gradio-compatible /gradio_api routes on Spaces
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from starlette.testclient import TestClient
2+
3+
from trackio.asgi_app import create_trackio_starlette_app
4+
5+
6+
def _echo(msg: str = "") -> str:
7+
return msg
8+
9+
10+
def _hf_echo(hf_token: str | None = None) -> str | None:
11+
return hf_token
12+
13+
14+
def test_gradio_api_info_call_poll_and_headers(monkeypatch, temp_dir):
15+
monkeypatch.setenv("SYSTEM", "spaces")
16+
app = create_trackio_starlette_app([], {"echo": _echo})
17+
client = TestClient(app)
18+
19+
info = client.get("/gradio_api/info")
20+
assert info.status_code == 200
21+
body = info.json()
22+
assert "/echo" in body["named_endpoints"]
23+
assert body["unnamed_endpoints"] == {}
24+
25+
post = client.post("/gradio_api/call/echo", json={"data": ["hi"]})
26+
assert post.status_code == 200
27+
event_id = post.json()["event_id"]
28+
29+
poll = client.get(f"/gradio_api/call/echo/{event_id}")
30+
assert poll.status_code == 200
31+
assert poll.headers.get("cache-control") == "no-store"
32+
assert poll.headers.get("x-accel-buffering") == "no"
33+
assert "event: complete" in poll.text
34+
assert '"hi"' in poll.text
35+
36+
37+
def test_gradio_poll_wrong_api_name_not_consumed(monkeypatch, temp_dir):
38+
monkeypatch.setenv("SYSTEM", "spaces")
39+
app = create_trackio_starlette_app([], {"echo": _echo})
40+
client = TestClient(app)
41+
42+
post = client.post("/gradio_api/call/echo", json={"data": ["x"]})
43+
event_id = post.json()["event_id"]
44+
45+
bad = client.get(f"/gradio_api/call/other/{event_id}")
46+
assert bad.status_code == 404
47+
48+
ok = client.get(f"/gradio_api/call/echo/{event_id}")
49+
assert ok.status_code == 200
50+
51+
52+
def test_hf_token_from_authorization_on_spaces(monkeypatch, temp_dir):
53+
monkeypatch.setenv("SYSTEM", "spaces")
54+
app = create_trackio_starlette_app([], {"hf_echo": _hf_echo})
55+
client = TestClient(app)
56+
57+
r = client.post(
58+
"/api/hf_echo",
59+
json={},
60+
headers={"Authorization": "Bearer space-token"},
61+
)
62+
assert r.status_code == 200
63+
assert r.json()["data"] == "space-token"
64+
65+
66+
def test_hf_token_empty_or_whitespace_body_uses_bearer(monkeypatch, temp_dir):
67+
monkeypatch.setenv("SYSTEM", "spaces")
68+
app = create_trackio_starlette_app([], {"hf_echo": _hf_echo})
69+
client = TestClient(app)
70+
71+
r = client.post(
72+
"/api/hf_echo",
73+
json={"hf_token": ""},
74+
headers={"Authorization": "Bearer from-header"},
75+
)
76+
assert r.status_code == 200
77+
assert r.json()["data"] == "from-header"
78+
79+
r2 = client.post(
80+
"/api/hf_echo",
81+
json={"hf_token": " "},
82+
headers={"Authorization": "Bearer from-header"},
83+
)
84+
assert r2.status_code == 200
85+
assert r2.json()["data"] == "from-header"
86+
87+
88+
def test_gradio_upload_aliases_api_upload(monkeypatch, temp_dir):
89+
monkeypatch.setenv("SYSTEM", "spaces")
90+
app = create_trackio_starlette_app([], {"echo": _echo})
91+
client = TestClient(app)
92+
93+
g = client.post("/gradio_api/upload", files={"files": ("a.txt", b"x")})
94+
assert g.status_code == 200
95+
assert "paths" in g.json()

0 commit comments

Comments
 (0)