Skip to content

Commit 61c98fc

Browse files
committed
Basic LLM support (Encoder/Decoder and Decoder Only)
1 parent e6f8be8 commit 61c98fc

File tree

103 files changed

+7991
-3987
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+7991
-3987
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,7 @@ site/
350350
docker-test-output/*
351351

352352
Examples/*
353-
TensorStack.Tokenizers/*
354353
TensorStack.UI.WPF/*
355354
TensorStackFull.sln
356-
TensorStack.Converter/*
357355
TensorStack.Diffusers/*
358-
TensorStack.Converter/*
356+
TensorStudio/*
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System;
4+
using System.Collections;
5+
using System.Collections.Generic;
6+
using System.Diagnostics.CodeAnalysis;
7+
8+
namespace TensorStack.Common
9+
{
10+
public class MapCollection<TKey, TValue> : IDictionary<TKey, TValue>
11+
{
12+
private readonly IDictionary<TKey, TValue> _collectionKeys;
13+
private readonly IDictionary<TValue, TKey> _collectionValues;
14+
15+
public MapCollection()
16+
{
17+
_collectionKeys = new Dictionary<TKey, TValue>();
18+
_collectionValues = new Dictionary<TValue, TKey>();
19+
}
20+
21+
public MapCollection(IDictionary<TKey, TValue> collection) : this()
22+
{
23+
foreach (var item in collection)
24+
Add(item.Key, item.Value);
25+
}
26+
27+
public MapCollection(IDictionary<TValue, TKey> collection) : this()
28+
{
29+
foreach (var item in collection)
30+
Add(item.Value, item.Key);
31+
}
32+
33+
34+
public ICollection<TKey> Keys => _collectionKeys.Keys;
35+
public ICollection<TValue> Values => _collectionValues.Keys;
36+
public int Count => _collectionKeys.Count;
37+
public bool IsReadOnly => false;
38+
39+
40+
public TValue this[TKey key]
41+
{
42+
get { return _collectionKeys[key]; }
43+
set
44+
{
45+
if (_collectionKeys.TryGetValue(key, out var oldValue))
46+
_collectionValues.Remove(oldValue);
47+
48+
_collectionKeys[key] = value;
49+
_collectionValues[value] = key;
50+
}
51+
}
52+
53+
54+
public TKey this[TValue key]
55+
{
56+
get { return _collectionValues[key]; }
57+
set
58+
{
59+
if (_collectionValues.TryGetValue(key, out var oldKey))
60+
_collectionKeys.Remove(oldKey);
61+
62+
_collectionValues[key] = value;
63+
_collectionKeys[value] = key;
64+
}
65+
}
66+
67+
68+
public void Add(TKey key, TValue value)
69+
{
70+
_collectionKeys.Add(key, value);
71+
_collectionValues.Add(value, key);
72+
}
73+
74+
public void Add(KeyValuePair<TKey, TValue> item)
75+
{
76+
_collectionKeys.Add(item.Key, item.Value);
77+
_collectionValues.Add(item.Value, item.Key);
78+
}
79+
80+
public bool TryAdd(TKey key, TValue value)
81+
{
82+
return _collectionKeys.TryAdd(key, value)
83+
&& _collectionValues.TryAdd(value, key);
84+
}
85+
86+
public void Clear()
87+
{
88+
_collectionKeys.Clear();
89+
_collectionValues.Clear();
90+
}
91+
92+
public bool Contains(KeyValuePair<TKey, TValue> item)
93+
{
94+
return _collectionKeys.Contains(item);
95+
}
96+
97+
public bool ContainsKey(TKey key)
98+
{
99+
return _collectionKeys.ContainsKey(key);
100+
}
101+
102+
public bool ContainsKey(TValue key)
103+
{
104+
return _collectionValues.ContainsKey(key);
105+
}
106+
107+
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
108+
{
109+
throw new NotImplementedException();
110+
}
111+
112+
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
113+
{
114+
return _collectionKeys.GetEnumerator();
115+
}
116+
117+
public bool Remove(TKey key)
118+
{
119+
if (_collectionKeys.TryGetValue(key, out TValue value))
120+
{
121+
_collectionKeys.Remove(key);
122+
_collectionValues.Remove(value);
123+
return true;
124+
}
125+
return false;
126+
}
127+
128+
public bool Remove(KeyValuePair<TKey, TValue> item)
129+
{
130+
return Remove(item.Key);
131+
}
132+
133+
public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value)
134+
{
135+
return _collectionKeys.TryGetValue(key, out value);
136+
}
137+
138+
139+
public bool TryGetValue(TValue key, [MaybeNullWhen(false)] out TKey value)
140+
{
141+
return _collectionValues.TryGetValue(key, out value);
142+
}
143+
144+
145+
IEnumerator IEnumerable.GetEnumerator()
146+
{
147+
return GetEnumerator();
148+
}
149+
}
150+
}

TensorStack.Common/Common/ParameterCollection.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
using Microsoft.ML.OnnxRuntime;
44
using System;
55
using System.Collections.Generic;
6+
using System.Linq;
67

78
namespace TensorStack.Common
89
{
910
public sealed class ParameterCollection : IDisposable
1011
{
1112
private readonly List<NamedMetadata> _metaData;
1213
private readonly Dictionary<string, OrtValue> _values;
14+
private readonly List<string> _disposables;
1315

1416
/// <summary>
1517
/// Initializes a new instance of the <see cref="ParameterCollection"/> class.
@@ -18,6 +20,7 @@ public ParameterCollection()
1820
{
1921
_metaData = new List<NamedMetadata>();
2022
_values = new Dictionary<string, OrtValue>();
23+
_disposables = new List<string>();
2124
}
2225

2326

@@ -26,21 +29,29 @@ public ParameterCollection()
2629
/// </summary>
2730
/// <param name="metaData">The meta data.</param>
2831
/// <param name="value">The value.</param>
29-
public void Add(NamedMetadata metaData, OrtValue value)
32+
public void Add(NamedMetadata metaData, OrtValue value, bool dispose = true)
3033
{
3134
_metaData.Add(metaData);
3235
_values.Add(metaData.Name, value);
36+
if (dispose)
37+
{
38+
_disposables.Add(metaData.Name);
39+
}
3340
}
3441

3542

3643
/// <summary>
3744
/// Adds the name only.
3845
/// </summary>
3946
/// <param name="metaData">The meta data.</param>
40-
public void AddName(NamedMetadata metaData)
47+
public void AddName(NamedMetadata metaData, bool dispose = true)
4148
{
4249
_metaData.Add(metaData);
4350
_values.Add(metaData.Name, default);
51+
if (dispose)
52+
{
53+
_disposables.Add(metaData.Name);
54+
}
4455
}
4556

4657
/// <summary>
@@ -66,8 +77,13 @@ public void AddName(NamedMetadata metaData)
6677
/// </summary>
6778
public void Dispose()
6879
{
69-
foreach (var ortValue in _values.Values)
70-
ortValue?.Dispose();
80+
foreach (var value in _values)
81+
{
82+
if (!_disposables.Contains(value.Key))
83+
continue;
84+
85+
value.Value?.Dispose();
86+
}
7187

7288
_values.Clear();
7389
_metaData.Clear();

TensorStack.Common/Extensions/OrtExtensions.cs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,24 @@ private static Tensor<T> CreateTensor<T>(OrtValue ortValue) where T : unmanaged,
181181
/// <param name="metadata">The input metadata.</param>
182182
/// <param name="tensor">The tensor input.</param>
183183
private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
184+
{
185+
return CreateOrtValue(metadata.Value.ElementDataType, tensor);
186+
}
187+
188+
189+
/// <summary>
190+
/// Creates a OrtValue from Tensor
191+
/// </summary>
192+
/// <typeparam name="T"></typeparam>
193+
/// <param name="ortType">Type of the ort.</param>
194+
/// <param name="tensor">The tensor.</param>
195+
/// <returns>OrtValue.</returns>
196+
public static OrtValue CreateOrtValue<T>(OrtType ortType, TensorSpan<T> tensor) where T : unmanaged, INumber<T>
184197
{
185198
var buffer = tensor.Span;
186199
var dimensions = tensor.Dimensions.ToLong();
187200
var memoryInstance = OrtMemoryInfo.DefaultInstance;
188-
return metadata.Value.ElementDataType switch
201+
return ortType switch
189202
{
190203
OrtType.Float => OrtValue.CreateTensorValueFromMemory<float>(memoryInstance, buffer.ConvertBuffer<T, float>(), dimensions),
191204
OrtType.UInt8 => OrtValue.CreateTensorValueFromMemory<byte>(memoryInstance, buffer.ConvertBuffer<T, byte>(), dimensions),
@@ -204,6 +217,40 @@ private static OrtValue CreateOrtValue<T>(NamedMetadata metadata, TensorSpan<T>
204217
}
205218

206219

220+
/// <summary>
221+
/// Clones the specified OrtValue.
222+
/// </summary>
223+
/// <param name="original">The original.</param>
224+
/// <returns>OrtValue.</returns>
225+
public static OrtValue Clone(this OrtValue original)
226+
{
227+
var info = original.GetTensorTypeAndShape();
228+
return info.ElementDataType switch
229+
{
230+
OrtType.Float => original.Clone<float>(info),
231+
OrtType.Float16 => original.Clone<Float16>(info),
232+
_ => throw new NotSupportedException($"Unsupported element type: {info.ElementDataType}")
233+
};
234+
}
235+
236+
237+
/// <summary>
238+
/// Clones the specified OrtValue.
239+
/// </summary>
240+
/// <typeparam name="T"></typeparam>
241+
/// <param name="original">The original.</param>
242+
/// <param name="info">The information.</param>
243+
/// <returns>OrtValue.</returns>
244+
public static OrtValue Clone<T>(this OrtValue original, OrtTensorTypeAndShapeInfo info) where T : unmanaged
245+
{
246+
var newValue = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, info.ElementDataType, info.Shape);
247+
var source = original.GetTensorDataAsSpan<T>();
248+
var destination = newValue.GetTensorMutableDataAsSpan<T>();
249+
source.CopyTo(destination);
250+
return newValue;
251+
}
252+
253+
207254
/// <summary>
208255
/// Creates an Array from OrtValue.
209256
/// </summary>

TensorStack.Common/Extensions/TensorExtensions.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,11 @@ public static Tensor<T> FirstBatch<T>(this Tensor<T> tensor)
338338
return Split(tensor).FirstOrDefault();
339339
}
340340

341+
public static Tensor<T> LastBatch<T>(this Tensor<T> tensor)
342+
{
343+
return Split(tensor).LastOrDefault();
344+
}
345+
341346

342347
/// <summary>
343348
/// Reshapes to new tensor.

0 commit comments

Comments
 (0)