Skip to content

Commit c126e62

Browse files
authored
Add support for numpy array and other types to gr.Dataframe() initial value (#2804)
* bigquery draft * updated guide * orjson fix * formatting * changelog * rm guide
1 parent 39ffe9d commit c126e62

6 files changed

Lines changed: 56 additions & 25 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
No changes to highlight.
55

66
## Bug Fixes:
7-
No changes to highlight.
7+
* Allows `gr.Dataframe()` to take a `pandas.DataFrame` that includes numpy array and other types as its initial value, by [@abidlabs](https://github.com/abidlabs) in [PR 2804](https://github.com/gradio-app/gradio/pull/2804)
88

99
## Documentation Changes:
1010
No changes to highlight.

gradio/components.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2449,7 +2449,7 @@ def postprocess(
24492449
"""
24502450
if y is None:
24512451
return self.postprocess(self.test_input)
2452-
if isinstance(y, Dict):
2452+
if isinstance(y, dict):
24532453
return y
24542454
if isinstance(y, str):
24552455
y = pd.read_csv(y)

gradio/routes.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from fastapi.security import OAuth2PasswordRequestForm
3232
from fastapi.templating import Jinja2Templates
3333
from jinja2.exceptions import TemplateNotFound
34+
from jinja2.utils import htmlsafe_json_dumps
3435
from starlette.responses import RedirectResponse
3536
from starlette.websockets import WebSocketState
3637

@@ -55,11 +56,28 @@
5556
class ORJSONResponse(JSONResponse):
5657
media_type = "application/json"
5758

59+
@staticmethod
60+
def _render(content: Any) -> bytes:
61+
return orjson.dumps(
62+
content,
63+
option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
64+
default=str,
65+
)
66+
5867
def render(self, content: Any) -> bytes:
59-
return orjson.dumps(content, option=orjson.OPT_SERIALIZE_NUMPY)
68+
return ORJSONResponse._render(content)
69+
70+
@staticmethod
71+
def _render_str(content: Any) -> str:
72+
return ORJSONResponse._render(content).decode("utf-8")
73+
74+
75+
def toorjson(value):
76+
return htmlsafe_json_dumps(value, dumps=ORJSONResponse._render_str)
6077

6178

6279
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
80+
templates.env.filters["toorjson"] = toorjson
6381

6482

6583
###########

test/test_components.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,7 @@ async def test_in_interface(self):
551551

552552

553553
class TestImage:
554-
@pytest.mark.asyncio
555-
async def test_component_functions(self):
554+
def test_component_functions(self):
556555
"""
557556
Preprocess, postprocess, serialize, generate_sample, get_config, _segment_by_slic
558557
type: pil, file, filepath, numpy
@@ -618,8 +617,7 @@ async def test_component_functions(self):
618617
image_output = gr.Image(type="numpy")
619618
assert image_output.postprocess(y_img).startswith("data:image/png;base64,")
620619

621-
@pytest.mark.asyncio
622-
async def test_in_interface_as_input(self):
620+
def test_in_interface_as_input(self):
623621
"""
624622
Interface, process, interpret
625623
type: file
@@ -638,8 +636,7 @@ async def test_in_interface_as_input(self):
638636
lambda x: np.sum(x), image_input, "number", interpretation="default"
639637
)
640638

641-
@pytest.mark.asyncio
642-
async def test_in_interface_as_output(self):
639+
def test_in_interface_as_output(self):
643640
"""
644641
Interface, process
645642
"""
@@ -789,8 +786,7 @@ def test_tokenize(self):
789786
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
790787
assert similarity > 0.9
791788

792-
@pytest.mark.asyncio
793-
async def test_in_interface(self):
789+
def test_in_interface(self):
794790
def reverse_audio(audio):
795791
sr, data = audio
796792
return (sr, np.flipud(data))
@@ -806,8 +802,7 @@ def reverse_audio(audio):
806802
).ratio()
807803
assert similarity > 0.99
808804

809-
@pytest.mark.asyncio
810-
async def test_in_interface_as_output(self):
805+
def test_in_interface_as_output(self):
811806
"""
812807
Interface, process
813808
"""
@@ -1007,7 +1002,7 @@ def test_dataframe_postprocess_all_types(self):
10071002
"%B %d, %Y, %r"
10081003
),
10091004
"number": np.array([0.2233, 0.57281]),
1010-
"number_2": np.array([84, 23]).astype(np.int),
1005+
"number_2": np.array([84, 23]).astype(np.int64),
10111006
"bool": [True, False],
10121007
"markdown": ["# Hello", "# Goodbye"],
10131008
}
@@ -1167,8 +1162,7 @@ def test_component_functions(self):
11671162
}
11681163
).endswith(".mp4")
11691164

1170-
@pytest.mark.asyncio
1171-
async def test_in_interface(self):
1165+
def test_in_interface(self):
11721166
"""
11731167
Interface, process
11741168
"""
@@ -1396,8 +1390,7 @@ def test_color_argument(self):
13961390
)
13971391
assert update_5["color"] == "transparent"
13981392

1399-
@pytest.mark.asyncio
1400-
async def test_in_interface(self):
1393+
def test_in_interface(self):
14011394
"""
14021395
Interface, process
14031396
"""
@@ -1640,8 +1633,7 @@ def test_component_functions(self):
16401633
"root_url": None,
16411634
} == html_component.get_config()
16421635

1643-
@pytest.mark.asyncio
1644-
async def test_in_interface(self):
1636+
def test_in_interface(self):
16451637
"""
16461638
Interface, process
16471639
"""
@@ -1660,8 +1652,7 @@ def test_component_functions(self):
16601652
"""<h1>Let\'s learn about <span class="math inline"><span style=\'font-size: 0px\'>x</span><svg xmlns:xlink="http://www.w3.org/1999/xlink" width="11.6pt" height="19.35625pt" viewBox="0 0 11.6 19.35625" xmlns="http://www.w3.org/2000/svg" version="1.1">\n \n <defs>\n <style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id="figure_1">\n <g id="patch_1">\n <path d="M 0 19.35625"""
16611653
)
16621654

1663-
@pytest.mark.asyncio
1664-
async def test_in_interface(self):
1655+
def test_in_interface(self):
16651656
"""
16661657
Interface, process
16671658
"""
@@ -1693,8 +1684,7 @@ def test_component_functions(self):
16931684
"style": {},
16941685
} == component.get_config()
16951686

1696-
@pytest.mark.asyncio
1697-
async def test_in_interface(self):
1687+
def test_in_interface(self):
16981688
"""
16991689
Interface, process
17001690
"""

test/test_routes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import sys
55
from unittest.mock import patch
66

7+
import numpy as np
8+
import pandas as pd
79
import pytest
810
import starlette.routing
911
import websockets
@@ -398,3 +400,24 @@ def test_show_api_queue_not_enabled():
398400
io.close()
399401
io.launch(prevent_thread_lock=True, show_api=False)
400402
assert not io.show_api
403+
404+
405+
def test_orjson_serialization():
406+
df = pd.DataFrame(
407+
{
408+
"date_1": pd.date_range("2021-01-01", periods=2),
409+
"date_2": pd.date_range("2022-02-15", periods=2).strftime("%B %d, %Y, %r"),
410+
"number": np.array([0.2233, 0.57281]),
411+
"number_2": np.array([84, 23]).astype(np.int64),
412+
"bool": [True, False],
413+
"markdown": ["# Hello", "# Goodbye"],
414+
}
415+
)
416+
417+
with gr.Blocks() as demo:
418+
gr.DataFrame(df)
419+
app, _, _ = demo.launch(prevent_thread_lock=True)
420+
test_client = TestClient(app)
421+
response = test_client.get("/")
422+
assert response.status_code == 200
423+
demo.close()

ui/packages/app/build_plugins.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export function inject_ejs(): Plugin {
1111
transformIndexHtml: (html) => {
1212
return html.replace(
1313
/%gradio_config%/,
14-
`<script>window.gradio_config = {{ config | tojson }};</script>`
14+
`<script>window.gradio_config = {{ config | toorjson }};</script>`
1515
);
1616
}
1717
};

0 commit comments

Comments
 (0)