Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 6ae4eee

Browse files
committed
Fix incorrect upscale tiling
1 parent de788da commit 6ae4eee

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
148148
private async Task<OnnxImage> UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
149149
{
150150
var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels);
151-
var outputTensor = await RunInternalAsync(inputTensor, inputImage.Height, inputImage.Width, cancellationToken);
151+
var outputTensor = await RunInternalAsync(inputTensor, cancellationToken);
152152
return new OnnxImage(outputTensor, _upscaleModel.NormalizeType);
153153
}
154154

@@ -164,10 +164,7 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
164164
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
165165
inputTensor.NormalizeOneOneToZeroOne();
166166

167-
var height = inputTensor.Dimensions[2];
168-
var width = inputTensor.Dimensions[3];
169-
var result = await RunInternalAsync(inputTensor, height, width, cancellationToken);
170-
167+
var result = await RunInternalAsync(inputTensor, cancellationToken);
171168
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
172169
result.NormalizeZeroOneToOneOne();
173170

@@ -181,9 +178,9 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
181178
/// <param name="inputTensor">The input tensor.</param>
182179
/// <param name="cancellationToken">The cancellation token.</param>
183180
/// <returns></returns>
184-
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, int height, int width, CancellationToken cancellationToken = default)
181+
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
185182
{
186-
if (height <= _upscaleModel.TileSize && width <= _upscaleModel.TileSize)
183+
if (inputTensor.Dimensions[2] <= _upscaleModel.SampleSize && inputTensor.Dimensions[3] <= _upscaleModel.SampleSize)
187184
{
188185
return await RunInferenceAsync(inputTensor, cancellationToken);
189186
}
@@ -194,10 +191,10 @@ private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> input
194191
inputTiles.Width * _upscaleModel.ScaleFactor,
195192
inputTiles.Height * _upscaleModel.ScaleFactor,
196193
inputTiles.Overlap * _upscaleModel.ScaleFactor,
197-
await RunInternalAsync(inputTiles.Tile1, inputTiles.Height, inputTiles.Width, cancellationToken),
198-
await RunInternalAsync(inputTiles.Tile2, inputTiles.Height, inputTiles.Width, cancellationToken),
199-
await RunInternalAsync(inputTiles.Tile3, inputTiles.Height, inputTiles.Width, cancellationToken),
200-
await RunInternalAsync(inputTiles.Tile4, inputTiles.Height, inputTiles.Width, cancellationToken)
194+
await RunInternalAsync(inputTiles.Tile1, cancellationToken),
195+
await RunInternalAsync(inputTiles.Tile2, cancellationToken),
196+
await RunInternalAsync(inputTiles.Tile3, cancellationToken),
197+
await RunInternalAsync(inputTiles.Tile4, cancellationToken)
201198
);
202199
return outputTiles.JoinImageTiles();
203200
}

0 commit comments

Comments
 (0)