Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/deploy-spaces.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ jobs:
uses: pnpm/action-setup@v2.2.2
with:
version: 7
- uses: actions/setup-node@v3
with:
node-version: 16
cache: pnpm
cache-dependency-path: ui/pnpm-lock.yaml
- name: Install pip
run: python -m pip install build requests
- name: Get PR Number
Expand All @@ -33,11 +38,13 @@ jobs:
export AWS_DEFAULT_REGION=us-east-1
echo ${{ env.GRADIO_VERSION }} > gradio/version.txt
cd ui
pnpm i
pnpm i --frozen-lockfile
pnpm build
cd ..
python3 -m build -w
aws s3 cp dist/gradio-${{ env.GRADIO_VERSION }}-py3-none-any.whl s3://gradio-builds/${{ github.sha }}/
env:
NODE_OPTIONS: --max_old_space_size=8192
- name: Install Hub Client Library
run: pip install huggingface-hub
- name: Set up Demos
Expand Down
31 changes: 31 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,37 @@

## New Features:

### Support for altair plots

The `Plot` component can now accept altair plots as values!
Simply return an altair plot from your event listener and gradio will display it in the front-end.
See the example below:

```python
import gradio as gr
import altair as alt
from vega_datasets import data

cars = data.cars()
chart = (
alt.Chart(cars)
.mark_point()
.encode(
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
)
)

with gr.Blocks() as demo:
gr.Plot(value=chart)
demo.launch()
```

<img width="1366" alt="image" src="https://user-images.githubusercontent.com/41651716/204660697-f994316f-5ca7-4e8a-93bc-eb5e0d556c91.png">

By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741)

### Set the color of a Label component with a function

The `Label` component now accepts a `color` argument by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2736](https://github.com/gradio-app/gradio/pull/2736).
Expand Down
2 changes: 2 additions & 0 deletions demo/altair_plot/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
altair
vega_datasets
1 change: 1 addition & 0 deletions demo/altair_plot/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: altair_plot"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio altair vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair as alt\n", "import gradio as gr\n", "import numpy as np\n", "import pandas as pd\n", "from vega_datasets import data\n", "\n", "\n", "def make_plot(plot_type):\n", " if plot_type == \"scatter_plot\":\n", " cars = data.cars()\n", " return alt.Chart(cars).mark_point().encode(\n", " x='Horsepower',\n", " y='Miles_per_Gallon',\n", " color='Origin',\n", " )\n", " elif plot_type == \"heatmap\":\n", " # Compute x^2 + y^2 across a 2D grid\n", " x, y = np.meshgrid(range(-5, 5), range(-5, 5))\n", " z = x ** 2 + y ** 2\n", "\n", " # Convert this grid to columnar data expected by Altair\n", " source = pd.DataFrame({'x': x.ravel(),\n", " 'y': y.ravel(),\n", " 'z': z.ravel()})\n", " return alt.Chart(source).mark_rect().encode(\n", " x='x:O',\n", " y='y:O',\n", " color='z:Q'\n", " )\n", " elif plot_type == \"us_map\":\n", " states = alt.topo_feature(data.us_10m.url, 'states')\n", " source = data.income.url\n", "\n", " return alt.Chart(source).mark_geoshape().encode(\n", " shape='geo:G',\n", " color='pct:Q',\n", " tooltip=['name:N', 'pct:Q'],\n", " facet=alt.Facet('group:N', columns=2),\n", " ).transform_lookup(\n", " lookup='id',\n", " from_=alt.LookupData(data=states, key='id'),\n", " as_='geo'\n", " ).properties(\n", " width=300,\n", " height=175,\n", " ).project(\n", " type='albersUsa'\n", " )\n", " elif plot_type == \"interactive_barplot\":\n", " source = data.movies.url\n", "\n", " pts = alt.selection(type=\"single\", encodings=['x'])\n", "\n", " rect = alt.Chart(data.movies.url).mark_rect().encode(\n", " alt.X('IMDB_Rating:Q', bin=True),\n", " alt.Y('Rotten_Tomatoes_Rating:Q', bin=True),\n", " alt.Color('count()',\n", " scale=alt.Scale(scheme='greenblue'),\n", " legend=alt.Legend(title='Total Records')\n", " )\n", " )\n", "\n", " circ = rect.mark_point().encode(\n", " alt.ColorValue('grey'),\n", " alt.Size('count()',\n", " legend=alt.Legend(title='Records in Selection')\n", " )\n", " ).transform_filter(\n", " pts\n", " )\n", "\n", " bar = alt.Chart(source).mark_bar().encode(\n", " x='Major_Genre:N',\n", " y='count()',\n", " color=alt.condition(pts, alt.ColorValue(\"steelblue\"), alt.ColorValue(\"grey\"))\n", " ).properties(\n", " width=550,\n", " height=200\n", " ).add_selection(pts)\n", "\n", " plot = alt.vconcat(\n", " rect + circ,\n", " bar\n", " ).resolve_legend(\n", " color=\"independent\",\n", " size=\"independent\"\n", " )\n", " return plot\n", " elif plot_type == \"radial\":\n", " source = pd.DataFrame({\"values\": [12, 23, 47, 6, 52, 19]})\n", "\n", " base = alt.Chart(source).encode(\n", " theta=alt.Theta(\"values:Q\", stack=True),\n", " radius=alt.Radius(\"values\", scale=alt.Scale(type=\"sqrt\", zero=True, rangeMin=20)),\n", " color=\"values:N\",\n", " )\n", "\n", " c1 = base.mark_arc(innerRadius=20, stroke=\"#fff\")\n", "\n", " c2 = base.mark_text(radiusOffset=10).encode(text=\"values:Q\")\n", "\n", " return c1 + c2\n", " elif plot_type == \"multiline\":\n", " source = data.stocks()\n", "\n", " highlight = alt.selection(type='single', on='mouseover',\n", " fields=['symbol'], nearest=True)\n", "\n", " base = alt.Chart(source).encode(\n", " x='date:T',\n", " y='price:Q',\n", " color='symbol:N'\n", " )\n", "\n", " points = base.mark_circle().encode(\n", " opacity=alt.value(0)\n", " ).add_selection(\n", " highlight\n", " ).properties(\n", " width=600\n", " )\n", "\n", " lines = base.mark_line().encode(\n", " size=alt.condition(~highlight, alt.value(1), alt.value(3))\n", " )\n", "\n", " return points + lines\n", "\n", "\n", "with gr.Blocks() as demo:\n", " button = gr.Radio(label=\"Plot type\",\n", " choices=['scatter_plot', 'heatmap', 'us_map',\n", " 'interactive_barplot', \"radial\", \"multiline\"], value='scatter_plot')\n", " plot = gr.Plot(label=\"Plot\")\n", " button.change(make_plot, inputs=button, outputs=[plot])\n", " demo.load(make_plot, inputs=[button], outputs=[plot])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
140 changes: 140 additions & 0 deletions demo/altair_plot/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
from vega_datasets import data


def make_plot(plot_type):
if plot_type == "scatter_plot":
cars = data.cars()
return alt.Chart(cars).mark_point().encode(
x='Horsepower',
y='Miles_per_Gallon',
color='Origin',
)
elif plot_type == "heatmap":
# Compute x^2 + y^2 across a 2D grid
x, y = np.meshgrid(range(-5, 5), range(-5, 5))
z = x ** 2 + y ** 2

# Convert this grid to columnar data expected by Altair
source = pd.DataFrame({'x': x.ravel(),
'y': y.ravel(),
'z': z.ravel()})
return alt.Chart(source).mark_rect().encode(
x='x:O',
y='y:O',
color='z:Q'
)
elif plot_type == "us_map":
states = alt.topo_feature(data.us_10m.url, 'states')
source = data.income.url

return alt.Chart(source).mark_geoshape().encode(
shape='geo:G',
color='pct:Q',
tooltip=['name:N', 'pct:Q'],
facet=alt.Facet('group:N', columns=2),
).transform_lookup(
lookup='id',
from_=alt.LookupData(data=states, key='id'),
as_='geo'
).properties(
width=300,
height=175,
).project(
type='albersUsa'
)
elif plot_type == "interactive_barplot":
source = data.movies.url

pts = alt.selection(type="single", encodings=['x'])

rect = alt.Chart(data.movies.url).mark_rect().encode(
alt.X('IMDB_Rating:Q', bin=True),
alt.Y('Rotten_Tomatoes_Rating:Q', bin=True),
alt.Color('count()',
scale=alt.Scale(scheme='greenblue'),
legend=alt.Legend(title='Total Records')
)
)

circ = rect.mark_point().encode(
alt.ColorValue('grey'),
alt.Size('count()',
legend=alt.Legend(title='Records in Selection')
)
).transform_filter(
pts
)

bar = alt.Chart(source).mark_bar().encode(
x='Major_Genre:N',
y='count()',
color=alt.condition(pts, alt.ColorValue("steelblue"), alt.ColorValue("grey"))
).properties(
width=550,
height=200
).add_selection(pts)

plot = alt.vconcat(
rect + circ,
bar
).resolve_legend(
color="independent",
size="independent"
)
return plot
elif plot_type == "radial":
source = pd.DataFrame({"values": [12, 23, 47, 6, 52, 19]})

base = alt.Chart(source).encode(
theta=alt.Theta("values:Q", stack=True),
radius=alt.Radius("values", scale=alt.Scale(type="sqrt", zero=True, rangeMin=20)),
color="values:N",
)

c1 = base.mark_arc(innerRadius=20, stroke="#fff")

c2 = base.mark_text(radiusOffset=10).encode(text="values:Q")

return c1 + c2
elif plot_type == "multiline":
source = data.stocks()

highlight = alt.selection(type='single', on='mouseover',
fields=['symbol'], nearest=True)

base = alt.Chart(source).encode(
x='date:T',
y='price:Q',
color='symbol:N'
)

points = base.mark_circle().encode(
opacity=alt.value(0)
).add_selection(
highlight
).properties(
width=600
)

lines = base.mark_line().encode(
size=alt.condition(~highlight, alt.value(1), alt.value(3))
)

return points + lines


with gr.Blocks() as demo:
button = gr.Radio(label="Plot type",
choices=['scatter_plot', 'heatmap', 'us_map',
'interactive_barplot', "radial", "multiline"], value='scatter_plot')
plot = gr.Plot(label="Plot")
button.change(make_plot, inputs=button, outputs=[plot])
demo.load(make_plot, inputs=[button], outputs=[plot])


if __name__ == "__main__":
demo.launch()
3 changes: 2 additions & 1 deletion demo/outbreak_forecast/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy
matplotlib
bokeh
plotly
plotly
altair
2 changes: 1 addition & 1 deletion demo/outbreak_forecast/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
9 changes: 8 additions & 1 deletion demo/outbreak_forecast/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import altair

import gradio as gr
from math import sqrt
import matplotlib
Expand Down Expand Up @@ -39,12 +41,16 @@ def outbreak(plot_type, r, month, countries, social_distancing):
yaxis_title="Days Since Day 0",
)
return fig
elif plot_type == "Altair":
df = df.melt(id_vars="day").rename(columns={"variable": "country"})
fig = altair.Chart(df).mark_line().encode(x="day", y='value', color='country')
return fig
else:
raise ValueError("A plot type must be selected")


inputs = [
gr.Dropdown(["Matplotlib", "Plotly"], label="Plot Type"),
gr.Dropdown(["Matplotlib", "Plotly", "Altair"], label="Plot Type"),
gr.Slider(1, 4, 3.2, label="R"),
gr.Dropdown(["January", "February", "March", "April", "May"], label="Month"),
gr.CheckboxGroup(
Expand All @@ -60,6 +66,7 @@ def outbreak(plot_type, r, month, countries, social_distancing):
outputs=outputs,
examples=[
["Matplotlib", 2, "March", ["Mexico", "UK"], True],
["Altair", 2, "March", ["Mexico", "Canada"], True],
["Plotly", 3.6, "February", ["Canada", "Mexico", "UK"], False],
],
cache_examples=True,
Expand Down
Loading