Skip to content

Commit 8bec3c4

Browse files
committed
[fix] Fixes by @coderabbitai
1 parent 619b153 commit 8bec3c4

2 files changed

Lines changed: 19 additions & 12 deletions

File tree

openwisp_users/management/commands/export_users.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.contrib.auth import get_user_model
44
from django.core.exceptions import ObjectDoesNotExist
55
from django.core.management.base import BaseCommand, CommandError
6+
from django.db.models.manager import BaseManager
67
from django.db.models.query import QuerySet
78
from django.utils.translation import gettext_lazy as _
89

@@ -21,6 +22,10 @@ def normalize_field(field):
2122
class Command(BaseCommand):
2223
help = _("Exports user data to a CSV file")
2324

25+
def _normalize_value(self, value):
26+
"""Convert None to empty string, otherwise stringify the value."""
27+
return "" if value is None else str(value)
28+
2429
def add_arguments(self, parser):
2530
parser.add_argument(
2631
"--exclude-fields",
@@ -73,7 +78,7 @@ def handle(self, *args, **options):
7378
row = []
7479
for field in fields:
7580
val = self._get_field_value(user, field)
76-
row.append(val if val is not None else "")
81+
row.append(val)
7782
csv_writer.writerow(row)
7883
self.stdout.write(
7984
self.style.SUCCESS(
@@ -98,7 +103,7 @@ def serialize_related(self, manager, subfields):
98103
row = []
99104
for f in subfields:
100105
val = self._get_nested_attr(obj, f)
101-
row.append("" if val is None else str(val))
106+
row.append(self._normalize_value(val))
102107
rows.append(row)
103108
if not rows:
104109
return ""
@@ -126,17 +131,15 @@ def _get_nested_attr(self, obj, attr_path):
126131
# missing attribute or intermediate raises -> None sentinel
127132
return None
128133
# Detect querysets/related managers robustly.
129-
if (isinstance(current, QuerySet) or hasattr(current, "all")) and i < len(
130-
parts
131-
) - 1:
134+
if (isinstance(current, (QuerySet, BaseManager))) and (i < len(parts) - 1):
132135
remaining_path = ".".join(parts[i + 1 :])
133136
# We use current.all() instead of current.iterator() to utilize
134137
# the prefetch_related queryset cache. The iterator() method
135138
# would bypass the cache and cause additional queries.
136139
values = []
137140
for item in current.all():
138141
v = self._get_nested_attr(item, remaining_path)
139-
values.append("" if v is None else str(v))
142+
values.append(self._normalize_value(v))
140143
return ",".join(values)
141144
return current
142145

@@ -148,19 +151,23 @@ def _get_field_value(self, user, field):
148151
# Priority: callable > fields > name
149152
if callable_fn is not None:
150153
try:
151-
return callable_fn(user)
154+
val = callable_fn(user)
152155
except Exception as e:
153156
func_name = getattr(callable_fn, "__name__", repr(callable_fn))
154157
raise CommandError(
155158
_(
156159
"Error calling function {func_name!r} for field '{name}': {e}"
157160
).format(func_name=func_name, name=name, e=e)
158161
)
162+
return self._normalize_value(val)
159163
if subfields is not None:
160164
attr = self._get_nested_attr(user, name)
161165
if attr is None:
162166
return ""
163-
if isinstance(attr, QuerySet) or hasattr(attr, "all"):
167+
if isinstance(attr, (QuerySet, BaseManager)):
164168
return self.serialize_related(attr, subfields)
165-
return ",".join(str(self._get_nested_attr(attr, f)) for f in subfields)
166-
return self._get_nested_attr(user, name)
169+
return ",".join(
170+
self._normalize_value(self._get_nested_attr(attr, f)) for f in subfields
171+
)
172+
val = self._get_nested_attr(user, name)
173+
return self._normalize_value(val)

openwisp_users/tests/test_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_related_fields(self):
139139
csv_reader = csv.reader(temp_file)
140140
csv_data = list(csv_reader)
141141

142-
# 3 user and 1 header
142+
# 1 user and 1 header
143143
self.assertEqual(len(csv_data), 2)
144144
# When fields are ["id", "auth_token.key"] the expected headers
145145
# are the literal keys used to identify columns in the CSV.
@@ -264,4 +264,4 @@ class FakeUser:
264264
intermediate = FakeIntermediate()
265265

266266
result = Command()._get_field_value(FakeUser(), "intermediate.sub_field")
267-
self.assertEqual(result, None)
267+
self.assertEqual(result, "")

0 commit comments

Comments
 (0)