Skip to content

Commit a42bae7

Browse files
committed
Batch spell-check calls to reduce subprocess overhead
1 parent a6f91a9 commit a42bae7

File tree

1 file changed

+140
-5
lines changed

1 file changed

+140
-5
lines changed

pyspelling/__init__.py

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Spell check with Aspell or Hunspell."""
22
import os
3+
import re
34
import importlib
45
from . import util
56
from .__meta__ import __version__, __version_info__ # noqa: F401
@@ -143,6 +144,11 @@ def _spelling_pipeline(self, sources, options, personal_dict):
143144
err = self.get_error(e)
144145
yield Results([], source.context, source.category, err)
145146

147+
def _collect_pipeline_sources(self, sources, options, personal_dict):
148+
"""Run pipeline steps and yield extracted sources without spell checking."""
149+
150+
yield from self._pipeline_step(sources, options, personal_dict)
151+
146152
def spell_check_no_pipeline(self, sources, options, personal_dict):
147153
"""Spell check without the pipeline."""
148154

@@ -357,6 +363,57 @@ def spell_check_no_pipeline(self, sources, options, personal_dict):
357363
err = self.get_error(e)
358364
yield Results([], source.context, source.category, err)
359365

366+
def _batch_spellcheck(self, extracted_sources, options, personal_dict):
367+
"""Run a single aspell call for all sources; map misspelled words back per source."""
368+
369+
error_sources = []
370+
valid = [] # list of (source, utf8_bytes)
371+
372+
for source in extracted_sources:
373+
if source._has_error():
374+
error_sources.append(source)
375+
continue
376+
if not source.text or source.text.isspace():
377+
continue
378+
encoding = source.encoding
379+
if source._is_bytes():
380+
text = source.text
381+
if encoding and not encoding.startswith('utf-8'):
382+
text = text.decode(encoding, errors='replace').encode('utf-8')
383+
else:
384+
# Python str — encode directly as UTF-8
385+
text = source.text.encode('utf-8')
386+
valid.append((source, text))
387+
388+
for source in error_sources:
389+
yield Results([], source.context, source.category, source.error)
390+
391+
if not valid:
392+
return
393+
394+
# Blank-line separator acts as a paragraph boundary in aspell list mode,
395+
# preventing words from different sources from being joined across boundaries.
396+
combined = b'\n\n'.join(text for _, text in valid)
397+
cmd = self.setup_command('utf-8', options, personal_dict)
398+
self.log("Batch command: " + str(cmd), 4)
399+
400+
try:
401+
wordlist = util.call_spellchecker(cmd, input_text=combined, encoding='utf-8')
402+
misspelled_global = {w for w in wordlist.replace('\r', '').split('\n') if w}
403+
404+
for source, text in valid:
405+
source_str = text.decode('utf-8', errors='replace')
406+
per_source = sorted(
407+
w for w in misspelled_global
408+
if re.search(r'\b' + re.escape(w) + r'\b', source_str)
409+
)
410+
yield Results(per_source, source.context, source.category)
411+
412+
except Exception as e: # pragma: no cover
413+
err = self.get_error(e)
414+
for source, _ in valid:
415+
yield Results([], source.context, source.category, err)
416+
360417
def setup_command(self, encoding, options, personal_dict, file_name=None):
361418
"""Setup the command."""
362419

@@ -484,6 +541,54 @@ def spell_check_no_pipeline(self, sources, options, personal_dict):
484541
err = self.get_error(e)
485542
yield Results([], source.context, source.category, err)
486543

544+
def _batch_spellcheck(self, extracted_sources, options, personal_dict): # pragma: no cover
545+
"""Run a single hunspell call for all sources; map misspelled words back per source."""
546+
547+
error_sources = []
548+
valid = [] # list of (source, utf8_bytes)
549+
550+
for source in extracted_sources:
551+
if source._has_error():
552+
error_sources.append(source)
553+
continue
554+
if not source.text or source.text.isspace():
555+
continue
556+
encoding = source.encoding
557+
if source._is_bytes():
558+
text = source.text
559+
if encoding and not encoding.startswith('utf-8'):
560+
text = text.decode(encoding, errors='replace').encode('utf-8')
561+
else:
562+
text = source.text.encode('utf-8')
563+
valid.append((source, text))
564+
565+
for source in error_sources:
566+
yield Results([], source.context, source.category, source.error)
567+
568+
if not valid:
569+
return
570+
571+
combined = b'\n\n'.join(text for _, text in valid)
572+
cmd = self.setup_command('utf-8', options, personal_dict)
573+
self.log("Batch command: " + str(cmd), 4)
574+
575+
try:
576+
wordlist = util.call_spellchecker(cmd, input_text=combined, encoding='utf-8')
577+
misspelled_global = {w for w in wordlist.replace('\r', '').split('\n') if w}
578+
579+
for source, text in valid:
580+
source_str = text.decode('utf-8', errors='replace')
581+
per_source = sorted(
582+
w for w in misspelled_global
583+
if re.search(r'\b' + re.escape(w) + r'\b', source_str)
584+
)
585+
yield Results(per_source, source.context, source.category)
586+
587+
except Exception as e:
588+
err = self.get_error(e)
589+
for source, _ in valid:
590+
yield Results([], source.context, source.category, err)
591+
487592
def setup_command(self, encoding, options, personal_dict, file_name=None):
488593
"""Setup command."""
489594

@@ -651,6 +756,18 @@ def multi_check(self, f):
651756

652757
return list(self.process_file(f, self.get_checker()))
653758

759+
def extract_only(self, f):
    """
    Return the list of sources extracted from file `f`, without spell checking.

    Intended as the per-file job handed to worker processes. The file is
    announced through `self.log`, a checker is obtained from `get_checker()`,
    and its `get_source(f)` output is either returned directly (no pipeline
    configured) or first run through the checker's pipeline steps.
    """

    self.log('', 2)
    self.log('> Processing: %s' % f, 1)

    checker = self.get_checker()
    raw_sources = checker.get_source(f)

    # Without pipeline steps the raw sources are the final result.
    if checker.pipeline_steps is None:
        return list(raw_sources)

    piped = checker._collect_pipeline_sources(raw_sources, self.options, self.personal_dict)
    return list(piped)
770+
654771
def run_task(self, task, source_patterns=None):
655772
"""Walk source and initiate spell check."""
656773

@@ -681,18 +798,36 @@ def run_task(self, task, source_patterns=None):
681798
jobs = self.config.get('jobs', 1) if self.jobs is None else self.jobs
682799

683800
expect_match = self.task.get('expect_match', True)
801+
checker = self.get_checker()
802+
all_extracted = []
803+
684804
if jobs != 1 and jobs > 0:
685-
# Use multi-processing to process files concurrently
805+
# Use multi-processing to extract sources concurrently, then spell-check once
686806
with ProcessPoolExecutor(max_workers=jobs if jobs else None) as pool:
687-
for results in pool.map(self.multi_check, self.walk_src(source_patterns, glob_flags, glob_limit)):
807+
for extracted_list in pool.map(
808+
self.extract_only, self.walk_src(source_patterns, glob_flags, glob_limit)
809+
):
688810
self.found_match = True
689-
yield from results
811+
all_extracted.extend(extracted_list)
690812
else:
691813
# Avoid overhead of multiprocessing if we are single threaded
692-
checker = self.get_checker()
693814
for f in self.walk_src(source_patterns, glob_flags, glob_limit):
694815
self.found_match = True
695-
yield from self.process_file(f, checker)
816+
self.log('', 2)
817+
self.log('> Processing: %s' % f, 1)
818+
source = checker.get_source(f)
819+
if checker.pipeline_steps is not None:
820+
all_extracted.extend(
821+
checker._collect_pipeline_sources(source, self.options, self.personal_dict)
822+
)
823+
else:
824+
all_extracted.extend(source)
825+
826+
if all_extracted:
827+
if checker.pipeline_steps is not None:
828+
yield from checker._batch_spellcheck(all_extracted, self.options, self.personal_dict)
829+
else:
830+
yield from checker.spell_check_no_pipeline(all_extracted, self.options, self.personal_dict)
696831

697832
if not self.found_match and expect_match:
698833
raise RuntimeError(

0 commit comments

Comments
 (0)