Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5627a18

Browse files
committed
Background Removal Pipeline
1 parent d08aaa1 commit 5627a18

File tree

3 files changed

+221
-3
lines changed

3 files changed

+221
-3
lines changed

OnnxStack.Core/Extensions/TensorExtension.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ private static DenseTensor<float> ConcatenateAxis1(DenseTensor<float> tensor1, D
397397

398398
// Copy data from the second tensor
399399
for (int i = 0; i < dimensions[0]; i++)
400-
for (int j = 0; j < tensor1.Dimensions[1]; j++)
400+
for (int j = 0; j < tensor2.Dimensions[1]; j++)
401401
concatenatedTensor[i, j + tensor1.Dimensions[1]] = tensor2[i, j];
402402

403403
return concatenatedTensor;

OnnxStack.Core/Image/OnnxImage.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public OnnxImage(DenseTensor<float> imageTensor, ImageNormalizeType normalizeTyp
6464
{
6565
var height = imageTensor.Dimensions[2];
6666
var width = imageTensor.Dimensions[3];
67+
var hasTransparency = imageTensor.Dimensions[1] == 4;
6768
_imageData = new Image<Rgba32>(width, height);
6869
for (var y = 0; y < height; y++)
6970
{
@@ -74,14 +75,16 @@ public OnnxImage(DenseTensor<float> imageTensor, ImageNormalizeType normalizeTyp
7475
_imageData[x, y] = new Rgba32(
7576
DenormalizeZeroToOneToByte(imageTensor, 0, y, x),
7677
DenormalizeZeroToOneToByte(imageTensor, 1, y, x),
77-
DenormalizeZeroToOneToByte(imageTensor, 2, y, x));
78+
DenormalizeZeroToOneToByte(imageTensor, 2, y, x),
79+
hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
7880
}
7981
else
8082
{
8183
_imageData[x, y] = new Rgba32(
8284
DenormalizeOneToOneToByte(imageTensor, 0, y, x),
8385
DenormalizeOneToOneToByte(imageTensor, 1, y, x),
84-
DenormalizeOneToOneToByte(imageTensor, 2, y, x));
86+
DenormalizeOneToOneToByte(imageTensor, 2, y, x),
87+
hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
8588
}
8689
}
8790
}
@@ -337,6 +340,7 @@ private DenseTensor<float> NormalizeToZeroToOne(ReadOnlySpan<int> dimensions)
337340
var width = dimensions[3];
338341
var height = dimensions[2];
339342
var channels = dimensions[1];
343+
var hasTransparency = channels == 4;
340344
var imageArray = new DenseTensor<float>(new[] { 1, channels, height, width });
341345
_imageData.ProcessPixelRows(img =>
342346
{
@@ -348,6 +352,8 @@ private DenseTensor<float> NormalizeToZeroToOne(ReadOnlySpan<int> dimensions)
348352
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f);
349353
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f);
350354
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f);
355+
if (hasTransparency)
356+
imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f);
351357
}
352358
}
353359
});
@@ -366,6 +372,7 @@ private DenseTensor<float> NormalizeToOneToOne(ReadOnlySpan<int> dimensions)
366372
var width = dimensions[3];
367373
var height = dimensions[2];
368374
var channels = dimensions[1];
375+
var hasTransparency = channels == 4;
369376
var imageArray = new DenseTensor<float>(new[] { 1, channels, height, width });
370377
_imageData.ProcessPixelRows(img =>
371378
{
@@ -377,6 +384,8 @@ private DenseTensor<float> NormalizeToOneToOne(ReadOnlySpan<int> dimensions)
377384
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f;
378385
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f;
379386
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f;
387+
if (hasTransparency)
388+
imageArray[0, 3, y, x] = (pixelSpan[x].A / 255.0f) * 2.0f - 1.0f;
380389
}
381390
}
382391
});
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Image;
6+
using OnnxStack.Core.Model;
7+
using OnnxStack.Core.Video;
8+
using OnnxStack.FeatureExtractor.Common;
9+
using System;
10+
using System.Collections.Generic;
11+
using System.IO;
12+
using System.Linq;
13+
using System.Runtime.CompilerServices;
14+
using System.Threading;
15+
using System.Threading.Tasks;
16+
17+
namespace OnnxStack.FeatureExtractor.Pipelines
{
    /// <summary>
    /// Pipeline that runs a background-removal ONNX model over images, videos and video streams,
    /// producing RGBA output where the model's single-channel mask becomes the alpha channel.
    /// </summary>
    public class BackgroundRemovalPipeline
    {
        private readonly string _name;
        private readonly ILogger _logger;
        private readonly FeatureExtractorModel _model;

        /// <summary>
        /// Initializes a new instance of the <see cref="BackgroundRemovalPipeline"/> class.
        /// </summary>
        /// <param name="name">The pipeline name.</param>
        /// <param name="model">The background-removal model.</param>
        /// <param name="logger">The logger.</param>
        public BackgroundRemovalPipeline(string name, FeatureExtractorModel model, ILogger logger = default)
        {
            _name = name;
            _logger = logger;
            _model = model;
        }


        /// <summary>
        /// Gets the name.
        /// </summary>
        /// <value>The pipeline name.</value>
        public string Name => _name;


        /// <summary>
        /// Loads the model.
        /// </summary>
        /// <returns>A task that completes when the model is loaded.</returns>
        public Task LoadAsync()
        {
            return _model.LoadAsync();
        }


        /// <summary>
        /// Unloads the model and releases its resources.
        /// </summary>
        public async Task UnloadAsync()
        {
            // Yield so disposal does not run synchronously on the caller's thread.
            await Task.Yield();
            _model?.Dispose();
        }


        /// <summary>
        /// Generates the background removal image result.
        /// </summary>
        /// <param name="inputImage">The input image.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The input image with the background made transparent.</returns>
        public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
        {
            var timestamp = _logger?.LogBegin("Removing image background...");
            var result = await RunInternalAsync(inputImage, cancellationToken);
            _logger?.LogEnd("Removing image background complete.", timestamp);
            return result;
        }


        /// <summary>
        /// Generates the background removal video result.
        /// </summary>
        /// <param name="video">The input video.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>A new video whose frames have transparent backgrounds.</returns>
        public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
        {
            var timestamp = _logger?.LogBegin("Removing video background...");
            var videoFrames = new List<OnnxImage>();
            foreach (var videoFrame in video.Frames)
            {
                videoFrames.Add(await RunAsync(videoFrame, cancellationToken));
            }
            _logger?.LogEnd("Removing video background complete.", timestamp);

            // The model output size may differ from the source video, so take the
            // dimensions from the first processed frame.
            return new OnnxVideo(video.Info with
            {
                Height = videoFrames[0].Height,
                Width = videoFrames[0].Width,
            }, videoFrames);
        }


        /// <summary>
        /// Generates the background removal video stream.
        /// </summary>
        /// <param name="imageFrames">The image frames.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>An async stream of frames with transparent backgrounds.</returns>
        public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var timestamp = _logger?.LogBegin("Removing video stream background...");
            await foreach (var imageFrame in imageFrames)
            {
                yield return await RunInternalAsync(imageFrame, cancellationToken);
            }
            _logger?.LogEnd("Removing video stream background complete.", timestamp);
        }


        /// <summary>
        /// Runs the background-removal inference for a single image.
        /// </summary>
        /// <param name="inputImage">The input image.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The RGBA image with the predicted mask as alpha channel.</returns>
        private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
        {
            var sourceImageTensor = await inputImage.GetImageTensorAsync(_model.SampleSize, _model.SampleSize, ImageNormalizeType.ZeroToOne);
            var metadata = await _model.GetMetadataAsync();
            cancellationToken.ThrowIfCancellationRequested();

            var outputShape = new[] { 1, _model.Channels, _model.SampleSize, _model.SampleSize };
            // Some models declare a 3D (CHW) output; drop the batch dimension for the buffer in that case.
            var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(sourceImageTensor);
                inferenceParameters.AddOutputBuffer(outputBuffer);

                var results = await _model.RunInferenceAsync(inferenceParameters);
                using (var result = results.First())
                {
                    cancellationToken.ThrowIfCancellationRequested();

                    var resultTensor = result.ToDenseTensor(outputShape);
                    if (_model.Normalize)
                        resultTensor.NormalizeMinMax();

                    // Use the (possibly normalized) tensor data as the alpha channel.
                    // Reading result.GetTensorDataAsSpan here would discard the normalization above.
                    var imageTensor = AddAlphaChannel(sourceImageTensor, resultTensor.Buffer.Span);
                    return new OnnxImage(imageTensor, ImageNormalizeType.ZeroToOne);
                }
            }
        }


        /// <summary>
        /// Adds an alpha channel to the RGB tensor.
        /// </summary>
        /// <param name="sourceImage">The source image tensor (1 x 3 x H x W, planar layout).</param>
        /// <param name="alphaChannel">The alpha channel values (H x W).</param>
        /// <returns>A 1 x 4 x H x W tensor with the alpha plane appended.</returns>
        private static DenseTensor<float> AddAlphaChannel(DenseTensor<float> sourceImage, ReadOnlySpan<float> alphaChannel)
        {
            var resultTensor = new DenseTensor<float>(new int[] { 1, 4, sourceImage.Dimensions[2], sourceImage.Dimensions[3] });

            // Planar layout: the RGB planes occupy the first sourceImage.Length floats,
            // the alpha plane fills the remainder.
            sourceImage.Buffer.Span.CopyTo(resultTensor.Buffer[..(int)sourceImage.Length].Span);
            alphaChannel.CopyTo(resultTensor.Buffer[(int)sourceImage.Length..].Span);
            return resultTensor;
        }


        /// <summary>
        /// Creates the pipeline from a FeatureExtractorModelSet.
        /// </summary>
        /// <param name="modelSet">The model set.</param>
        /// <param name="logger">The logger.</param>
        /// <returns>A configured <see cref="BackgroundRemovalPipeline"/>.</returns>
        public static BackgroundRemovalPipeline CreatePipeline(FeatureExtractorModelSet modelSet, ILogger logger = default)
        {
            var model = new FeatureExtractorModel(modelSet.FeatureExtractorConfig.ApplyDefaults(modelSet));
            return new BackgroundRemovalPipeline(modelSet.Name, model, logger);
        }


        /// <summary>
        /// Creates the pipeline from the specified file.
        /// </summary>
        /// <param name="modelFile">The model file.</param>
        /// <param name="sampleSize">The model input/output sample size.</param>
        /// <param name="deviceId">The device identifier.</param>
        /// <param name="executionProvider">The execution provider.</param>
        /// <param name="logger">The logger.</param>
        /// <returns>A configured <see cref="BackgroundRemovalPipeline"/>.</returns>
        public static BackgroundRemovalPipeline CreatePipeline(string modelFile, int sampleSize = 512, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
        {
            var name = Path.GetFileNameWithoutExtension(modelFile);
            var configuration = new FeatureExtractorModelSet
            {
                Name = name,
                IsEnabled = true,
                DeviceId = deviceId,
                ExecutionProvider = executionProvider,
                FeatureExtractorConfig = new FeatureExtractorModelConfig
                {
                    OnnxModelPath = modelFile,
                    SampleSize = sampleSize,
                    Normalize = false,
                    Channels = 1
                }
            };
            return CreatePipeline(configuration, logger);
        }
    }
}

0 commit comments

Comments
 (0)