Linux lhjmq-records 5.15.0-118-generic #128-Ubuntu SMP Fri Jul 5 09:28:59 UTC 2024 x86_64
Your IP : 18.191.195.180
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.cred}, now with 30% more starch.
"""
from binascii import hexlify, unhexlify
from zope.interface import Interface, implementer
from twisted.cred import checkers, credentials, error, portal
from twisted.internet import defer
from twisted.python import components
from twisted.python.versions import Version
from twisted.trial import unittest
try:
from crypt import crypt as _crypt
except ImportError:
crypt = None
else:
crypt = _crypt
# The Twisted version in which UsernameHashedPassword is first deprecated.
# Tests in this module pass it to getDeprecatedModuleAttribute whenever they
# look up twisted.cred.credentials.UsernameHashedPassword.
_uhpVersion = Version("Twisted", 21, 2, 0)
class ITestable(Interface):
    """
    A marker interface standing in for a theoretical protocol.

    It declares no methods or attributes of its own.
    """
class TestAvatar:
    """
    A minimal avatar which records whether it has been logged in and out.

    @ivar name: The identifier this avatar was created with.
    @ivar loggedIn: C{True} once L{login} has been called.
    @ivar loggedOut: C{True} once L{logout} has been called.
    """

    def __init__(self, name):
        """
        @param name: The identifier to remember for this avatar.
        """
        self.loggedOut = False
        self.loggedIn = False
        self.name = name

    def login(self):
        """
        Flag the avatar as logged in.

        An avatar which is already logged in trips the assertion instead.
        """
        assert not self.loggedIn
        self.loggedIn = True

    def logout(self):
        """
        Flag the avatar as logged out.
        """
        self.loggedOut = True
@implementer(ITestable)
class Testable(components.Adapter):
    """
    An adapter declared to implement L{ITestable}; it adds nothing to
    L{components.Adapter}.
    """


# Register Testable as the adapter from TestAvatar to ITestable.
components.registerAdapter(Testable, TestAvatar, ITestable)
class IDerivedCredentials(credentials.IUsernamePassword):
    """
    A credentials interface which extends L{credentials.IUsernamePassword}
    without adding anything to it.
    """
@implementer(IDerivedCredentials, ITestable)
class DerivedCredentials:
    """
    Username/password credentials declared to implement
    L{IDerivedCredentials} and L{ITestable}.

    @ivar username: The username given at construction.
    @ivar password: The password given at construction.
    """

    def __init__(self, username, password):
        """
        @param username: The username to store.
        @param password: The password to store.
        """
        self.username, self.password = username, password

    def checkPassword(self, password):
        """
        Compare a candidate password against the stored one.

        @param password: The password to compare.
        @return: The result of comparing C{password} to C{self.password}
            for equality.
        """
        matches = password == self.password
        return matches
@implementer(portal.IRealm)
class TestRealm:
    """
    A basic test realm which keeps one L{TestAvatar} per avatar identifier.

    @ivar avatars: A mapping of avatar identifiers to the L{TestAvatar}
        created for each of them.
    """

    def __init__(self):
        self.avatars = {}

    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        Look up, or create on first use, the avatar for C{avatarId} and log
        it in.

        @param avatarId: The identifier of the avatar being requested.
        @param mind: Unused.
        @param interfaces: The requested interfaces; only the first one is
            considered.
        @return: A 3-tuple of the first requested interface, the avatar
            adapted to that interface, and the avatar's C{logout} method.
        """
        try:
            avatar = self.avatars[avatarId]
        except KeyError:
            avatar = self.avatars[avatarId] = TestAvatar(avatarId)
        avatar.login()
        requested = interfaces[0]
        return (requested, requested(avatar), avatar.logout)
class CredTests(unittest.TestCase):
    """
    Tests for the meat of L{twisted.cred} -- realms, portals, avatars, and
    checkers.
    """

    def setUp(self):
        """
        Create a portal around a L{TestRealm} with a single in-memory checker
        which knows the user C{b"bob"} with the password C{b"hello"}.
        """
        self.realm = TestRealm()
        self.portal = portal.Portal(self.realm)
        self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
        self.checker.addUser(b"bob", b"hello")
        self.portal.registerChecker(self.checker)

    def _assertLoginThenLogout(self, creds):
        """
        Log in through the portal with C{creds}, requesting L{ITestable}, and
        verify the returned interface, the avatar's state, and the logout
        callable.

        @param creds: The credentials to present to the portal; the login is
            required to have already succeeded.
        """
        login = self.successResultOf(self.portal.login(creds, self, ITestable))
        iface, impl, logout = login

        # whitebox
        self.assertEqual(iface, ITestable)
        self.assertTrue(iface.providedBy(impl), f"{impl} does not implement {iface}")

        # greybox
        self.assertTrue(impl.original.loggedIn)
        self.assertFalse(impl.original.loggedOut)
        logout()
        self.assertTrue(impl.original.loggedOut)

    def _assertLoginUnauthorized(self, creds):
        """
        Log in through the portal with C{creds}, requesting L{ITestable}, and
        verify that the attempt has already failed with
        L{error.UnauthorizedLogin}.

        @param creds: The credentials to present to the portal.
        """
        login = self.failureResultOf(self.portal.login(creds, self, ITestable))
        self.assertTrue(login)
        self.assertEqual(error.UnauthorizedLogin, login.type)

    def test_listCheckers(self):
        """
        The checkers in a portal can check only certain types of credentials.
        Since this portal has
        L{checkers.InMemoryUsernamePasswordDatabaseDontUse} registered, it
        lists L{credentials.IUsernamePassword} and
        L{credentials.IUsernameHashedPassword} as its credentials interfaces.
        """
        expected = [credentials.IUsernamePassword, credentials.IUsernameHashedPassword]
        got = self.portal.listCredentialsInterfaces()
        self.assertEqual(sorted(got), sorted(expected))

    def test_basicLogin(self):
        """
        Calling C{login} on a portal with correct credentials and an interface
        that the portal's realm supports works.
        """
        self._assertLoginThenLogout(credentials.UsernamePassword(b"bob", b"hello"))

    def test_derivedInterface(self):
        """
        Logging in with correct derived credentials and an interface
        that the portal's realm supports works.
        """
        self._assertLoginThenLogout(DerivedCredentials(b"bob", b"hello"))

    def test_failedLoginPassword(self):
        """
        Calling C{login} with incorrect credentials (in this case a wrong
        password) causes L{error.UnauthorizedLogin} to be raised.
        """
        self._assertLoginUnauthorized(credentials.UsernamePassword(b"bob", b"h3llo"))

    def test_failedLoginName(self):
        """
        Calling C{login} with incorrect credentials (in this case no known
        user) causes L{error.UnauthorizedLogin} to be raised.
        """
        self._assertLoginUnauthorized(credentials.UsernamePassword(b"jay", b"hello"))
class OnDiskDatabaseTests(unittest.TestCase):
    """
    Tests for L{checkers.FilePasswordDB} backed by a file of
    C{username:password} lines.
    """

    users = [
        (b"user1", b"pass1"),
        (b"user2", b"pass2"),
        (b"user3", b"pass3"),
    ]

    def setUp(self):
        """
        Write every entry of C{users} to a fresh temporary file, one
        C{username:password} line each.
        """
        self.dbfile = self.mktemp()
        with open(self.dbfile, "wb") as f:
            f.writelines(name + b":" + secret + b"\n" for name, secret in self.users)

    def test_getUserNonexistentDatabase(self):
        """
        A missing db file will cause a permanent rejection of authorization
        attempts.
        """
        self.db = checkers.FilePasswordDB("test_thisbetternoteverexist.db")
        self.assertRaises(error.UnauthorizedLogin, self.db.getUser, "user")

    def testUserLookup(self):
        """
        With the default settings a user is found under the exact name
        written to the file, and the upper-cased name raises L{KeyError}.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        for name, secret in self.users:
            self.assertRaises(KeyError, self.db.getUser, name.upper())
            self.assertEqual(self.db.getUser(name), (name, secret))

    def testCaseInSensitivity(self):
        """
        With C{caseSensitive=False} the upper-cased name finds the entry as
        it was written to the file.
        """
        self.db = checkers.FilePasswordDB(self.dbfile, caseSensitive=False)
        for name, secret in self.users:
            self.assertEqual(self.db.getUser(name.upper()), (name, secret))

    def testRequestAvatarId(self):
        """
        Requesting an avatar id with each user's L{credentials.UsernamePassword}
        results in that user's name.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        creds = [
            credentials.UsernamePassword(name, secret) for name, secret in self.users
        ]
        requests = [defer.maybeDeferred(self.db.requestAvatarId, c) for c in creds]
        expectedIds = [name for name, _ in self.users]
        d = defer.gatherResults(requests)
        d.addCallback(self.assertEqual, expectedIds)
        return d

    def testRequestAvatarId_hashed(self):
        """
        Requesting an avatar id with each user's deprecated
        C{UsernameHashedPassword} credentials results in that user's name.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        UsernameHashedPassword = self.getDeprecatedModuleAttribute(
            "twisted.cred.credentials", "UsernameHashedPassword", _uhpVersion
        )
        creds = [UsernameHashedPassword(name, secret) for name, secret in self.users]
        requests = [defer.maybeDeferred(self.db.requestAvatarId, c) for c in creds]
        expectedIds = [name for name, _ in self.users]
        d = defer.gatherResults(requests)
        d.addCallback(self.assertEqual, expectedIds)
        return d
class HashedPasswordOnDiskDatabaseTests(unittest.TestCase):
    """
    Tests for L{checkers.FilePasswordDB} backed by a file whose passwords
    were hashed with C{crypt}.
    """

    users = [
        (b"user1", b"pass1"),
        (b"user2", b"pass2"),
        (b"user3", b"pass3"),
    ]

    # Nothing here can run without the crypt function.
    if crypt is None:
        skip = "crypt module not available"

    def setUp(self):
        """
        Write every entry of C{users} to a fresh temporary file with the
        password hashed, using the first two bytes of the username as salt,
        and register the resulting checker with a portal over a L{TestRealm}.
        """
        dbfile = self.mktemp()
        self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
        with open(dbfile, "wb") as f:
            f.writelines(
                name + b":" + self.hash(name, secret, name[:2]) + b"\n"
                for name, secret in self.users
            )

        self.port = portal.Portal(TestRealm())
        self.port.registerChecker(self.db)

    def hash(self, u: bytes, p: bytes, s: bytes) -> bytes:
        """
        Hash a password with C{crypt}.

        @param u: Unused.
        @param p: The password to hash.
        @param s: The value handed to C{crypt} as its second argument.
        @return: The hash as bytes.
        """
        hashed_password = crypt(p.decode("ascii"), s.decode("ascii"))  # type: ignore[misc]
        # workaround for pypy3 3.6.9 and above which returns bytes from crypt.crypt()
        # This is fixed in pypy3 7.3.5.
        # See L{https://foss.heptapod.net/pypy/pypy/-/issues/3395}
        if not isinstance(hashed_password, bytes):
            hashed_password = hashed_password.encode("ascii")
        return hashed_password

    def testGoodCredentials(self):
        """
        Requesting an avatar id directly from the checker with each user's
        plaintext credentials results in that user's name.
        """
        goodCreds = [
            credentials.UsernamePassword(name, secret) for name, secret in self.users
        ]
        expectedIds = [name for name, _ in self.users]
        d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds])
        d.addCallback(self.assertEqual, expectedIds)
        return d

    def testGoodCredentials_login(self):
        """
        Logging in through the portal with each user's plaintext credentials
        results in an avatar named after that user.
        """
        goodCreds = [
            credentials.UsernamePassword(name, secret) for name, secret in self.users
        ]
        logins = [self.port.login(c, None, ITestable) for c in goodCreds]
        expectedNames = [name for name, _ in self.users]

        def avatarNames(results):
            # Each result is an (interface, avatar aspect, logout) triple.
            return [aspect.original.name for _, aspect, _ in results]

        d = defer.gatherResults(logins)
        d.addCallback(avatarNames)
        d.addCallback(self.assertEqual, expectedNames)
        return d

    def testBadCredentials(self):
        """
        Logging in with a wrong password fails with
        L{error.UnauthorizedLogin} for every user.
        """
        badCreds = [
            credentials.UsernamePassword(name, b"wrong password")
            for name, _ in self.users
        ]
        logins = [self.port.login(c, None, ITestable) for c in badCreds]
        d = defer.DeferredList(logins, consumeErrors=True)
        d.addCallback(self._assertFailures, error.UnauthorizedLogin)
        return d

    def testHashedCredentials(self):
        """
        Logging in with deprecated C{UsernameHashedPassword} credentials
        which carry the hashed password fails with
        L{error.UnhandledCredentials} for every user.
        """
        UsernameHashedPassword = self.getDeprecatedModuleAttribute(
            "twisted.cred.credentials", "UsernameHashedPassword", _uhpVersion
        )
        hashedCreds = [
            UsernameHashedPassword(name, self.hash(None, secret, name[:2]))
            for name, secret in self.users
        ]
        logins = [self.port.login(c, None, ITestable) for c in hashedCreds]
        d = defer.DeferredList(logins, consumeErrors=True)
        d.addCallback(self._assertFailures, error.UnhandledCredentials)
        return d

    def _assertFailures(self, failures, *expectedFailures):
        """
        Verify that every entry of a L{defer.DeferredList} result is a
        failure of one of the expected types.

        @param failures: An iterable of C{(flag, failure)} pairs.
        @param expectedFailures: The exception types passed to each
            failure's C{trap}.
        """
        for outcome, reason in failures:
            self.assertEqual(outcome, defer.FAILURE)
            reason.trap(*expectedFailures)
class CheckersMixin:
    """
    L{unittest.TestCase} mixin for testing that some checkers accept
    and deny specified credentials.

    Subclasses must provide
      - C{getCheckers} which returns a sequence of
        L{checkers.ICredentialChecker}
      - C{getGoodCredentials} which returns a list of 2-tuples of
        credential to check and avatarId to expect.
      - C{getBadCredentials} which returns a list of credentials
        which are expected to be unauthorized.
    """

    @defer.inlineCallbacks
    def test_positive(self):
        """
        Every checker accepts every good credential and answers with the
        C{avatarId} paired with it.
        """
        for checker in self.getCheckers():
            for goodCred, expectedId in self.getGoodCredentials():
                actualId = yield checker.requestAvatarId(goodCred)
                self.assertEqual(actualId, expectedId)

    @defer.inlineCallbacks
    def test_negative(self):
        """
        Every checker rejects every bad credential with
        L{error.UnauthorizedLogin}.
        """
        for checker in self.getCheckers():
            for badCred in self.getBadCredentials():
                attempt = checker.requestAvatarId(badCred)
                yield self.assertFailure(attempt, error.UnauthorizedLogin)
class HashlessFilePasswordDBMixin:
    """
    Provider of the L{CheckersMixin} hooks for L{checkers.FilePasswordDB}
    with passwords stored and presented unhashed.

    @cvar credClass: The credentials class instantiated for every good or
        bad credential.
    @cvar diskHash: C{None}, or a callable applied to each password before
        it is written to a database file.
    """

    credClass = credentials.UsernamePassword
    diskHash = None

    @staticmethod
    def networkHash(x: bytes) -> bytes:
        """
        Transform a password before it is placed in a credentials object;
        here the password is returned unchanged.
        """
        return x

    _validCredentials = [
        (b"user1", b"password1"),
        (b"user2", b"password2"),
        (b"user3", b"password3"),
    ]

    def getGoodCredentials(self):
        """
        Generate a C{(credentials, username)} pair for each valid user.
        """
        for name, secret in self._validCredentials:
            yield self.credClass(name, self.networkHash(secret)), name

    def getBadCredentials(self):
        """
        Generate credentials which pair a valid user with another user's
        password, plus credentials for a user that is not in the database.
        """
        rejected = [
            (b"user1", b"password3"),
            (b"user2", b"password1"),
            (b"bloof", b"blarf"),
        ]
        for name, secret in rejected:
            yield self.credClass(name, self.networkHash(secret))

    def getCheckers(self):
        """
        Generate L{checkers.FilePasswordDB} instances over three differently
        laid out files, each built once with caching on and once with it off.
        """
        if self.diskHash:
            toDisk = self.diskHash

            def hashCheck(username, password, stored):
                return self.diskHash(password)

        else:

            def toDisk(x):
                return x

            hashCheck = self.diskHash

        def writeDatabase(formatLine):
            # One line per valid user; the password is transformed for disk
            # before the line is formatted.
            path = self.mktemp()
            with open(path, "wb") as fObj:
                for name, secret in self._validCredentials:
                    fObj.write(formatLine(name, toDisk(secret)))
            return path

        for cache in (True, False):
            # username:password
            fn = writeDatabase(lambda name, stored: name + b":" + stored + b"\n")
            yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck)

            # Space delimited, password in field 0 and username in field 3.
            fn = writeDatabase(
                lambda name, stored: stored + b" dingle dongle " + name + b"\n"
            )
            yield checkers.FilePasswordDB(fn, b" ", 3, 0, cache=cache, hash=hashCheck)

            # Comma delimited, title-cased username in field 2 and password
            # in field 4.
            fn = writeDatabase(
                lambda name, stored: (
                    b"zip,zap," + name.title() + b",zup," + stored + b"\n"
                )
            )
            yield checkers.FilePasswordDB(
                fn, b",", 2, 4, False, cache=cache, hash=hashCheck
            )
class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant of L{HashlessFilePasswordDBMixin} which uses hex encoding as the
    on-disk hash.
    """

    @staticmethod
    def diskHash(x):
        """
        @param x: The password to transform.
        @return: C{x} hex-encoded.
        """
        encoded = hexlify(x)
        return encoded
class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant of L{HashlessFilePasswordDBMixin} whose credentials carry the
    password hex-encoded.
    """

    @staticmethod
    def networkHash(x: bytes) -> bytes:
        """
        @param x: The password to transform.
        @return: C{x} hex-encoded.
        """
        encoded = hexlify(x)
        return encoded

    class credClass(credentials.UsernamePassword):
        """
        Credentials which hold a hex-encoded password.
        """

        def checkPassword(self, password):
            """
            @param password: The password to compare against.
            @return: Whether the hex-decoded stored password equals
                C{password}.
            """
            decoded = unhexlify(self.password)
            return decoded == password
class HashlessFilePasswordDBCheckerTests(
    HashlessFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    Runs the L{CheckersMixin} tests against the checkers and credentials of
    L{HashlessFilePasswordDBMixin}.
    """
class LocallyHashedFilePasswordDBCheckerTests(
    LocallyHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    Runs the L{CheckersMixin} tests against the checkers and credentials of
    L{LocallyHashedFilePasswordDBMixin}.
    """
class NetworkHashedFilePasswordDBCheckerTests(
    NetworkHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    Runs the L{CheckersMixin} tests against the checkers and credentials of
    L{NetworkHashedFilePasswordDBMixin}.
    """
class UsernameHashedPasswordTests(unittest.TestCase):
    """
    UsernameHashedPassword is a deprecated class that is functionally
    equivalent to UsernamePassword.
    """

    def test_deprecation(self):
        """
        Looking up UsernameHashedPassword on L{twisted.cred.credentials} is
        deprecated in favour of UsernamePassword.
        """
        moduleName = "twisted.cred.credentials"
        attributeName = "UsernameHashedPassword"
        advice = "Use twisted.cred.credentials.UsernamePassword instead."
        self.getDeprecatedModuleAttribute(
            moduleName, attributeName, _uhpVersion, advice
        )
|