From 442ea7c752467dd0eab61a8f0f8cf89abddb6dea Mon Sep 17 00:00:00 2001 From: David Bachowski Date: Thu, 19 Dec 2019 12:58:03 -0500 Subject: [PATCH] Adding support for custom domains --- aiosfstream/auth.py | 53 +++++++++++++++++++++++++++++-------------- aiosfstream/client.py | 9 ++++++-- tests/test_auth.py | 25 +++++++++++++++++--- tests/test_client.py | 3 +++ 4 files changed, 68 insertions(+), 22 deletions(-) diff --git a/aiosfstream/auth.py b/aiosfstream/auth.py index 726d0bf..f940abd 100644 --- a/aiosfstream/auth.py +++ b/aiosfstream/auth.py @@ -14,8 +14,10 @@ from aiosfstream.exceptions import AuthenticationError -TOKEN_URL = "https://login.salesforce.com/services/oauth2/token" -SANDBOX_TOKEN_URL = "https://test.salesforce.com/services/oauth2/token" +# formatted later with a domain +BASE_URL = "https://{}.salesforce.com/services/oauth2/token" +LOGIN_DOMAIN = "login" +SANDBOX_DOMAIN = "test" # pylint: disable=too-many-instance-attributes @@ -23,20 +25,29 @@ class AuthenticatorBase(AuthExtension): """Abstract base class to serve as a base for implementing concrete authenticators""" - def __init__(self, sandbox: bool = False, + def __init__(self, sandbox: bool = None, domain: str = None, json_dumps: JsonDumper = json.dumps, json_loads: JsonLoader = json.loads) -> None: """ - :param sandbox: Marks whether the authentication has to be done \ - for a sandbox org or for a production org + :param sandbox: Marks whether the connection has to be made with \ + a sandbox org or with a production org. Cannot be used concurrently with \ + a value for domain. + :param domain: A custom salesforce domain instead of 'login' or 'test'. \ + Cannot be used concurrently with a value for sandbox :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` :param json_loads: Function for JSON deserialization, the default is \ :func:`json.loads` """ - #: Marks whether the authentication has to be done for a sandbox org \ - #: or for a production org - self._sandbox = sandbox + if sandbox is not None and domain is not None: + raise ValueError('You cannot specify a value for sandbox AND domain. Please use just one.') + elif domain is not None: + self._domain = domain + elif sandbox is True: + self._domain = SANDBOX_DOMAIN + else: + self._domain = LOGIN_DOMAIN + #: Salesforce session ID that can be used with the web services API self.access_token: Optional[str] = None #: Value is Bearer for all responses that include an access token @@ -60,9 +71,7 @@ def __init__(self, sandbox: bool = False, @property def _token_url(self) -> str: """The URL that should be used for token requests""" - if self._sandbox: - return SANDBOX_TOKEN_URL - return TOKEN_URL + return BASE_URL.format(self._domain) async def outgoing(self, payload: Payload, headers: Headers) -> None: """Process outgoing *payload* and *headers* @@ -124,7 +133,8 @@ async def _authenticate(self) -> Tuple[int, JsonObject]: class PasswordAuthenticator(AuthenticatorBase): """Authenticator for using the OAuth 2.0 Username-Password Flow""" def __init__(self, consumer_key: str, consumer_secret: str, - username: str, password: str, sandbox: bool = False, + username: str, password: str, + sandbox: bool = None, domain: str = None, json_dumps: JsonDumper = json.dumps, json_loads: JsonLoader = json.loads) -> None: """ @@ -134,14 +144,18 @@ def __init__(self, consumer_key: str, consumer_secret: str, connected app definition :param username: Salesforce username :param password: Salesforce password - :param sandbox: Marks whether the authentication has to be done \ - for a sandbox org or for a production org + :param sandbox: Marks whether the connection has to be made with \ + a sandbox org or with a production org. Cannot be used concurrently with \ + a value for domain. + :param domain: A custom salesforce domain instead of 'login' or 'test'. \ + Cannot be used concurrently with a value for sandbox :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` :param json_loads: Function for JSON deserialization, the default is \ :func:`json.loads` """ super().__init__(sandbox=sandbox, + domain=domain, json_dumps=json_dumps, json_loads=json_loads) #: OAuth2 client id @@ -178,7 +192,8 @@ async def _authenticate(self) -> Tuple[int, JsonObject]: class RefreshTokenAuthenticator(AuthenticatorBase): """Authenticator for using the OAuth 2.0 Refresh Token Flow""" def __init__(self, consumer_key: str, consumer_secret: str, - refresh_token: str, sandbox: bool = False, + refresh_token: str, + sandbox: bool = None, domain: str = None, json_dumps: JsonDumper = json.dumps, json_loads: JsonLoader = json.loads) -> None: """ @@ -189,14 +204,18 @@ def __init__(self, consumer_key: str, consumer_secret: str, :param refresh_token: A refresh token obtained from Salesforce \ by using one of its authentication methods (for example with the \ OAuth 2.0 Web Server Authentication Flow) - :param sandbox: Marks whether the authentication has to be done \ - for a sandbox org or for a production org + :param sandbox: Marks whether the connection has to be made with \ + a sandbox org or with a production org. Cannot be used concurrently with \ + a value for domain. + :param domain: A custom salesforce domain instead of 'login' or 'test'. \ + Cannot be used concurrently with a value for sandbox :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` :param json_loads: Function for JSON deserialization, the default is \ :func:`json.loads` """ super().__init__(sandbox=sandbox, + domain=domain, json_dumps=json_dumps, json_loads=json_loads) #: OAuth2 client id diff --git a/aiosfstream/client.py b/aiosfstream/client.py index 0c52b7d..261d9b3 100644 --- a/aiosfstream/client.py +++ b/aiosfstream/client.py @@ -295,7 +295,8 @@ def __init__(self, *, # pylint: disable=too-many-locals replay_storage_policy: ReplayMarkerStoragePolicy = ReplayMarkerStoragePolicy.AUTOMATIC, connection_timeout: Union[int, float] = 10.0, - max_pending_count: int = 100, sandbox: bool = False, + max_pending_count: int = 100, + sandbox: bool = None, domain: str = None, json_dumps: JsonDumper = json.dumps, json_loads: JsonLoader = json.loads, loop: Optional[asyncio.AbstractEventLoop] = None): @@ -326,7 +327,10 @@ def __init__(self, *, # pylint: disable=too-many-locals consumed. \ If it is less than or equal to zero, the count is infinite. :param sandbox: Marks whether the connection has to be made with \ - a sandbox org or with a production org + a sandbox org or with a production org. Cannot be used concurrently with \ + a value for domain. + :param domain: A custom salesforce domain instead of 'login' or 'test'. \ + Cannot be used concurrently with a value for sandbox :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` :param json_loads: Function for JSON deserialization, the default is \ @@ -342,6 +346,7 @@ def __init__(self, *, # pylint: disable=too-many-locals username=username, password=password, sandbox=sandbox, + domain=domain, json_dumps=json_dumps, json_loads=json_loads, ) diff --git a/tests/test_auth.py b/tests/test_auth.py index a34237a..449ed0c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -5,7 +5,7 @@ from aiohttp.client_exceptions import ClientError from aiosfstream.auth import AuthenticatorBase, PasswordAuthenticator, \ - TOKEN_URL, SANDBOX_TOKEN_URL, RefreshTokenAuthenticator + LOGIN_DOMAIN, SANDBOX_DOMAIN, BASE_URL, RefreshTokenAuthenticator from aiosfstream.exceptions import AuthenticationError @@ -131,12 +131,31 @@ async def test_incoming(self): def test_token_url_non_sandbox(self): auth = Authenticator() - self.assertEqual(auth._token_url, TOKEN_URL) + self.assertEqual(auth._token_url, BASE_URL.format(LOGIN_DOMAIN)) def test_token_url_sandbox(self): auth = Authenticator(sandbox=True) - self.assertEqual(auth._token_url, SANDBOX_TOKEN_URL) + self.assertEqual(auth._token_url, BASE_URL.format(SANDBOX_DOMAIN)) + + def test_custom_url_sandbox(self): + domain = 'sparkles' + auth = Authenticator(domain=domain) + self.assertEqual(auth._token_url, BASE_URL.format(domain)) + + def test_sandbox_true_with_custom_domain(self): + domain = 'sparkles' + sandbox = True + with self.assertRaisesRegex(ValueError, + "You cannot specify a value for sandbox AND domain"): + auth = Authenticator(sandbox=sandbox, domain=domain) + + def test_sandbox_false_with_custom_domain(self): + domain = 'sparkles' + sandbox = False + with self.assertRaisesRegex(ValueError, + "You cannot specify a value for sandbox AND domain"): + auth = Authenticator(sandbox=sandbox, domain=domain) class TestPasswordAuthenticator(TestCase): diff --git a/tests/test_client.py b/tests/test_client.py index 44872eb..cce793b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -418,6 +418,7 @@ def test_init(self, super_init, authenticator_cls): json_loads = object() loop = object() sandbox_enabled = True + domain = None SalesforceStreamingClient( consumer_key=consumer_key, @@ -430,6 +431,7 @@ def test_init(self, super_init, authenticator_cls): connection_timeout=connection_timeout, max_pending_count=max_pending_count, sandbox=sandbox_enabled, + domain=domain, json_dumps=json_dumps, json_loads=json_loads, loop=loop @@ -441,6 +443,7 @@ def test_init(self, super_init, authenticator_cls): username=username, password=password, sandbox=sandbox_enabled, + domain=domain, json_dumps=json_dumps, json_loads=json_loads, )