Skip to content

Commit 4bd6ea6

Browse files
committed
Add CPU, DML, CUDA providers
1 parent 6c8f47d commit 4bd6ea6

File tree

10 files changed

+365
-0
lines changed

10 files changed

+365
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using TensorStack.Common;
3+
4+
namespace TensorStack.Providers
5+
{
6+
public static class Provider
7+
{
8+
public const string CPUProviderName = "CPU Provider";
9+
public const string CUDAProviderName = "CUDA Provider";
10+
11+
/// <summary>
12+
/// Gets the CPU provider.
13+
/// </summary>
14+
/// <param name="optimizationLevel">The optimization level.</param>
15+
/// <returns>ExecutionProvider.</returns>
16+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
17+
{
18+
return new ExecutionProvider(CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
19+
{
20+
var sessionOptions = new SessionOptions
21+
{
22+
EnableCpuMemArena = true,
23+
EnableMemoryPattern = true,
24+
GraphOptimizationLevel = optimizationLevel
25+
};
26+
sessionOptions.AppendExecutionProvider_CPU();
27+
return sessionOptions;
28+
});
29+
}
30+
31+
32+
/// <summary>
33+
/// Gets the CUDA provider.
34+
/// </summary>
35+
/// <param name="deviceId">The device identifier.</param>
36+
/// <param name="optimizationLevel">The optimization level.</param>
37+
/// <returns>ExecutionProvider.</returns>
38+
public static ExecutionProvider GetProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
39+
{
40+
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCUDA_PINNED, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
41+
return new ExecutionProvider(CUDAProviderName, memoryInfo, configuration =>
42+
{
43+
var sessionOptions = new SessionOptions
44+
{
45+
GraphOptimizationLevel = optimizationLevel
46+
};
47+
48+
sessionOptions.AppendExecutionProvider_CUDA(deviceId);
49+
sessionOptions.AppendExecutionProvider_CPU();
50+
return sessionOptions;
51+
});
52+
}
53+
}
54+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>net9.0-windows</TargetFramework>
5+
<PlatformTarget>x64</PlatformTarget>
6+
<Description></Description>
7+
</PropertyGroup>
8+
9+
<!--Projects-->
10+
<ItemGroup Condition=" '$(Configuration)' == 'Debug'">
11+
<ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
12+
</ItemGroup>
13+
14+
<!--Packages-->
15+
<ItemGroup Condition=" '$(Configuration)' == 'Release'">
16+
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
17+
</ItemGroup>
18+
19+
<!--Other Packages-->
20+
<ItemGroup>
21+
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu.Windows" Version="1.23.0" />
22+
</ItemGroup>
23+
24+
25+
<!--Nuget Settings-->
26+
<PropertyGroup>
27+
<Title>$(AssemblyName)</Title>
28+
<PackageId>$(AssemblyName)</PackageId>
29+
<Product>$(AssemblyName)</Product>
30+
<PackageIcon>Icon.png</PackageIcon>
31+
</PropertyGroup>
32+
<ItemGroup Condition="'$(Configuration)' == 'Debug'">
33+
<None Remove="README.md" />
34+
<None Remove="Icon.png" />
35+
</ItemGroup>
36+
<ItemGroup Condition="'$(Configuration)' == 'Release'">
37+
<None Remove="Add.png" />
38+
<None Update="README.md">
39+
<Pack>True</Pack>
40+
<PackagePath>\</PackagePath>
41+
</None>
42+
<None Include="..\Assets\Icon.png">
43+
<Pack>True</Pack>
44+
<PackagePath>\</PackagePath>
45+
</None>
46+
</ItemGroup>
47+
48+
</Project>
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using TensorStack.Common;
3+
4+
namespace TensorStack.Providers
5+
{
6+
public static class Provider
7+
{
8+
public const string CPUProviderName = "CPU Provider";
9+
public const string DMLProviderName = "DirectML Provider";
10+
11+
/// <summary>
12+
/// Gets the CPU provider.
13+
/// </summary>
14+
/// <param name="optimizationLevel">The optimization level.</param>
15+
/// <returns>ExecutionProvider.</returns>
16+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
17+
{
18+
return new ExecutionProvider(CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
19+
{
20+
var sessionOptions = new SessionOptions
21+
{
22+
EnableCpuMemArena = true,
23+
EnableMemoryPattern = true,
24+
GraphOptimizationLevel = optimizationLevel
25+
};
26+
sessionOptions.AppendExecutionProvider_CPU();
27+
return sessionOptions;
28+
});
29+
}
30+
31+
32+
/// <summary>
33+
/// Gets the DirectML provider.
34+
/// </summary>
35+
/// <param name="deviceId">The device identifier.</param>
36+
/// <param name="optimizationLevel">The optimization level.</param>
37+
/// <returns>ExecutionProvider.</returns>
38+
public static ExecutionProvider GetProvider(int deviceId, GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
39+
{
40+
var memoryInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.DeviceAllocator, deviceId, OrtMemType.Default);
41+
return new ExecutionProvider(DMLProviderName, memoryInfo, configuration =>
42+
{
43+
var sessionOptions = new SessionOptions
44+
{
45+
GraphOptimizationLevel = optimizationLevel
46+
};
47+
48+
sessionOptions.AppendExecutionProvider_DML(deviceId);
49+
sessionOptions.AppendExecutionProvider_CPU();
50+
return sessionOptions;
51+
});
52+
}
53+
}
54+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>net9.0-windows</TargetFramework>
5+
<PlatformTarget>x64</PlatformTarget>
6+
<Description></Description>
7+
</PropertyGroup>
8+
9+
<!--Projects-->
10+
<ItemGroup Condition=" '$(Configuration)' == 'Debug'">
11+
<ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
12+
</ItemGroup>
13+
14+
<!--Packages-->
15+
<ItemGroup Condition=" '$(Configuration)' == 'Release'">
16+
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
17+
</ItemGroup>
18+
19+
<!--Other Packages-->
20+
<ItemGroup>
21+
<PackageReference Include="Microsoft.ML.OnnxRuntime.DirectML" Version="1.23.0" />
22+
</ItemGroup>
23+
24+
25+
<!--Nuget Settings-->
26+
<PropertyGroup>
27+
<Title>$(AssemblyName)</Title>
28+
<PackageId>$(AssemblyName)</PackageId>
29+
<Product>$(AssemblyName)</Product>
30+
<PackageIcon>Icon.png</PackageIcon>
31+
</PropertyGroup>
32+
<ItemGroup Condition="'$(Configuration)' == 'Debug'">
33+
<None Remove="README.md" />
34+
<None Remove="Icon.png" />
35+
</ItemGroup>
36+
<ItemGroup Condition="'$(Configuration)' == 'Release'">
37+
<None Remove="Add.png" />
38+
<None Update="README.md">
39+
<Pack>True</Pack>
40+
<PackagePath>\</PackagePath>
41+
</None>
42+
<None Include="..\Assets\Icon.png">
43+
<Pack>True</Pack>
44+
<PackagePath>\</PackagePath>
45+
</None>
46+
</ItemGroup>
47+
48+
</Project>

TensorStack.Provider/Provider.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using TensorStack.Common;
3+
4+
namespace TensorStack.Providers
5+
{
6+
public static class Provider
7+
{
8+
public const string CPUProviderName = "CPU Provider";
9+
10+
/// <summary>
11+
/// Gets the CPU provider.
12+
/// </summary>
13+
/// <param name="optimizationLevel">The optimization level.</param>
14+
/// <returns>ExecutionProvider.</returns>
15+
public static ExecutionProvider GetProvider(GraphOptimizationLevel optimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL)
16+
{
17+
return new ExecutionProvider(CPUProviderName, OrtMemoryInfo.DefaultInstance, configuration =>
18+
{
19+
var sessionOptions = new SessionOptions
20+
{
21+
EnableCpuMemArena = true,
22+
EnableMemoryPattern = true,
23+
GraphOptimizationLevel = optimizationLevel
24+
};
25+
sessionOptions.AppendExecutionProvider_CPU();
26+
return sessionOptions;
27+
});
28+
}
29+
}
30+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>net9.0</TargetFramework>
5+
<PlatformTarget>x64</PlatformTarget>
6+
<Description></Description>
7+
</PropertyGroup>
8+
9+
<!--Projects-->
10+
<ItemGroup Condition=" '$(Configuration)' == 'Debug'">
11+
<ProjectReference Include="..\TensorStack.Common\TensorStack.Common.csproj" />
12+
</ItemGroup>
13+
14+
<!--Packages-->
15+
<ItemGroup Condition=" '$(Configuration)' == 'Release'">
16+
<PackageReference Include="TensorStack.Common" Version="$(Version)" />
17+
</ItemGroup>
18+
19+
20+
<!--Nuget Settings-->
21+
<PropertyGroup>
22+
<Title>$(AssemblyName)</Title>
23+
<PackageId>$(AssemblyName)</PackageId>
24+
<Product>$(AssemblyName)</Product>
25+
<PackageIcon>Icon.png</PackageIcon>
26+
</PropertyGroup>
27+
<ItemGroup Condition="'$(Configuration)' == 'Debug'">
28+
<None Remove="README.md" />
29+
<None Remove="Icon.png" />
30+
</ItemGroup>
31+
<ItemGroup Condition="'$(Configuration)' == 'Release'">
32+
<None Remove="Add.png" />
33+
<None Update="README.md">
34+
<Pack>True</Pack>
35+
<PackagePath>\</PackagePath>
36+
</None>
37+
<None Include="..\Assets\Icon.png">
38+
<Pack>True</Pack>
39+
<PackagePath>\</PackagePath>
40+
</None>
41+
</ItemGroup>
42+
43+
</Project>
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Phi3 Pipeline
2+
3+
### Greedy
4+
```csharp
5+
var provider = Provider.GetProvider(0);
6+
var modelPath = "M:\\Models\\Phi-3-medium-128k-instruct-onnx-directml";
7+
var pipeline = Phi3Pipeline.Create(provider, modelPath, PhiType.Mini);
8+
var options = new GenerateOptions
9+
{
10+
Prompt = "<|user|>What is an apple?<|end|><|assistant|>"
11+
};
12+
13+
var generateResult = await pipeline.RunAsync(options);
14+
System.Console.WriteLine(generateResult.Result);
15+
```
16+
17+
### Beam Search
18+
```csharp
19+
var provider = Provider.GetProvider(0);
20+
var modelPath = "M:\\Models\\Phi-3-medium-128k-instruct-onnx-directml";
21+
var pipeline = Phi3Pipeline.Create(provider, modelPath, PhiType.Mini);
22+
var options = new SearchOptions
23+
{
24+
Seed = 0,
25+
TopK = 50,
26+
Beams = 3,
27+
TopP = 0.9f,
28+
Temperature = 1f,
29+
LengthPenalty = -1f,
30+
DiversityLength = 20,
31+
NoRepeatNgramSize = 3,
32+
EarlyStopping = EarlyStopping.None,
33+
Prompt = "<|user|>What is an apple?<|end|><|assistant|>"
34+
};
35+
36+
foreach (var beamResult in await pipeline.RunAsync(options))
37+
{
38+
System.Console.WriteLine(beamResult.Result);
39+
}
40+
```
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Whisper
2+
3+
### Greedy
4+
```csharp
5+
var provider = Provider.GetProvider(0);
6+
var modelPath = "M:\\Models\\Whisper-Base";
7+
var pipeline = WhisperPipeline.Create(provider, modelPath, WhisperType.Base);
8+
var options = new GenerateOptions
9+
{
10+
Task = TaskType.Transcribe,
11+
Language = LanguageType.EN,
12+
AudioInput = await AudioInput.CreateAsync("kennedy.mp3")
13+
};
14+
15+
var generateResult = await pipeline.RunAsync(options);
16+
System.Console.WriteLine(generateResult.Result);
17+
```
18+
19+
### Beam Search
20+
```csharp
21+
var provider = Provider.GetProvider(0);
22+
var modelPath = "M:\\Models\\Whisper-Large";
23+
var pipeline = WhisperPipeline.Create(provider, modelPath, WhisperType.Large);
24+
var options = new SearchOptions
25+
{
26+
Seed = 0,
27+
TopK = 50,
28+
Beams = 3,
29+
TopP = 0.9f,
30+
Temperature = 1f,
31+
LengthPenalty = -1f,
32+
DiversityLength = 20,
33+
NoRepeatNgramSize = 3,
34+
EarlyStopping = EarlyStopping.BestBeam,
35+
Task = TaskType.Transcribe,
36+
Language = LanguageType.EN,
37+
AudioInput = await AudioInput.CreateAsync("kennedy.mp3")
38+
};
39+
40+
foreach (var beamResult in await pipeline.RunAsync(options))
41+
{
42+
System.Console.WriteLine(beamResult.Result);
43+
}
44+
```
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# TensorStack.TextGeneration

TensorStack.TextGeneration/TensorStack.TextGeneration.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@
3333
</PropertyGroup>
3434
<ItemGroup Condition="'$(Configuration)' == 'Debug'">
3535
<None Remove="README.md" />
36+
<None Remove="*/*/README.md" />
3637
<None Remove="Icon.png" />
3738
</ItemGroup>
3839
<ItemGroup Condition="'$(Configuration)' == 'Release'">
3940
<None Update="README.md">
4041
<Pack>True</Pack>
4142
<PackagePath>\</PackagePath>
4243
</None>
44+
<None Update="*/*/README.md">
45+
</None>
4346
<None Include="..\Assets\Icon.png">
4447
<Pack>True</Pack>
4548
<PackagePath>\</PackagePath>

0 commit comments

Comments
 (0)