From d7692569972a31db763b0021e45b9a6798728044 Mon Sep 17 00:00:00 2001 From: Robert Schütz Date: Tue, 3 Jul 2018 00:21:16 +0200 Subject: python.pkgs.asyncssh: fix tests (#42848) --- .../python-modules/asyncssh/default.nix | 3 + .../python-modules/asyncssh/mock_getnameinfo.patch | 159 +++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 pkgs/development/python-modules/asyncssh/mock_getnameinfo.patch (limited to 'pkgs') diff --git a/pkgs/development/python-modules/asyncssh/default.nix b/pkgs/development/python-modules/asyncssh/default.nix index ef21cb001b49..76b863e24fb3 100644 --- a/pkgs/development/python-modules/asyncssh/default.nix +++ b/pkgs/development/python-modules/asyncssh/default.nix @@ -12,6 +12,9 @@ buildPythonPackage rec { sha256 = "a44736830741e2bb9c4e3992819288b77ac4af217a46d12f415bb57c18ed9c22"; }; + # See https://github.com/ronf/asyncssh/commit/6a92930e00f3bbc67d9cdf66816917e64f2450e9#r29559647 + patches = ./mock_getnameinfo.patch; + propagatedBuildInputs = [ bcrypt cryptography diff --git a/pkgs/development/python-modules/asyncssh/mock_getnameinfo.patch b/pkgs/development/python-modules/asyncssh/mock_getnameinfo.patch new file mode 100644 index 000000000000..5e8f475445e7 --- /dev/null +++ b/pkgs/development/python-modules/asyncssh/mock_getnameinfo.patch @@ -0,0 +1,159 @@ +diff --git a/tests/server.py b/tests/server.py +index fbd7e37..e7542dc 100644 +--- a/tests/server.py ++++ b/tests/server.py +@@ -208,12 +208,10 @@ class ServerTestCase(AsyncTestCase): + cls._server = yield from cls.start_server() + + sock = cls._server.sockets[0] +- cls._client_host, _ = yield from cls.loop.getnameinfo(('127.0.0.1', 0)) + cls._server_addr = '127.0.0.1' + cls._server_port = sock.getsockname()[1] + +- host = '[%s]:%d,%s ' % (cls._server_addr, cls._server_port, +- cls._client_host) ++ host = '[%s]:%d,localhost ' % (cls._server_addr, cls._server_port) + + with open('known_hosts', 'w') as known_hosts: + known_hosts.write(host) +diff --git a/tests/test_auth_keys.py b/tests/test_auth_keys.py +index 1d625ef..72a49f7 100644 +--- a/tests/test_auth_keys.py ++++ b/tests/test_auth_keys.py +@@ -13,13 +13,13 @@ + """Unit tests for matching against authorized_keys file""" + + import unittest +-from unittest.mock import patch + + import asyncssh + +-from .util import TempDirTestCase, x509_available ++from .util import TempDirTestCase, patch_getnameinfo, x509_available + + ++@patch_getnameinfo + class _TestAuthorizedKeys(TempDirTestCase): + """Unit tests for auth_keys module""" + +@@ -69,36 +69,22 @@ class _TestAuthorizedKeys(TempDirTestCase): + def match_keys(self, tests, x509=False): + """Match against authorized keys""" + +- def getnameinfo(sockaddr, flags): +- """Mock reverse DNS lookup of client address""" +- +- # pylint: disable=unused-argument +- +- host, port = sockaddr +- +- if host == '127.0.0.1': +- return ('localhost', port) +- else: +- return sockaddr +- +- with patch('socket.getnameinfo', getnameinfo): +- for keys, matches in tests: +- auth_keys = self.build_keys(keys, x509) +- for (msg, keynum, client_addr, +- cert_principals, match) in matches: +- with self.subTest(msg, x509=x509): +- if x509: +- result, trusted_cert = auth_keys.validate_x509( +- self.imported_certlist[keynum], client_addr) +- if (trusted_cert and trusted_cert.subject != +- self.imported_certlist[keynum].subject): +- result = None +- else: +- result = auth_keys.validate( +- self.imported_keylist[keynum], client_addr, +- cert_principals, keynum == 1) +- +- self.assertEqual(result is not None, match) ++ for keys, matches in tests: ++ auth_keys = self.build_keys(keys, x509) ++ for (msg, keynum, client_addr, cert_principals, match) in matches: ++ with self.subTest(msg, x509=x509): ++ if x509: ++ result, trusted_cert = auth_keys.validate_x509( ++ self.imported_certlist[keynum], client_addr) ++ if (trusted_cert and trusted_cert.subject != ++ self.imported_certlist[keynum].subject): ++ result = None ++ else: ++ result = auth_keys.validate( ++ self.imported_keylist[keynum], client_addr, ++ cert_principals, keynum == 1) ++ ++ self.assertEqual(result is not None, match) + + def test_matches(self): + """Test authorized keys matching""" +diff --git a/tests/test_connection_auth.py b/tests/test_connection_auth.py +index 3da8a5b..ff3e3cc 100644 +--- a/tests/test_connection_auth.py ++++ b/tests/test_connection_auth.py +@@ -23,8 +23,8 @@ from asyncssh.packet import String + from asyncssh.public_key import CERT_TYPE_USER, CERT_TYPE_HOST + + from .server import Server, ServerTestCase +-from .util import asynctest, gss_available, patch_gss, make_certificate +-from .util import x509_available ++from .util import asynctest, gss_available, patch_getnameinfo, patch_gss ++from .util import make_certificate, x509_available + + + class _FailValidateHostSSHServerConnection(asyncssh.SSHServerConnection): +@@ -455,6 +455,7 @@ class _TestGSSFQDN(ServerTestCase): + yield from conn.wait_closed() + + ++@patch_getnameinfo + class _TestHostBasedAuth(ServerTestCase): + """Unit tests for host-based authentication""" + +@@ -579,7 +580,7 @@ class _TestHostBasedAuth(ServerTestCase): + """Test stripping of trailing dot from client host""" + + with (yield from self.connect(username='user', client_host_keys='skey', +- client_host=self._client_host + '.', ++ client_host='localhost.', + client_username='user')) as conn: + pass + +@@ -667,6 +668,7 @@ class _TestHostBasedAsyncServerAuth(_TestHostBasedAuth): + client_username='user') + + ++@patch_getnameinfo + class _TestLimitedHostBasedSignatureAlgs(ServerTestCase): + """Unit tests for limited host key signature algorithms""" + +diff --git a/tests/util.py b/tests/util.py +index 42bb596..4d92ec3 100644 +--- a/tests/util.py ++++ b/tests/util.py +@@ -84,6 +84,24 @@ def asynctest35(func): + return async_wrapper + + ++def patch_getnameinfo(cls): ++ """Decorator for patching socket.getnameinfo""" ++ ++ def getnameinfo(sockaddr, flags): ++ """Mock reverse DNS lookup of client address""" ++ ++ # pylint: disable=unused-argument ++ ++ host, port = sockaddr ++ ++ if host == '127.0.0.1': ++ return ('localhost', port) ++ else: ++ return sockaddr ++ ++ return patch('socket.getnameinfo', getnameinfo)(cls) ++ ++ + def patch_gss(cls): + """Decorator for patching GSSAPI classes""" -- cgit 1.4.1