22import os
33import urllib .request
44from typing import Callable , Dict , Generator , Optional , Tuple
5+ from typing_extensions import Literal
56from uuid import uuid4
67
78import cv2
89import numpy as np
910from pydantic import root_validator
1011
1112from .base_data import BaseData
13+ from ..types import TypedArray
1214
1315logger = logging .getLogger (__name__ )
1416
@@ -19,7 +21,7 @@ class VideoData(BaseData):
1921 """
2022 file_path : Optional [str ] = None
2123 url : Optional [str ] = None
22- frames : Optional [Dict [int , np . ndarray ]] = None
24+ frames : Optional [Dict [int , TypedArray [ Literal [ 'uint8' ]] ]] = None
2325
2426 def load_frames (self , overwrite : bool = False ) -> None :
2527 """
@@ -33,8 +35,14 @@ def load_frames(self, overwrite: bool = False) -> None:
3335 return
3436
3537 for count , frame in self .frame_generator ():
38+ if self .frames is None :
39+ self .frames = {}
3640 self .frames [count ] = frame
3741
42+ @property
43+ def data (self ):
44+ return self .frame_generator ()
45+
3846 def frame_generator (
3947 self ,
4048 cache_frames = False ,
@@ -48,26 +56,27 @@ def frame_generator(
4856 download_dir (str): Directory to save the video to. Defaults to `/tmp` dir
4957 """
5058 if self .frames is not None :
51- for idx , img in self .frames .items ():
52- yield idx , img
59+ for idx , frame in self .frames .items ():
60+ yield idx , frame
5361 return
5462 elif self .url and not self .file_path :
5563 file_path = os .path .join (download_dir , f"{ uuid4 ()} .mp4" )
5664 logger .info ("Downloading the video locally to %s" , file_path )
57- urllib . request . urlretrieve ( self .url , file_path )
65+ self .fetch_remote ( file_path )
5866 self .file_path = file_path
5967
6068 vidcap = cv2 .VideoCapture (self .file_path )
6169
62- success , img = vidcap .read ()
70+ success , frame = vidcap .read ()
6371 count = 0
64- self .frames = {}
72+ if cache_frames :
73+ self .frames = {}
6574 while success :
66- img = img [:, :, ::- 1 ]
67- yield count , img
75+ frame = frame [:, :, ::- 1 ]
76+ yield count , frame
6877 if cache_frames :
69- self .frames [count ] = img
70- success , img = vidcap .read ()
78+ self .frames [count ] = frame
79+ success , frame = vidcap .read ()
7180 count += 1
7281
7382 def __getitem__ (self , idx : int ) -> np .ndarray :
@@ -77,6 +86,18 @@ def __getitem__(self, idx: int) -> np.ndarray:
7786 )
7887 return self .frames [idx ]
7988
89+ def fetch_remote (self , local_path ) -> None :
90+ """
91+ Method for downloading data from self.url
92+
93+ If url is not publicly accessible or requires another access pattern
94+ simply override this function
95+
96+ Args:
97+ local_path: Where to save the thing too.
98+ """
99+ urllib .request .urlretrieve (self .url , local_path )
100+
80101 def create_url (self , signer : Callable [[bytes ], str ]) -> None :
81102 """
82103 Utility for creating a url from any of the other video references.
@@ -134,7 +155,5 @@ def validate_data(cls, values):
134155 return values
135156
136157 class Config :
137- # Required for numpy arrays
138- arbitrary_types_allowed = True
139158 # Required for discriminating between data types
140159 extra = 'forbid'
0 commit comments