11using Microsoft . Extensions . Logging ;
22using Microsoft . ML . OnnxRuntime . Tensors ;
3+ using Newtonsoft . Json . Linq ;
34using OnnxStack . Core ;
45using OnnxStack . Core . Config ;
56using OnnxStack . Core . Image ;
1011using System . Collections . Generic ;
1112using System . IO ;
1213using System . Linq ;
14+ using System . Numerics . Tensors ;
1315using System . Runtime . CompilerServices ;
1416using System . Threading ;
1517using System . Threading . Tasks ;
@@ -72,7 +74,7 @@ public async Task UnloadAsync()
7274 public async Task < DenseTensor < float > > RunAsync ( DenseTensor < float > inputImage , CancellationToken cancellationToken = default )
7375 {
7476 var timestamp = _logger ? . LogBegin ( "Upscale image.." ) ;
75- var result = await RunInternalAsync ( inputImage , cancellationToken ) ;
77+ var result = await UpscaleTensorAsync ( inputImage , cancellationToken ) ;
7678 _logger ? . LogEnd ( "Upscale image complete." , timestamp ) ;
7779 return result ;
7880 }
@@ -87,7 +89,7 @@ public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, Ca
8789 public async Task < OnnxImage > RunAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
8890 {
8991 var timestamp = _logger ? . LogBegin ( "Upscale image.." ) ;
90- var result = await RunInternalAsync ( inputImage , cancellationToken ) ;
92+ var result = await UpscaleImageAsync ( inputImage , cancellationToken ) ;
9193 _logger ? . LogEnd ( "Upscale image complete." , timestamp ) ;
9294 return result ;
9395 }
@@ -105,7 +107,7 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
105107 var upscaledFrames = new List < OnnxImage > ( ) ;
106108 foreach ( var videoFrame in inputVideo . Frames )
107109 {
108- upscaledFrames . Add ( await RunInternalAsync ( videoFrame , cancellationToken ) ) ;
110+ upscaledFrames . Add ( await UpscaleImageAsync ( videoFrame , cancellationToken ) ) ;
109111 }
110112
111113 var firstFrame = upscaledFrames . First ( ) ;
@@ -131,23 +133,44 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
131133 var timestamp = _logger ? . LogBegin ( "Upscale video stream.." ) ;
132134 await foreach ( var imageFrame in imageFrames )
133135 {
134- yield return await RunInternalAsync ( imageFrame , cancellationToken ) ;
136+ yield return await UpscaleImageAsync ( imageFrame , cancellationToken ) ;
135137 }
136138 _logger ? . LogEnd ( "Upscale video stream complete." , timestamp ) ;
137139 }
138140
139141
142+
140143 /// <summary>
141- /// Runs the upscale pipeline
144+ /// Upscales the OnnxImage.
142145 /// </summary>
143146 /// <param name="inputImage">The input image.</param>
144147 /// <param name="cancellationToken">The cancellation token.</param>
145148 /// <returns></returns>
146- private async Task < OnnxImage > RunInternalAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
149+ private async Task < OnnxImage > UpscaleImageAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
147150 {
148- var inputTensor = inputImage . GetImageTensor ( ImageNormalizeType . ZeroToOne , _upscaleModel . Channels ) ;
151+ var inputTensor = inputImage . GetImageTensor ( _upscaleModel . NormalizeType , _upscaleModel . Channels ) ;
149152 var outputTensor = await RunInternalAsync ( inputTensor , cancellationToken ) ;
150- return new OnnxImage ( outputTensor , ImageNormalizeType . ZeroToOne ) ;
153+ return new OnnxImage ( outputTensor , _upscaleModel . NormalizeType ) ;
154+ }
155+
156+
157+ /// <summary>
158+ /// Upscales the DenseTensor
159+ /// </summary>
160+ /// <param name="inputImage">The input image.</param>
161+ /// <param name="cancellationToken">The cancellation token.</param>
162+ /// <returns></returns>
163+ public async Task < DenseTensor < float > > UpscaleTensorAsync ( DenseTensor < float > inputImage , CancellationToken cancellationToken = default )
164+ {
165+ if ( _upscaleModel . NormalizeInput && _upscaleModel . NormalizeType == ImageNormalizeType . ZeroToOne )
166+ inputImage . NormalizeOneOneToZeroOne ( ) ;
167+
168+ var result = await RunInternalAsync ( inputImage , cancellationToken ) ;
169+
170+ if ( _upscaleModel . NormalizeInput && _upscaleModel . NormalizeType == ImageNormalizeType . ZeroToOne )
171+ result . NormalizeZeroOneToOneOne ( ) ;
172+
173+ return result ;
151174 }
152175
153176
@@ -233,7 +256,7 @@ public static ImageUpscalePipeline CreatePipeline(UpscaleModelSet modelSet, ILog
233256 /// <param name="executionProvider">The execution provider.</param>
234257 /// <param name="logger">The logger.</param>
235258 /// <returns></returns>
236- public static ImageUpscalePipeline CreatePipeline ( string modelFile , int scaleFactor , int sampleSize , int tileSize = 0 , int tileOverlap = 20 , int channels = 3 , int deviceId = 0 , ExecutionProvider executionProvider = ExecutionProvider . DirectML , ILogger logger = default )
259+ public static ImageUpscalePipeline CreatePipeline ( string modelFile , int scaleFactor , int sampleSize , ImageNormalizeType normalizeType = ImageNormalizeType . ZeroToOne , bool normalizeInput = true , int tileSize = 0 , int tileOverlap = 20 , int channels = 3 , int deviceId = 0 , ExecutionProvider executionProvider = ExecutionProvider . DirectML , ILogger logger = default )
237260 {
238261 var name = Path . GetFileNameWithoutExtension ( modelFile ) ;
239262 var configuration = new UpscaleModelSet
@@ -249,10 +272,13 @@ public static ImageUpscalePipeline CreatePipeline(string modelFile, int scaleFac
249272 ScaleFactor = scaleFactor ,
250273 TileOverlap = tileOverlap ,
251274 TileSize = Math . Min ( sampleSize , tileSize > 0 ? tileSize : sampleSize ) ,
252- OnnxModelPath = modelFile
275+ NormalizeType = normalizeType ,
276+ NormalizeInput = normalizeInput ,
277+ OnnxModelPath = modelFile ,
253278 }
254279 } ;
255280 return CreatePipeline ( configuration , logger ) ;
256281 }
257282 }
283+
258284}
0 commit comments