Skip to content

Commit 1f5b86e

Browse files
committed
Add support for password authentication (password and md5)
Other authentication protocols that require frontend intervention are not supported yet.
1 parent 309e671 commit 1f5b86e

File tree

7 files changed

+314
-27
lines changed

7 files changed

+314
-27
lines changed

asyncpg/_testbase.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def assertRunUnder(self, delta):
9696
_default_cluster = None
9797

9898

99-
def _start_cluster():
99+
def _start_cluster(server_settings={}):
100100
global _default_cluster
101101

102102
if _default_cluster is None:
@@ -107,7 +107,8 @@ def _start_cluster():
107107
else:
108108
_default_cluster = pg_cluster.TempCluster()
109109
_default_cluster.init()
110-
_default_cluster.start(port=12345)
110+
_default_cluster.trust_local_connections()
111+
_default_cluster.start(port=12345, server_settings=server_settings)
111112
atexit.register(_shutdown_cluster, _default_cluster)
112113

113114
return _default_cluster
@@ -121,7 +122,9 @@ def _shutdown_cluster(cluster):
121122
class ClusterTestCase(TestCase):
122123
def setUp(self):
123124
super().setUp()
124-
self.cluster = _start_cluster()
125+
self.cluster = _start_cluster({
126+
'log_connections': 'on'
127+
})
125128

126129

127130
class ConnectedTestCase(ClusterTestCase):

asyncpg/cluster.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,23 @@ def start(self, wait=60, *, server_settings={}, **opts):
131131

132132
self._test_connection(timeout=wait)
133133

134+
def reload(self):
135+
"""Reload server configuration."""
136+
status = self.get_status()
137+
if status != 'running':
138+
raise ClusterError('cannot reload: cluster is not running')
139+
140+
process = subprocess.run(
141+
[self._pg_ctl, 'reload', '-D', self._data_dir],
142+
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
143+
144+
stderr = process.stderr
145+
146+
if process.returncode != 0:
147+
raise ClusterError(
148+
'pg_ctl stop exited with status {:d}: {}'.format(
149+
process.returncode, stderr.decode()))
150+
134151
def stop(self, wait=60):
135152
process = subprocess.run(
136153
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
@@ -165,6 +182,68 @@ def get_connection_addr(self):
165182

166183
return self._connection_addr['host'], self._connection_addr['port']
167184

185+
def reset_hba(self):
186+
"""Remove all records from pg_hba.conf."""
187+
status = self.get_status()
188+
if status == 'not-initialized':
189+
raise ClusterError(
190+
'cannot modify HBA records: cluster is not initialized')
191+
192+
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
193+
194+
try:
195+
with open(pg_hba, 'w'):
196+
pass
197+
except IOError as e:
198+
raise ClusterError(
199+
'cannot modify HBA records: {}'.format(e)) from e
200+
201+
def add_hba_entry(self, *, type='host', database, user, address=None,
202+
auth_method, auth_options=None):
203+
"""Add a record to pg_hba.conf."""
204+
status = self.get_status()
205+
if status == 'not-initialized':
206+
raise ClusterError(
207+
'cannot modify HBA records: cluster is not initialized')
208+
209+
if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
210+
raise ValueError('invalid HBA record type: {!r}'.format(type))
211+
212+
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
213+
214+
record = '{} {} {}'.format(type, database, user)
215+
216+
if type != 'local':
217+
if address is None:
218+
raise ValueError(
219+
'{!r} entry requires a valid address'.format(type))
220+
else:
221+
record += ' {}'.format(address)
222+
223+
record += ' {}'.format(auth_method)
224+
225+
if auth_options is not None:
226+
record += ' ' + ' '.join(
227+
'{}={}'.format(k, v) for k, v in auth_options)
228+
229+
try:
230+
with open(pg_hba, 'a') as f:
231+
print(record, file=f)
232+
except IOError as e:
233+
raise ClusterError(
234+
'cannot modify HBA records: {}'.format(e)) from e
235+
236+
def trust_local_connections(self):
237+
self.reset_hba()
238+
self.add_hba_entry(type='local', database='all',
239+
user='all', auth_method='trust')
240+
self.add_hba_entry(type='host', address='127.0.0.1/32',
241+
database='all', user='all',
242+
auth_method='trust')
243+
status = self.get_status()
244+
if status == 'running':
245+
self.reload()
246+
168247
def _init_env(self):
169248
self._pg_config = self._find_pg_config(self._pg_config_path)
170249
self._pg_config_data = self._run_pg_config(self._pg_config)

asyncpg/protocol/buffer.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ cdef class ReadBuffer:
100100
cdef inline read_byte(self)
101101
cdef inline char* _try_read_bytes(self, int nbytes)
102102
cdef inline read(self, int nbytes)
103+
cdef inline read_bytes(self, ssize_t n)
103104
cdef inline read_int32(self)
104105
cdef inline read_int16(self)
105106
cdef inline read_cstr(self)

asyncpg/protocol/buffer.pyx

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -293,21 +293,6 @@ cdef class ReadBuffer:
293293
raise RuntimeError(
294294
'debug: second buffer of ReadBuffer is empty')
295295

296-
cdef inline read_byte(self):
297-
cdef char* first_byte
298-
299-
IF DEBUG:
300-
if not self._buf0:
301-
raise RuntimeError(
302-
'debug: first buffer of ReadBuffer is empty')
303-
304-
self._ensure_first_buf()
305-
first_byte = self._try_read_bytes(1)
306-
if first_byte is NULL:
307-
raise BufferError('not enough data to read one byte')
308-
309-
return first_byte[0]
310-
311296
cdef inline char* _try_read_bytes(self, int nbytes):
312297
# Important: caller must call _ensure_first_buf() prior
313298
# to calling try_read_bytes, and must not overread
@@ -373,6 +358,34 @@ cdef class ReadBuffer:
373358
result,
374359
len(result))
375360

361+
cdef inline read_byte(self):
362+
cdef char* first_byte
363+
364+
IF DEBUG:
365+
if not self._buf0:
366+
raise RuntimeError(
367+
'debug: first buffer of ReadBuffer is empty')
368+
369+
self._ensure_first_buf()
370+
first_byte = self._try_read_bytes(1)
371+
if first_byte is NULL:
372+
raise BufferError('not enough data to read one byte')
373+
374+
return first_byte[0]
375+
376+
cdef inline read_bytes(self, ssize_t n):
377+
cdef:
378+
Memory mem
379+
char *cbuf
380+
381+
self._ensure_first_buf()
382+
cbuf = self._try_read_bytes(n)
383+
if cbuf != NULL:
384+
return cbuf
385+
else:
386+
mem = <Memory>(self.read(n))
387+
return mem.buf
388+
376389
cdef inline read_int32(self):
377390
cdef:
378391
Memory mem
@@ -508,7 +521,10 @@ cdef class ReadBuffer:
508521
cdef Memory consume_message(self):
509522
if not self._current_message_ready:
510523
raise BufferError('no message to consume')
511-
mem = self.read(self._current_message_len_unread)
524+
if self._current_message_len_unread > 0:
525+
mem = self.read(self._current_message_len_unread)
526+
else:
527+
mem = None
512528
self._discard_message()
513529
return mem
514530

asyncpg/protocol/coreproto.pxd

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@ cdef enum ProtocolState:
2727
PROTOCOL_BIND = 16
2828

2929

30+
cdef enum AuthenticationMessage:
31+
AUTH_SUCCESSFUL = 0
32+
AUTH_REQUIRED_KERBEROS = 2
33+
AUTH_REQUIRED_PASSWORD = 3
34+
AUTH_REQUIRED_PASSWORDMD5 = 5
35+
AUTH_REQUIRED_SCMCRED = 6
36+
AUTH_REQUIRED_GSS = 7
37+
AUTH_REQUIRED_GSS_CONTINUE = 8
38+
AUTH_REQUIRED_SSPI = 9
39+
40+
41+
AUTH_METHOD_NAME = {
42+
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
43+
AUTH_REQUIRED_PASSWORD: 'password',
44+
AUTH_REQUIRED_PASSWORDMD5: 'md5',
45+
AUTH_REQUIRED_GSS: 'gss',
46+
AUTH_REQUIRED_SSPI: 'sspi',
47+
}
48+
49+
3050
cdef enum ResultType:
3151
RESULT_OK = 1
3252
RESULT_FAILED = 2
@@ -87,6 +107,9 @@ cdef class CoreProtocol:
87107
cdef _parse_msg_error_response(self, is_error)
88108
cdef _parse_msg_command_complete(self)
89109

110+
cdef _auth_password_message_cleartext(self)
111+
cdef _auth_password_message_md5(self, bytes salt)
112+
90113
cdef _write(self, buf)
91114
cdef inline _write_sync_message(self)
92115

asyncpg/protocol/coreproto.pyx

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
from hashlib import md5 as hashlib_md5 # for MD5 authentication
9+
10+
811
cdef class CoreProtocol:
912

1013
def __init__(self, con_args):
1114
self.buffer = ReadBuffer()
15+
self.user = con_args.get('user')
16+
self.password = con_args.pop('password', None)
17+
self.auth_msg = None
1218
self.con_args = con_args
1319
self.transport = None
1420
self.con_status = CONNECTION_BAD
@@ -106,6 +112,11 @@ cdef class CoreProtocol:
106112
self._push_result()
107113
self.transport.close()
108114

115+
elif self.auth_msg is not None:
116+
# Server wants us to send auth data, so do that.
117+
self._write(self.auth_msg)
118+
self.auth_msg = None
119+
109120
elif mtype == b'K':
110121
# BackendKeyData
111122
self._parse_msg_backend_key_data()
@@ -294,18 +305,69 @@ cdef class CoreProtocol:
294305
self._set_server_parameter(name, val)
295306

296307
cdef _parse_msg_authentication(self):
297-
cdef int status
308+
cdef:
309+
int32_t status
310+
bytes md5_salt
311+
298312
status = self.buffer.read_int32()
299-
if status != 0:
313+
314+
if status == AUTH_SUCCESSFUL:
315+
# AuthenticationOk
316+
self.result_type = RESULT_OK
317+
318+
elif status == AUTH_REQUIRED_PASSWORD:
319+
# AuthenticationCleartextPassword
320+
self.result_type = RESULT_OK
321+
self.auth_msg = self._auth_password_message_cleartext()
322+
323+
elif status == AUTH_REQUIRED_PASSWORDMD5:
324+
# AuthenticationMD5Password
325+
# Note: MD5 salt is passed as a four-byte sequence
326+
md5_salt = cpython.PyBytes_FromStringAndSize(
327+
self.buffer.read_bytes(4), 4)
328+
self.auth_msg = self._auth_password_message_md5(md5_salt)
329+
330+
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
331+
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
332+
AUTH_REQUIRED_SSPI):
300333
self.result_type = RESULT_FAILED
301334
self.result = apg_exc.InterfaceError(
302-
'unsupported status {} for Authentication (R) '
303-
'message'.format(status))
335+
'unsupported authentication method requested by the '
336+
'server: {!r}'.format(AUTH_METHOD_NAME[status]))
337+
304338
else:
305-
# 0 == AuthenticationOk
306-
self.result_type = RESULT_OK
339+
self.result_type = RESULT_FAILED
340+
self.result = apg_exc.InterfaceError(
341+
'unsupported authentication method requested by the '
342+
'server: {}'.format(status))
343+
307344
self.buffer.consume_message()
308345

346+
cdef _auth_password_message_cleartext(self):
347+
cdef:
348+
WriteBuffer msg
349+
350+
msg = WriteBuffer.new_message(b'p')
351+
msg.write_bytestring(self.password.encode('ascii'))
352+
msg.end_message()
353+
354+
return msg
355+
356+
cdef _auth_password_message_md5(self, bytes salt):
357+
cdef:
358+
WriteBuffer msg
359+
360+
msg = WriteBuffer.new_message(b'p')
361+
362+
# 'md5' + md5(md5(password + username) + salt))
363+
userpass = ((self.password or '') + (self.user or '')).encode('ascii')
364+
hash = hashlib_md5(hashlib_md5(userpass).hexdigest().\
365+
encode('ascii') + salt).hexdigest().encode('ascii')
366+
367+
msg.write_bytestring(b'md5' + hash)
368+
msg.end_message()
369+
370+
return msg
309371

310372
cdef _parse_msg_ready_for_query(self):
311373
cdef char status = self.buffer.read_byte()

0 commit comments

Comments
 (0)