diff --git a/pom.xml b/pom.xml index 69046e0f..f6d00358 100644 --- a/pom.xml +++ b/pom.xml @@ -317,6 +317,8 @@ limitations under the License. com.microsoft.azure.cosmosdb.internal.query.metrics.* com.microsoft.azure.cosmosdb.internal.query.orderbyquery.* com.microsoft.azure.cosmosdb.internal.routing.* + com.microsoft.azure.cosmosdb.TokenResolver + com.microsoft.azure.cosmosdb.CosmosResourceType diff --git a/src/main/scala/com/microsoft/azure/cosmosdb/spark/AsyncCosmosDBConnection.scala b/src/main/scala/com/microsoft/azure/cosmosdb/spark/AsyncCosmosDBConnection.scala index 2955f830..e70585f6 100644 --- a/src/main/scala/com/microsoft/azure/cosmosdb/spark/AsyncCosmosDBConnection.scala +++ b/src/main/scala/com/microsoft/azure/cosmosdb/spark/AsyncCosmosDBConnection.scala @@ -31,19 +31,20 @@ import com.microsoft.azure.cosmosdb._ import com.microsoft.azure.cosmosdb.internal._ import com.microsoft.azure.cosmosdb.rx.AsyncDocumentClient import com.microsoft.azure.cosmosdb.spark.schema.CosmosDBRowConverter -import com.microsoft.azure.cosmosdb.spark.streaming.CosmosDBWriteStreamRetryPolicy import org.apache.spark.sql.Row import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag - import java.util.concurrent.ConcurrentHashMap +import com.microsoft.azure.cosmosdb.spark.util.CosmosUtils + case class AsyncClientConfiguration(host: String, key: String, connectionPolicy: ConnectionPolicy, - consistencyLevel: ConsistencyLevel) + consistencyLevel: ConsistencyLevel, + tokenResolver: CosmosDBTokenResolver) object AsyncCosmosDBConnection { private lazy val clients: ConcurrentHashMap[Config, AsyncDocumentClient] = { @@ -103,13 +104,26 @@ object AsyncCosmosDBConnection { val consistencyLevel = ConsistencyLevel.valueOf(config.get[String](CosmosDBConfig.ConsistencyLevel) .getOrElse(CosmosDBConfig.DefaultConsistencyLevel)) - val resourceToken = config.getOrElse[String](CosmosDBConfig.ResourceToken, "") + var resourceKey: String = null + + // Check Resource Token and Token Resolver + var tokenResolver: CosmosDBTokenResolver = null + val tokenResolverClassName = config.getOrElse[String](CosmosDBConfig.TokenResolver, "") + + if (!tokenResolverClassName.isEmpty) { + tokenResolver = CosmosUtils.getTokenResolverFromClassName(tokenResolverClassName) + tokenResolver.initialize(config) + } else { + val resourceToken = config.getOrElse[String](CosmosDBConfig.ResourceToken, "") + resourceKey = config.getOrElse[String](CosmosDBConfig.Masterkey, resourceToken) + } AsyncClientConfiguration( config.get[String](CosmosDBConfig.Endpoint).get, - config.getOrElse[String](CosmosDBConfig.Masterkey, resourceToken), + resourceKey, connectionPolicy, - consistencyLevel + consistencyLevel, + tokenResolver ) } @@ -126,6 +140,7 @@ object AsyncCosmosDBConnection { .withMasterKeyOrResourceToken(clientConfig.key) .withConnectionPolicy(clientConfig.connectionPolicy) .withConsistencyLevel(clientConfig.consistencyLevel) + .withTokenResolver(clientConfig.tokenResolver) .build() } } diff --git a/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBConnection.scala b/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBConnection.scala index 9d3b7cf5..17e9adaf 100644 --- a/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBConnection.scala +++ b/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBConnection.scala @@ -25,8 +25,9 @@ package com.microsoft.azure.cosmosdb.spark import java.lang.management.ManagementFactory import java.util.{Timer, TimerTask} +import com.microsoft.azure.cosmosdb.{CosmosResourceType} import com.microsoft.azure.cosmosdb.spark.config._ -import com.microsoft.azure.documentdb +import com.microsoft.azure.cosmosdb.spark.util.CosmosUtils import com.microsoft.azure.documentdb._ import com.microsoft.azure.documentdb.bulkexecutor.DocumentBulkExecutor import com.microsoft.azure.documentdb.internal._ @@ -43,7 +44,8 @@ case class ClientConfiguration(host: String, key: String, connectionPolicy: ConnectionPolicy, consistencyLevel: ConsistencyLevel, - resourceLink: String) + resourceLink: String, + tokenResolver: CosmosDBTokenResolver) object CosmosDBConnection extends CosmosDBLoggingTrait { // For verification purpose @@ -518,11 +520,21 @@ private[spark] case class CosmosDBConnection(config: Config) extends CosmosDBLog val consistencyLevel = ConsistencyLevel.valueOf(config.get[String](CosmosDBConfig.ConsistencyLevel) .getOrElse(CosmosDBConfig.DefaultConsistencyLevel)) - //Check if resource token exists - val resourceToken = config.getOrElse[String](CosmosDBConfig.ResourceToken, "") - var resourceLink: String = "" - if(!resourceToken.isEmpty) { - resourceLink = s"dbs/${config.get[String](CosmosDBConfig.Database).get}/colls/${config.get[String](CosmosDBConfig.Collection).get}" + // check Token Resolver before checking resource token + var resourceLink = s"dbs/${config.get[String](CosmosDBConfig.Database).get}/colls/${config.get[String](CosmosDBConfig.Collection).get}" + var resourceToken = config.getOrElse(CosmosDBConfig.ResourceToken, "") + + var tokenResolver: CosmosDBTokenResolver = null + val tokenResolverClassName = config.getOrElse[String](CosmosDBConfig.TokenResolver, "") + + if (!tokenResolverClassName.isEmpty) { + tokenResolver = CosmosUtils.getTokenResolverFromClassName(tokenResolverClassName) + tokenResolver.initialize(config) + resourceToken = tokenResolver.getAuthorizationToken("GET", resourceLink, CosmosResourceType.DocumentCollection, config.asOptions) + } + + if(resourceToken.isEmpty) { + resourceLink = "" } ClientConfiguration( @@ -530,7 +542,9 @@ private[spark] case class CosmosDBConnection(config: Config) extends CosmosDBLog config.getOrElse[String](CosmosDBConfig.Masterkey, resourceToken), connectionPolicy, consistencyLevel, - resourceLink) + resourceLink, + tokenResolver + ) } } diff --git a/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBTokenResolver.scala b/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBTokenResolver.scala new file mode 100644 index 00000000..e81f1e25 --- /dev/null +++ b/src/main/scala/com/microsoft/azure/cosmosdb/spark/CosmosDBTokenResolver.scala @@ -0,0 +1,8 @@ +package com.microsoft.azure.cosmosdb.spark + +import com.microsoft.azure.cosmosdb.spark.config.Config +import com.microsoft.azure.cosmosdb.TokenResolver + +trait CosmosDBTokenResolver extends TokenResolver { + def initialize(config: Config): Unit +} diff --git a/src/main/scala/com/microsoft/azure/cosmosdb/spark/config/CosmosDBConfig.scala b/src/main/scala/com/microsoft/azure/cosmosdb/spark/config/CosmosDBConfig.scala index 10bf6fcb..b818e05e 100755 --- a/src/main/scala/com/microsoft/azure/cosmosdb/spark/config/CosmosDBConfig.scala +++ b/src/main/scala/com/microsoft/azure/cosmosdb/spark/config/CosmosDBConfig.scala @@ -39,6 +39,7 @@ object CosmosDBConfig { val Collection = "collection" val Masterkey = "masterkey" val ResourceToken = "resourcetoken" + val TokenResolver = "tokenresolver" val PreferredRegionsList = "preferredregions" val ConsistencyLevel = "consistencylevel" diff --git a/src/main/scala/com/microsoft/azure/cosmosdb/spark/util/CosmosUtils.scala b/src/main/scala/com/microsoft/azure/cosmosdb/spark/util/CosmosUtils.scala new file mode 100644 index 00000000..b2983775 --- /dev/null +++ b/src/main/scala/com/microsoft/azure/cosmosdb/spark/util/CosmosUtils.scala @@ -0,0 +1,12 @@ +package com.microsoft.azure.cosmosdb.spark.util + +import com.microsoft.azure.cosmosdb.spark.CosmosDBTokenResolver + +object CosmosUtils extends Serializable { + + def getTokenResolverFromClassName(className: String, constructorArgs: AnyRef*): CosmosDBTokenResolver = { + val argsClassSeq = constructorArgs.map(e => e.getClass) + Class.forName(className).getDeclaredConstructor(argsClassSeq:_*).newInstance(constructorArgs:_*).asInstanceOf[CosmosDBTokenResolver] + } + +}