Skip to content

Commit f1678b7

Browse files
committed
clean up deviceVariant, clean up tests
1 parent 7a7d4d3 commit f1678b7

File tree

4 files changed

+7
-63
lines changed

4 files changed

+7
-63
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -627,11 +627,7 @@ void VideoEncoder::initializeEncoder(
627627

628628
// Try to find a hardware-accelerated encoder if not using CPU
629629
if (videoStreamOptions.device.type() != torch::kCPU) {
630-
auto hardwareCodec =
631-
deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec);
632-
if (hardwareCodec.has_value()) {
633-
avCodec = hardwareCodec.value();
634-
}
630+
avCodec = deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec).value_or(avCodec);
635631
}
636632

637633
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);

src/torchcodec/_core/custom_ops.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,6 @@ void _add_video_stream(
379379
validateDeviceInterface(std::string(device), std::string(device_variant));
380380

381381
videoStreamOptions.device = torch::Device(std::string(device));
382-
videoStreamOptions.deviceVariant = device_variant;
383-
384382
std::vector<Transform*> transforms =
385383
makeTransforms(std::string(transform_specs));
386384

@@ -609,7 +607,6 @@ void encode_video_to_file(
609607
videoStreamOptions.crf = crf;
610608

611609
videoStreamOptions.device = torch::Device(std::string(device));
612-
videoStreamOptions.deviceVariant = "ffmpeg";
613610
VideoEncoder(
614611
frames,
615612
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -629,7 +626,6 @@ at::Tensor encode_video_to_tensor(
629626
videoStreamOptions.crf = crf;
630627

631628
videoStreamOptions.device = torch::Device(std::string(device));
632-
videoStreamOptions.deviceVariant = "ffmpeg";
633629
return VideoEncoder(
634630
frames,
635631
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -654,9 +650,7 @@ void _encode_video_to_file_like(
654650

655651
VideoStreamOptions videoStreamOptions;
656652
videoStreamOptions.crf = crf;
657-
658653
videoStreamOptions.device = torch::Device(std::string(device));
659-
videoStreamOptions.deviceVariant = "ffmpeg";
660654

661655
VideoEncoder encoder(
662656
frames,

test/test_encoders.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -687,35 +687,3 @@ def encode_to_tensor(frames):
687687
torch.testing.assert_close(
688688
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
689689
)
690-
691-
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
692-
@pytest.mark.parametrize(
693-
"device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
694-
)
695-
def test_device_video_encoder(self, method, device, tmp_path):
696-
# Test that encoding works on CPU and CUDA devices
697-
# num_frames, channels, height, width = 5, 3, 1024, 1024
698-
# frames = (torch.rand(num_frames, channels, height, width) * 255).to(torch.uint8)
699-
700-
asset = TEST_SRC_2_720P
701-
frames = self.decode(asset.path).data
702-
703-
encoder = VideoEncoder(frames, frame_rate=30, device=device)
704-
705-
if method == "to_file":
706-
dest = str(tmp_path / "output.mp4")
707-
encoder.to_file(dest=dest)
708-
# Verify file was created
709-
assert Path(dest).exists()
710-
elif method == "to_tensor":
711-
encoded = encoder.to_tensor(format="mp4")
712-
assert encoded.dtype == torch.uint8
713-
assert encoded.ndim == 1
714-
assert encoded.numel() > 0
715-
elif method == "to_file_like":
716-
file_like = io.BytesIO()
717-
encoder.to_file_like(file_like, format="mp4")
718-
encoded_bytes = file_like.getvalue()
719-
assert len(encoded_bytes) > 0
720-
else:
721-
raise ValueError(f"Unknown method: {method}")

test/test_ops.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,20 +1409,13 @@ def get_encoded_data(self):
14091409
decoded_samples = self.decode(file_like.get_encoded_data())
14101410

14111411
if device == "cuda":
1412-
atol = 15
1413-
percentage = 98
1414-
assert_close = partial(
1415-
assert_tensor_close_on_at_least, percentage=percentage
1416-
)
1412+
assert_close = assert_frames_equal
14171413
else:
1418-
assert_close = torch.testing.assert_close
1419-
atol = 2
1414+
assert_close = partial(torch.testing.assert_close, atol=2)
14201415

14211416
assert_close(
14221417
decoded_samples.data,
1423-
source_frames,
1424-
atol=atol,
1425-
rtol=0,
1418+
source_frames
14261419
)
14271420

14281421
@pytest.mark.parametrize(
@@ -1445,20 +1438,13 @@ def test_to_file_like_real_file(self, tmp_path, device):
14451438
decoded_samples = self.decode(str(file_path))
14461439

14471440
if device == "cuda":
1448-
atol = 15
1449-
percentage = 98
1450-
assert_close = partial(
1451-
assert_tensor_close_on_at_least, percentage=percentage
1452-
)
1441+
assert_close = assert_frames_equal
14531442
else:
1454-
assert_close = torch.testing.assert_close
1455-
atol = 2
1443+
assert_close = partial(torch.testing.assert_close, atol=2)
14561444

14571445
assert_close(
14581446
decoded_samples.data,
1459-
source_frames,
1460-
atol=atol,
1461-
rtol=0,
1447+
source_frames
14621448
)
14631449

14641450
def test_to_file_like_bad_methods(self):

0 commit comments

Comments
 (0)