Skip to content

Commit 982c1d4

Browse files
rafaelodonRafael Odon
andauthored
INTPYTHON-785 Only list authorized collections when listing collections (langchain-ai#226)
Fixes langchain-ai#225 --------- Co-authored-by: Rafael Odon <rafael.odon@serpro.gov.br>
1 parent 3913e40 commit 982c1d4

File tree

7 files changed

+20
-7
lines changed

7 files changed

+20
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ __pycache__
99
docs/langchain_mongodb
1010
docs/langgraph_checkpoint_mongodb
1111
docs/index.md
12+
.vscode

libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def __init__(
5757
)
5858
self._include_colls = set(include_collections or [])
5959
self._ignore_colls = set(ignore_collections or [])
60-
self._all_colls = set(self._db.list_collection_names())
60+
self._all_colls = set(
61+
self._db.list_collection_names(authorizedCollections=True)
62+
)
6163

6264
self._sample_docs_in_coll_info = sample_docs_in_collection_info
6365
self._indexes_in_coll_info = indexes_in_collection_info

libs/langchain-mongodb/langchain_mongodb/cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def __init__(
5252
self.__database_name = database_name
5353
self.__collection_name = collection_name
5454

55-
if self.__collection_name not in self.database.list_collection_names():
55+
if self.__collection_name not in self.database.list_collection_names(
56+
authorizedCollections=True
57+
):
5658
self.database.create_collection(self.__collection_name)
5759
# Create an index on key and llm_string
5860
self.collection.create_index([self.PROMPT, self.LLM])

libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ def __init__(
142142
driver=DRIVER_METADATA,
143143
)
144144
db = client[database_name]
145-
if collection_name not in db.list_collection_names():
145+
if collection_name not in db.list_collection_names(
146+
authorizedCollections=True
147+
):
146148
validator = {"$jsonSchema": self._schema} if validate else None
147149
collection = client[database_name].create_collection(
148150
collection_name,

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def create_vector_search_index(
6060
"""
6161
logger.info("Creating Search Index %s on %s", index_name, collection.name)
6262

63-
if collection.name not in collection.database.list_collection_names():
63+
if collection.name not in collection.database.list_collection_names(
64+
authorizedCollections=True
65+
):
6466
collection.database.create_collection(collection.name)
6567

6668
result = collection.create_search_index(
@@ -219,7 +221,9 @@ def create_fulltext_search_index(
219221
"""
220222
logger.info("Creating Search Index %s on %s", index_name, collection.name)
221223

222-
if collection.name not in collection.database.list_collection_names():
224+
if collection.name not in collection.database.list_collection_names(
225+
authorizedCollections=True
226+
):
223227
collection.database.create_collection(collection.name)
224228

225229
if isinstance(field, str):

libs/langchain-mongodb/tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ class MockDatabase:
238238
def __init__(self, client=None):
239239
self.client = client or MockClient()
240240

241-
def list_collection_names(self) -> list[str]:
241+
def list_collection_names(self, authorizedCollections: bool = True) -> list[str]:
242242
return ["test"]
243243

244244
def __getitem__(self, key: str) -> Any:

libs/langgraph-store-mongodb/langgraph/store/mongodb/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,9 @@ def from_conn_string(
251251
driver=DRIVER_METADATA,
252252
)
253253
db = client[db_name]
254-
if collection_name not in db.list_collection_names():
254+
if collection_name not in db.list_collection_names(
255+
authorizedCollections=True
256+
):
255257
db.create_collection(collection_name)
256258
collection = client[db_name][collection_name]
257259

0 commit comments

Comments
 (0)