@@ -2010,3 +2010,176 @@ def test_ivf_flat_taskgraph_query(tmp_path):
20102010 queries , k = k , nprobe = nprobe , nthreads = 8 , mode = Mode .LOCAL , num_partitions = 10
20112011 )
20122012 assert accuracy (result , gt_i ) > MINIMUM_ACCURACY
2013+
2014+
def test_dimensions_parameter_override(tmp_path):
    """
    Check that `dimensions` passed to ingest() overrides the array domain.

    The source is a dense TileDB array whose two dimensions both have a
    domain far larger than the data written to it (the situation reported
    for TileDBSOMA-produced arrays in
    https://github.com/TileDB-Inc/TileDB-Vector-Search/issues/564).

    Two ingestions are exercised:
      * with an explicit `dimensions`, the index must answer queries;
      * with `dimensions=-1`, the index reports the domain-derived
        dimensionality and rejects 64-dimensional queries.
    """
    num_vectors, vector_dims = 1000, 64
    num_queries, top_k = 10, 5
    # Inclusive upper bound of both array dimensions (100001 cells each).
    domain_upper = 100000

    base_vectors = np.random.rand(num_vectors, vector_dims).astype(np.float32)
    queries = np.random.rand(num_queries, vector_dims).astype(np.float32)

    # Build the dense source array with the oversized domain.
    source_uri = os.path.join(tmp_path, "source_array")
    oversized_domain = (0, domain_upper)
    dim_0 = tiledb.Dim(
        name="__dim_0", domain=oversized_domain, tile=1000, dtype="int32"
    )
    dim_1 = tiledb.Dim(
        name="__dim_1", domain=oversized_domain, tile=vector_dims, dtype="int32"
    )
    values_attr = tiledb.Attr(
        name="values", dtype="float32", var=False, nullable=False
    )
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(dim_0, dim_1),
        attrs=[values_attr],
        cell_order="col-major",
        tile_order="col-major",
        capacity=10000,
        sparse=False,
    )
    tiledb.Array.create(source_uri, schema)

    # NOTE(review): vectors are written one per row (nb x dims); confirm this
    # matches the layout ingest() expects for TILEDB_ARRAY sources.
    with tiledb.open(source_uri, "w") as array:
        array[0:num_vectors, 0:vector_dims] = base_vectors

    # Arguments common to both ingestions; only the target URI and the
    # `dimensions` value differ between them.
    shared_ingest_args = dict(
        index_type="FLAT",
        source_uri=source_uri,
        source_type="TILEDB_ARRAY",
        size=num_vectors,
    )

    # Case 1: explicit override with the real vector dimensionality.
    overridden_index = ingest(
        index_uri=os.path.join(tmp_path, "test_index"),
        dimensions=vector_dims,
        **shared_ingest_args,
    )
    assert overridden_index is not None
    overridden_index.vacuum()

    distances, indices = overridden_index.query(queries, k=top_k)
    expected_shape = (num_queries, top_k)
    assert distances.shape == expected_shape
    assert indices.shape == expected_shape
    assert np.all(indices >= 0)
    assert np.all(indices < num_vectors)

    # Case 2: dimensions=-1 keeps the dimensionality derived from the domain.
    # Ingestion itself succeeds; the mismatch only surfaces at query time.
    detected_index = ingest(
        index_uri=os.path.join(tmp_path, "test_index_2"),
        dimensions=-1,
        **shared_ingest_args,
    )
    assert detected_index is not None
    detected_index.vacuum()

    # The index was built with the full domain extent as its dimensionality.
    assert detected_index.dimensions == domain_upper + 1

    # Querying with 64-dimensional vectors must be rejected, which is the
    # failure the `dimensions` override exists to avoid.
    expected_error = (
        "A query in queries has 64 dimensions, but the indexed data had 100001 dimensions"
    )
    with pytest.raises(Exception) as exc_info:
        detected_index.query(queries, k=top_k)
    assert expected_error in str(exc_info.value)
2124+
2125+
def test_dimensions_parameter_with_numpy_input(tmp_path):
    """
    Test the dimensions parameter with numpy input vectors.

    When `input_vectors` is supplied as a numpy array, this test expects
    ingest() to accept a `dimensions` argument without error and to build an
    index that behaves the same as one built without the argument.

    Two FLAT indexes are built from the same vectors, one with a bogus
    `dimensions=999` and one without the parameter, and their query results
    are compared against each other.
    """
    nb = 100
    actual_dimensions = 32
    nq = 5
    k = 3

    # Seeded generator: the two indexes' results are compared element-wise
    # below, so any failure must be reproducible.
    rng = np.random.default_rng(0)
    input_vectors = rng.random((nb, actual_dimensions), dtype=np.float32)
    queries = rng.random((nq, actual_dimensions), dtype=np.float32)

    # Ingest with a deliberately wrong `dimensions`; the test expects it to
    # be ignored because the vectors are provided in memory.
    index_uri = os.path.join(tmp_path, "test_numpy_index")
    index = ingest(
        index_type="FLAT",
        index_uri=index_uri,
        input_vectors=input_vectors,
        dimensions=999,  # Expected to be ignored since input_vectors is provided
    )
    assert index is not None
    index.vacuum()

    # Ingest again without the parameter (default behavior).
    index_uri_2 = os.path.join(tmp_path, "test_numpy_index_2")
    index_2 = ingest(
        index_type="FLAT",
        index_uri=index_uri_2,
        input_vectors=input_vectors,
    )
    assert index_2 is not None
    index_2.vacuum()

    # Both indexes must report the dimensionality of the input array.
    assert index.dimensions == actual_dimensions
    assert index_2.dimensions == actual_dimensions

    distances, indices = index.query(queries, k=k)
    distances_2, indices_2 = index_2.query(queries, k=k)

    # Basic sanity checks on both result sets.
    for result_distances, result_indices in (
        (distances, indices),
        (distances_2, indices_2),
    ):
        assert result_distances.shape == (nq, k)
        assert result_indices.shape == (nq, k)
        assert np.all(result_indices >= 0)
        assert np.all(result_indices < nb)

    # Compare the two indexes: same data and same queries must give the same
    # neighbours. Previously only the shapes of the second result were checked.
    np.testing.assert_array_equal(indices_2, indices)
    np.testing.assert_allclose(distances_2, distances, rtol=1e-5)
0 commit comments