Skip to content

Commit c7b6326

Browse files
committed
[fix] Fixes by @coderabbitai
1 parent 852e204 commit c7b6326

2 files changed

Lines changed: 84 additions & 49 deletions

File tree

openwisp_users/management/commands/export_users.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,6 @@ def normalize_field(field):
1616
return {"name": field}
1717

1818

19-
def serialize_related(manager, subfields):
20-
"""Serialize a RelatedManager queryset using the given subfields.
21-
22-
Single subfield → comma-separated values: val1,val2,...
23-
Multiple subfields → tuple-per-row format: ((v1,v2),(v3,v4))
24-
"""
25-
rows = [[str(getattr(obj, f, "")) for f in subfields] for obj in manager.all()]
26-
if not rows:
27-
return ""
28-
if len(subfields) == 1:
29-
return ",".join(row[0] for row in rows)
30-
return "(" + ",".join("(" + ",".join(row) + ")" for row in rows) + ")"
31-
32-
3319
class Command(BaseCommand):
3420
help = "Exports user data to a CSV file"
3521

@@ -51,15 +37,15 @@ def add_arguments(self, parser):
5137
)
5238

5339
def handle(self, *args, **options):
54-
fields = app_settings.EXPORT_USERS_COMMAND_CONFIG.get("fields", []).copy()
40+
raw_fields = app_settings.EXPORT_USERS_COMMAND_CONFIG.get("fields", []).copy()
5541
# Get the fields to be excluded from the command-line argument
5642
exclude_fields = [
5743
t.strip() for t in options.get("exclude_fields").split(",") if t.strip()
5844
]
5945
# Remove excluded fields from the export fields (match on the field name)
6046
fields = [
6147
field
62-
for field in fields
48+
for field in raw_fields
6349
if normalize_field(field)["name"] not in exclude_fields
6450
]
6551
# Fetch all user data using select_related and prefetch_related
@@ -89,6 +75,40 @@ def handle(self, *args, **options):
8975
self.style.SUCCESS(f"User data exported successfully to {filename}!")
9076
)
9177

78+
def serialize_related(self, manager, subfields):
79+
"""Serialize a RelatedManager queryset using the given subfields.
80+
81+
Single subfield → comma-separated values: val1,val2,...
82+
Multiple subfields → tuple-per-row format: ((v1,v2),(v3,v4))
83+
"""
84+
rows = [
85+
[str(self._get_nested_attr(obj, f)) for f in subfields]
86+
for obj in manager.all()
87+
]
88+
if not rows:
89+
return ""
90+
if len(subfields) == 1:
91+
return ",".join(row[0] for row in rows)
92+
return "(" + ",".join("(" + ",".join(row) + ")" for row in rows) + ")"
93+
94+
def _get_nested_attr(self, obj, attr_path):
95+
if not attr_path:
96+
return obj
97+
parts = attr_path.split(".")
98+
current = obj
99+
for i, part in enumerate(parts):
100+
try:
101+
current = getattr(current, part)
102+
except (ObjectDoesNotExist, AttributeError):
103+
return ""
104+
if hasattr(current, "iterator") and i < len(parts) - 1:
105+
remaining_path = ".".join(parts[i + 1 :])
106+
return ",".join(
107+
str(self._get_nested_attr(item, remaining_path))
108+
for item in current.iterator()
109+
)
110+
return current
111+
92112
def _get_field_value(self, user, field):
93113
normalized = normalize_field(field)
94114
name = normalized["name"]
@@ -101,30 +121,14 @@ def _get_field_value(self, user, field):
101121
except Exception as e:
102122
raise CommandError(f"Error calling function for field '{name}': {e}")
103123
if subfields is not None:
104-
try:
105-
attr = getattr(user, name)
106-
except ObjectDoesNotExist:
107-
return ""
124+
attr = self._get_nested_attr(user, name)
108125
if attr is None:
109126
return ""
110127
if hasattr(attr, "iterator"):
111-
return serialize_related(attr, subfields)
112-
return ",".join(str(getattr(attr, f, "")) for f in subfields)
128+
return self.serialize_related(attr, subfields)
129+
return ",".join(str(self._get_nested_attr(attr, f)) for f in subfields)
113130

114131
# Dot-notation: e.g. "auth_token.key" or "profile.phone_number"
115132
if "." in name:
116-
model_attr, sub_attr = name.split(".", 1)
117-
try:
118-
intermediate = getattr(user, model_attr)
119-
except ObjectDoesNotExist:
120-
return ""
121-
if hasattr(intermediate, "iterator"):
122-
# Related manager accessed via dot notation → comma-separated values
123-
return ",".join(
124-
str(getattr(obj, sub_attr, "")) for obj in intermediate.iterator()
125-
)
126-
try:
127-
return getattr(intermediate, sub_attr)
128-
except ObjectDoesNotExist:
129-
return ""
130-
return getattr(user, name)
133+
return self._get_nested_attr(user, name)
134+
return self._get_nested_attr(user, name)

openwisp_users/tests/test_commands.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from openwisp_utils.tests import capture_stdout
1111

1212
from .. import settings as app_settings
13-
from ..management.commands.export_users import Command, normalize_field
13+
from ..management.commands.export_users import Command
1414
from .utils import TestOrganizationMixin
1515

1616

@@ -47,9 +47,25 @@ def test_export_users(self):
4747

4848
# 3 user and 1 header
4949
self.assertEqual(len(csv_data), 4)
50+
# Expected headers are the keys produced by normalize_field for the
51+
# default EXPORT_USERS_COMMAND_CONFIG. These are stable values and
52+
# asserted explicitly to avoid mirroring production code in the test.
5053
expected_headers = [
51-
normalize_field(f)["name"]
52-
for f in app_settings.EXPORT_USERS_COMMAND_CONFIG["fields"]
54+
"id",
55+
"username",
56+
"email",
57+
"password",
58+
"first_name",
59+
"last_name",
60+
"is_staff",
61+
"is_active",
62+
"date_joined",
63+
"phone_number",
64+
"birth_date",
65+
"location",
66+
"notes",
67+
"language",
68+
"organizations",
5369
]
5470
self.assertEqual(csv_data[0], expected_headers)
5571
# Ensuring ordering
@@ -67,9 +83,24 @@ def test_exclude_fields(self):
6783
call_command(
6884
"export_users",
6985
filename=self.temp_file.name,
86+
# Exclude all fields except "id"
7087
exclude_fields=",".join(
71-
normalize_field(f)["name"]
72-
for f in app_settings.EXPORT_USERS_COMMAND_CONFIG["fields"][1:]
88+
[
89+
"username",
90+
"email",
91+
"password",
92+
"first_name",
93+
"last_name",
94+
"is_staff",
95+
"is_active",
96+
"date_joined",
97+
"phone_number",
98+
"birth_date",
99+
"location",
100+
"notes",
101+
"language",
102+
"organizations",
103+
]
73104
),
74105
)
75106
with open(self.temp_file.name, "r") as temp_file:
@@ -107,10 +138,9 @@ def test_related_fields(self):
107138

108139
# 3 user and 1 header
109140
self.assertEqual(len(csv_data), 2)
110-
expected_headers = [
111-
normalize_field(f)["name"]
112-
for f in app_settings.EXPORT_USERS_COMMAND_CONFIG["fields"]
113-
]
141+
# When fields are ["id", "auth_token.key"] the expected headers
142+
# are the literal keys used to identify columns in the CSV.
143+
expected_headers = ["id", "auth_token.key"]
114144
self.assertEqual(csv_data[0], expected_headers)
115145
self.assertEqual(csv_data[1][0], str(user.id))
116146
self.assertEqual(csv_data[1][1], str(token.key))
@@ -152,9 +182,10 @@ def _broken_callable(user):
152182
}
153183
self._create_user()
154184
stderr = StringIO()
155-
with patch.object(
156-
app_settings, "EXPORT_USERS_COMMAND_CONFIG", config
157-
), self.assertRaises(Exception) as context:
185+
with (
186+
patch.object(app_settings, "EXPORT_USERS_COMMAND_CONFIG", config),
187+
self.assertRaises(Exception) as context,
188+
):
158189
call_command("export_users", filename=self.temp_file.name, stderr=stderr)
159190
self.assertIn(
160191
"Error calling function for field 'broken'", str(context.exception)

0 commit comments

Comments
 (0)