6060"""
6161
6262from abc import ABC , abstractmethod
63+ import functools
6364import re
6465builtin_any = any
6566
@@ -137,15 +138,46 @@ def matches(self, arg):
137138 return True
138139
139140 def __repr__ (self ):
140- return "<Any: %s>" % self .wanted_type
141+ return "<Any: %s>" % _any_wanted_type_label (self .wanted_type )
142+
143+
144+ def _any_wanted_type_label (wanted_type ):
145+ if isinstance (wanted_type , type ):
146+ return _type_label (wanted_type )
147+
148+ if (
149+ isinstance (wanted_type , tuple )
150+ and all (isinstance (t , type ) for t in wanted_type )
151+ ):
152+ items = [_type_label (t ) for t in wanted_type ]
153+ if len (items ) == 1 :
154+ return '(%s,)' % items [0 ]
155+ return '(%s)' % ', ' .join (items )
156+
157+ return _safe_repr (wanted_type )
158+
159+
160+ def _type_label (type_ ):
161+ module = _safe_getattr (type_ , '__module__' )
162+ qualname = _safe_getattr (type_ , '__qualname__' ) or _safe_getattr (type_ , '__name__' )
163+ if qualname is None :
164+ return _safe_repr (type_ )
165+
166+ if module is None or module == 'builtins' :
167+ return qualname
168+
169+ return '%s.%s' % (module , qualname )
141170
142171
143172class ValueMatcher (Matcher ):
144173 def __init__ (self , value ):
145174 self .value = value
146175
147176 def __repr__ (self ):
148- return "<%s: %s>" % (self .__class__ .__name__ , self .value )
177+ return "<%s: %s>" % (
178+ self .__class__ .__name__ ,
179+ _safe_repr (self .value ),
180+ )
149181
150182
151183class Eq (ValueMatcher ):
@@ -223,7 +255,93 @@ def matches(self, arg):
223255 return self .predicate (arg )
224256
225257 def __repr__ (self ):
226- return "<ArgThat>"
258+ return "<ArgThat: %s>" % _arg_that_predicate_label (self .predicate )
259+
260+
261+ def _arg_that_predicate_label (predicate ):
262+ try :
263+ return _arg_that_predicate_label_unchecked (predicate )
264+ except Exception :
265+ predicate_class = _safe_getattr (
266+ _safe_getattr (predicate , '__class__' ),
267+ '__name__' ,
268+ )
269+ if predicate_class is None :
270+ return 'callable'
271+
272+ return 'callable %s' % predicate_class
273+
274+
275+ def _arg_that_predicate_label_unchecked (predicate ):
276+ if isinstance (predicate , functools .partial ):
277+ return _arg_that_partial_label (predicate )
278+
279+ function_line = _line_of_callable (predicate )
280+ function_name = _safe_getattr (predicate , '__name__' )
281+ if function_name is not None :
282+ if function_name == '<lambda>' :
283+ return _label_with_line ('lambda' , function_line )
284+ return _label_with_line ('def %s' % function_name , function_line )
285+
286+ predicate_class = _safe_getattr (
287+ _safe_getattr (predicate , '__class__' ),
288+ '__name__' ,
289+ )
290+ if predicate_class is None :
291+ predicate_class = 'object'
292+
293+ call = _safe_getattr (predicate , '__call__' )
294+ call_line = _line_of_callable (call )
295+ return _label_with_line (
296+ 'callable %s.__call__' % predicate_class ,
297+ call_line ,
298+ )
299+
300+
301+ def _arg_that_partial_label (predicate ):
302+ partial_func = _safe_getattr (predicate , 'func' )
303+ partial_name = _safe_getattr (partial_func , '__name__' )
304+
305+ if partial_name is not None :
306+ return 'partial %s' % partial_name
307+
308+ return 'partial'
309+
310+
311+ def _line_of_callable (value ):
312+ if value is None :
313+ return None
314+
315+ func = _safe_getattr (value , '__func__' , value )
316+ code = _safe_getattr (func , '__code__' )
317+ if code is None :
318+ return None
319+
320+ return _safe_getattr (code , 'co_firstlineno' )
321+
322+
323+ def _safe_getattr (value , name , default = None ):
324+ try :
325+ return getattr (value , name )
326+ except Exception :
327+ return default
328+
329+
330+ def _safe_repr (value ):
331+ try :
332+ return repr (value )
333+ except Exception :
334+ try :
335+ return object .__repr__ (value )
336+ except Exception :
337+ return '<unrepresentable>'
338+
339+
340+ def _label_with_line (label , line_number ):
341+ if line_number is None :
342+ return label
343+
344+ return '%s at line %s' % (label , line_number )
227345
228346
229347class Contains (Matcher ):
@@ -236,24 +354,41 @@ def matches(self, arg):
236354 return self .sub and len (self .sub ) > 0 and arg .find (self .sub ) > - 1
237355
238356 def __repr__ (self ):
239- return "<Contains: '%s' >" % self .sub
357+ return "<Contains: %s >" % _safe_repr ( self .sub )
240358
241359
242360class Matches (Matcher ):
243361 def __init__ (self , regex , flags = 0 ):
244362 self .regex = re .compile (regex , flags )
363+ self .flags = _explicit_regex_flags (regex , flags )
245364
246365 def matches (self , arg ):
247366 if not isinstance (arg , str ):
248367 return
249368 return self .regex .match (arg ) is not None
250369
251370 def __repr__ (self ):
252- if self .regex .flags :
253- return "<Matches: %s flags=%d>" % (self .regex .pattern ,
254- self .regex .flags )
371+ if self .flags :
372+ return "<Matches: %r flags=%d>" % (self .regex .pattern , self .flags )
255373 else :
256- return "<Matches: %s>" % self .regex .pattern
374+ return "<Matches: %r>" % self .regex .pattern
375+
376+
377+ def _explicit_regex_flags (regex , flags ):
378+ if flags :
379+ return flags
380+
381+ compiled_flags = _safe_getattr (regex , 'flags' )
382+ pattern = _safe_getattr (regex , 'pattern' )
383+ if compiled_flags is None or pattern is None :
384+ return 0
385+
386+ try :
387+ baseline_flags = re .compile (pattern ).flags
388+ except Exception :
389+ return compiled_flags
390+
391+ return compiled_flags & ~ baseline_flags
257392
258393
259394class ArgumentCaptor (Matcher , Capturing ):
0 commit comments