Skip to content

Commit d4ab33c

Browse files
committed
Allow IPv6 outbound and prioritizes IPv6 if prefer_ipv6 is set
1 parent 1a4329c commit d4ab33c

2 files changed

Lines changed: 115 additions & 35 deletions

File tree

test/test_ipv6.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
from unittest.mock import patch
3+
import socket
4+
5+
from whois.whois import NICClient
6+
7+
8+
class TestNICClientIPv6(unittest.TestCase):
9+
10+
def setUp(self):
11+
self.ipv4_info = (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('1.2.3.4', 43))
12+
self.ipv6_info = (socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('2001:db8::1', 43, 0, 0))
13+
self.mock_addr_info = [self.ipv4_info, self.ipv6_info]
14+
15+
@patch('socket.getaddrinfo')
16+
@patch('socket.socket')
17+
def test_connect_prioritizes_ipv6(self, mock_socket, mock_getaddrinfo):
18+
mock_getaddrinfo.return_value = self.mock_addr_info
19+
20+
client = NICClient(prefer_ipv6=True)
21+
try:
22+
client._connect("example.com", timeout=10)
23+
except Exception:
24+
pass
25+
26+
first_call_args = mock_socket.call_args_list[0][0]
27+
# Make sure we used IPv6 when creating socket
28+
self.assertEqual(first_call_args[0], socket.AF_INET6)
29+
30+
@patch('socket.getaddrinfo')
31+
@patch('socket.socket')
32+
def test_connect_keeps_default_order(self, mock_socket, mock_getaddrinfo):
33+
mock_getaddrinfo.return_value = self.mock_addr_info
34+
35+
client = NICClient(prefer_ipv6=False)
36+
try:
37+
client._connect("example.com", timeout=10)
38+
except Exception:
39+
pass
40+
41+
first_call_args = mock_socket.call_args_list[0][0]
42+
# Make sure we used IPv4 when creating socket, which is the first appearing in our mock.
43+
self.assertEqual(first_call_args[0], socket.AF_INET)

whois/whois.py

Lines changed: 72 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ class NICClient:
123123

124124
ip_whois: list[str] = [LNICHOST, RNICHOST, PNICHOST, BNICHOST, PANDIHOST]
125125

126-
def __init__(self):
126+
def __init__(self, prefer_ipv6: bool = False):
127127
self.use_qnichost: bool = False
128+
self.prefer_ipv6 = prefer_ipv6
128129

129130
@staticmethod
130131
def findwhois_server(buf: str, hostname: str, query: str) -> Optional[str]:
@@ -150,40 +151,69 @@ def findwhois_server(buf: str, hostname: str, query: str) -> Optional[str]:
150151
return nhost
151152

152153
@staticmethod
153-
def get_socket():
154-
if "SOCKS" in os.environ:
155-
try:
156-
import socks
157-
except ImportError as e:
158-
logger.error(
159-
"You need to install the Python socks module. Install PIP "
160-
"(https://bootstrap.pypa.io/get-pip.py) and then 'pip install PySocks'"
161-
)
162-
raise e
163-
socks_user, socks_password = None, None
164-
if "@" in os.environ["SOCKS"]:
165-
creds, proxy = os.environ["SOCKS"].split("@")
166-
socks_user, socks_password = creds.split(":")
167-
else:
168-
proxy = os.environ["SOCKS"]
169-
socksproxy, port = proxy.split(":")
170-
socks_proto = socket.AF_INET
171-
if socket.AF_INET6 in [
172-
sock[0] for sock in socket.getaddrinfo(socksproxy, port)
173-
]:
174-
socks_proto = socket.AF_INET6
175-
s = socks.socksocket(socks_proto)
176-
s.set_proxy(
177-
socks.SOCKS5, socksproxy, int(port), True, socks_user, socks_password
154+
def get_socks_socket():
155+
try:
156+
import socks
157+
except ImportError as e:
158+
logger.error(
159+
"You need to install the Python socks module. Install PIP "
160+
"(https://bootstrap.pypa.io/get-pip.py) and then 'pip install PySocks'"
178161
)
162+
raise e
163+
socks_user, socks_password = None, None
164+
if "@" in os.environ["SOCKS"]:
165+
creds, proxy = os.environ["SOCKS"].split("@")
166+
socks_user, socks_password = creds.split(":")
179167
else:
180-
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
168+
proxy = os.environ["SOCKS"]
169+
socksproxy, port = proxy.split(":")
170+
socks_proto = socket.AF_INET
171+
if socket.AF_INET6 in [
172+
sock[0] for sock in socket.getaddrinfo(socksproxy, port)
173+
]:
174+
socks_proto = socket.AF_INET6
175+
s = socks.socksocket(socks_proto)
176+
s.set_proxy(
177+
socks.SOCKS5, socksproxy, int(port), True, socks_user, socks_password
178+
)
181179
return s
182180

181+
def _connect(self, hostname: str, timeout: int) -> socket.socket:
182+
"""Resolve WHOIS IP address and connect to its TCP 43 port."""
183+
port = 43
184+
185+
if "SOCKS" in os.environ:
186+
s = NICClient.get_socks_socket()
187+
s.settimeout(timeout)
188+
s.connect((hostname, port))
189+
return s
190+
191+
# Resolve all IP addresses for the WHOIS server
192+
addr_infos = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
193+
194+
if self.prefer_ipv6:
195+
# Sort by family to prioritize AF_INET6 (10) over AF_INET (2)
196+
addr_infos.sort(key=lambda x: x[0], reverse=True)
197+
198+
last_err = None
199+
# Attempt to connect to each related IP address until one works
200+
for family, sock_type, proto, __, sockaddr in addr_infos:
201+
s = None
202+
try:
203+
s = socket.socket(family, sock_type, proto)
204+
s.settimeout(timeout)
205+
s.connect(sockaddr)
206+
return s
207+
except socket.error as e:
208+
last_err = e
209+
if s:
210+
s.close()
211+
continue
212+
213+
raise last_err or socket.error(f"Could not connect to {hostname}")
214+
183215
def findwhois_iana(self, tld: str, timeout: int = 10) -> Optional[str]:
184-
s = self.get_socket()
185-
s.settimeout(timeout)
186-
s.connect(("whois.iana.org", 43))
216+
s = self._connect("whois.iana.org", timeout)
187217
s.send(bytes(tld, "utf-8") + b"\r\n")
188218
response = b""
189219
while True:
@@ -219,11 +249,10 @@ def whois(
219249
a string containing the error.
220250
"""
221251
response = b""
222-
s = self.get_socket()
223-
s.settimeout(timeout)
252+
s = None
224253
try: # socket.connect in a try, in order to allow things like looping whois on different domains without
225254
# stopping on timeouts: https://stackoverflow.com/questions/25447803/python-socket-connection-exception
226-
s.connect((hostname, 43))
255+
s = self._connect(hostname, timeout)
227256
if hostname == NICClient.DENICHOST:
228257
query_bytes = "-T dn,ace -C UTF-8 " + query
229258
elif hostname == NICClient.DK_HOST:
@@ -261,7 +290,8 @@ def whois(
261290
else:
262291
raise e
263292
finally:
264-
s.close()
293+
if s:
294+
s.close()
265295
return response_str
266296

267297
def choose_server(self, domain: str, timeout: int = 10) -> Optional[str]:
@@ -567,6 +597,13 @@ def parse_command_line(argv: list[str]) -> tuple[optparse.Values, list[str]]:
567597
dest="port",
568598
help="Lookup using specified tcp port",
569599
)
600+
parser.add_option(
601+
"--prefer-ipv6",
602+
action="store_true",
603+
dest="prefer_ipv6",
604+
default=False,
605+
help="Prioritize IPv6 resolution for WHOIS servers",
606+
)
570607
parser.add_option(
571608
"-Q",
572609
"--quick",
@@ -621,8 +658,8 @@ def parse_command_line(argv: list[str]) -> tuple[optparse.Values, list[str]]:
621658

622659
if __name__ == "__main__":
623660
flags = 0
624-
nic_client = NICClient()
625661
options, args = parse_command_line(sys.argv)
662+
nic_client = NICClient(prefer_ipv6=options.prefer_ipv6)
626663
if options.b_quicklookup:
627664
flags = flags | NICClient.WHOIS_QUICK
628665
logger.debug(nic_client.whois_lookup(options.__dict__, args[1], flags))

0 commit comments

Comments
 (0)