Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 140 additions & 5 deletions pyspelling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Spell check with Aspell or Hunspell."""
import os
import re
import importlib
from . import util
from .__meta__ import __version__, __version_info__ # noqa: F401
Expand Down Expand Up @@ -143,6 +144,11 @@ def _spelling_pipeline(self, sources, options, personal_dict):
err = self.get_error(e)
yield Results([], source.context, source.category, err)

def _collect_pipeline_sources(self, sources, options, personal_dict):
    """Run pipeline steps and yield extracted sources without spell checking."""

    # Delegate each extracted source straight through; no spell check happens here.
    for extracted in self._pipeline_step(sources, options, personal_dict):
        yield extracted

def spell_check_no_pipeline(self, sources, options, personal_dict):
"""Spell check without the pipeline."""

Expand Down Expand Up @@ -357,6 +363,57 @@ def spell_check_no_pipeline(self, sources, options, personal_dict):
err = self.get_error(e)
yield Results([], source.context, source.category, err)

def _batch_spellcheck(self, extracted_sources, options, personal_dict):
    """
    Run a single aspell call for all sources; map misspelled words back per source.

    Sources flagged with an error yield empty `Results` carrying that error.
    Empty or whitespace-only sources are skipped entirely (no `Results`
    yielded for them).  Every remaining source is normalized to UTF-8
    bytes, the texts are joined with a blank line, and one spell-checker
    subprocess is invoked for the whole batch.  Each reported misspelled
    word is then attributed to every source whose text contains it as a
    whole word, and a per-source `Results` is yielded.
    """

    error_sources = []
    valid = []  # list of (source, utf8_bytes)

    for source in extracted_sources:
        if source._has_error():
            error_sources.append(source)
            continue
        if not source.text or source.text.isspace():
            continue
        encoding = source.encoding
        if source._is_bytes():
            text = source.text
            # Compare case-insensitively: upstream detection may report
            # 'UTF-8' / 'UTF-8-SIG'; those must not be round-tripped
            # through a pointless decode/encode.
            if encoding and not encoding.lower().startswith('utf-8'):
                text = text.decode(encoding, errors='replace').encode('utf-8')
        else:
            # Python str — encode directly as UTF-8
            text = source.text.encode('utf-8')
        valid.append((source, text))

    for source in error_sources:
        yield Results([], source.context, source.category, source.error)

    if not valid:
        return

    # Blank-line separator acts as a paragraph boundary in aspell list mode,
    # preventing words from different sources from being joined across boundaries.
    combined = b'\n\n'.join(text for _, text in valid)
    cmd = self.setup_command('utf-8', options, personal_dict)
    self.log("Batch command: " + str(cmd), 4)

    try:
        wordlist = util.call_spellchecker(cmd, input_text=combined, encoding='utf-8')
        misspelled_global = {w for w in wordlist.replace('\r', '').split('\n') if w}

        # Compile each word's whole-word pattern once, not once per
        # (word, source) pair, so mapping cost is O(words + words*sources
        # searches) instead of rebuilding patterns in the inner loop.
        patterns = [
            (w, re.compile(r'\b' + re.escape(w) + r'\b')) for w in misspelled_global
        ]

        for source, text in valid:
            source_str = text.decode('utf-8', errors='replace')
            per_source = sorted(w for w, pat in patterns if pat.search(source_str))
            yield Results(per_source, source.context, source.category)

    except Exception as e:  # pragma: no cover
        err = self.get_error(e)
        for source, _ in valid:
            yield Results([], source.context, source.category, err)

def setup_command(self, encoding, options, personal_dict, file_name=None):
"""Setup the command."""

Expand Down Expand Up @@ -484,6 +541,54 @@ def spell_check_no_pipeline(self, sources, options, personal_dict):
err = self.get_error(e)
yield Results([], source.context, source.category, err)

def _batch_spellcheck(self, extracted_sources, options, personal_dict):  # pragma: no cover
    """
    Run a single hunspell call for all sources; map misspelled words back per source.

    Sources flagged with an error yield empty `Results` carrying that error.
    Empty or whitespace-only sources are skipped entirely (no `Results`
    yielded for them).  Every remaining source is normalized to UTF-8
    bytes, the texts are joined with a blank line, and one spell-checker
    subprocess is invoked for the whole batch.  Each reported misspelled
    word is then attributed to every source whose text contains it as a
    whole word, and a per-source `Results` is yielded.
    """

    error_sources = []
    valid = []  # list of (source, utf8_bytes)

    for source in extracted_sources:
        if source._has_error():
            error_sources.append(source)
            continue
        if not source.text or source.text.isspace():
            continue
        encoding = source.encoding
        if source._is_bytes():
            text = source.text
            # Compare case-insensitively: upstream detection may report
            # 'UTF-8' / 'UTF-8-SIG'; those must not be round-tripped
            # through a pointless decode/encode.
            if encoding and not encoding.lower().startswith('utf-8'):
                text = text.decode(encoding, errors='replace').encode('utf-8')
        else:
            # Python str — encode directly as UTF-8
            text = source.text.encode('utf-8')
        valid.append((source, text))

    for source in error_sources:
        yield Results([], source.context, source.category, source.error)

    if not valid:
        return

    # Blank-line separator keeps words from different sources from being
    # joined across source boundaries when the texts are concatenated.
    combined = b'\n\n'.join(text for _, text in valid)
    cmd = self.setup_command('utf-8', options, personal_dict)
    self.log("Batch command: " + str(cmd), 4)

    try:
        wordlist = util.call_spellchecker(cmd, input_text=combined, encoding='utf-8')
        misspelled_global = {w for w in wordlist.replace('\r', '').split('\n') if w}

        # Compile each word's whole-word pattern once, not once per
        # (word, source) pair, so the inner loop only runs searches.
        patterns = [
            (w, re.compile(r'\b' + re.escape(w) + r'\b')) for w in misspelled_global
        ]

        for source, text in valid:
            source_str = text.decode('utf-8', errors='replace')
            per_source = sorted(w for w, pat in patterns if pat.search(source_str))
            yield Results(per_source, source.context, source.category)

    except Exception as e:
        err = self.get_error(e)
        for source, _ in valid:
            yield Results([], source.context, source.category, err)

def setup_command(self, encoding, options, personal_dict, file_name=None):
"""Setup command."""

Expand Down Expand Up @@ -651,6 +756,18 @@ def multi_check(self, f):

return list(self.process_file(f, self.get_checker()))

def extract_only(self, f):
    """Extract pipeline sources for a file without spell checking (for worker processes)."""

    self.log('', 2)
    self.log('> Processing: %s' % f, 1)

    checker = self.get_checker()
    sources = checker.get_source(f)

    # With no pipeline configured, the raw sources are the extraction result.
    if checker.pipeline_steps is None:
        return list(sources)
    return list(
        checker._collect_pipeline_sources(sources, self.options, self.personal_dict)
    )

def run_task(self, task, source_patterns=None):
"""Walk source and initiate spell check."""

Expand Down Expand Up @@ -681,18 +798,36 @@ def run_task(self, task, source_patterns=None):
jobs = self.config.get('jobs', 1) if self.jobs is None else self.jobs

expect_match = self.task.get('expect_match', True)
checker = self.get_checker()
all_extracted = []

if jobs != 1 and jobs > 0:
# Use multi-processing to process files concurrently
# Use multi-processing to extract sources concurrently, then spell-check once
with ProcessPoolExecutor(max_workers=jobs if jobs else None) as pool:
for results in pool.map(self.multi_check, self.walk_src(source_patterns, glob_flags, glob_limit)):
for extracted_list in pool.map(
self.extract_only, self.walk_src(source_patterns, glob_flags, glob_limit)
):
self.found_match = True
yield from results
all_extracted.extend(extracted_list)
else:
# Avoid overhead of multiprocessing if we are single threaded
checker = self.get_checker()
for f in self.walk_src(source_patterns, glob_flags, glob_limit):
self.found_match = True
yield from self.process_file(f, checker)
self.log('', 2)
self.log('> Processing: %s' % f, 1)
source = checker.get_source(f)
if checker.pipeline_steps is not None:
all_extracted.extend(
checker._collect_pipeline_sources(source, self.options, self.personal_dict)
)
else:
all_extracted.extend(source)

if all_extracted:
if checker.pipeline_steps is not None:
yield from checker._batch_spellcheck(all_extracted, self.options, self.personal_dict)
else:
yield from checker.spell_check_no_pipeline(all_extracted, self.options, self.personal_dict)

if not self.found_match and expect_match:
raise RuntimeError(
Expand Down
Loading