diff --git a/DPF/filters/images/complexity_filter.py b/DPF/filters/images/complexity_filter.py
new file mode 100644
index 0000000..d79c2c2
--- /dev/null
+++ b/DPF/filters/images/complexity_filter.py
@@ -0,0 +1,123 @@
+import os
+from typing import Any
+from urllib.request import urlretrieve
+
+import numpy as np
+import torch
+from segment_anything import (  # type: ignore
+    SamAutomaticMaskGenerator,
+    sam_model_registry,
+)
+
+from DPF.utils import read_image_rgb_from_bytes
+
+from ...types import ModalityToDataMapping
+from .img_filter import ImageFilter
+
+WEIGHTS_URL = {'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
+               'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
+               'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'}
+
+
+class ComplexityFilter(ImageFilter):
+    """
+    Image complexity filter based on SAM: https://github.com/facebookresearch/segment-anything
+
+    Parameters
+    ----------
+    weights_folder: str
+        Folder where the weights will be stored
+    model_name: str = 'vit_h'
+        Model version to use: vit_h - huge, vit_l - large, vit_b - base
+    points_per_side: int = 32
+        Number of points sampled along each side of the image; controls the granularity of the automatic segmentation
+    batch_size: int = 1
+        Number of points processed per batch during mask generation for one image
+    device: str = "cuda:0"
+        Device to use
+    workers: int = 16
+        Number of processes to use for reading data
+    pbar: bool = True
+        Whether to use a progress bar
+    """
+
+    def __init__(
+        self,
+        weights_folder: str,
+        model_name: str = 'vit_h',
+        points_per_side: int = 32,
+        workers: int = 16,
+        batch_size: int = 1,
+        device: str = "cuda:0",
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.batch_size = batch_size
+        self.device = device
+
+        self.model_name = model_name
+        self.weights_folder = weights_folder
+        self.points_per_side = points_per_side
+
+        # Download checkpoints
+        path_to_model = os.path.join(self.weights_folder, self.model_name + '.pth')
+        if not os.path.exists(path_to_model):
+            os.makedirs(self.weights_folder, exist_ok=True)
+            urlretrieve(WEIGHTS_URL[self.model_name], path_to_model)
+
+        sam = sam_model_registry[self.model_name](checkpoint=path_to_model)
+        sam = sam.to(torch.device(self.device))
+        self.mask_generator = SamAutomaticMaskGenerator(
+            sam, points_per_batch=batch_size,
+            points_per_side=points_per_side
+        )
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["complexity_num_segments", "complexity_max_segment_area", "complexity_mean_segment_area"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": 1,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        pil_img = read_image_rgb_from_bytes(modality2data['image'])
+        img = np.array(pil_img)
+        return key, img
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for data in batch:
+            key, img = data
+            h, w = img.shape[:2]
+            hw = h * w
+            with torch.no_grad():
+                outputs = self.mask_generator.generate(img)
+            num_segments = len(outputs)
+            if num_segments > 0:
+                areas = [x['area'] for x in outputs]
+                bg_area = hw - np.sum(areas)
+                areas.append(bg_area)
+                max_area = np.max(areas) / hw
+                mean_area = np.mean(areas) / hw
+            else:
+                max_area = mean_area = 0
+
+
df_batch_labels["complexity_num_segments"].extend([num_segments]) + df_batch_labels["complexity_max_segment_area"].extend([max_area]) + df_batch_labels["complexity_mean_segment_area"].extend([mean_area]) + df_batch_labels[self.key_column].extend([key]) + + return df_batch_labels diff --git a/DPF/filters/videos/cogvlm2_filter.py b/DPF/filters/videos/cogvlm2_filter.py new file mode 100644 index 0000000..76d9016 --- /dev/null +++ b/DPF/filters/videos/cogvlm2_filter.py @@ -0,0 +1,222 @@ +import re +from io import BytesIO +from typing import Any + +import numpy as np +import torch +from decord import VideoReader, bridge +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +from DPF.types import ModalityToDataMapping + +from .video_filter import VideoFilter + +prompt_templates = { + 'detailed_video': 'Describe this video and its style in a very detailed manner', + 'short_video': 'Describe this video and its style briefly', + '1_sentance': "Describe this video very shortly in 1 sentence." + } +MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat" +TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[ + 0] >= 8 else torch.float16 + +compiled_regexs = [ + (re.compile(r'the video (also )?is '), ''), + (re.compile(r'the video (also )?features '), ''), + (re.compile(r'the video (also )?shows '), ''), + (re.compile(r'the video (also )?depicts '), ''), + (re.compile(r'the video (also )?showcases '), ''), + (re.compile(r'the video (also )?captures '), ''), + (re.compile(r'the video (also )?provides '), ''), + (re.compile(r'the video (also )?showcases '), ''), + (re.compile(r'throughout the video, '), ''), +] + + +def clean_with_regex(caption: str) -> str: + lower_caption = str(caption).lower().strip() + for re_compiled, replacement in compiled_regexs: + iterator = reversed(list(re_compiled.finditer(lower_caption))) + for match in iterator: + pos = list(match.span()) + caption = caption[:pos[0]] + replacement + caption[pos[1]:] + lower_caption = str(caption).lower().strip() + + if caption.count('-') > 2: + split_captions = [] + for split_caption in caption.split(): + if split_caption.count('-') > 2: + split_caption = re.sub(r'-', ' ', split_caption) + split_captions.append(split_caption) + caption = ' '.join(split_captions) + + caption = caption.strip('—-:/+=|@#&*') + + return caption.strip() + + +class CogVLM2Filter(VideoFilter): + """ + CogVLM2 inference class to get captions for auto-labeling videos. + More info about the model here: https://github.com/THUDM/CogVLM2 + + Parameters + ---------- + prompt: str = '1_sentance' + Prompt for the model. 
+ quant: int = 16 + Model quantization mode: 4, 8 or 16 + num_frames: int = 24 + Number of frames to sample from the video + device: str = "cuda:0" + Device to use + workers: int = 16 + Number of processes to use for reading data and calculating flow scores + pbar: bool = True + Whether to use a progress bar + """ + def __init__( + self, + prompt: str = '1_sentance', + quant: int = 16, + num_frames: int = 24, + temperature: float = 0.05, + max_new_tokens: int = 1024, + device: str = "cuda:0", + workers: int = 16, + pbar: bool = True, + _pbar_position: int = 0 + ): + super().__init__(pbar, _pbar_position) + self.strategy = 'chat' + self.prompt = prompt + self.tokenizer = AutoTokenizer.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + # padding_side="left" + ) + self.num_frames = num_frames + + if quant == 4: + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype=TORCH_TYPE, + trust_remote_code=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=TORCH_TYPE, + ), + low_cpu_mem_usage=True, + revision='ca14f13b05f5ead425188aae3e5e725bf4905cd1' + ).eval() + elif quant == 8: + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype=TORCH_TYPE, + trust_remote_code=True, + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + bnb_4bit_compute_dtype=TORCH_TYPE, + ), + low_cpu_mem_usage=True, + revision='ca14f13b05f5ead425188aae3e5e725bf4905cd1' + ).eval() + else: + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype=TORCH_TYPE, + trust_remote_code=True, + revision='ca14f13b05f5ead425188aae3e5e725bf4905cd1' + ).eval().to(device) + + self.query = prompt_templates[prompt] + + self.num_workers = workers + self.device = device + + self.temperature = temperature + self.max_new_tokens = max_new_tokens + + @property + def result_columns(self) -> list[str]: + return ["caption_cogvlm", "caption_cogvlm_clean"] + + @property + def dataloader_kwargs(self) -> dict[str, Any]: + return { + "num_workers": self.num_workers, + "batch_size": 1, + "drop_last": False, + } + + def preprocess_data( + self, + modality2data: ModalityToDataMapping, + metadata: dict[str, Any] + ) -> Any: + key = metadata[self.key_column] + video_file = BytesIO(modality2data['video']) + loaded_video_file = self.load_video(video_file, strategy=self.strategy) + return key, loaded_video_file + + def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: + df_batch_labels = self._get_dict_from_schema() + + key, video = batch[0] + inputs = self.model.build_conversation_input_ids( + tokenizer=self.tokenizer, + query=self.query, + images=[video], + history=[], + template_version=self.strategy + ) + + inputs = { + 'input_ids': inputs['input_ids'].unsqueeze(0).to(self.device), + 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(self.device), + 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(self.device), + 'images': [[inputs['images'][0].to(self.device).to(TORCH_TYPE)]], + } + gen_kwargs = { + "max_new_tokens": self.max_new_tokens, + "pad_token_id": 128002, + "top_k": 1, + "do_sample": True, + "top_p": 0.1, + "temperature": self.temperature, + } + with torch.no_grad(): + outputs = self.model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs['input_ids'].shape[1]:] + response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + response_clean = clean_with_regex(response) + df_batch_labels[self.schema[1]].extend([response]) + 
df_batch_labels[self.schema[2]].extend([response_clean]) + df_batch_labels[self.key_column].extend([key]) + return df_batch_labels + + + def load_video(self, video_path: BytesIO, strategy: str = 'chat') -> torch.Tensor: + bridge.set_bridge('torch') + num_frames = self.num_frames + + decord_vr = VideoReader(uri=video_path) + total_frames = len(decord_vr) + if strategy == 'base': + frame_id_list = np.linspace(0, total_frames - 1, num_frames, dtype=int) + elif strategy == 'chat': + timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames)) + timestamps = [i[0] for i in timestamps] + max_second = round(max(timestamps)) + 1 + frame_id_list = [] # type: ignore + for second in range(max_second): + closest_num = min(timestamps, key=lambda x: abs(x - second)) + index = timestamps.index(closest_num) + frame_id_list.append(index) # type: ignore + if len(frame_id_list) >= num_frames: + break + else: + frame_id_list = None + video_data: torch.Tensor = decord_vr.get_batch(frame_id_list) + video_data = video_data.permute(3, 0, 1, 2) + return video_data diff --git a/DPF/filters/videos/dover_filter.py b/DPF/filters/videos/dover_filter.py new file mode 100644 index 0000000..758353e --- /dev/null +++ b/DPF/filters/videos/dover_filter.py @@ -0,0 +1,193 @@ +import os +from io import BytesIO +from typing import Any +from urllib.request import urlretrieve + +import decord +import numpy as np +import torch +import yaml +from decord import VideoReader +from dover.datasets import UnifiedFrameSampler, get_single_view # type: ignore +from dover.models import DOVER # type: ignore + +from DPF.types import ModalityToDataMapping + +from .video_filter import VideoFilter + +WEIGHTS_URL = {'dover': 'https://github.com/QualityAssessment/DOVER/releases/download/v0.1.0/DOVER.pth', + 'dover_plus_plus': 'https://huggingface.co/teowu/DOVER/resolve/main/DOVER_plus_plus.pth', + 'dover-mobile': 'https://github.com/QualityAssessment/DOVER/releases/download/v0.5.0/DOVER-Mobile.pth'} + +CONFIGS_URL = {'dover': 'https://raw.githubusercontent.com/teowu/DOVER-Dev/master/dover.yml', + 'dover_plus_plus': 'https://raw.githubusercontent.com/teowu/DOVER-Dev/master/dover.yml', + 'dover-mobile': 'https://raw.githubusercontent.com/teowu/DOVER-Dev/master/dover-mobile.yml'} + + +def fuse_results(results: list[float]) -> dict[str, np.ndarray]: # type: ignore + t, a = (results[0] + 0.0758) / 0.0129, (results[1] - 0.1253) / 0.0318 + # t, a = (results[0] - 0.1107) / 0.07355, (results[1] + 0.08285) / 0.03774 + x = t * 0.6104 + a * 0.3896 + return { + "aesthetic": 1 / (1 + np.exp(-a)), + "technical": 1 / (1 + np.exp(-t)), + "overall": 1 / (1 + np.exp(-x)), + } + + +def spatial_temporal_view_decomposition( # type: ignore + video_path: str | BytesIO, sample_types: dict, samplers: dict, is_train: bool = False, augment: bool = False, # type: ignore +): + video = {} + decord.bridge.set_bridge("torch") + vreader = VideoReader(video_path) + ### Avoid duplicated video decoding!!! Important!!!! + all_frame_inds = [] + frame_inds = {} + for stype in samplers: + frame_inds[stype] = samplers[stype](len(vreader), is_train) + all_frame_inds.append(frame_inds[stype]) + + ### Each frame is only decoded one time!!! 
+    all_frame_inds = np.concatenate(all_frame_inds, 0)
+    frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+
+    for stype in samplers:
+        imgs = [frame_dict[idx] for idx in frame_inds[stype]]
+        video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+
+    sampled_video = {}
+    for stype, sopt in sample_types.items():
+        sampled_video[stype] = get_single_view(video[stype], stype, **sopt)
+    return sampled_video, frame_inds
+
+
+class DOVERFilter(VideoFilter):
+    """
+    DOVER model inference class to get video quality scores.
+    More info about the model here: https://github.com/teowu/DOVER/
+
+    Parameters
+    ----------
+    weights_folder: str
+        Path to the folder where the weights are located.
+        If there are no weights, they will be downloaded automatically
+    model_name: str = "dover_plus_plus"
+        Version of the model to use: "dover_plus_plus", "dover" or "dover-mobile"
+    device: str = "cuda:0"
+        Device to use
+    workers: int = 16
+        Number of processes to use for reading data
+    pbar: bool = True
+        Whether to use a progress bar
+    """
+
+    def __init__(
+        self,
+        weights_folder: str,
+        model_name: str = 'dover_plus_plus',
+        device: str = "cuda:0",
+        workers: int = 16,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.device = device
+
+        self.model_name = model_name
+        self.weights_folder = weights_folder
+
+        # Download checkpoints and configs
+        path_to_model = os.path.join(self.weights_folder, self.model_name + '.pth')
+        if not os.path.exists(path_to_model):
+            os.makedirs(self.weights_folder, exist_ok=True)
+            urlretrieve(WEIGHTS_URL[self.model_name], path_to_model)
+        path_to_config = os.path.join(self.weights_folder, self.model_name + '.yml')
+        if not os.path.exists(path_to_config):
+            os.makedirs(self.weights_folder, exist_ok=True)
+            urlretrieve(CONFIGS_URL[self.model_name], path_to_config)
+
+        # Load model
+        with open(path_to_config) as f:
+            opt = yaml.safe_load(f)
+        self.model = DOVER(**opt["model"]["args"]).to(self.device)
+        state_dict = torch.load(path_to_model, map_location=self.device)
+        if self.model_name == 'dover_plus_plus':
+            state_dict = state_dict['state_dict']
+        self.model.load_state_dict(state_dict)
+
+        self.dopt = opt["data"]["val-l1080p"]["args"]
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["dover_aesthetic", "dover_technical", "dover_overall"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": 1,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        video_file = BytesIO(modality2data['video'])
+
+        mean, std = (
+            torch.FloatTensor([123.675, 116.28, 103.53]),
+            torch.FloatTensor([58.395, 57.12, 57.375])
+        )
+
+        temporal_samplers = {}
+        for stype, sopt in self.dopt["sample_types"].items():
+            if "t_frag" not in sopt:
+                # resized temporal sampling for TQE in DOVER
+                temporal_samplers[stype] = UnifiedFrameSampler(
+                    sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"]
+                )
+            else:
+                # temporal sampling for AQE in DOVER
+                temporal_samplers[stype] = UnifiedFrameSampler(
+                    sopt["clip_len"] // sopt["t_frag"],
+                    sopt["t_frag"],
+                    sopt["frame_interval"],
+                    sopt["num_clips"],
+                )
+
+        ### View Decomposition
+        views, _ = spatial_temporal_view_decomposition(
+            video_file, self.dopt["sample_types"], temporal_samplers
+        )
+
+        for k, v in views.items():
+            num_clips =
self.dopt["sample_types"][k].get("num_clips", 1) + views[k] = ( + ((v.permute(1, 2, 3, 0) - mean) / std) + .permute(3, 0, 1, 2) + .reshape(v.shape[0], num_clips, -1, *v.shape[2:]) + .transpose(0, 1) + ) + + return key, views + + def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: + df_batch_labels = self._get_dict_from_schema() + + key, views = batch[0] + for k, v in views.items(): + views[k] = v.to(self.device) + + with torch.no_grad(): + results = [r.mean().item() for r in self.model(views)] + rescaled_results = fuse_results(results) + df_batch_labels[self.key_column].append(key) + df_batch_labels[self.schema[1]].append(rescaled_results['aesthetic']) + df_batch_labels[self.schema[2]].append(rescaled_results['technical']) + df_batch_labels[self.schema[3]].append(rescaled_results['overall']) + return df_batch_labels diff --git a/DPF/filters/videos/farneback_filter.py b/DPF/filters/videos/farneback_filter.py index 760abc4..0e876d0 100644 --- a/DPF/filters/videos/farneback_filter.py +++ b/DPF/filters/videos/farneback_filter.py @@ -18,7 +18,7 @@ def transform_frame(frame: MatLike, target_size: tuple[int, int]) -> MatLike: def transform_keep_ar(frame: MatLike, min_side_size: int) -> MatLike: - h, w = frame.shape[:2] + h, w = frame.shape[:2] # type: ignore aspect_ratio = w / h if h <= w: new_height = min_side_size @@ -155,5 +155,5 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: for data in batch: key, mean_optical_flow = data df_batch_labels[self.key_column].append(key) - df_batch_labels[self.result_columns[0]].append(round(mean_optical_flow, 3)) + df_batch_labels[self.result_columns[0]].append(mean_optical_flow) return df_batch_labels diff --git a/DPF/filters/videos/raft_filter.py b/DPF/filters/videos/raft_filter.py index 6193e7b..0a44b8c 100644 --- a/DPF/filters/videos/raft_filter.py +++ b/DPF/filters/videos/raft_filter.py @@ -30,7 +30,7 @@ def transform_frame(frame: MatLike, target_size: tuple[int, int]) -> Tensor: def transform_keep_ar(frame: MatLike, min_side_size: int) -> Tensor: - h, w = frame.shape[:2] + h, w = frame.shape[:2] # type: ignore aspect_ratio = w / h if h <= w: new_height = min_side_size @@ -196,5 +196,5 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: mean_value = np.mean(mean_magnitudes) df_batch_labels[self.key_column].append(key) - df_batch_labels[self.schema[1]].append(round(mean_value, 3)) + df_batch_labels[self.schema[1]].append(mean_value) return df_batch_labels diff --git a/DPF/filters/videos/rpknet_filter.py b/DPF/filters/videos/rpknet_filter.py new file mode 100644 index 0000000..ea5dea8 --- /dev/null +++ b/DPF/filters/videos/rpknet_filter.py @@ -0,0 +1,172 @@ +import io +from typing import Any, Optional + +import cv2 +import imageio.v3 as iio +import numpy as np +import ptlflow +import torch +import torch.nn.functional as F +from cv2.typing import MatLike +from torch import Tensor + +from DPF.types import ModalityToDataMapping + +from .video_filter import VideoFilter + +WEIGHTS_URL = 'https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip' + + +def transform_keep_ar(frame: MatLike, min_side_size: int) -> Tensor: + h, w = frame.shape[:2] # type: ignore + aspect_ratio = w / h + if h <= w: + new_height = min_side_size + new_width = int(aspect_ratio * new_height) + else: + new_width = min_side_size + new_height = int(new_width / aspect_ratio) + + frame = cv2.resize(frame, dsize=(new_width, new_height), interpolation=cv2.INTER_LINEAR) + frame_tensor = torch.from_numpy(frame).permute(2, 0, 
1).float()[None]
+
+    padder = InputPadder(frame_tensor.shape)  # type: ignore
+    frame_tensor = padder.pad(frame_tensor)[0]
+    return frame_tensor
+
+
+class InputPadder:
+    """ Pads images such that dimensions are divisible by 8 """
+
+    def __init__(self, dims: list[int], mode: str = 'sintel'):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        if mode == 'sintel':
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
+                         pad_ht // 2, pad_ht - pad_ht // 2]
+        else:
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
+                         0, pad_ht]
+
+    def pad(self, *inputs) -> list[Tensor]:  # type: ignore
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self, x: Tensor) -> Tensor:
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+class RPKnetOpticalFlowFilter(VideoFilter):
+    """
+    RPKNet model inference class to get the mean optical flow of each video.
+    Optical flow is computed between consecutive sampled frames, and the mean flow
+    magnitude over all frame pairs is reported as the score for the video.
+    More info about the model here: https://github.com/hmorimitsu/ptlflow
+
+    Parameters
+    ----------
+    pass_frames: int = 10
+        Step between sampled frames. Set pass_frames = 1 to process all frames.
+    num_passes: Optional[int] = None
+        Number of flow calculations per video. Set to None to calculate flow scores on the whole video
+    min_frame_size: int = 512
+        The size of the minimum side of the video frame after resizing
+    norm: bool = True
+        Whether to normalize the flow by the frame width and height
+    frames_batch_size: int = 16
+        Batch size during one video processing
+    device: str = "cuda:0"
+        Device to use
+    workers: int = 16
+        Number of processes to use for reading data and calculating flow scores
+    pbar: bool = True
+        Whether to use a progress bar
+    """
+
+    def __init__(
+        self,
+        pass_frames: int = 10,
+        num_passes: Optional[int] = None,
+        min_frame_size: int = 512,
+        norm: bool = True,
+        frames_batch_size: int = 16,
+        device: str = "cuda:0",
+        workers: int = 16,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.device = device
+
+        assert pass_frames >= 1, "Number of pass_frames should be greater than or equal to 1."
+ self.pass_frames = pass_frames + self.num_passes = num_passes + self.min_frame_size = min_frame_size + self.frames_batch_size = frames_batch_size + self.norm = norm + + self.model = ptlflow.get_model('rpknet', pretrained_ckpt='things') + self.model.to(self.device) + self.model.eval() + + @property + def result_columns(self) -> list[str]: + return ["optical_flow_rpk_mean"] + + @property + def dataloader_kwargs(self) -> dict[str, Any]: + return { + "num_workers": self.num_workers, + "batch_size": 1, + "drop_last": False, + } + + def preprocess_data( + self, + modality2data: ModalityToDataMapping, + metadata: dict[str, Any] + ) -> Any: + key = metadata[self.key_column] + video_file = modality2data['video'] + + frames = iio.imread(io.BytesIO(video_file), plugin="pyav") + max_frame_to_process = self.num_passes*self.pass_frames if self.num_passes else len(frames) + frames_transformed = [] + frames_transformed = [ + transform_keep_ar(frames[i], self.min_frame_size) + for i in range(self.pass_frames, min(max_frame_to_process+1, len(frames)), self.pass_frames) + ] + return key, frames_transformed + + def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]: + df_batch_labels = self._get_dict_from_schema() + + for data in batch: + key, frames = data + magnitudes: list[float] = [] + with torch.no_grad(): + for i in range(0, len(frames)-1, self.frames_batch_size): + end = min(i+self.frames_batch_size, len(frames)-1) + current_frame = torch.cat(frames[i:end], dim=0) + next_frame = torch.cat(frames[i+1:i+self.frames_batch_size+1], dim=0) + + current_frame_cuda = current_frame.to(self.device) + next_frame_cuda = next_frame.to(self.device) + + inputs = torch.stack([current_frame_cuda, next_frame_cuda], dim=1) + + flow = self.model({'images': inputs})['flows'][:, 0] + if self.norm: + h, w = current_frame.shape[-2:] + flow[:, 0] = flow[:, 0] / w + flow[:, 1] = flow[:, 1] / h + magnitude = ((flow[:,0]**2+flow[:,1]**2)**0.5).detach().cpu().numpy() + magnitudes.extend(magnitude) + mean_value = np.mean(magnitudes) + + df_batch_labels[self.key_column].append(key) + df_batch_labels[self.schema[1]].append(mean_value) + return df_batch_labels diff --git a/DPF/filters/videos/structural_dynamics_filter.py b/DPF/filters/videos/structural_dynamics_filter.py new file mode 100644 index 0000000..4dc70d1 --- /dev/null +++ b/DPF/filters/videos/structural_dynamics_filter.py @@ -0,0 +1,158 @@ +import io +from typing import Any + +import cv2 +import imageio.v3 as iio +import numpy as np +import torch +import torch.nn.functional as F +from cv2.typing import MatLike +from pytorch_msssim import MS_SSIM +from torch import Tensor + +from DPF.types import ModalityToDataMapping + +from .video_filter import VideoFilter + + +def transform_keep_ar(frame: MatLike, min_side_size: int) -> Tensor: + h, w = frame.shape[:2] # type: ignore + aspect_ratio = w / h + if h <= w: + new_height = min_side_size + new_width = int(aspect_ratio * new_height) + else: + new_width = min_side_size + new_height = int(new_width / aspect_ratio) + + frame = cv2.resize(frame, dsize=(new_width, new_height), interpolation=cv2.INTER_LINEAR) + frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float()[None] + + padder = InputPadder(frame_tensor.shape) # type: ignore + frame_tensor = padder.pad(frame_tensor)[0] + return frame_tensor + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + + def __init__(self, dims: list[int], mode: str = 'sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - 
self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        if mode == 'sintel':
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
+                         pad_ht // 2, pad_ht - pad_ht // 2]
+        else:
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
+                         0, pad_ht]
+
+    def pad(self, *inputs) -> list[Tensor]:  # type: ignore
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self, x: Tensor) -> Tensor:
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+class StructuralDynamicsFilter(VideoFilter):
+    """
+    Structural dynamics score from https://arxiv.org/pdf/2407.01094
+    MS-SSIM is computed between consecutive sampled frames, and the mean, maximum
+    and minimum over all frame pairs are reported for the video.
+
+    Parameters
+    ----------
+    pass_frames: int = 10
+        Step between sampled frames. Set pass_frames = 1 to process all frames.
+    min_frame_size: int = 512
+        The size of the minimum side of the video frame after resizing
+    frames_batch_size: int = 16
+        Batch size during one video processing
+    device: str = "cuda:0"
+        Device to use
+    workers: int = 16
+        Number of processes to use for reading data
+    pbar: bool = True
+        Whether to use a progress bar
+    """
+
+    def __init__(
+        self,
+        pass_frames: int = 10,
+        min_frame_size: int = 512,
+        frames_batch_size: int = 16,
+        device: str = "cuda:0",
+        workers: int = 16,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.device = device
+
+        assert pass_frames >= 1, "Number of pass_frames should be greater than or equal to 1."
+        self.pass_frames = pass_frames
+        self.min_frame_size = min_frame_size
+        self.frames_batch_size = frames_batch_size
+        self.model = MS_SSIM(data_range=255, size_average=False, channel=3, win_size=11)
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["structural_dynamics", 'structural_dynamics_max', 'structural_dynamics_min']
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": 1,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        video_file = modality2data['video']
+
+        frames = iio.imread(io.BytesIO(video_file), plugin="pyav")
+        frames_transformed = [
+            transform_keep_ar(frames[i], self.min_frame_size)
+            for i in range(self.pass_frames, len(frames), self.pass_frames)
+        ]
+        return key, frames_transformed
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for data in batch:
+            key, frames = data
+            # collect per-pair MS-SSIM values for this video only
+            values: list[float] = []
+            with torch.no_grad():
+                for i in range(0, len(frames)-1, self.frames_batch_size):
+                    end = min(i+self.frames_batch_size, len(frames)-1)
+                    current_frame = torch.cat(frames[i:end], dim=0)
+                    next_frame = torch.cat(frames[i+1:i+self.frames_batch_size+1], dim=0)
+
+                    current_frame_cuda = current_frame.to(self.device)
+                    next_frame_cuda = next_frame.to(self.device)
+
+                    ssim = self.model(
+                        current_frame_cuda,
+                        next_frame_cuda
+                    )
+                    values.extend(ssim.detach().cpu().numpy())
+            mean_value = np.mean(values)
+            mn = np.min(values)
+            mx = np.max(values)
+
+            df_batch_labels[self.key_column].append(key)
+            df_batch_labels[self.schema[1]].append(mean_value)
+
df_batch_labels[self.schema[2]].append(mx) + df_batch_labels[self.schema[3]].append(mn) + return df_batch_labels diff --git a/pyproject.toml b/pyproject.toml index b59d1e8..26f899c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,10 @@ filters = [ 'py3langid', 'deep_translator', 'huggingface_hub', + 'opencv-contrib-python', + 'protobuf==3.20.0', + 'pytorch_msssim', + 'ptlflow', 'videohash' ] nsfw_detector = ['tensorflow', 'autokeras'] @@ -70,6 +74,36 @@ grounding_gpt = [ 'torchvision==0.16.2', 'torchaudio==2.1.2' ] +dover = [ + 'DOVER @ git+https://github.com/teowu/DOVER-Dev' +] +complexity = [ + 'segment-anything @ git+https://github.com/facebookresearch/segment-anything.git' +] +cogvlm = [ + 'pydantic==1.10.14', + 'opencv-python==4.5.5.64', + 'decord>=0.6.0', + 'torch==2.1.0', + 'torchvision== 0.16.0', + 'pytorchvideo==0.1.5', + 'transformers==4.40', + 'pillow', + 'chainlit>=1.0', + 'pydantic>=2.7.1', + 'timm>=0.9.16', + 'openai>=1.30.1', + 'loguru>=0.7.2', + 'einops', + 'sse-starlette>=2.1.0', + 'bitsandbytes>=0.43.1', + 'flask', + 'gunicorn', + 'gevent', + 'requests', + 'xformers', + 'huggingface-hub>=0.23.0', +] [tool.hatch.version] path = "DPF/__init__.py"
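Example usage of the new filters (a minimal sketch): they implement the same ImageFilter/VideoFilter interface as the existing DPF filters, so they are applied through a dataset processor in the usual way. The snippet below assumes the standard DPF reading pattern (ShardsDatasetConfig, DatasetReader, processor.apply_data_filter) and a hypothetical dataset path and video column name; adjust these to the actual API and data layout.

    from DPF import DatasetReader, ShardsDatasetConfig
    from DPF.filters.videos.dover_filter import DOVERFilter

    # Hypothetical shards dataset with a 'video_name' column
    config = ShardsDatasetConfig.from_path_and_columns(
        "path/to/video_dataset",
        video_name_col="video_name",
    )
    processor = DatasetReader().read_from_config(config)

    # Downloads the DOVER++ checkpoint into ./weights on first use, then scores every video
    datafilter = DOVERFilter(
        weights_folder="./weights",
        model_name="dover_plus_plus",
        device="cuda:0",
        workers=8,
    )
    processor.apply_data_filter(datafilter)

    # Scores appear as new columns on the processor dataframe
    print(processor.df[["dover_aesthetic", "dover_technical", "dover_overall"]].head())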