11using Microsoft . Extensions . Logging ;
22using Microsoft . ML . OnnxRuntime . Tensors ;
3- using Newtonsoft . Json . Linq ;
43using OnnxStack . Core ;
54using OnnxStack . Core . Config ;
65using OnnxStack . Core . Image ;
1110using System . Collections . Generic ;
1211using System . IO ;
1312using System . Linq ;
14- using System . Numerics . Tensors ;
1513using System . Runtime . CompilerServices ;
1614using System . Threading ;
1715using System . Threading . Tasks ;
@@ -73,9 +71,9 @@ public async Task UnloadAsync()
7371 /// <returns></returns>
7472 public async Task < DenseTensor < float > > RunAsync ( DenseTensor < float > inputImage , CancellationToken cancellationToken = default )
7573 {
76- var timestamp = _logger ? . LogBegin ( "Upscale image .." ) ;
74+ var timestamp = _logger ? . LogBegin ( "Upscale DenseTensor .." ) ;
7775 var result = await UpscaleTensorAsync ( inputImage , cancellationToken ) ;
78- _logger ? . LogEnd ( "Upscale image complete." , timestamp ) ;
76+ _logger ? . LogEnd ( "Upscale DenseTensor complete." , timestamp ) ;
7977 return result ;
8078 }
8179
@@ -88,9 +86,9 @@ public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, Ca
8886 /// <returns></returns>
8987 public async Task < OnnxImage > RunAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
9088 {
91- var timestamp = _logger ? . LogBegin ( "Upscale image .." ) ;
89+ var timestamp = _logger ? . LogBegin ( "Upscale OnnxImage .." ) ;
9290 var result = await UpscaleImageAsync ( inputImage , cancellationToken ) ;
93- _logger ? . LogEnd ( "Upscale image complete." , timestamp ) ;
91+ _logger ? . LogEnd ( "Upscale OnnxImage complete." , timestamp ) ;
9492 return result ;
9593 }
9694
@@ -103,7 +101,7 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
103101 /// <returns></returns>
104102 public async Task < OnnxVideo > RunAsync ( OnnxVideo inputVideo , CancellationToken cancellationToken = default )
105103 {
106- var timestamp = _logger ? . LogBegin ( "Upscale video .." ) ;
104+ var timestamp = _logger ? . LogBegin ( "Upscale OnnxVideo .." ) ;
107105 var upscaledFrames = new List < OnnxImage > ( ) ;
108106 foreach ( var videoFrame in inputVideo . Frames )
109107 {
@@ -117,7 +115,7 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
117115 Height = firstFrame . Height ,
118116 } ;
119117
120- _logger ? . LogEnd ( "Upscale video complete." , timestamp ) ;
118+ _logger ? . LogEnd ( "Upscale OnnxVideo complete." , timestamp ) ;
121119 return new OnnxVideo ( videoInfo , upscaledFrames ) ;
122120 }
123121
@@ -130,16 +128,15 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
130128 /// <returns></returns>
131129 public async IAsyncEnumerable < OnnxImage > RunAsync ( IAsyncEnumerable < OnnxImage > imageFrames , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
132130 {
133- var timestamp = _logger ? . LogBegin ( "Upscale video stream.." ) ;
131+ var timestamp = _logger ? . LogBegin ( "Upscale OnnxImage stream.." ) ;
134132 await foreach ( var imageFrame in imageFrames )
135133 {
136134 yield return await UpscaleImageAsync ( imageFrame , cancellationToken ) ;
137135 }
138- _logger ? . LogEnd ( "Upscale video stream complete." , timestamp ) ;
136+ _logger ? . LogEnd ( "Upscale OnnxImage stream complete." , timestamp ) ;
139137 }
140138
141139
142-
143140 /// <summary>
144141 /// Upscales the OnnxImage.
145142 /// </summary>
@@ -149,23 +146,25 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
149146 private async Task < OnnxImage > UpscaleImageAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
150147 {
151148 var inputTensor = inputImage . GetImageTensor ( _upscaleModel . NormalizeType , _upscaleModel . Channels ) ;
152- var outputTensor = await RunInternalAsync ( inputTensor , cancellationToken ) ;
149+ var outputTensor = await RunInternalAsync ( inputTensor , inputImage . Height , inputImage . Width , cancellationToken ) ;
153150 return new OnnxImage ( outputTensor , _upscaleModel . NormalizeType ) ;
154151 }
155152
156153
157154 /// <summary>
158155 /// Upscales the DenseTensor
159156 /// </summary>
160- /// <param name="inputImage ">The input image .</param>
157+ /// <param name="inputTensor ">The input Tensor .</param>
161158 /// <param name="cancellationToken">The cancellation token.</param>
162159 /// <returns></returns>
163- public async Task < DenseTensor < float > > UpscaleTensorAsync ( DenseTensor < float > inputImage , CancellationToken cancellationToken = default )
160+ public async Task < DenseTensor < float > > UpscaleTensorAsync ( DenseTensor < float > inputTensor , CancellationToken cancellationToken = default )
164161 {
165162 if ( _upscaleModel . NormalizeInput && _upscaleModel . NormalizeType == ImageNormalizeType . ZeroToOne )
166- inputImage . NormalizeOneOneToZeroOne ( ) ;
163+ inputTensor . NormalizeOneOneToZeroOne ( ) ;
167164
168- var result = await RunInternalAsync ( inputImage , cancellationToken ) ;
165+ var height = inputTensor . Dimensions [ 2 ] ;
166+ var width = inputTensor . Dimensions [ 3 ] ;
167+ var result = await RunInternalAsync ( inputTensor , height , width , cancellationToken ) ;
169168
170169 if ( _upscaleModel . NormalizeInput && _upscaleModel . NormalizeType == ImageNormalizeType . ZeroToOne )
171170 result . NormalizeZeroOneToOneOne ( ) ;
@@ -180,9 +179,9 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
180179 /// <param name="inputTensor">The input tensor.</param>
181180 /// <param name="cancellationToken">The cancellation token.</param>
182181 /// <returns></returns>
183- private async Task < DenseTensor < float > > RunInternalAsync ( DenseTensor < float > inputTensor , CancellationToken cancellationToken = default )
182+ private async Task < DenseTensor < float > > RunInternalAsync ( DenseTensor < float > inputTensor , int height , int width , CancellationToken cancellationToken = default )
184183 {
185- if ( inputTensor . Dimensions [ 2 ] <= _upscaleModel . TileSize && inputTensor . Dimensions [ 3 ] <= _upscaleModel . TileSize )
184+ if ( height <= _upscaleModel . TileSize && width <= _upscaleModel . TileSize )
186185 {
187186 return await RunInferenceAsync ( inputTensor , cancellationToken ) ;
188187 }
@@ -193,10 +192,10 @@ private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> input
193192 inputTiles . Width * _upscaleModel . ScaleFactor ,
194193 inputTiles . Height * _upscaleModel . ScaleFactor ,
195194 inputTiles . Overlap * _upscaleModel . ScaleFactor ,
196- await RunInternalAsync ( inputTiles . Tile1 , cancellationToken ) ,
197- await RunInternalAsync ( inputTiles . Tile2 , cancellationToken ) ,
198- await RunInternalAsync ( inputTiles . Tile3 , cancellationToken ) ,
199- await RunInternalAsync ( inputTiles . Tile4 , cancellationToken )
195+ await RunInternalAsync ( inputTiles . Tile1 , inputTiles . Height , inputTiles . Width , cancellationToken ) ,
196+ await RunInternalAsync ( inputTiles . Tile2 , inputTiles . Height , inputTiles . Width , cancellationToken ) ,
197+ await RunInternalAsync ( inputTiles . Tile3 , inputTiles . Height , inputTiles . Width , cancellationToken ) ,
198+ await RunInternalAsync ( inputTiles . Tile4 , inputTiles . Height , inputTiles . Width , cancellationToken )
200199 ) ;
201200 return outputTiles . JoinImageTiles ( ) ;
202201 }
0 commit comments