33using System . Text ;
44using LLama . Abstractions ;
55using LLama . Native ;
6+ using System . Collections . Generic ;
67
78namespace LLama . Extensions ;
89
@@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
4546 result . tensor_split = ( float * ) disposer . Add ( @params . TensorSplits . Pin ( ) ) . Pointer ;
4647 }
4748
48- // Add tensor buffer overrides, if any
49- if ( @params . TensorBufferOverrides . Count > 0 )
49+ // Add tensor buffer overrides
50+ unsafe
5051 {
51- var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper ( ) ;
52- disposer . Add ( bufferOverrideHelper ) ;
53-
54- foreach ( var tensorOverride in @params . TensorBufferOverrides )
55- {
56- bufferOverrideHelper . AddOverride ( tensorOverride . Pattern , tensorOverride . BufferType ) ;
57- }
58-
59- bufferOverrideHelper . ApplyToModelParams ( ref result ) ;
52+ result . tensor_buft_overrides = ConvertOverrides ( @params . TensorBufferOverrides , disposer ) ;
6053 }
6154
55+ // Add metadata overrides
6256 if ( @params . MetadataOverrides . Count == 0 )
6357 {
6458 unsafe
@@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
106100
107101 return disposer ;
108102 }
103+
/// <summary>
/// Build a lookup from buffer type name (as reported by `ggml_backend_buft_name`) to the buffer type
/// handle (as returned by `ggml_backend_dev_buffer_type`), covering every backend device index
/// below `ggml_backend_dev_count`.
/// </summary>
/// <remarks>
/// Buffer types that report a null or empty name are left out. If several devices report the same
/// name, the handle of the device with the highest index is the one that is kept.
/// </remarks>
/// <returns>Dictionary mapping buffer type names to their handles</returns>
private static IReadOnlyDictionary<string, IntPtr> GetAvailableBufferTypes()
{
    var handlesByName = new Dictionary<string, IntPtr>();

    var deviceCount = NativeApi.ggml_backend_dev_count();
    for (nuint index = 0; index < deviceCount; index++)
    {
        // Resolve the device at this index, then the buffer type it exposes
        var bufferType = NativeApi.ggml_backend_dev_buffer_type(NativeApi.ggml_backend_dev_get(index));

        // The native name is read as an ANSI string; unnamed buffer types cannot be looked up, so skip them
        var bufferTypeName = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(bufferType));
        if (!string.IsNullOrEmpty(bufferTypeName))
        {
            handlesByName[bufferTypeName] = bufferType;
        }
    }

    return handlesByName;
}
127+
/// <summary>
/// Convert managed tensor buffer overrides into a pinned native array of
/// <see cref="LLamaModelTensorBufferOverride"/> that can be assigned to `tensor_buft_overrides`.
/// </summary>
/// <param name="overrides">Overrides to convert. Entries whose buffer type name is not found
/// among the available buffer types are skipped.</param>
/// <param name="disposer">Receives every pin handle created here (one per pattern string plus one for
/// the array), so the pinned memory is released when the disposer is disposed.</param>
/// <returns>Pointer to the first element of the pinned array, or null if there is nothing to apply
/// (no overrides supplied, or none of them named a known buffer type).</returns>
private static unsafe LLamaModelTensorBufferOverride* ConvertOverrides(List<TensorBufferOverride> overrides, GroupDisposable disposer)
{
    // Early out if there are no overrides (a missing list is treated like an empty one)
    if (overrides is null || overrides.Count == 0)
    {
        return null;
    }

    var bufferTypes = GetAvailableBufferTypes();

    // One slot more than the number of overrides: entries are packed at the front, so at least the
    // final slot always stays default-initialised. NOTE(review): presumably this zeroed entry is the
    // list terminator expected by the native side - confirm against llama.h.
    var overridesArray = new LLamaModelTensorBufferOverride[overrides.Count + 1];
    var overridesCount = 0;

    foreach (var @override in overrides)
    {
        // Skip overrides which refer to a buffer type that is not available
        if (!bufferTypes.TryGetValue(@override.BufferType, out var bufferType))
        {
            continue;
        }

        // Encode the pattern as UTF8 ending in exactly one NUL byte, then pin it so the
        // pointer handed to native code stays valid until the disposer releases it
        var patternBytes = Encoding.UTF8.GetBytes(@override.Pattern + "\0");
        var patternPin = patternBytes.AsMemory().Pin();
        disposer.Add(patternPin);

        // Append to the packed front of the array
        overridesArray[overridesCount++] = new()
        {
            Pattern = (byte*)patternPin.Pointer,
            BufferType = bufferType
        };
    }

    // Early out if there were no valid overrides (nothing has been pinned in that case)
    if (overridesCount == 0)
    {
        return null;
    }

    // Pin the array itself so it can be safely passed across to native code
    var overrideArrayPin = overridesArray.AsMemory().Pin();
    disposer.Add(overrideArrayPin);

    return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer;
}
109168}
0 commit comments