Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/389.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated the behaviour of recursive lookups when converting nested relationships. Note that this change could cause issues in transforms or generators that use the convert_query_response feature if "id" or "__typename" is not requested for nested related objects.
53 changes: 48 additions & 5 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ async def _process_relationships(
branch: str,
related_nodes: list[InfrahubNode],
timeout: int | None = None,
recursive: bool = False,
) -> None:
"""Processes the Relationships of a InfrahubNode and add Related Nodes to a list.

Expand All @@ -903,24 +904,45 @@ async def _process_relationships(
branch (str): The branch name.
related_nodes (list[InfrahubNode]): The list to which related nodes will be appended.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
recursive (bool, optional): Whether to recursively process relationships of related nodes. Defaults to False.
"""
for rel_name in self._relationships:
rel = getattr(self, rel_name)
if rel and isinstance(rel, RelatedNode):
relation = node_data["node"].get(rel_name, None)
if relation and relation.get("node", None):
related_node = await InfrahubNode.from_graphql(
client=self._client, branch=branch, data=relation, timeout=timeout
client=self._client,
branch=branch,
data=relation,
timeout=timeout,
)
related_nodes.append(related_node)
if recursive:
await related_node._process_relationships(
node_data=relation,
branch=branch,
related_nodes=related_nodes,
recursive=recursive,
)
elif rel and isinstance(rel, RelationshipManager):
peers = node_data["node"].get(rel_name, None)
if peers and peers["edges"]:
for peer in peers["edges"]:
related_node = await InfrahubNode.from_graphql(
client=self._client, branch=branch, data=peer, timeout=timeout
client=self._client,
branch=branch,
data=peer,
timeout=timeout,
)
related_nodes.append(related_node)
if recursive:
await related_node._process_relationships(
node_data=peer,
branch=branch,
related_nodes=related_nodes,
recursive=recursive,
)

async def get_pool_allocated_resources(self, resource: InfrahubNode) -> list[InfrahubNode]:
"""Fetch all nodes that were allocated for the pool and a given resource.
Expand Down Expand Up @@ -1520,6 +1542,7 @@ def _process_relationships(
branch: str,
related_nodes: list[InfrahubNodeSync],
timeout: int | None = None,
recursive: bool = False,
) -> None:
"""Processes the Relationships of a InfrahubNodeSync and add Related Nodes to a list.

Expand All @@ -1528,25 +1551,45 @@ def _process_relationships(
branch (str): The branch name.
related_nodes (list[InfrahubNodeSync]): The list to which related nodes will be appended.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.

recursive (bool, optional): Whether to recursively process relationships of related nodes. Defaults to False.
"""
for rel_name in self._relationships:
rel = getattr(self, rel_name)
if rel and isinstance(rel, RelatedNodeSync):
relation = node_data["node"].get(rel_name, None)
if relation and relation.get("node", None):
related_node = InfrahubNodeSync.from_graphql(
client=self._client, branch=branch, data=relation, timeout=timeout
client=self._client,
branch=branch,
data=relation,
timeout=timeout,
)
related_nodes.append(related_node)
if recursive:
related_node._process_relationships(
node_data=relation,
branch=branch,
related_nodes=related_nodes,
recursive=recursive,
)
elif rel and isinstance(rel, RelationshipManagerSync):
peers = node_data["node"].get(rel_name, None)
if peers and peers["edges"]:
for peer in peers["edges"]:
related_node = InfrahubNodeSync.from_graphql(
client=self._client, branch=branch, data=peer, timeout=timeout
client=self._client,
branch=branch,
data=peer,
timeout=timeout,
)
related_nodes.append(related_node)
if recursive:
related_node._process_relationships(
node_data=peer,
branch=branch,
related_nodes=related_nodes,
recursive=recursive,
)

def get_pool_allocated_resources(self, resource: InfrahubNodeSync) -> list[InfrahubNodeSync]:
"""Fetch all nodes that were allocated for the pool and a given resource.
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def process_nodes(self, data: dict) -> None:
)
self._nodes.append(node)
await node._process_relationships(
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes, recursive=True
)

for node in self._nodes + self._related_nodes:
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,3 +2698,31 @@ async def mock_query_tasks_05(httpx_mock: HTTPXMock) -> HTTPXMock:
is_reusable=True,
)
return httpx_mock


@pytest.fixture
async def nested_device_with_interfaces_schema() -> NodeSchemaAPI:
    """Node schema for an InfraDevice carrying a many-cardinality `interfaces` relationship.

    Intended for tests that exercise deep (multi-level) relationship processing.
    """
    device_attributes = [
        {"name": "name", "kind": "Text", "unique": True},
        {"name": "description", "kind": "Text", "optional": True},
    ]
    interfaces_relationship = {
        "name": "interfaces",
        "peer": "InfraInterfaceL3",
        "identifier": "device__interface",
        "optional": True,
        "cardinality": "many",
        "kind": "Component",
    }
    schema = NodeSchema(
        name="Device",
        namespace="Infra",
        label="Device",
        default_filter="name__value",
        order_by=["name__value"],
        display_labels=["name__value"],
        attributes=device_attributes,
        relationships=[interfaces_relationship],
    )
    # Tests consume the API-side representation of the schema.
    return schema.convert_api()
213 changes: 213 additions & 0 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from infrahub_sdk.node.constants import SAFE_VALUE
from infrahub_sdk.node.related_node import RelatedNode, RelatedNodeSync
from infrahub_sdk.schema import GenericSchema, NodeSchemaAPI
from tests.unit.sdk.conftest import BothClients

if TYPE_CHECKING:
from infrahub_sdk.client import InfrahubClient, InfrahubClientSync
Expand Down Expand Up @@ -2394,3 +2395,215 @@ async def test_from_graphql(clients, mock_schema_query_01, location_data01, clie
node = InfrahubNodeSync.from_graphql(client=clients.sync, schema=schema, branch="main", data=location_data01)

assert node.id == "llllllll-llll-llll-llll-llllllllllll"


@pytest.mark.parametrize("client_type", client_types)
async def test_process_relationships_recursive_deep_nesting(
    clients: BothClients,
    nested_device_with_interfaces_schema: NodeSchemaAPI,
    client_type: str,
) -> None:
    """Test that _process_relationships with recursive=True processes deeply nested relationships.

    This test validates 3-level deep nesting:
    Device -> Interfaces (many) -> IP Addresses (many)

    With recursive=False, only level 1 (the interfaces) should be appended to related_nodes.
    With recursive=True, level 2 (the interfaces' ip_addresses) should be appended as well.
    Runs against both the async (`standard`) and sync client via the client_type parameter.
    """
    # GraphQL-style payload: one device with two interfaces; interface-1 carries
    # two IP addresses and interface-2 carries one (3 IPs total at level 2).
    nested_device_data = {
        "node": {
            "id": "device-1",
            "__typename": "InfraDevice",
            "display_label": "atl1-edge1",
            "name": {"value": "atl1-edge1"},
            "description": {"value": "Edge device in Atlanta"},
            "interfaces": {
                "edges": [
                    {
                        "node": {
                            "id": "interface-1",
                            "__typename": "InfraInterfaceL3",
                            "display_label": "Ethernet1",
                            "name": {"value": "Ethernet1"},
                            "description": {"value": "Primary interface"},
                            "ip_addresses": {
                                "edges": [
                                    {
                                        "node": {
                                            "id": "ip-1",
                                            "__typename": "InfraIPAddress",
                                            "display_label": "10.0.0.1/24",
                                            "address": {"value": "10.0.0.1/24"},
                                        }
                                    },
                                    {
                                        "node": {
                                            "id": "ip-2",
                                            "__typename": "InfraIPAddress",
                                            "display_label": "10.0.0.2/24",
                                            "address": {"value": "10.0.0.2/24"},
                                        }
                                    },
                                ]
                            },
                        }
                    },
                    {
                        "node": {
                            "id": "interface-2",
                            "__typename": "InfraInterfaceL3",
                            "display_label": "Ethernet2",
                            "name": {"value": "Ethernet2"},
                            "description": {"value": "Secondary interface"},
                            "ip_addresses": {
                                "edges": [
                                    {
                                        "node": {
                                            "id": "ip-3",
                                            "__typename": "InfraIPAddress",
                                            "display_label": "10.0.1.1/24",
                                            "address": {"value": "10.0.1.1/24"},
                                        }
                                    }
                                ]
                            },
                        }
                    },
                ]
            },
        }
    }
    schema_data = {
        "version": "1.0",
        "nodes": [
            # Minimal schema definitions for all three kinds so that nested
            # lookups (Device -> InterfaceL3 -> IPAddress) can resolve peers
            # from the client-side schema cache.
            {
                "name": "Device",
                "namespace": "Infra",
                "attributes": [{"name": "name", "kind": "Text"}],
                "relationships": [
                    {
                        "name": "interfaces",
                        "peer": "InfraInterfaceL3",
                        "cardinality": "many",
                        "optional": True,
                    }
                ],
            },
            {
                "name": "InterfaceL3",
                "namespace": "Infra",
                "attributes": [{"name": "name", "kind": "Text"}],
                "relationships": [
                    {
                        "name": "ip_addresses",
                        "peer": "InfraIPAddress",
                        "cardinality": "many",
                        "optional": True,
                    }
                ],
            },
            {
                "name": "IPAddress",
                "namespace": "Infra",
                "attributes": [{"name": "address", "kind": "IPHost"}],
                "relationships": [],
            },
        ],
    }

    # Set up schemas in the client cache so no HTTP schema fetch is needed
    if client_type == "standard":
        # Async client branch

        clients.standard.schema.set_cache(schema_data, branch="main")

        # Test with recursive=False - should only process interfaces (level 1)
        device_node = await InfrahubNode.from_graphql(
            client=clients.standard,
            schema=nested_device_with_interfaces_schema,
            branch="main",
            data=nested_device_data,
        )
        related_nodes_non_recursive_async: list[InfrahubNode] = []
        await device_node._process_relationships(
            node_data=nested_device_data,
            branch="main",
            related_nodes=related_nodes_non_recursive_async,
            recursive=False,
        )
        # Rebind to a branch-neutral name; the typed _async/_sync intermediates
        # exist so mypy sees a concrete element type in each branch.
        related_nodes_non_recursive = related_nodes_non_recursive_async

        # Test with recursive=True - should process all levels
        device_node_recursive = await InfrahubNode.from_graphql(
            client=clients.standard,
            schema=nested_device_with_interfaces_schema,
            branch="main",
            data=nested_device_data,
        )
        related_nodes_recursive_async: list[InfrahubNode] = []
        await device_node_recursive._process_relationships(
            node_data=nested_device_data,
            branch="main",
            related_nodes=related_nodes_recursive_async,
            recursive=True,
        )
        related_nodes_recursive = related_nodes_recursive_async

    else:
        # Sync client branch - mirrors the async branch above without awaits
        clients.sync.schema.set_cache(schema_data, branch="main")

        # Test with recursive=False
        device_node = InfrahubNodeSync.from_graphql(
            client=clients.sync,
            schema=nested_device_with_interfaces_schema,
            branch="main",
            data=nested_device_data,
        )
        related_nodes_non_recursive_sync: list[InfrahubNodeSync] = []
        device_node._process_relationships(
            node_data=nested_device_data,
            branch="main",
            related_nodes=related_nodes_non_recursive_sync,
            recursive=False,
        )
        related_nodes_non_recursive = related_nodes_non_recursive_sync

        # Test with recursive=True
        device_node_recursive = InfrahubNodeSync.from_graphql(
            client=clients.sync,
            schema=nested_device_with_interfaces_schema,
            branch="main",
            data=nested_device_data,
        )
        related_nodes_recursive_sync: list[InfrahubNodeSync] = []
        device_node_recursive._process_relationships(
            node_data=nested_device_data,
            branch="main",
            related_nodes=related_nodes_recursive_sync,
            recursive=True,
        )

        related_nodes_recursive = related_nodes_recursive_sync

    # With recursive=False, should only process the 2 interfaces (level 1)
    # IP addresses (level 2) should NOT be processed
    non_recursive_ids = {rn.id for rn in related_nodes_non_recursive}
    assert "interface-1" in non_recursive_ids
    assert "interface-2" in non_recursive_ids
    # IP addresses should NOT be in the list when recursive=False
    assert "ip-1" not in non_recursive_ids
    assert "ip-2" not in non_recursive_ids
    assert "ip-3" not in non_recursive_ids
    assert len(related_nodes_non_recursive) == 2  # Only 2 interfaces

    # With recursive=True, should process interfaces AND their IP addresses
    recursive_ids = {rn.id for rn in related_nodes_recursive}
    assert "interface-1" in recursive_ids
    assert "interface-2" in recursive_ids
    assert "ip-1" in recursive_ids  # From interface-1
    assert "ip-2" in recursive_ids  # From interface-1
    assert "ip-3" in recursive_ids  # From interface-2
    assert len(related_nodes_recursive) == 5  # 2 interfaces + 3 IP addresses