Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ No changes to highlight.

## Full Changelog:
* Allow `gr.Templates` to accept parameters to override the defaults by [@abidlabs](https://github.com/abidlabs) in [PR 2600](https://github.com/gradio-app/gradio/pull/2600)
* Update queue with using deque & update requirements by [@GLGDLY](https://github.com/GLGDLY) in [PR 2428](https://github.com/gradio-app/gradio/pull/2428)
* Allow auth with using queue by [@GLGDLY](https://github.com/GLGDLY) in [PR 2611](https://github.com/gradio-app/gradio/pull/2611)

## Contributors Shoutout:
No changes to highlight.
Expand Down Expand Up @@ -57,7 +59,6 @@ No changes to highlight.

## Full Changelog:
* Add `api_name` to `Blocks.__call__` by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2593](https://github.com/gradio-app/gradio/pull/2593)
* Update queue with using deque & update requirements by [@GLGDLY](https://github.com/GLGDLY) in [PR 2428](https://github.com/gradio-app/gradio/pull/2428)


## Contributors Shoutout:
Expand Down
6 changes: 2 additions & 4 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,10 +1324,8 @@ def reverse(text):
requests.get(f"{self.local_url}startup-events")

if self.enable_queue:
if self.auth is not None or self.encrypt:
raise ValueError(
"Cannot queue with encryption or authentication enabled."
)
if self.encrypt:
raise ValueError("Cannot queue with encryption enabled.")
utils.launch_counter()

self.share = (
Expand Down
3 changes: 3 additions & 0 deletions gradio/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, websocket: fastapi.WebSocket, fn_index: int | None = None):
self.lost_connection_time: float | None = None
self.fn_index: int | None = fn_index
self.session_hash: str = "foo"
self.token: str | None = None

async def disconnect(self, code=1000):
await self.websocket.close(code=code)
Expand Down Expand Up @@ -254,6 +255,7 @@ def get_estimation(self) -> Estimation:

async def call_prediction(self, events: List[Event], batch: bool):
data = events[0].data
token = events[0].token
if batch:
data.data = list(zip(*[event.data.data for event in events if event.data]))
data.batched = True
Expand All @@ -262,6 +264,7 @@ async def call_prediction(self, events: List[Event], batch: bool):
url=f"{self.server_path}api/predict",
json=dict(data),
headers={"Authorization": f"Bearer {self.access_token}"},
cookies={"access-token": token} if token is not None else None,
)
return response

Expand Down
26 changes: 19 additions & 7 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, **kwargs):
self.iterators = defaultdict(dict)
self.lock = asyncio.Lock()
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
super().__init__(**kwargs)

def configure_app(self, blocks: gradio.Blocks) -> None:
Expand Down Expand Up @@ -120,6 +121,10 @@ def login_check(user: str = Depends(get_current_user)):
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)

async def ws_login_check(websocket: WebSocket) -> str:
token = websocket.cookies.get("access-token")
return token # token is returned to allow request in queue

@app.get("/token")
@app.get("/token/")
def get_token(request: Request) -> dict:
Expand Down Expand Up @@ -344,12 +349,19 @@ async def predict(
return result

@app.websocket("/queue/join")
async def join_queue(websocket: WebSocket):
async def join_queue(
websocket: WebSocket, token: str = Depends(ws_login_check)
):
if app.auth is not None and token is None:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
if app.blocks._queue.server_path is None:
app_url = get_server_url_from_ws_url(str(websocket.url))
app.blocks._queue.set_url(app_url)
await websocket.accept()
event = Event(websocket)
# set the token into Event to allow using the same token for call_prediction
event.token = token

# In order to cancel jobs, we need the session_hash and fn_index
# to create a unique id for each job
Expand Down Expand Up @@ -391,13 +403,13 @@ async def join_queue(websocket: WebSocket):
async def get_queue_status():
return app.blocks._queue.get_estimation()

@app.get(
"/startup-events",
dependencies=[Depends(login_check)],
)
@app.get("/startup-events")
async def startup_events():
app.blocks.startup_events()
return True
if not app.startup_events_triggered:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come we need this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comments are removed here, it is changed to allow the initialization to run the queue background coros

app.blocks.startup_events()
app.startup_events_triggered = True
return True
return False

return app

Expand Down
71 changes: 71 additions & 0 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,5 +911,76 @@ def test_queue_enabled_for_fn():
assert demo.queue_enabled_for_fn(1)


@pytest.mark.asyncio
async def test_queue_when_using_auth():
sleep_time = 5

async def say_hello(name):
await asyncio.sleep(sleep_time)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to sleep for so long. The test takes quite a while to run right now.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is 1 sec and 3 times now

return f"Hello {name}!"

with gr.Blocks() as demo:
_input = gr.Textbox()
_output = gr.Textbox()
button = gr.Button()
button.click(say_hello, _input, _output)
demo.queue()
app, _, _ = demo.launch(auth=("abc", "123"), prevent_thread_lock=True)
client = TestClient(app)

resp = client.post(
f"{demo.local_url}login", data={"username": "abc", "password": "123"}
)
assert resp.ok
token = resp.cookies.get("access-token")
assert token

with pytest.raises(Exception) as e:
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join",
) as ws:
await ws.recv()
assert e.type == websockets.InvalidStatusCode

async def run_ws(_loop, _time):
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join",
extra_headers={"Cookie": f"access-token={token}"},
) as ws:
while True:
try:
msg = json.loads(await ws.recv())
except websockets.ConnectionClosedOK:
break
if msg["msg"] == "send_hash":
await ws.send(
json.dumps({"fn_index": 0, "session_hash": "enwpitpex2q"})
)
if msg["msg"] == "send_data":
await ws.send(
json.dumps(
{
"data": ["123"],
"fn_index": 0,
"session_hash": "enwpitpex2q",
}
)
)
msg = json.loads(await ws.recv())
assert msg["msg"] == "process_starts"
if msg["msg"] == "process_completed":
assert msg["success"]
assert msg["output"]["data"] == ["Hello 123!"]
assert _loop.time() > _time
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not passing on our CI. I don't think we need to check the time. I think it would be sufficient to test that the output is correct. Might make sense to rewrite this loop so that the "name" is different for each event in the queue.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"name" is different for each event in the queue is added now, time is checked as I think we need to ensure the queue is actually working

break

loop = asyncio.get_event_loop()
tm = loop.time()
group = asyncio.gather(
*[run_ws(loop, tm + sleep_time * (i + 1) - 1) for i in range(5)]
)
await group


if __name__ == "__main__":
unittest.main()