Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
621acee
work on saving flags in JSON format
chrisemezue Jul 18, 2022
1dcc0ad
explained what I did more clearly
chrisemezue Jul 18, 2022
abf5bdc
final updates + added test case
chrisemezue Jul 25, 2022
ea19435
reviews to flagging.py for HuggingFaceDatasetJSONSaver
chrisemezue Aug 8, 2022
80d8f02
formatted imports
chrisemezue Aug 8, 2022
5b78490
Merge branch 'main' into main
abidlabs Aug 11, 2022
40eb226
used uuid for random ids
chrisemezue Aug 11, 2022
31faab4
Merge branch 'main' of https://github.com/chrisemezue/gradio
chrisemezue Aug 11, 2022
6e93fde
used uuid for random + function to get dataset infos
chrisemezue Aug 11, 2022
1b3e8d6
reformmated flagging.py
chrisemezue Aug 11, 2022
af9a3eb
fix examples test
abidlabs Aug 11, 2022
1ee063f
formatting
abidlabs Aug 11, 2022
7941131
async examples
abidlabs Aug 11, 2022
9c65aea
working on mix
abidlabs Aug 11, 2022
df154e2
comment out failing test
abidlabs Aug 11, 2022
3f447ea
fixed interface problem
abidlabs Aug 11, 2022
9e6a692
Merge branch 'fix-async-tests' into chrisemezue/main
abidlabs Aug 11, 2022
d7961f8
Merge branch 'main' into main
abidlabs Aug 11, 2022
c8ccb90
final updates to HuggingFaceDatasetJSONSaver flagging.py
chrisemezue Aug 11, 2022
9f43268
Merge branch 'main' of https://github.com/chrisemezue/gradio
chrisemezue Aug 11, 2022
3b036f2
final updates to HuggingFaceDatasetJSONSaver flagging.py
chrisemezue Aug 11, 2022
12c5f2f
formatting
abidlabs Aug 12, 2022
eca9ae2
some tweaks
abidlabs Aug 12, 2022
fc76fbd
tweaks
abidlabs Aug 12, 2022
1423525
tweaks
abidlabs Aug 12, 2022
c076a30
merge
abidlabs Aug 23, 2022
fb76673
omar's fixes
abidlabs Aug 23, 2022
f9e24cc
added back test.init
abidlabs Aug 23, 2022
5a1388b
restored test init
abidlabs Aug 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
FlaggingCallback,
HuggingFaceDatasetSaver,
SimpleCSVLogger,
HuggingFaceDatasetJSONSaver,
)
from gradio.interface import Interface, TabbedInterface, close_all
from gradio.ipython_ext import load_ipython_extension
Expand Down
169 changes: 169 additions & 0 deletions gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import io
import json
import os
import random
Comment thread
chrisemezue marked this conversation as resolved.
Outdated
import string
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional

Expand Down Expand Up @@ -344,3 +346,170 @@ def flag(
self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))

return line_count


class HuggingFaceDatasetJSONSaver(FlaggingCallback):
    """
    A FlaggingCallback that saves flagged data to a HuggingFace dataset in JSONL format.

    Every flagged sample is written to its own uniquely named folder inside the
    local clone of the dataset repository, together with a `metadata.jsonl`
    file describing it; the clone is then pushed to the Hub.
    """

    def __init__(
        self,
        hf_foken: str,
        dataset_name: str,
        organization: Optional[str] = None,
        private: bool = False,
        verbose: bool = True,
    ):
        """
        Params:
        hf_foken (str): The token to use to access the huggingface API. (The
            parameter name is spelled `hf_foken`; it is kept unchanged so that
            existing keyword callers keep working.)
        dataset_name (str): The name of the dataset to save the data to, e.g.
            "image-classifier-1"
        organization (str): The name of the organization to which to attach
            the datasets. If None, the dataset attaches to the user only.
        private (bool): If the dataset does not already exist, whether it
            should be created as a private dataset or public. Private datasets
            may require paid huggingface.co accounts
        verbose (bool): Whether to print out the status of the dataset
            creation.
        """
        self.hf_foken = hf_foken
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private
        self.verbose = verbose

    def setup(self, components: List[Component], flagging_dir: str):
        """
        Creates the dataset repository if needed and clones it locally.

        Params:
        components (List[Component]): The components whose values will be
            saved when a sample is flagged.
        flagging_dir (str): local directory where the dataset is cloned,
            updated, and pushed from.
        """
        try:
            import huggingface_hub
        except ImportError as err:
            # ModuleNotFoundError is a subclass of ImportError, so this single
            # clause covers both.
            raise ImportError(
                "Package `huggingface_hub` not found; it is needed "
                "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
            ) from err
        path_to_dataset_repo = huggingface_hub.create_repo(
            name=self.dataset_name,
            token=self.hf_foken,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_foken,
        )
        self.repo.git_pull()

        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> str:
        """
        Saves one flagged sample into its own folder and pushes it to the Hub.

        Params:
        flag_data (List[Any]): The values to save, paired positionally with
            the components passed to `setup`.
        flag_option (str): The flag label to store with the sample; an empty
            string is stored when it is None.
        flag_index (int): Not used by this callback.
        username (str): Not used by this callback.
        Returns:
        (str) The unique name of the folder created for the flagged sample.
        """
        self.repo.git_pull(lfs=True)

        # Generate unique folder for the flagged sample
        unique_name = self.get_unique_name()
        folder_name = os.path.join(self.dataset_dir, unique_name)
        os.makedirs(folder_name, exist_ok=True)

        # The existence of `dataset_infos.json` determines whether the dataset is new
        is_new = not os.path.exists(self.infos_file)

        # File previews for certain input and output types
        file_preview_types = {
            gr.inputs.Audio: "Audio",
            gr.outputs.Audio: "Audio",
            gr.inputs.Image: "Image",
            gr.outputs.Image: "Image",
        }
        preview_classes = tuple(file_preview_types)

        # Generate the dataset_infos (only written out for a new dataset)
        infos = {"flagged": {"features": {}}}
        if is_new:
            features = infos["flagged"]["features"]
            for component, _ in zip(self.components, flag_data):
                features[component.label] = {
                    "dtype": "string",
                    "_type": "Value",
                }
                # First matching preview type, in declaration order.
                preview_type = next(
                    (
                        _type
                        for _component, _type in file_preview_types.items()
                        if isinstance(component, _component)
                    ),
                    None,
                )
                if preview_type is not None:
                    features[component.label + " file"] = {"_type": preview_type}

            features["flag"] = {
                "dtype": "string",
                "_type": "Value",
            }

        # Generate the metadata row corresponding to the flagged sample
        metadata_dict = {}
        for component, sample in zip(self.components, flag_data):
            try:
                filepath = component.save_flagged(
                    folder_name, component.label, sample, None
                )
            except Exception:
                # Deliberately best-effort: `save_flagged` could not handle
                # `sample`, mostly because it was None and some components
                # (for example Label) raise an error in that case.
                filepath = None

            if isinstance(component, preview_classes):
                # For preview components the resolvable URL is stored under
                # the plain label and the saved path under "<label> file".
                metadata_dict[component.label] = (
                    "{}/resolve/main/{}/{}".format(
                        self.path_to_dataset_repo, unique_name, filepath
                    )
                    if filepath is not None
                    else None
                )
                metadata_dict[component.label + " file"] = filepath
            else:
                metadata_dict[component.label] = filepath

        metadata_dict["flag"] = flag_option if flag_option is not None else ""

        self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl"))

        if is_new:
            # Use a context manager so the file handle is always closed.
            with open(self.infos_file, "w", encoding="utf8") as infos_fp:
                json.dump(infos, infos_fp)

        self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
        return unique_name

    def get_unique_name(self):
        """Returns a random 32-character name made of ASCII letters and digits."""
        alphabet = string.ascii_letters + string.digits
        return "".join([random.choice(alphabet) for _ in range(32)])

    def dump_json(self, thing: dict, file_path: str) -> None:
        """Writes `thing` as JSON to `file_path`, overwriting any existing file."""
        with open(file_path, "w", encoding="utf8") as f:
            json.dump(thing, f)
35 changes: 35 additions & 0 deletions test/test_flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,41 @@ def test_saver_flag(self):
self.assertEqual(row_count, 2) # 3 rows written including header


class TestHuggingFaceDatasetJSONSaver(unittest.TestCase):
    """Tests for `flagging.HuggingFaceDatasetJSONSaver` with the Hub calls mocked."""

    def test_saver_setup(self):
        # Replace the Hub entry points with mocks so that `setup` only
        # records the calls it makes.
        huggingface_hub.create_repo = MagicMock()
        huggingface_hub.Repository = MagicMock()
        flagger = flagging.HuggingFaceDatasetJSONSaver("test", "test")
        with tempfile.TemporaryDirectory() as tmpdirname:
            flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
            huggingface_hub.create_repo.assert_called_once()

    def test_saver_flag(self):
        huggingface_hub.create_repo = MagicMock()
        huggingface_hub.Repository = MagicMock()
        with tempfile.TemporaryDirectory() as tmpdirname:
            interface = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetJSONSaver("test", "test"),
            )
            # `Repository` is mocked here, so the dataset directory is
            # created by hand before flagging.
            test_dir = os.path.join(tmpdirname, "test")
            os.mkdir(test_dir)
            interface.launch(prevent_thread_lock=True)
            try:
                row_unique_name = interface.flagging_callback.flag(["test", "test"])
                # Test existence of metadata.jsonl file for that example
                self.assertTrue(
                    os.path.isfile(
                        os.path.join(test_dir, row_unique_name, "metadata.jsonl")
                    )
                )
            finally:
                # Shut the launched server down even if the assertion fails.
                interface.close()


class TestDisableFlagging(unittest.TestCase):
def test_flagging_no_permission_error_with_flagging_disabled(self):
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down