Skip to content

Commit c1798db

Browse files
authored
C++ implementation of crop transform (#967)
1 parent 62336a0 commit c1798db

19 files changed

+470
-201
lines changed

.github/workflows/linux_wheel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
run: python -m pip install --upgrade pip
8686
- name: Install PyTorch
8787
run: |
88-
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
88+
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
8989
- name: Install torchcodec from the wheel
9090
run: |
9191
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`

.github/workflows/macos_wheel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ jobs:
8686

8787
- name: Install PyTorch
8888
run: |
89-
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
89+
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
9090
9191
- name: Install torchcodec from the wheel
9292
run: |

.github/workflows/reference_resources.yaml

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,51 @@ defaults:
1414
shell: bash -l -eo pipefail {0}
1515

1616
jobs:
17+
generate-matrix:
18+
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
19+
with:
20+
package-type: wheel
21+
os: linux
22+
test-infra-repository: pytorch/test-infra
23+
test-infra-ref: main
24+
with-xpu: disable
25+
with-rocm: disable
26+
with-cuda: disable
27+
build-python-only: "disable"
28+
29+
build:
30+
needs: generate-matrix
31+
strategy:
32+
fail-fast: false
33+
name: Build and Upload Linux wheel
34+
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
35+
with:
36+
repository: meta-pytorch/torchcodec
37+
ref: ""
38+
test-infra-repository: pytorch/test-infra
39+
test-infra-ref: main
40+
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
41+
pre-script: packaging/pre_build_script.sh
42+
post-script: packaging/post_build_script.sh
43+
smoke-test-script: packaging/fake_smoke_test.py
44+
package-name: torchcodec
45+
trigger-event: ${{ github.event_name }}
46+
build-platform: "python-build-package"
47+
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"
48+
1749
test-reference-resource-generation:
50+
needs: build
1851
runs-on: ubuntu-latest
1952
strategy:
2053
fail-fast: false
2154
matrix:
2255
python-version: ['3.10']
2356
ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1']
2457
steps:
58+
- uses: actions/download-artifact@v4
59+
with:
60+
name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64
61+
path: pytorch/torchcodec/dist/
2562
- name: Setup conda env
2663
uses: conda-incubator/setup-miniconda@v2
2764
with:
@@ -43,11 +80,16 @@ jobs:
4380
# Note that we're installing stable - this is for running a script where we're a normal PyTorch
4481
# user, not for building TorhCodec.
4582
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
46-
python -m pip install numpy pillow
83+
python -m pip install numpy pillow pytest
4784
85+
- name: Install torchcodec from the wheel
86+
run: |
87+
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
88+
echo Installing $wheel_path
89+
python -m pip install $wheel_path -vvv
4890
- name: Check out repo
4991
uses: actions/checkout@v3
5092

5193
- name: Run generation reference resources
5294
run: |
53-
python test/generate_reference_resources.py
95+
python -m test.generate_reference_resources

.github/workflows/windows_wheel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ jobs:
9393
run: python -m pip install --upgrade pip
9494
- name: Install PyTorch
9595
run: |
96-
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
96+
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
9797
- name: Install torchcodec from the wheel
9898
run: |
9999
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ FilterGraph::FilterGraph(
130130
TORCH_CHECK(
131131
status >= 0,
132132
"Failed to configure filter graph: ",
133-
getFFMPEGErrorStringFromErrorCode(status));
133+
getFFMPEGErrorStringFromErrorCode(status),
134+
", provided filters: " + filtersContext.filtergraphStr);
134135
}
135136

136137
UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {

src/torchcodec/_core/Frame.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
namespace facebook::torchcodec {
1010

11+
FrameDims::FrameDims(int height, int width) : height(height), width(width) {
12+
TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height);
13+
TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width);
14+
}
15+
1116
FrameBatchOutput::FrameBatchOutput(
1217
int64_t numFrames,
1318
const FrameDims& outputDims,

src/torchcodec/_core/Frame.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct FrameDims {
1919

2020
FrameDims() = default;
2121

22-
FrameDims(int h, int w) : height(h), width(w) {}
22+
FrameDims(int h, int w);
2323
};
2424

2525
// All public video decoding entry points return either a FrameOutput or a

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <sstream>
1313
#include <stdexcept>
1414
#include <string_view>
15+
#include "Metadata.h"
1516
#include "torch/types.h"
1617

1718
namespace facebook::torchcodec {
@@ -527,6 +528,7 @@ void SingleStreamDecoder::addVideoStream(
527528
if (transform->getOutputFrameDims().has_value()) {
528529
resizedOutputDims_ = transform->getOutputFrameDims().value();
529530
}
531+
transform->validate(streamMetadata);
530532

531533
// Note that we are claiming ownership of the transform objects passed in to
532534
// us.

src/torchcodec/_core/Transform.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,31 @@ int ResizeTransform::getSwsFlags() const {
5757
return toSwsInterpolation(interpolationMode_);
5858
}
5959

60+
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
61+
: outputDims_(dims), x_(x), y_(y) {
62+
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
63+
TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
64+
}
65+
66+
std::string CropTransform::getFilterGraphCpu() const {
67+
return "crop=" + std::to_string(outputDims_.width) + ":" +
68+
std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" +
69+
std::to_string(y_) + ":exact=1";
70+
}
71+
72+
std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
73+
return outputDims_;
74+
}
75+
76+
void CropTransform::validate(const StreamMetadata& streamMetadata) const {
77+
TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
78+
TORCH_CHECK(
79+
x_ + outputDims_.width <= streamMetadata.width,
80+
"Crop x position out of bounds")
81+
TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
82+
TORCH_CHECK(
83+
y_ + outputDims_.height <= streamMetadata.height,
84+
"Crop y position out of bounds");
85+
}
86+
6087
} // namespace facebook::torchcodec

src/torchcodec/_core/Transform.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <optional>
1010
#include <string>
1111
#include "src/torchcodec/_core/Frame.h"
12+
#include "src/torchcodec/_core/Metadata.h"
1213

1314
namespace facebook::torchcodec {
1415

@@ -33,6 +34,16 @@ class Transform {
3334
virtual bool isResize() const {
3435
return false;
3536
}
37+
38+
// The validity of some transforms depends on the characteristics of the
39+
// AVStream they're being applied to. For example, some transforms will
40+
// specify coordinates inside a frame, we need to validate that those are
41+
// within the frame's bounds.
42+
//
43+
// Note that the validation function does not return anything. We expect
44+
// invalid configurations to throw an exception.
45+
virtual void validate(
46+
[[maybe_unused]] const StreamMetadata& streamMetadata) const {}
3647
};
3748

3849
class ResizeTransform : public Transform {
@@ -56,4 +67,18 @@ class ResizeTransform : public Transform {
5667
InterpolationMode interpolationMode_;
5768
};
5869

70+
class CropTransform : public Transform {
71+
public:
72+
CropTransform(const FrameDims& dims, int x, int y);
73+
74+
std::string getFilterGraphCpu() const override;
75+
std::optional<FrameDims> getOutputFrameDims() const override;
76+
void validate(const StreamMetadata& streamMetadata) const override;
77+
78+
private:
79+
FrameDims outputDims_;
80+
int x_;
81+
int y_;
82+
};
83+
5984
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)