Skip to content

Commit 22f8f03

Browse files
committed
Fix kwargs in plot_graph
1 parent d3b4ed8 commit 22f8f03

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

city2graph/utils.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4177,7 +4177,7 @@ def _resolve_plot_parameter(
41774177
param_value: str | float | pd.Series | None,
41784178
_param_name: str,
41794179
default_value: Any, # noqa: ANN401
4180-
) -> str | float | pd.Series:
4180+
) -> str | float | np.ndarray | pd.Series:
41814181
"""
41824182
Resolve a plot parameter to a value usable by GeoPandas plot().
41834183
@@ -4196,15 +4196,15 @@ def _resolve_plot_parameter(
41964196
41974197
Returns
41984198
-------
4199-
str, float, or pd.Series
4199+
str, float, pd.Series, or np.ndarray
42004200
Resolved parameter value.
42014201
"""
42024202
if param_value is None:
42034203
return default_value # type: ignore[no-any-return]
42044204
if isinstance(param_value, pd.Series):
4205-
return param_value
4205+
return param_value.to_numpy()
42064206
if isinstance(param_value, str) and param_value in gdf.columns:
4207-
return gdf[param_value] # type: ignore[no-any-return]
4207+
return gdf[param_value].to_numpy() # type: ignore[no-any-return]
42084208
return param_value
42094209

42104210

@@ -4336,6 +4336,9 @@ def _or_default(val: Any, default: Any) -> Any: # noqa: ANN401
43364336
"legend_position",
43374337
"labelcolor",
43384338
"title_color",
4339+
"legend",
4340+
"legend_kwargs",
4341+
"title",
43394342
}
43404343

43414344
# Start with all global kwargs, resolving potential type-specific dictionaries
@@ -4410,7 +4413,7 @@ def _plot_gdf(
44104413
for param_name, default_val in param_defaults.items():
44114414
val = _resolve_plot_parameter(gdf, kwargs.get(param_name), param_name, default_val)
44124415
if val is not None:
4413-
if param_name == "color" and isinstance(val, pd.Series):
4416+
if param_name == "color" and isinstance(val, (pd.Series, np.ndarray)):
44144417
plot_kwargs["column"] = val
44154418
plot_kwargs.pop("color", None)
44164419
else:
@@ -4749,16 +4752,32 @@ def _plot_homo_graph(
47494752
**kwargs : Any
47504753
Additional styling arguments.
47514754
"""
4755+
title = kwargs.get("title")
4756+
legend = kwargs.get("legend")
4757+
legend_kwargs = kwargs.get("legend_kwargs")
4758+
plot_kwargs = {k: v for k, v in kwargs.items() if k not in {"title", "legend", "legend_kwargs"}}
4759+
47524760
# Plot edges first (in background)
47534761
if edges is not None and isinstance(edges, gpd.GeoDataFrame):
4754-
style_kwargs = _resolve_style_kwargs(kwargs, None, is_edge=True)
4762+
style_kwargs = _resolve_style_kwargs(plot_kwargs, None, is_edge=True)
47554763
_plot_gdf(edges, ax, **style_kwargs)
47564764

47574765
# Plot nodes on top
47584766
if nodes is not None and isinstance(nodes, gpd.GeoDataFrame):
4759-
style_kwargs = _resolve_style_kwargs(kwargs, None, is_edge=False)
4767+
style_kwargs = _resolve_style_kwargs(plot_kwargs, None, is_edge=False)
4768+
if legend is not None:
4769+
style_kwargs["legend"] = legend
4770+
if legend_kwargs is not None:
4771+
style_kwargs["legend_kwds"] = legend_kwargs
47604772
_plot_gdf(nodes, ax, **style_kwargs)
47614773

4774+
if title is not None:
4775+
title_color = kwargs.get("title_color")
4776+
ax.set_title(
4777+
title,
4778+
color=title_color if title_color is not None else PLOT_DEFAULTS["title_color"],
4779+
)
4780+
47624781

47634782
def _validate_graph_input(
47644783
graph: gpd.GeoDataFrame | tuple[gpd.GeoDataFrame, gpd.GeoDataFrame] | nx.Graph | nx.MultiGraph,

tests/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,6 +2151,38 @@ def test_plot_graph_with_ax(
21512151
finally:
21522152
plt.close(fig)
21532153

2154+
def test_plot_graph_homo_with_legend_kwargs_and_title(
2155+
self,
2156+
sample_nodes_gdf: gpd.GeoDataFrame,
2157+
sample_edges_gdf: gpd.GeoDataFrame,
2158+
) -> None:
2159+
"""Homogeneous plotting should support legend kwargs and title."""
2160+
if not utils.MATPLOTLIB_AVAILABLE:
2161+
pytest.skip("matplotlib not available")
2162+
2163+
nodes = sample_nodes_gdf.copy()
2164+
nodes["centrality_quantile"] = range(len(nodes))
2165+
2166+
edge_linewidth = pd.Series([1.0] * len(sample_edges_gdf), index=sample_edges_gdf.index)
2167+
fig, ax = plt.subplots(figsize=(5, 5))
2168+
try:
2169+
utils.plot_graph(
2170+
nodes=nodes,
2171+
edges=sample_edges_gdf,
2172+
node_color="centrality_quantile",
2173+
edge_color="#bbbbbb",
2174+
edge_linewidth=edge_linewidth,
2175+
legend=True,
2176+
legend_kwargs={"label": "Betweenness Centrality", "orientation": "horizontal"},
2177+
title="Central London Transit Network Betweenness Centrality of Stops",
2178+
ax=ax,
2179+
)
2180+
assert (
2181+
ax.get_title() == "Central London Transit Network Betweenness Centrality of Stops"
2182+
)
2183+
finally:
2184+
plt.close(fig)
2185+
21542186
@pytest.mark.skipif(not MATPLOTLIB_AVAILABLE, reason="Matplotlib not available")
21552187
def test_plot_empty_gdf(self, empty_gdf: gpd.GeoDataFrame) -> None:
21562188
"""Test _plot_gdf with empty GeoDataFrame (line 3246)."""

0 commit comments

Comments
 (0)