Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
# Upcoming Release

## New Features:
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)

### Accessing the Requests Object Directly

You can now access the Request object directly in your Python function, thanks to [@abidlabs](https://github.com/abidlabs) in [PR 2641](https://github.com/gradio-app/gradio/pull/2641). This means that you can access request headers, the client IP address, and so on. In order to use it, add a parameter to your function and set its type hint to be `gr.Request`. Here's a simple example:

```py
import gradio as gr

def echo(name, request: gr.Request):
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
return name

io = gr.Interface(echo, "textbox", "textbox").launch()
```


## Bug Fixes:
* Updated the minimum FastAPI version used in tests to version 0.87 by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2647](https://github.com/gradio-app/gradio/pull/2647)
Expand All @@ -18,7 +33,7 @@ No changes to highlight.
No changes to highlight.

## Full Changelog:
No changes to highlight.
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)

## Contributors Shoutout:
No changes to highlight.
Expand Down
8 changes: 8 additions & 0 deletions demo/audio_debugger/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gradio as gr
import subprocess
import os
import fastapi
Comment thread
abidlabs marked this conversation as resolved.
Outdated

audio_file = os.path.join(os.path.dirname(__file__), "cantina.wav")

Expand All @@ -11,7 +12,14 @@
with gr.Tab("Interface"):
gr.Interface(lambda x:x, "audio", "audio", examples=[audio_file])
with gr.Tab("console"):
ip = gr.Textbox(label="User IP Address")
gr.Interface(lambda cmd:subprocess.run([cmd], capture_output=True, shell=True).stdout.decode('utf-8').strip(), "text", "text")

def get_ip(request: gr.Request):
    # gr.Request wraps the incoming FastAPI request; client.host is the
    # connecting client's IP address as seen by the server.
    return request.client.host

demo.load(get_ip, None, ip)

if __name__ == "__main__":
demo.queue()
demo.launch()
2 changes: 1 addition & 1 deletion gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from gradio.ipython_ext import load_ipython_extension
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
from gradio.mix import Parallel, Series
from gradio.routes import mount_gradio_app
from gradio.routes import Request, mount_gradio_app
from gradio.templates import (
Files,
ImageMask,
Expand Down
39 changes: 35 additions & 4 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random
import sys
import time
import typing
import warnings
import webbrowser
from types import ModuleType
Expand Down Expand Up @@ -57,6 +58,7 @@

set_documentation_group("blocks")


if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import comet_ml
import mlflow
Expand Down Expand Up @@ -456,6 +458,23 @@ def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) ->
return predictions


def add_request_to_inputs(
    fn: Callable, inputs: List[Any], request: routes.Request | List[routes.Request]
):
    """
    Splice the gr.Request object into a function's positional inputs.

    For every positional parameter of ``fn`` whose resolved type hint is
    ``routes.Request``, ``request`` is inserted into ``inputs`` at that
    parameter's position so a later ``fn(*inputs)`` call lines up correctly.
    Callables whose hints cannot be resolved (e.g. partials) are left alone.
    """
    positional_params = inspect.getfullargspec(fn)[0]
    try:
        hints = typing.get_type_hints(fn)
    except TypeError:
        # Raised for partials and other rare callables that cannot be
        # introspected; forward the inputs unchanged in that case.
        return inputs
    for position, name in enumerate(positional_params):
        if hints.get(name, "") == routes.Request:
            inputs.insert(position, request)
    return inputs


@document("load")
class Blocks(BlockContext):
"""
Expand Down Expand Up @@ -795,7 +814,12 @@ def __call__(self, *inputs, fn_index: int = 0, api_name: str = None):
if batch:
processed_inputs = [[inp] for inp in processed_inputs]

outputs = utils.synchronize_async(self.process_api, fn_index, processed_inputs)
outputs = utils.synchronize_async(
self.process_api,
fn_index=fn_index,
inputs=processed_inputs,
request=None,
)
outputs = outputs["data"]

if batch:
Expand All @@ -811,11 +835,11 @@ async def call_function(
fn_index: int,
processed_input: List[Any],
iterator: Iterator[Any] | None = None,
request: routes.Request | List[routes.Request] | None = None,
):
"""Calls and times function with given index and preprocessed input."""
block_fn = self.fns[fn_index]
is_generating = False
start = time.time()

if block_fn.inputs_as_dict:
processed_input = [
Expand All @@ -825,6 +849,12 @@ async def call_function(
}
]

processed_input = add_request_to_inputs(
block_fn.fn, list(processed_input), request
)

start = time.time()

if iterator is None: # If not a generator function that has already run
if inspect.iscoroutinefunction(block_fn.fn):
prediction = await block_fn.fn(*processed_input)
Expand Down Expand Up @@ -943,6 +973,7 @@ async def process_api(
self,
fn_index: int,
inputs: List[Any],
request: routes.Request | List[routes.Request] | None = None,
username: str = None,
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
iterators: Dict[int, Any] | None = None,
Expand Down Expand Up @@ -979,15 +1010,15 @@ async def process_api(
)

inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
result = await self.call_function(fn_index, zip(*inputs), None)
result = await self.call_function(fn_index, zip(*inputs), None, request)
preds = result["prediction"]
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
data = list(zip(*data))
is_generating, iterator = None, None
else:
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
result = await self.call_function(fn_index, inputs, iterator)
result = await self.call_function(fn_index, inputs, iterator, request)
data = self.postprocess_data(fn_index, result["prediction"], state)
is_generating, iterator = result["is_generating"], result["iterator"]

Expand Down
5 changes: 4 additions & 1 deletion gradio/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel

Expand All @@ -10,6 +10,9 @@ class PredictBody(BaseModel):
batched: Optional[
bool
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
request: Optional[
Union[Dict, List[Dict]]
] = None # dictionary of request headers, query parameters, url, etc. (used to pass the request in when queuing)


class ResetBody(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion gradio/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async def cache(self) -> None:
if self.batch:
processed_input = [[value] for value in processed_input]
prediction = await Context.root_block.process_api(
fn_index, processed_input
fn_index=fn_index, inputs=processed_input, request=None
Comment thread
abidlabs marked this conversation as resolved.
)
output = prediction["data"]
if self.batch:
Expand Down
51 changes: 31 additions & 20 deletions gradio/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import sys
import time
from collections import deque
from itertools import islice
from typing import Deque, Dict, List, Optional, Tuple
from typing import Any, Deque, Dict, List, Optional, Tuple

import fastapi
from pydantic import BaseModel

from gradio.dataclasses import PredictBody
from gradio.utils import Request, run_coro_in_background, set_task_name
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name


class Estimation(BaseModel):
Expand All @@ -26,7 +25,11 @@ class Estimation(BaseModel):


class Event:
def __init__(self, websocket: fastapi.WebSocket, fn_index: int | None = None):
def __init__(
self,
websocket: fastapi.WebSocket,
fn_index: int | None = None,
):
self.websocket = websocket
self.data: PredictBody | None = None
self.lost_connection_time: float | None = None
Expand Down Expand Up @@ -157,18 +160,6 @@ async def broadcast_live_estimations(self) -> None:
if self.live_updates:
await self.broadcast_estimations()

async def gather_data_for_first_ranks(self) -> None:
Comment thread
abidlabs marked this conversation as resolved.
"""
Gather data for the first x events.
"""
# Send all messages concurrently
await asyncio.gather(
*[
self.gather_event_data(event)
for event in islice(self.event_queue, self.data_gathering_start)
]
)

async def gather_event_data(self, event: Event) -> bool:
"""
Gather data for the event
Expand Down Expand Up @@ -253,14 +244,34 @@ def get_estimation(self) -> Estimation:
queue_eta=self.queue_duration,
)

def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
    """
    Serialize the relevant fields of a websocket handshake into a plain dict.

    The resulting mapping (url, headers, query/path params, and client
    host/port) is JSON-friendly so it can be forwarded with the predict
    payload and reconstructed into a request object server-side.
    """
    client_info = {
        "host": websocket.client.host,
        "port": websocket.client.port,
    }
    params: Dict[str, Any] = {
        "url": str(websocket.url),
        "headers": dict(websocket.headers),
        "query_params": dict(websocket.query_params),
        "path_params": dict(websocket.path_params),
        "client": client_info,
    }
    return params

async def call_prediction(self, events: List[Event], batch: bool):
data = events[0].data
token = events[0].token
try:
data.request = self.get_request_params(events[0].websocket)
except ValueError:
pass

if batch:
data.data = list(zip(*[event.data.data for event in events if event.data]))
data.request = [
self.get_request_params(event.websocket)
for event in events
if event.data
]
data.batched = True
response = await Request(
method=Request.Method.POST,

response = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=f"{self.server_path}api/predict",
json=dict(data),
headers={"Authorization": f"Bearer {self.access_token}"},
Expand Down Expand Up @@ -370,8 +381,8 @@ async def get_message(self, event) -> Optional[PredictBody]:
return None

async def reset_iterators(self, session_hash: str, fn_index: int):
await Request(
method=Request.Method.POST,
await AsyncRequest(
method=AsyncRequest.Method.POST,
url=f"{self.server_path}reset",
json={
"session_hash": session_hash,
Expand Down
Loading