Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 01d6290

Browse files
committed
Handle both normalization types
1 parent c1025c2 commit 01d6290

File tree

4 files changed

+65
-11
lines changed

4 files changed

+65
-11
lines changed

OnnxStack.Core/Extensions/TensorExtension.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,5 +413,25 @@ private static DenseTensor<float> ConcatenateAxis2(DenseTensor<float> tensor1, D
413413

414414
return concatenatedTensor;
415415
}
416+
417+
418+
/// <summary>
419+
/// Normalizes the tensor values from range -1 to 1 to 0 to 1.
420+
/// </summary>
421+
/// <param name="imageTensor">The image tensor.</param>
422+
public static void NormalizeOneOneToZeroOne(this DenseTensor<float> imageTensor)
423+
{
424+
Parallel.For(0, (int)imageTensor.Length, (i) => imageTensor.SetValue(i, imageTensor.GetValue(i) / 2f + 0.5f));
425+
}
426+
427+
428+
/// <summary>
429+
/// Normalizes the tensor values from range 0 to 1 to -1 to 1.
430+
/// </summary>
431+
/// <param name="imageTensor">The image tensor.</param>
432+
public static void NormalizeZeroOneToOneOne(this DenseTensor<float> imageTensor)
433+
{
434+
Parallel.For(0, (int)imageTensor.Length, (i) => imageTensor.SetValue(i, 2f * imageTensor.GetValue(i) - 1f));
435+
}
416436
}
417437
}

OnnxStack.ImageUpscaler/Common/UpscaleModel.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.ML.OnnxRuntime;
22
using OnnxStack.Core.Config;
3+
using OnnxStack.Core.Image;
34
using OnnxStack.Core.Model;
45
using System;
56

@@ -19,13 +20,15 @@ public UpscaleModel(UpscaleModelConfig configuration) : base(configuration)
1920
public int ScaleFactor => _configuration.ScaleFactor;
2021
public int TileSize => _configuration.TileSize;
2122
public int TileOverlap => _configuration.TileOverlap;
23+
public ImageNormalizeType NormalizeType => _configuration.NormalizeType;
24+
public bool NormalizeInput => _configuration.NormalizeInput;
2225

2326
public static UpscaleModel Create(UpscaleModelConfig configuration)
2427
{
2528
return new UpscaleModel(configuration);
2629
}
2730

28-
public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
31+
public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleSize, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
2932
{
3033
var configuration = new UpscaleModelConfig
3134
{
@@ -34,6 +37,8 @@ public static UpscaleModel Create(string modelFile, int scaleFactor, int sampleS
3437
ScaleFactor = scaleFactor,
3538
TileOverlap = tileOverlap,
3639
TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize),
40+
NormalizeType = normalizeType,
41+
NormalizeInput = normalizeInput,
3742
DeviceId = deviceId,
3843
ExecutionProvider = executionProvider,
3944
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,

OnnxStack.ImageUpscaler/Common/UpscaleModelConfig.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using OnnxStack.Core.Config;
2+
using OnnxStack.Core.Image;
23

34
namespace OnnxStack.ImageUpscaler.Common
45
{
@@ -10,5 +11,7 @@ public record UpscaleModelConfig : OnnxModelConfig
1011

1112
public int TileSize { get; set; }
1213
public int TileOverlap { get; set; }
14+
public ImageNormalizeType NormalizeType { get; set; }
15+
public bool NormalizeInput { get; set; }
1316
}
1417
}

OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using Newtonsoft.Json.Linq;
34
using OnnxStack.Core;
45
using OnnxStack.Core.Config;
56
using OnnxStack.Core.Image;
@@ -10,6 +11,7 @@
1011
using System.Collections.Generic;
1112
using System.IO;
1213
using System.Linq;
14+
using System.Numerics.Tensors;
1315
using System.Runtime.CompilerServices;
1416
using System.Threading;
1517
using System.Threading.Tasks;
@@ -72,7 +74,7 @@ public async Task UnloadAsync()
7274
public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, CancellationToken cancellationToken = default)
7375
{
7476
var timestamp = _logger?.LogBegin("Upscale image..");
75-
var result = await RunInternalAsync(inputImage, cancellationToken);
77+
var result = await UpscaleTensorAsync(inputImage, cancellationToken);
7678
_logger?.LogEnd("Upscale image complete.", timestamp);
7779
return result;
7880
}
@@ -87,7 +89,7 @@ public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, Ca
8789
public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
8890
{
8991
var timestamp = _logger?.LogBegin("Upscale image..");
90-
var result = await RunInternalAsync(inputImage, cancellationToken);
92+
var result = await UpscaleImageAsync(inputImage, cancellationToken);
9193
_logger?.LogEnd("Upscale image complete.", timestamp);
9294
return result;
9395
}
@@ -105,7 +107,7 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
105107
var upscaledFrames = new List<OnnxImage>();
106108
foreach (var videoFrame in inputVideo.Frames)
107109
{
108-
upscaledFrames.Add(await RunInternalAsync(videoFrame, cancellationToken));
110+
upscaledFrames.Add(await UpscaleImageAsync(videoFrame, cancellationToken));
109111
}
110112

111113
var firstFrame = upscaledFrames.First();
@@ -131,23 +133,44 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
131133
var timestamp = _logger?.LogBegin("Upscale video stream..");
132134
await foreach (var imageFrame in imageFrames)
133135
{
134-
yield return await RunInternalAsync(imageFrame, cancellationToken);
136+
yield return await UpscaleImageAsync(imageFrame, cancellationToken);
135137
}
136138
_logger?.LogEnd("Upscale video stream complete.", timestamp);
137139
}
138140

139141

142+
140143
/// <summary>
141-
/// Runs the upscale pipeline
144+
/// Upscales the OnnxImage.
142145
/// </summary>
143146
/// <param name="inputImage">The input image.</param>
144147
/// <param name="cancellationToken">The cancellation token.</param>
145148
/// <returns></returns>
146-
private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
149+
private async Task<OnnxImage> UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
147150
{
148-
var inputTensor = inputImage.GetImageTensor(ImageNormalizeType.ZeroToOne, _upscaleModel.Channels);
151+
var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels);
149152
var outputTensor = await RunInternalAsync(inputTensor, cancellationToken);
150-
return new OnnxImage(outputTensor, ImageNormalizeType.ZeroToOne);
153+
return new OnnxImage(outputTensor, _upscaleModel.NormalizeType);
154+
}
155+
156+
157+
/// <summary>
158+
/// Upscales the DenseTensor
159+
/// </summary>
160+
/// <param name="inputImage">The input image.</param>
161+
/// <param name="cancellationToken">The cancellation token.</param>
162+
/// <returns></returns>
163+
public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inputImage, CancellationToken cancellationToken = default)
164+
{
165+
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
166+
inputImage.NormalizeOneOneToZeroOne();
167+
168+
var result = await RunInternalAsync(inputImage, cancellationToken);
169+
170+
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
171+
result.NormalizeZeroOneToOneOne();
172+
173+
return result;
151174
}
152175

153176

@@ -233,7 +256,7 @@ public static ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet, ILog
233256
/// <param name="executionProvider">The execution provider.</param>
234257
/// <param name="logger">The logger.</param>
235258
/// <returns></returns>
236-
public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
259+
public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFactor, int sampleSize, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, int tileSize = 0, int tileOverlap = 20, int channels = 3, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
237260
{
238261
var name = Path.GetFileNameWithoutExtension(modelFile);
239262
var configuration = new UpscaleModelSet
@@ -249,10 +272,13 @@ public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFac
249272
ScaleFactor = scaleFactor,
250273
TileOverlap = tileOverlap,
251274
TileSize = Math.Min(sampleSize, tileSize > 0 ? tileSize : sampleSize),
252-
OnnxModelPath = modelFile
275+
NormalizeType = normalizeType,
276+
NormalizeInput = normalizeInput,
277+
OnnxModelPath = modelFile,
253278
}
254279
};
255280
return CreatePipeline(configuration, logger);
256281
}
257282
}
283+
258284
}

0 commit comments

Comments
 (0)