Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ No changes to highlight.
* Fixed small typos in the docs [@julien-c](https://github.com/julien-c) in [PR 2373](https://github.com/gradio-app/gradio/pull/2373)
* Adds ability to disable pre/post-processing for examples [@abidlabs](https://github.com/abidlabs) in [PR 2383](https://github.com/gradio-app/gradio/pull/2383)
* Copy changelog file in website docker by [@aliabd](https://github.com/aliabd) in [PR 2384](https://github.com/gradio-app/gradio/pull/2384)
* Lets users provide a `gr.update()` dictionary even if post-processing is diabled [@abidlabs](https://github.com/abidlabs) in [PR 2385](https://github.com/gradio-app/gradio/pull/2385)

## Contributors Shoutout:
No changes to highlight.
Expand Down
78 changes: 48 additions & 30 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,33 @@ def skip() -> dict:
return update()


def postprocess_update_dict(block: Block, update_dict: Dict):
def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
"""
Converts a dictionary of updates into a format that can be sent to the frontend.
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
Into -> {"__type__": "update", "value": 2.0, "mode": "static"}

Parameters:
block: The Block that is being updated with this update dictionary.
update_dict: The original update dictionary
postprocess: Whether to postprocess the "value" key of the update dictionary.
"""
prediction_value = block.get_specific_update(update_dict)
if prediction_value.get("value") is components._Keywords.NO_VALUE:
prediction_value.pop("value")
prediction_value = delete_none(prediction_value, skip_value=True)
if "value" in prediction_value:
if "value" in prediction_value and postprocess:
prediction_value["value"] = block.postprocess(prediction_value["value"])
return prediction_value


def convert_update_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
"""
Converts a dictionary of component updates into a list of updates in the order of
the outputs_ids and including every output component.
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
"""
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
if all(keys_are_blocks):
reordered_predictions = [skip() for _ in outputs_ids]
Expand Down Expand Up @@ -705,57 +721,59 @@ def postprocess_data(self, fn_index, predictions, state):
dependency = self.dependencies[fn_index]

if type(predictions) is dict and len(predictions) > 0:
predictions = convert_update_dict_to_list(
predictions = convert_component_dict_to_list(
dependency["outputs"], predictions
)

if len(dependency["outputs"]) == 1:
predictions = (predictions,)

if block_fn.postprocess:
output = []
for i, output_id in enumerate(dependency["outputs"]):
if predictions[i] is components._Keywords.FINISHED_ITERATING:
output.append(None)
break
block = self.blocks[output_id]
if getattr(block, "stateful", False):
if not utils.is_update(predictions[i]):
state[output_id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if utils.is_update(prediction_value):
output_value = postprocess_update_dict(block, prediction_value)
else:
output_value = block.postprocess(prediction_value)
output.append(output_value)

else:
output = predictions
output = []
for i, output_id in enumerate(dependency["outputs"]):
if predictions[i] is components._Keywords.FINISHED_ITERATING:
output.append(None)
continue
block = self.blocks[output_id]
if getattr(block, "stateful", False):
if not utils.is_update(predictions[i]):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to your PR but why do we not set the value of a state to be an update?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm good point. I suppose that if the update includes a value key, we should assign the value to the state.

I think this check was introduced because if a function returns values in a component-dictionary format (e.g. {textbox: "hi"}), we create placeholder updates for the rest of the components, and those placeholder components are a singleton dictionary: {"__type__"="generic_update"}. And we wouldn't want to assign this dictionary itself to the state.

Let me make a fix and add a test!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine if we keep it for now too - don't mean to slow down this PR. I was just curious hehe

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok actually gr.State() doesn't have an update() method at the moment, so passing in a dictionary of updates for gr.State() is not supported. There's no strong reason to add update() when a function can just return the value itself, so I'll leave this out for now.

state[output_id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if utils.is_update(prediction_value):
prediction_value = postprocess_update_dict(
block=block,
update_dict=prediction_value,
postprocess=block_fn.postprocess,
)
elif block_fn.postprocess:
prediction_value = block.postprocess(prediction_value)
output.append(prediction_value)
return output

async def process_api(
self,
fn_index: int,
inputs: List[Any],
username: str = None,
state: Optional[Dict[int, Any]] = None,
iterators: Dict[int, Any] = None,
state: Dict[int, Any] | None = None,
iterators: Dict[int, Any] | None = None,
) -> Dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
then runs the relevant function, then postprocesses the output.
Parameters:
data: data recieved from the frontend
username: name of user if authentication is set up
state: data stored from stateful components for session
inputs: the list of raw inputs to pass to the function
username: name of user if authentication is set up (not used)
state: data stored from stateful components for session (key is input block id)
iterators: the in-progress iterators for each generator function (key is function index)
Returns: None
"""
block_fn = self.fns[fn_index]

inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None)
iterator = iterators.get(fn_index, None) if iterators else None

result = await self.call_function(fn_index, inputs, iterator)
block_fn.total_runtime += result["duration"]
Expand Down
8 changes: 5 additions & 3 deletions gradio/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import anyio

from gradio import utils
from gradio.blocks import convert_update_dict_to_list, postprocess_update_dict
from gradio.blocks import convert_component_dict_to_list, postprocess_update_dict
from gradio.components import Dataset
from gradio.context import Context
from gradio.documentation import document, set_documentation_group
Expand Down Expand Up @@ -287,7 +287,7 @@ async def predict_example(self, example_id: int) -> List[Any]:

output_ids = [output._id for output in self.outputs]
if type(predictions) is dict and len(predictions) > 0:
predictions = convert_update_dict_to_list(output_ids, predictions)
predictions = convert_component_dict_to_list(output_ids, predictions)
Comment thread
abidlabs marked this conversation as resolved.

if len(self.outputs) == 1:
predictions = [predictions]
Expand All @@ -296,7 +296,9 @@ async def predict_example(self, example_id: int) -> List[Any]:
for i, output_component in enumerate(self.outputs):
output = predictions[i]
if utils.is_update(predictions[i]):
output = postprocess_update_dict(output_component, output)
output = postprocess_update_dict(
output_component, output, self.postprocess
)
elif self.postprocess:
output = output_component.postprocess(output)
predictions_.append(output)
Expand Down
61 changes: 48 additions & 13 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def captured_output():
sys.stdout, sys.stderr = old_out, old_err


class TestBlocks(unittest.TestCase):
class TestBlocksMethods(unittest.TestCase):
maxDiff = None

def test_set_share(self):
Expand Down Expand Up @@ -304,24 +304,59 @@ def test_blocks_does_not_replace_keyword_literal(self):
output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None)
assert output[0]["value"] == "NO_VALUE"

def test_blocks_returns_correct_output_dict_single_key(self):
with gr.Blocks() as demo:
num = gr.Number()
num2 = gr.Number()
update = gr.Button(value="update")

def update_values():
return {num2: gr.Number.update(value=42)}

def test_blocks_returns_correct_output_dict_single_key():
update.click(update_values, inputs=[num], outputs=[num2])

with gr.Blocks() as demo:
num = gr.Number()
num2 = gr.Number()
update = gr.Button(value="update")
output = demo.postprocess_data(
0, {num2: gr.Number.update(value=42)}, state=None
)
assert output[0]["value"] == 42

def update_values():
return {num2: gr.Number.update(value=42)}
output = demo.postprocess_data(0, {num2: 23}, state=None)
assert output[0] == 23

update.click(update_values, inputs=[num], outputs=[num2])
@pytest.mark.asyncio
async def test_blocks_update_dict_without_postprocessing(self):
def infer(x):
return gr.media_data.BASE64_IMAGE, gr.update(visible=True)

with gr.Blocks() as demo:
prompt = gr.Textbox()
image = gr.Image()
run_button = gr.Button()
share_button = gr.Button("share", visible=False)
run_button.click(infer, prompt, [image, share_button], postprocess=False)
Comment thread
abidlabs marked this conversation as resolved.

output = demo.postprocess_data(0, {num2: gr.Number.update(value=42)}, state=None)
assert output[0]["value"] == 42
output = await demo.process_api(0, ["test"])
assert output["data"][0] == gr.media_data.BASE64_IMAGE
assert output["data"][1] == {"__type__": "update", "visible": True}

output = demo.postprocess_data(0, {num2: 23}, state=None)
assert output[0] == 23
@pytest.mark.asyncio
async def test_blocks_update_dict_does_not_postprocess_value_if_postprocessing_false(
self,
):
def infer(x):
return gr.Image.update(value=gr.media_data.BASE64_IMAGE)

with gr.Blocks() as demo:
prompt = gr.Textbox()
image = gr.Image()
run_button = gr.Button()
run_button.click(infer, [prompt], [image], postprocess=False)

output = await demo.process_api(0, ["test"])
assert output["data"][0] == {
"__type__": "update",
"value": gr.media_data.BASE64_IMAGE,
}


Comment thread
abidlabs marked this conversation as resolved.
class TestCallFunction:
Expand Down