2222
2323import collections
2424import json
25+ import os
2526import sys
2627import threading
2728
29+ import google .auth
2830import requests
2931import six
3032from six .moves import urllib
4143_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython' .format (
4244 firebase_admin .__version__ , sys .version_info .major , sys .version_info .minor )
4345_TRANSACTION_MAX_RETRIES = 25
46+ _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST'
4447
4548
4649def reference (path = '/' , app = None , url = None ):
@@ -768,46 +771,108 @@ class _DatabaseService(object):
768771 _DEFAULT_AUTH_OVERRIDE = '_admin_'
769772
770773 def __init__ (self , app ):
771- self ._credential = app .credential . get_credential ()
774+ self ._credential = app .credential
772775 db_url = app .options .get ('databaseURL' )
773776 if db_url :
774- self ._db_url = _DatabaseService ._validate_url (db_url )
777+ _DatabaseService ._parse_db_url (db_url ) # Just for validation.
778+ self ._db_url = db_url
775779 else :
776780 self ._db_url = None
777781 auth_override = _DatabaseService ._get_auth_override (app )
778782 if auth_override != self ._DEFAULT_AUTH_OVERRIDE and auth_override != {}:
779- encoded = json .dumps (auth_override , separators = (',' , ':' ))
780- self ._auth_override = 'auth_variable_override={0}' .format (encoded )
783+ self ._auth_override = json .dumps (auth_override , separators = (',' , ':' ))
781784 else :
782785 self ._auth_override = None
783786 self ._timeout = app .options .get ('httpTimeout' )
784787 self ._clients = {}
785788
786- def get_client (self , base_url = None ):
787- if base_url is None :
788- base_url = self ._db_url
789- base_url = _DatabaseService ._validate_url (base_url )
790- if base_url not in self ._clients :
791- client = _Client (self ._credential , base_url , self ._auth_override , self ._timeout )
792- self ._clients [base_url ] = client
793- return self ._clients [base_url ]
789+ emulator_host = os .environ .get (_EMULATOR_HOST_ENV_VAR )
790+ if emulator_host :
791+ if '//' in emulator_host :
792+ raise ValueError (
793+ 'Invalid {0}: "{1}". It must follow format "host:port".' .format (
794+ _EMULATOR_HOST_ENV_VAR , emulator_host ))
795+ self ._emulator_host = emulator_host
796+ else :
797+ self ._emulator_host = None
798+
799+ def get_client (self , db_url = None ):
800+ """Creates a client based on the db_url. Clients may be cached."""
801+ if db_url is None :
802+ db_url = self ._db_url
803+
804+ base_url , namespace = _DatabaseService ._parse_db_url (db_url , self ._emulator_host )
805+ if base_url == 'https://{0}.firebaseio.com' .format (namespace ):
806+ # Production base_url. No need to specify namespace in query params.
807+ params = {}
808+ credential = self ._credential .get_credential ()
809+ else :
810+ # Emulator base_url. Use fake credentials and specify ?ns=foo in query params.
811+ credential = _EmulatorAdminCredentials ()
812+ params = {'ns' : namespace }
813+ if self ._auth_override :
814+ params ['auth_variable_override' ] = self ._auth_override
815+
816+ client_cache_key = (base_url , json .dumps (params , sort_keys = True ))
817+ if client_cache_key not in self ._clients :
818+ client = _Client (credential , base_url , self ._timeout , params )
819+ self ._clients [client_cache_key ] = client
820+ return self ._clients [client_cache_key ]
794821
795822 @classmethod
796- def _validate_url (cls , url ):
797- """Parses and validates a given database URL."""
823+ def _parse_db_url (cls , url , emulator_host = None ):
824+ """Parses (base_url, namespace) from a database URL.
825+
826+ The input can be either a production URL (https://foo-bar.firebaseio.com/)
827+ or an Emulator URL (http://localhost:8080/?ns=foo-bar). In case of Emulator
828+ URL, the namespace is extracted from the query param ns. The resulting
829+ base_url never includes query params.
830+
831+ If url is a production URL and emulator_host is specified, the result
832+ base URL will use emulator_host instead. emulator_host is ignored
833+ if url is already an emulator URL.
834+ """
798835 if not url or not isinstance (url , six .string_types ):
799836 raise ValueError (
800837 'Invalid database URL: "{0}". Database URL must be a non-empty '
801838 'URL string.' .format (url ))
802- parsed = urllib .parse .urlparse (url )
803- if parsed .scheme != 'https' :
839+ parsed_url = urllib .parse .urlparse (url )
840+ if parsed_url .netloc .endswith ('.firebaseio.com' ):
841+ return cls ._parse_production_url (parsed_url , emulator_host )
842+ else :
843+ return cls ._parse_emulator_url (parsed_url )
844+
845+ @classmethod
846+ def _parse_production_url (cls , parsed_url , emulator_host ):
847+ """Parses production URL like https://foo-bar.firebaseio.com/"""
848+ if parsed_url .scheme != 'https' :
804849 raise ValueError (
805- 'Invalid database URL: "{0}". Database URL must be an HTTPS URL.' .format (url ))
806- elif not parsed .netloc .endswith ('.firebaseio.com' ):
850+ 'Invalid database URL scheme: "{0}". Database URL must be an HTTPS URL.' .format (
851+ parsed_url .scheme ))
852+ namespace = parsed_url .netloc .split ('.' )[0 ]
853+ if not namespace :
807854 raise ValueError (
808855 'Invalid database URL: "{0}". Database URL must be a valid URL to a '
809- 'Firebase Realtime Database instance.' .format (url ))
810- return 'https://{0}' .format (parsed .netloc )
856+ 'Firebase Realtime Database instance.' .format (parsed_url .geturl ()))
857+
858+ if emulator_host :
859+ base_url = 'http://{0}' .format (emulator_host )
860+ else :
861+ base_url = 'https://{0}' .format (parsed_url .netloc )
862+ return base_url , namespace
863+
864+ @classmethod
865+ def _parse_emulator_url (cls , parsed_url ):
866+ """Parses emulator URL like http://localhost:8080/?ns=foo-bar"""
867+ query_ns = urllib .parse .parse_qs (parsed_url .query ).get ('ns' )
868+ if parsed_url .scheme != 'http' or (not query_ns or len (query_ns ) != 1 or not query_ns [0 ]):
869+ raise ValueError (
870+ 'Invalid database URL: "{0}". Database URL must be a valid URL to a '
871+ 'Firebase Realtime Database instance.' .format (parsed_url .geturl ()))
872+
873+ namespace = query_ns [0 ]
874+ base_url = '{0}://{1}' .format (parsed_url .scheme , parsed_url .netloc )
875+ return base_url , namespace
811876
812877 @classmethod
813878 def _get_auth_override (cls , app ):
@@ -833,7 +898,7 @@ class _Client(_http_client.JsonHttpClient):
833898 marshalling and unmarshalling of JSON data.
834899 """
835900
836- def __init__ (self , credential , base_url , auth_override , timeout ):
901+ def __init__ (self , credential , base_url , timeout , params = None ):
837902 """Creates a new _Client from the given parameters.
838903
839904 This exists primarily to enable testing. For regular use, obtain _Client instances by
@@ -843,22 +908,21 @@ def __init__(self, credential, base_url, auth_override, timeout):
843908 credential: A Google credential that can be used to authenticate requests.
844909 base_url: A URL prefix to be added to all outgoing requests. This is typically the
845910 Firebase Realtime Database URL.
846- auth_override: The encoded auth_variable_override query parameter to be included in
847- outgoing requests.
848911 timeout: HTTP request timeout in seconds. If not set connections will never
849912 timeout, which is the default behavior of the underlying requests library.
913+ params: Dict of query parameters to add to all outgoing requests.
850914 """
851915 _http_client .JsonHttpClient .__init__ (
852916 self , credential = credential , base_url = base_url , headers = {'User-Agent' : _USER_AGENT })
853917 self .credential = credential
854- self .auth_override = auth_override
855918 self .timeout = timeout
919+ self .params = params if params else {}
856920
857921 def request (self , method , url , ** kwargs ):
858922 """Makes an HTTP call using the Python requests library.
859923
860- Extends the request() method of the parent JsonHttpClient class. Handles auth overrides,
861- and low-level exceptions.
924+ Extends the request() method of the parent JsonHttpClient class. Handles default
925+ params like auth overrides, and low-level exceptions.
862926
863927 Args:
864928 method: HTTP method name as a string (e.g. get, post).
@@ -872,13 +936,15 @@ def request(self, method, url, **kwargs):
872936 Raises:
873937 ApiCallError: If an error occurs while making the HTTP call.
874938 """
875- if self .auth_override :
876- params = kwargs .get ('params' )
877- if params :
878- params += '&{0}' .format (self .auth_override )
939+ query = '&' .join ('{0}={1}' .format (key , self .params [key ]) for key in self .params )
940+ extra_params = kwargs .get ('params' )
941+ if extra_params :
942+ if query :
943+ query = extra_params + '&' + query
879944 else :
880- params = self .auth_override
881- kwargs ['params' ] = params
945+ query = extra_params
946+ kwargs ['params' ] = query
947+
882948 if self .timeout :
883949 kwargs ['timeout' ] = self .timeout
884950 try :
@@ -911,3 +977,12 @@ def extract_error_message(cls, error):
911977 except ValueError :
912978 pass
913979 return '{0}\n Reason: {1}' .format (error , error .response .content .decode ())
980+
981+
982+ class _EmulatorAdminCredentials (google .auth .credentials .Credentials ):
983+ def __init__ (self ):
984+ google .auth .credentials .Credentials .__init__ (self )
985+ self .token = 'owner'
986+
987+ def refresh (self , request ):
988+ pass
0 commit comments