1717
1818@pytest .mark .test_async
1919class TestCollection (Base ):
20-
2120 @pytest .mark .run (order = 21 )
2221 @pytest .mark .asyncio
2322 async def test_a_create_collection (self ):
@@ -101,10 +100,11 @@ async def test_a_delete_collection(self):
101100
102101@pytest .mark .test_async
103102class TestRecord (Base ):
104-
105103 text_splitter_list = [
106- {"type" : "token" , "chunk_size" : 100 , "chunk_overlap" : 10 },
107- TokenTextSplitter (chunk_size = 200 , chunk_overlap = 20 ),
104+ # {"type": "token", "chunk_size": 100, "chunk_overlap": 10},
105+ # TokenTextSplitter(chunk_size=200, chunk_overlap=20),
106+ {"type" : "separator" , "chunk_size" : 100 , "chunk_overlap" : 10 , "separators" : ["." , "!" , "?" ]},
107+ TextSplitter (type = "separator" , chunk_size = 200 , chunk_overlap = 20 , separators = ["." , "!" , "?" ]),
108108 ]
109109
110110 upload_file_data_list = []
@@ -120,8 +120,8 @@ class TestRecord(Base):
120120
121121 @pytest .mark .run (order = 31 )
122122 @pytest .mark .asyncio
123- async def test_a_create_record_by_text ( self ):
124- text_splitter = TokenTextSplitter ( chunk_size = 200 , chunk_overlap = 100 )
123+ @ pytest . mark . parametrize ( "text_splitter" , text_splitter_list )
124+ async def test_a_create_record_by_text ( self , text_splitter ):
125125 text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
126126 create_record_data = {
127127 "type" : "text" ,
@@ -131,16 +131,10 @@ async def test_a_create_record_by_text(self):
131131 "text_splitter" : text_splitter ,
132132 "metadata" : {"key1" : "value1" , "key2" : "value2" },
133133 }
134-
135- for x in range (2 ):
136- # Create a record.
137- if x == 0 :
138- create_record_data .update ({"text_splitter" : {"type" : "token" , "chunk_size" : 100 , "chunk_overlap" : 10 }})
139-
140- res = await a_create_record (** create_record_data )
141- res_dict = vars (res )
142- assume_record_result (create_record_data , res_dict )
143- Base .record_id = res_dict ["record_id" ]
134+ res = await a_create_record (** create_record_data )
135+ res_dict = vars (res )
136+ assume_record_result (create_record_data , res_dict )
137+ Base .record_id = res_dict ["record_id" ]
144138
145139 @pytest .mark .run (order = 31 )
146140 @pytest .mark .asyncio
@@ -332,13 +326,14 @@ async def test_a_query_chunks(self):
332326 query_text = "Machine learning"
333327 top_k = 1
334328 res = await a_query_chunks (
335- collection_id = self .collection_id , query_text = query_text , top_k = top_k , max_tokens = 20000
329+ collection_id = self .collection_id , query_text = query_text , top_k = top_k , max_tokens = 20000 , score_threshold = 0.04
336330 )
337331 pytest .assume (len (res ) == top_k )
338332 for chunk in res :
339333 chunk_dict = vars (chunk )
340334 assume_query_chunk_result (query_text , chunk_dict )
341335 pytest .assume (chunk_dict .keys () == self .chunk_keys )
336+ pytest .assume (chunk_dict ["score" ] >= 0.04 )
342337
343338 @pytest .mark .run (order = 42 )
344339 @pytest .mark .asyncio
0 commit comments