Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit fa234ee

Browse files
committed
Feature Extractor: Allow output shapes without batch
1 parent d76a113 commit fa234ee

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

OnnxStack.Core/Extensions/OrtValueExtensions.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,19 @@ public static OrtValue CreateOutputBuffer(this OnnxNamedMetadata metadata, ReadO
105105
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
106106
{
107107
var typeInfo = ortValue.GetTensorTypeAndShape();
108-
var dimensions = typeInfo.Shape.ToInt();
108+
return ortValue.ToDenseTensor(typeInfo.Shape.ToInt());
109+
}
110+
111+
112+
/// <summary>
113+
/// Converts to DenseTensor<float>.
114+
/// TODO: Optimization
115+
/// </summary>
116+
/// <param name="ortValue">The ort value.</param>
117+
/// <returns></returns>
118+
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue, ReadOnlySpan<int> dimensions)
119+
{
120+
var typeInfo = ortValue.GetTensorTypeAndShape();
109121
return typeInfo.ElementDataType switch
110122
{
111123
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),

OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,19 @@ private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, Cancellatio
121121
var controlImage = await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, ImageNormalizeType.ZeroToOne);
122122
var metadata = await _featureExtractorModel.GetMetadataAsync();
123123
cancellationToken.ThrowIfCancellationRequested();
124+
var outputShape = new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize };
125+
var outputBuffer = metadata.Outputs[0].Value.Dimensions.Length == 4 ? outputShape : outputShape[1..];
124126
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
125127
{
126128
inferenceParameters.AddInputTensor(controlImage);
127-
inferenceParameters.AddOutputBuffer(new[] { 1, _featureExtractorModel.Channels, _featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize });
129+
inferenceParameters.AddOutputBuffer(outputBuffer);
128130

129131
var results = await _featureExtractorModel.RunInferenceAsync(inferenceParameters);
130132
using (var result = results.First())
131133
{
132134
cancellationToken.ThrowIfCancellationRequested();
133135

134-
var resultTensor = result.ToDenseTensor();
136+
var resultTensor = result.ToDenseTensor(outputShape);
135137
if (_featureExtractorModel.Normalize)
136138
resultTensor.NormalizeMinMax();
137139

0 commit comments

Comments
 (0)