@@ -602,25 +602,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
602602}
603603
604604FrameBatchOutput SingleStreamDecoder::getFramesAtIndices (
605- const std::vector< int64_t > & frameIndices) {
605+ const torch::Tensor & frameIndices) {
606606 validateActiveStream (AVMEDIA_TYPE_VIDEO);
607607
608- auto indicesAreSorted =
609- std::is_sorted (frameIndices.begin (), frameIndices.end ());
608+ auto frameIndicesAccessor = frameIndices.accessor <int64_t , 1 >();
609+
610+ bool indicesAreSorted = true ;
611+ for (int64_t i = 1 ; i < frameIndices.numel (); ++i) {
612+ if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1 ]) {
613+ indicesAreSorted = false ;
614+ break ;
615+ }
616+ }
610617
611618 std::vector<size_t > argsort;
612619 if (!indicesAreSorted) {
613620 // if frameIndices is [13, 10, 12, 11]
614621 // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
615622 // to use to decode the frames
616623 // and argsort is [ 1, 3, 2, 0]
617- argsort.resize (frameIndices.size ());
624+ argsort.resize (frameIndices.numel ());
618625 for (size_t i = 0 ; i < argsort.size (); ++i) {
619626 argsort[i] = i;
620627 }
621628 std::sort (
622- argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
623- return frameIndices[a] < frameIndices[b];
629+ argsort.begin (),
630+ argsort.end (),
631+ [&frameIndicesAccessor](size_t a, size_t b) {
632+ return frameIndicesAccessor[a] < frameIndicesAccessor[b];
624633 });
625634 }
626635
@@ -629,12 +638,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
629638 const auto & streamInfo = streamInfos_[activeStreamIndex_];
630639 const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
631640 FrameBatchOutput frameBatchOutput (
632- frameIndices.size (), videoStreamOptions, streamMetadata);
641+ frameIndices.numel (), videoStreamOptions, streamMetadata);
633642
634643 auto previousIndexInVideo = -1 ;
635- for (size_t f = 0 ; f < frameIndices.size (); ++f) {
644+ for (int64_t f = 0 ; f < frameIndices.numel (); ++f) {
636645 auto indexInOutput = indicesAreSorted ? f : argsort[f];
637- auto indexInVideo = frameIndices [indexInOutput];
646+ auto indexInVideo = frameIndicesAccessor [indexInOutput];
638647
639648 if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
640649 // Avoid decoding the same frame twice
@@ -776,7 +785,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
776785 frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
777786 }
778787
779- return getFramesAtIndices (frameIndices);
788+ // TODO: Support tensors natively instead of a vector to avoid a copy.
789+ return getFramesAtIndices (torch::tensor (frameIndices));
780790}
781791
782792FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange (
0 commit comments