@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737 m.def (
3838 " _encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
3939 m.def (
40- " encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()" );
40+ " encode_video_to_file(Tensor frames, int frame_rate, str filename, str device= \" cpu \" , str device_variant= \" ffmpeg \" , int? crf=None) -> ()" );
4141 m.def (
42- " encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor" );
42+ " encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device= \" cpu \" , str device_variant= \" ffmpeg \" , int? crf=None) -> Tensor" );
4343 m.def (
44- " _encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()" );
44+ " _encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device= \" cpu \" , str device_variant= \" ffmpeg \" , int? crf=None) -> ()" );
4545 m.def (
4646 " create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor" );
4747 m.def (
@@ -603,9 +603,15 @@ void encode_video_to_file(
603603 const at::Tensor& frames,
604604 int64_t frame_rate,
605605 std::string_view file_name,
606+ std::string_view device = " cpu" ,
607+ std::string_view device_variant = " ffmpeg" ,
606608 std::optional<int64_t > crf = std::nullopt ) {
607609 VideoStreamOptions videoStreamOptions;
608610 videoStreamOptions.crf = crf;
611+
612+ validateDeviceInterface (std::string (device), std::string (device_variant));
613+ videoStreamOptions.device = torch::Device (std::string (device));
614+ videoStreamOptions.deviceVariant = device_variant;
609615 VideoEncoder (
610616 frames,
611617 validateInt64ToInt (frame_rate, " frame_rate" ),
@@ -618,10 +624,16 @@ at::Tensor encode_video_to_tensor(
618624 const at::Tensor& frames,
619625 int64_t frame_rate,
620626 std::string_view format,
627+ std::string_view device = " cpu" ,
628+ std::string_view device_variant = " ffmpeg" ,
621629 std::optional<int64_t > crf = std::nullopt ) {
622630 auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
623631 VideoStreamOptions videoStreamOptions;
624632 videoStreamOptions.crf = crf;
633+
634+ validateDeviceInterface (std::string (device), std::string (device_variant));
635+ videoStreamOptions.device = torch::Device (std::string (device));
636+ videoStreamOptions.deviceVariant = device_variant;
625637 return VideoEncoder (
626638 frames,
627639 validateInt64ToInt (frame_rate, " frame_rate" ),
@@ -636,6 +648,8 @@ void _encode_video_to_file_like(
636648 int64_t frame_rate,
637649 std::string_view format,
638650 int64_t file_like_context,
651+ std::string_view device = " cpu" ,
652+ std::string_view device_variant = " ffmpeg" ,
639653 std::optional<int64_t > crf = std::nullopt ) {
640654 auto fileLikeContext =
641655 reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
@@ -646,6 +660,10 @@ void _encode_video_to_file_like(
646660 VideoStreamOptions videoStreamOptions;
647661 videoStreamOptions.crf = crf;
648662
663+ validateDeviceInterface (std::string (device), std::string (device_variant));
664+ videoStreamOptions.device = torch::Device (std::string (device));
665+ videoStreamOptions.deviceVariant = device_variant;
666+
649667 VideoEncoder encoder (
650668 frames,
651669 validateInt64ToInt (frame_rate, " frame_rate" ),
0 commit comments