Skip to content

Commit cecaf1a

Browse files
abidlabs and pngwn authored
Sketching + Inpainting Capabilities to Gradio (#2144)
* templates * working on backend * formatting * Sketching fe (#2184) * fix scaling on sketch + bg img * tweaks * sketch updates * cursor style * sketchpad * fixes * ensure background is white for bw sketch * fix everything * re-enable demos * updated demo and changed from dict to str * beta release * fix bugs, tweak webcam source * re-enable demos * fix clear button and tab changing * maybe fix test * maybe fix test again maybe * various fixes * fix img upload + color sketch * remove lazy brush but keep smoothing * fix sketch bg Co-authored-by: pngwn <hello@pngwn.io>
1 parent 581fbab commit cecaf1a

20 files changed

Lines changed: 727 additions & 356 deletions

File tree

demo/all_demos/tmp.zip

180 KB
Binary file not shown.

demo/blocks_mask/run.py

Lines changed: 126 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,135 @@
11
import gradio as gr
2-
import os
2+
from gradio.components import Markdown as md
33

4-
def fn(mask):
5-
return [mask["image"], mask["mask"]]
4+
demo = gr.Blocks()
65

6+
io1a = gr.Interface(lambda x: x, gr.Image(), gr.Image())
7+
io1b = gr.Interface(lambda x: x, gr.Image(source="webcam"), gr.Image())
8+
9+
io2a = gr.Interface(lambda x: x, gr.Image(source="canvas"), gr.Image())
10+
io2b = gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())
11+
12+
io3a = gr.Interface(
13+
lambda x: [x["mask"], x["image"]],
14+
gr.Image(source="upload", tool="sketch"),
15+
[gr.Image(), gr.Image()],
16+
)
17+
18+
io3b = gr.Interface(
19+
lambda x: [x["mask"], x["image"]],
20+
gr.ImageMask(),
21+
[gr.Image(), gr.Image()],
22+
)
23+
24+
io3b2 = gr.Interface(
25+
lambda x: [x["mask"], x["image"]],
26+
gr.ImageMask(),
27+
[gr.Image(), gr.Image()],
28+
)
29+
30+
io3b3 = gr.Interface(
31+
lambda x: [x["mask"], x["image"]],
32+
gr.ImageMask(),
33+
[gr.Image(), gr.Image()],
34+
)
35+
36+
io3c = gr.Interface(
37+
lambda x: [x["mask"], x["image"]],
38+
gr.Image(source="webcam", tool="sketch"),
39+
[gr.Image(), gr.Image()],
40+
)
41+
42+
io4a = gr.Interface(
43+
lambda x: x, gr.Image(source="canvas", tool="color-sketch"), gr.Image()
44+
)
45+
io4b = gr.Interface(lambda x: x, gr.Paint(), gr.Image())
46+
47+
io5a = gr.Interface(
48+
lambda x: x, gr.Image(source="upload", tool="color-sketch"), gr.Image()
49+
)
50+
io5b = gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())
51+
io5c = gr.Interface(
52+
lambda x: x, gr.Image(source="webcam", tool="color-sketch"), gr.Image()
53+
)
754

8-
demo = gr.Blocks()
955

1056
with demo:
11-
with gr.Row():
12-
with gr.Column():
13-
img = gr.Image(
14-
tool="sketch", source="upload", label="Mask", value=os.path.join(os.path.dirname(__file__), "lion.jpg")
15-
)
16-
with gr.Row():
17-
btn = gr.Button("Run")
18-
with gr.Column():
19-
img2 = gr.Image()
20-
img3 = gr.Image()
21-
22-
btn.click(fn=fn, inputs=img, outputs=[img2, img3])
57+
md("# Different Ways to Use the Image Input Component")
58+
md(
59+
"**1a. Standalone Image Upload: `gr.Interface(lambda x: x, gr.Image(), gr.Image())`**"
60+
)
61+
io1a.render()
62+
md(
63+
"**1b. Standalone Image from Webcam: `gr.Interface(lambda x: x, gr.Image(source='webcam'), gr.Image())`**"
64+
)
65+
io1b.render()
66+
md(
67+
"**2a. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas'), gr.Image())`**"
68+
)
69+
io2a.render()
70+
md(
71+
"**2b. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())`**"
72+
)
73+
io2b.render()
74+
md("**3a. Binary Mask with image upload:**")
75+
md(
76+
"""```python
77+
gr.Interface(
78+
lambda x: [x['mask'], x['image']],
79+
gr.Image(source='upload', tool='sketch'),
80+
[gr.Image(), gr.Image()],
81+
)
82+
```
83+
"""
84+
)
85+
io3a.render()
86+
md("**3b. Binary Mask with image upload:**")
87+
md(
88+
"""```python
89+
gr.Interface(
90+
lambda x: [x['mask'], x['image']],
91+
gr.ImageMask(),
92+
[gr.Image(), gr.Image()],
93+
)
94+
```
95+
"""
96+
)
97+
io3b.render()
98+
md("**3c. Binary Mask with webcam upload:**")
99+
md(
100+
"""```python
101+
gr.Interface(
102+
lambda x: [x['mask'], x['image']],
103+
gr.Image(source='webcam', tool='sketch'),
104+
[gr.Image(), gr.Image()],
105+
)
106+
```
107+
"""
108+
)
109+
io3c.render()
110+
md(
111+
"**4a. Color Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas', tool='color-sketch'), gr.Image())`**"
112+
)
113+
io4a.render()
114+
md("**4b. Color Sketchpad: `gr.Interface(lambda x: x, gr.Paint(), gr.Image())`**")
115+
io4b.render()
116+
md(
117+
"**5a. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.Image(source='upload', tool='color-sketch'), gr.Image())`**"
118+
)
119+
io5a.render()
120+
md(
121+
"**5b. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())`**"
122+
)
123+
io5b.render()
124+
md(
125+
"**5c. Color Sketchpad with webcam upload: `gr.Interface(lambda x: x, gr.Image(source='webcam', tool='color-sketch'), gr.Image())`**"
126+
)
127+
io5c.render()
128+
md("**Tabs**")
129+
with gr.Tab("One"):
130+
io3b2.render()
131+
with gr.Tab("Two"):
132+
io3b3.render()
23133

24134

25135
if __name__ == "__main__":

demo/filter_records/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def filter_records(records, gender):
1212
headers=["name", "age", "gender"],
1313
datatype=["str", "number", "str"],
1414
row_count=5,
15-
col_count=(3, "fixed")
15+
col_count=(3, "fixed"),
1616
),
1717
gr.Dropdown(["M", "F", "O"]),
1818
],

gradio/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,14 @@
6060
from gradio.templates import (
6161
Files,
6262
Highlight,
63+
ImageMask,
64+
ImagePaint,
6365
List,
6466
Matrix,
6567
Mic,
6668
Microphone,
6769
Numpy,
70+
Paint,
6871
Pil,
6972
PlayableVideo,
7073
Sketchpad,

gradio/components.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class DataframeData(TypedDict):
3232
import numpy as np
3333
import pandas as pd
3434
import PIL
35+
import PIL.ImageOps
3536
from ffmpy import FFmpeg
3637
from markdown_it import MarkdownIt
3738

@@ -1175,11 +1176,11 @@ def style(
11751176
)
11761177

11771178

1178-
@document("edit", "clear", "change", "stream", "change")
1179+
@document("edit", "clear", "change", "stream", "change", "style")
11791180
class Image(Editable, Clearable, Changeable, Streamable, IOComponent, ImgSerializable):
11801181
"""
11811182
Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
1182-
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch`. In the special case, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
1183+
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
11831184
Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
11841185
Examples-format: a {str} filepath to a local file that contains the image.
11851186
Demos: image_mod, image_mod_default_image
@@ -1194,7 +1195,7 @@ def __init__(
11941195
image_mode: str = "RGB",
11951196
invert_colors: bool = False,
11961197
source: str = "upload",
1197-
tool: str = "editor",
1198+
tool: str = None,
11981199
type: str = "numpy",
11991200
label: Optional[str] = None,
12001201
show_label: bool = True,
@@ -1212,7 +1213,7 @@ def __init__(
12121213
image_mode: "RGB" if color, or "L" if black and white.
12131214
invert_colors: whether to invert the image as a preprocessing step.
12141215
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
1215-
tool: Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool, "sketch" allows you to create a mask over the image and both the image and mask are passed into the function.
1216+
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to create a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
12161217
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
12171218
label: component name in interface.
12181219
show_label: if True, will display label.
@@ -1228,7 +1229,10 @@ def __init__(
12281229
self.image_mode = image_mode
12291230
self.source = source
12301231
requires_permissions = source == "webcam"
1231-
self.tool = tool
1232+
if tool is None:
1233+
self.tool = "sketch" if source == "canvas" else "editor"
1234+
else:
1235+
self.tool = tool
12321236
self.invert_colors = invert_colors
12331237
self.test_input = deepcopy(media_data.BASE64_IMAGE)
12341238
self.interpret_by_tokens = True
@@ -1279,9 +1283,10 @@ def update(
12791283
return IOComponent.add_interactive_to_config(updated_config, interactive)
12801284

12811285
def _format_image(
1282-
self, im: Optional[PIL.Image], fmt: str
1286+
self, im: Optional[PIL.Image]
12831287
) -> np.array | PIL.Image | str | None:
12841288
"""Helper method to format an image based on self.type"""
1289+
fmt = im.format
12851290
if im is None:
12861291
return im
12871292
if self.type == "pil":
@@ -1314,36 +1319,37 @@ def generate_sample(self):
13141319
def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
13151320
"""
13161321
Parameters:
1317-
x: base64 url data, or (if tool == "sketch) a dict of image and mask base64 url data
1322+
x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
13181323
Returns:
1319-
image in requested format
1324+
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
13201325
"""
13211326
if x is None:
13221327
return x
1323-
if self.tool == "sketch":
1328+
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
13241329
x, mask = x["image"], x["mask"]
1325-
13261330
im = processing_utils.decode_base64_to_image(x)
1327-
fmt = im.format
13281331
with warnings.catch_warnings():
13291332
warnings.simplefilter("ignore")
13301333
im = im.convert(self.image_mode)
13311334
if self.shape is not None:
13321335
im = processing_utils.resize_and_crop(im, self.shape)
13331336
if self.invert_colors:
13341337
im = PIL.ImageOps.invert(im)
1335-
if self.source == "webcam" and self.mirror_webcam is True:
1338+
if (
1339+
self.source == "webcam"
1340+
and self.mirror_webcam is True
1341+
and self.tool != "color-sketch"
1342+
):
13361343
im = PIL.ImageOps.mirror(im)
13371344

1338-
if not (self.tool == "sketch"):
1339-
return self._format_image(im, fmt)
1345+
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
1346+
mask_im = processing_utils.decode_base64_to_image(mask)
1347+
return {
1348+
"image": self._format_image(im),
1349+
"mask": self._format_image(mask_im),
1350+
}
13401351

1341-
mask_im = processing_utils.decode_base64_to_image(mask)
1342-
mask_fmt = mask_im.format
1343-
return {
1344-
"image": self._format_image(im, fmt),
1345-
"mask": self._format_image(mask_im, mask_fmt),
1346-
}
1352+
return self._format_image(im)
13471353

13481354
def postprocess(self, y: np.ndarray | PIL.Image | str | Path) -> str:
13491355
"""

gradio/templates.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Webcam(components.Image):
3131
is_template = True
3232

3333
def __init__(self, **kwargs):
34-
super().__init__(source="webcam", **kwargs)
34+
super().__init__(source="webcam", interactive=True, **kwargs)
3535

3636

3737
class Sketchpad(components.Image):
@@ -47,10 +47,48 @@ def __init__(self, **kwargs):
4747
source="canvas",
4848
shape=(28, 28),
4949
invert_colors=True,
50+
interactive=True,
5051
**kwargs
5152
)
5253

5354

55+
class Paint(components.Image):
56+
"""
57+
Sets source="canvas", tool="color-sketch"
58+
"""
59+
60+
is_template = True
61+
62+
def __init__(self, **kwargs):
63+
super().__init__(
64+
source="canvas", tool="color-sketch", interactive=True, **kwargs
65+
)
66+
67+
68+
class ImageMask(components.Image):
69+
"""
70+
Sets source="upload", tool="sketch"
71+
"""
72+
73+
is_template = True
74+
75+
def __init__(self, **kwargs):
76+
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
77+
78+
79+
class ImagePaint(components.Image):
80+
"""
81+
Sets source="upload", tool="color-sketch"
82+
"""
83+
84+
is_template = True
85+
86+
def __init__(self, **kwargs):
87+
super().__init__(
88+
source="upload", tool="color-sketch", interactive=True, **kwargs
89+
)
90+
91+
5492
class Pil(components.Image):
5593
"""
5694
Sets: type="pil"

gradio/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.3.1
1+
3.4b0

ui/packages/app/test/blocks_xray.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ test("can run an api request and display the data", async ({ page }) => {
6161
await page.check("label:has-text('Covid')");
6262
await page.check("label:has-text('Lung Cancer')");
6363

64-
const run_button = await page.locator("button", { hasText: /Run/ });
64+
const run_button = await page.locator("button", { hasText: /Run/ }).first();
6565

6666
await Promise.all([
6767
run_button.click(),
6868
page.waitForResponse("**/api/predict/")
6969
]);
7070

71-
const json = await page.locator("data-testid=json");
71+
const json = await page.locator("data-testid=json").first();
7272
await expect(json).toContainText(`Covid: 0.75, Lung Cancer: 0.25`);
7373
});

ui/packages/icons/src/Brush.svelte

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
<svg width="100%" height="100%" viewBox="0 0 32 32"
2+
><path
3+
d="M28.828 3.172a4.094 4.094 0 0 0-5.656 0L4.05 22.292A6.954 6.954 0 0 0 2 27.242V30h2.756a6.952 6.952 0 0 0 4.95-2.05L28.828 8.829a3.999 3.999 0 0 0 0-5.657zM10.91 18.26l2.829 2.829l-2.122 2.121l-2.828-2.828zm-2.619 8.276A4.966 4.966 0 0 1 4.756 28H4v-.759a4.967 4.967 0 0 1 1.464-3.535l1.91-1.91l2.829 2.828zM27.415 7.414l-12.261 12.26l-2.829-2.828l12.262-12.26a2.047 2.047 0 0 1 2.828 0a2 2 0 0 1 0 2.828z"
4+
fill="currentColor"
5+
/><path
6+
d="M6.5 15a3.5 3.5 0 0 1-2.475-5.974l3.5-3.5a1.502 1.502 0 0 0 0-2.121a1.537 1.537 0 0 0-2.121 0L3.415 5.394L2 3.98l1.99-1.988a3.585 3.585 0 0 1 4.95 0a3.504 3.504 0 0 1 0 4.949L5.439 10.44a1.502 1.502 0 0 0 0 2.121a1.537 1.537 0 0 0 2.122 0l4.024-4.024L13 9.95l-4.025 4.024A3.475 3.475 0 0 1 6.5 15z"
7+
fill="currentColor"
8+
/></svg
9+
>

0 commit comments

Comments
 (0)