@@ -62,16 +62,30 @@ public async Task UnloadAsync()
6262 }
6363
6464
65+ /// <summary>
66+ /// Generates the feature extractor image
67+ /// </summary>
68+ /// <param name="inputImage">The input image.</param>
69+ /// <returns></returns>
70+ public async Task < DenseTensor < float > > RunAsync ( DenseTensor < float > inputTensor , CancellationToken cancellationToken = default )
71+ {
72+ var timestamp = _logger ? . LogBegin ( "Extracting DenseTensor feature..." ) ;
73+ var result = await ExtractTensorAsync ( inputTensor , cancellationToken ) ;
74+ _logger ? . LogEnd ( "Extracting DenseTensor feature complete." , timestamp ) ;
75+ return result ;
76+ }
77+
78+
6579 /// <summary>
6680 /// Generates the feature extractor image
6781 /// </summary>
6882 /// <param name="inputImage">The input image.</param>
6983 /// <returns></returns>
7084 public async Task < OnnxImage > RunAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
7185 {
72- var timestamp = _logger ? . LogBegin ( "Extracting image feature..." ) ;
73- var result = await RunInternalAsync ( inputImage , cancellationToken ) ;
74- _logger ? . LogEnd ( "Extracting image feature complete." , timestamp ) ;
86+ var timestamp = _logger ? . LogBegin ( "Extracting OnnxImage feature..." ) ;
87+ var result = await ExtractImageAsync ( inputImage , cancellationToken ) ;
88+ _logger ? . LogEnd ( "Extracting OnnxImage feature complete." , timestamp ) ;
7589 return result ;
7690 }
7791
@@ -83,13 +97,13 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
8397 /// <returns></returns>
8498 public async Task < OnnxVideo > RunAsync ( OnnxVideo video , CancellationToken cancellationToken = default )
8599 {
86- var timestamp = _logger ? . LogBegin ( "Extracting video features..." ) ;
100+ var timestamp = _logger ? . LogBegin ( "Extracting OnnxVideo features..." ) ;
87101 var featureFrames = new List < OnnxImage > ( ) ;
88102 foreach ( var videoFrame in video . Frames )
89103 {
90104 featureFrames . Add ( await RunAsync ( videoFrame , cancellationToken ) ) ;
91105 }
92- _logger ? . LogEnd ( "Extracting video features complete." , timestamp ) ;
106+ _logger ? . LogEnd ( "Extracting OnnxVideo features complete." , timestamp ) ;
93107 return new OnnxVideo ( video . Info , featureFrames ) ;
94108 }
95109
@@ -102,28 +116,62 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancell
102116 /// <returns></returns>
103117 public async IAsyncEnumerable < OnnxImage > RunAsync ( IAsyncEnumerable < OnnxImage > imageFrames , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
104118 {
105- var timestamp = _logger ? . LogBegin ( "Extracting video stream features..." ) ;
119+ var timestamp = _logger ? . LogBegin ( "Extracting OnnxImage stream features..." ) ;
106120 await foreach ( var imageFrame in imageFrames )
107121 {
108- yield return await RunInternalAsync ( imageFrame , cancellationToken ) ;
122+ yield return await ExtractImageAsync ( imageFrame , cancellationToken ) ;
109123 }
110- _logger ? . LogEnd ( "Extracting video stream features complete." , timestamp ) ;
124+ _logger ? . LogEnd ( "Extracting OnnxImage stream features complete." , timestamp ) ;
111125 }
112126
113127
114128 /// <summary>
115- /// Runs the pipeline
129+ /// Extracts the feature to OnnxImage.
116130 /// </summary>
117131 /// <param name="inputImage">The input image.</param>
118132 /// <param name="cancellationToken">The cancellation token.</param>
119133 /// <returns></returns>
120- private async Task < OnnxImage > RunInternalAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
134+ private async Task < OnnxImage > ExtractImageAsync ( OnnxImage inputImage , CancellationToken cancellationToken = default )
121135 {
122136 var originalWidth = inputImage . Width ;
123137 var originalHeight = inputImage . Height ;
124138 var inputTensor = _featureExtractorModel . SampleSize <= 0
125- ? await inputImage . GetImageTensorAsync ( _featureExtractorModel . InputNormalization )
126- : await inputImage . GetImageTensorAsync ( _featureExtractorModel . SampleSize , _featureExtractorModel . SampleSize , _featureExtractorModel . InputNormalization , resizeMode : _featureExtractorModel . InputResizeMode ) ;
139+ ? await inputImage . GetImageTensorAsync ( _featureExtractorModel . NormalizeType )
140+ : await inputImage . GetImageTensorAsync ( _featureExtractorModel . SampleSize , _featureExtractorModel . SampleSize , _featureExtractorModel . NormalizeType , resizeMode : _featureExtractorModel . InputResizeMode ) ;
141+
142+ var outputTensor = await RunInternalAsync ( inputTensor , cancellationToken ) ;
143+ var imageResult = new OnnxImage ( outputTensor , _featureExtractorModel . NormalizeType ) ;
144+
145+ if ( _featureExtractorModel . InputResizeMode == ImageResizeMode . Stretch && ( imageResult . Width != originalWidth || imageResult . Height != originalHeight ) )
146+ imageResult . Resize ( originalHeight , originalWidth , _featureExtractorModel . InputResizeMode ) ;
147+
148+ return imageResult ;
149+ }
150+
151+
152+ /// <summary>
153+ /// Extracts the feature to DenseTensor.
154+ /// </summary>
155+ /// <param name="inputTensor">The input tensor.</param>
156+ /// <param name="cancellationToken">The cancellation token.</param>
157+ /// <returns></returns>
158+ public async Task < DenseTensor < float > > ExtractTensorAsync ( DenseTensor < float > inputTensor , CancellationToken cancellationToken = default )
159+ {
160+ if ( _featureExtractorModel . NormalizeInput && _featureExtractorModel . NormalizeType == ImageNormalizeType . ZeroToOne )
161+ inputTensor . NormalizeOneOneToZeroOne ( ) ;
162+
163+ return await RunInternalAsync ( inputTensor , cancellationToken ) ;
164+ }
165+
166+
167+ /// <summary>
168+ /// Runs the pipeline
169+ /// </summary>
170+ /// <param name="inputTensor">The input tensor.</param>
171+ /// <param name="cancellationToken">The cancellation token.</param>
172+ /// <returns></returns>
173+ private async Task < DenseTensor < float > > RunInternalAsync ( DenseTensor < float > inputTensor , CancellationToken cancellationToken = default )
174+ {
127175 var metadata = await _featureExtractorModel . GetMetadataAsync ( ) ;
128176 cancellationToken . ThrowIfCancellationRequested ( ) ;
129177 var outputShape = new [ ] { 1 , _featureExtractorModel . OutputChannels , inputTensor . Dimensions [ 2 ] , inputTensor . Dimensions [ 3 ] } ;
@@ -139,21 +187,13 @@ private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, Cancellatio
139187 cancellationToken . ThrowIfCancellationRequested ( ) ;
140188
141189 var outputTensor = inferenceResult . ToDenseTensor ( outputShape ) ;
142- if ( _featureExtractorModel . NormalizeOutputTensor )
190+ if ( _featureExtractorModel . NormalizeOutput )
143191 outputTensor . NormalizeMinMax ( ) ;
144192
145- var imageResult = default ( OnnxImage ) ;
146193 if ( _featureExtractorModel . SetOutputToInputAlpha )
147- imageResult = new OnnxImage ( AddAlphaChannel ( inputTensor , outputTensor ) , _featureExtractorModel . InputNormalization ) ;
148- else if ( _featureExtractorModel . OutputChannels >= 3 )
149- imageResult = new OnnxImage ( outputTensor , _featureExtractorModel . InputNormalization ) ;
150- else
151- imageResult = outputTensor . ToImageMask ( ) ;
152-
153- if ( _featureExtractorModel . InputResizeMode == ImageResizeMode . Stretch && ( imageResult . Width != originalWidth || imageResult . Height != originalHeight ) )
154- imageResult . Resize ( originalHeight , originalWidth , _featureExtractorModel . InputResizeMode ) ;
194+ return AddAlphaChannel ( inputTensor , outputTensor ) ;
155195
156- return imageResult ;
196+ return outputTensor ;
157197 }
158198 }
159199 }
@@ -200,7 +240,7 @@ public static FeatureExtractorPipeline CreatePipeline(FeatureExtractorModelSet m
200240 /// <param name="executionProvider">The execution provider.</param>
201241 /// <param name="logger">The logger.</param>
202242 /// <returns></returns>
203- public static FeatureExtractorPipeline CreatePipeline ( string modelFile , int sampleSize = 0 , int outputChannels = 1 , bool normalizeOutputTensor = false , ImageNormalizeType normalizeInputTensor = ImageNormalizeType . ZeroToOne , ImageResizeMode inputResizeMode = ImageResizeMode . Crop , bool setOutputToInputAlpha = false , int deviceId = 0 , ExecutionProvider executionProvider = ExecutionProvider . DirectML , ILogger logger = default )
243+ public static FeatureExtractorPipeline CreatePipeline ( string modelFile , int sampleSize = 0 , int outputChannels = 1 , ImageNormalizeType normalizeType = ImageNormalizeType . ZeroToOne , bool normalizeInput = true , bool normalizeOutput = false , ImageResizeMode inputResizeMode = ImageResizeMode . Crop , bool setOutputToInputAlpha = false , int deviceId = 0 , ExecutionProvider executionProvider = ExecutionProvider . DirectML , ILogger logger = default )
204244 {
205245 var name = Path . GetFileNameWithoutExtension ( modelFile ) ;
206246 var configuration = new FeatureExtractorModelSet
@@ -214,9 +254,10 @@ public static FeatureExtractorPipeline CreatePipeline(string modelFile, int samp
214254 OnnxModelPath = modelFile ,
215255 SampleSize = sampleSize ,
216256 OutputChannels = outputChannels ,
217- NormalizeOutputTensor = normalizeOutputTensor ,
257+ NormalizeOutput = normalizeOutput ,
258+ NormalizeInput = normalizeInput ,
259+ NormalizeType = normalizeType ,
218260 SetOutputToInputAlpha = setOutputToInputAlpha ,
219- NormalizeInputTensor = normalizeInputTensor ,
220261 InputResizeMode = inputResizeMode
221262 }
222263 } ;
0 commit comments