Commit af63f87
restore i3d + integration test (#562)
* restore i3d + integration test * pipeline timeout update * fix * action rec integration * black
1 parent cb51b0f commit af63f87

25 files changed (+1503 / -3 lines)

README.md

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ The following is a summary of commonly used Computer Vision scenarios that are c
| [Detection](scenarios/detection) | Base | Object Detection is a technique that allows you to detect the bounding box of an object within an image. |
| [Keypoints](scenarios/keypoints) | Base | Keypoint detection can be used to detect specific points on an object. A pre-trained model is provided to detect body joints for human pose estimation. |
| [Segmentation](scenarios/segmentation) | Base | Image Segmentation assigns a category to each pixel in an image. |
- | [Action recognition](scenarios/action_recognition) | Base | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times.|
+ | [Action recognition](scenarios/action_recognition) | Base | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times. We also provide an I3D implementation of action recognition, which can be found under [contrib](contrib). |
| [Crowd counting](contrib/crowd_counting) | Contrib | Counting the number of people in low-crowd-density (e.g. less than 10 people) and high-crowd-density (e.g. thousands of people) scenarios.|

We separate the supported CV scenarios into two locations: (i) **base**: code and notebooks within the "utils_cv" and "scenarios" folders which follow strict coding guidelines, are well tested and maintained; (ii) **contrib**: code and other assets within the "contrib" folder, mainly covering less common CV scenarios using bleeding edge state-of-the-art approaches. Code in "contrib" is not regularly tested or maintained.

contrib/README.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Each project should live in its own subdirectory ```/contrib/<project>``` and co
| Directory | Project description | Build status (optional) |
|---|---|---|
| [Crowd counting](crowd_counting) | Counting the number of people in low-crowd-density (e.g. less than 10 people) and high-crowd-density (e.g. thousands of people) scenarios. | [![Build Status](https://dev.azure.com/team-sharat/crowd-counting/_apis/build/status/lixzhang.cnt?branchName=lixzhang%2Fsubmodule-rev3)](https://dev.azure.com/team-sharat/crowd-counting/_build/latest?definitionId=49&branchName=lixzhang%2Fsubmodule-rev3)|
+ | [Action Recognition with I3D](action_recognition) | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times. Please note that we also have an R(2+1)D implementation of action recognition, which you can find under [scenarios](../scenarios). | |

## Tools
| Directory | Project description | Build status (optional) |
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
# Action Recognition

This directory contains resources for building video-based action recognition systems.

Action recognition (also known as activity recognition) consists of classifying various actions from a sequence of frames:

![](./media/action_recognition2.gif "Example of action recognition")

We implemented two state-of-the-art approaches: (i) [I3D](https://arxiv.org/pdf/1705.07750.pdf) and (ii) [R(2+1)D](https://arxiv.org/abs/1711.11248). This includes example notebooks, e.g. for scoring webcam footage or fine-tuning on the [HMDB-51](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) dataset. The R(2+1)D implementation can be accessed under [scenarios](../scenarios) at the root level.

We recommend using the **R(2+1)D** model for its competitive accuracy, fast inference speed, and fewer dependencies on other packages. For both approaches, our implementations were able to reproduce the reported accuracies:

| Model | Reported in the paper | Our results |
| ------- | ------- | ------- |
| R(2+1)D-34 RGB | 79.6% | 79.8% |
| I3D RGB | 74.8% | 73.7% |
| I3D Optical flow | 77.1% | 77.5% |
| I3D Two-Stream | 80.7% | 81.2% |

## Projects

| Directory | Description |
| -------- | ----------- |
| [i3d](i3d) | Scripts for fine-tuning a pre-trained I3D model on the HMDB-51 dataset. |
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
__pycache__/
models/__pycache__/
log/
.vscode/
checkpoints/
pretrained_models/
inference/.ipynb_checkpoints/
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
## Fine-tuning I3D model on HMDB-51

In this section we provide code for training a Two-Stream Inflated 3D ConvNet (I3D), introduced in \[[1](https://arxiv.org/pdf/1705.07750.pdf)\]. Our implementation uses the PyTorch models (and code) provided in [https://github.com/piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d), which have been pre-trained on the Kinetics Human Action Video dataset, and fine-tunes them on the HMDB-51 action recognition dataset. The I3D model consists of two "streams", which are independently trained models: one stream takes the RGB frames of a video as input, while the other takes pre-computed optical flow as input. At test time, the outputs of the two stream models are averaged to make the final prediction. The model results are as follows:

| Model | Paper top 1 accuracy (average over 3 splits) | Our top 1 accuracy (split 1 only) |
| ------- | ------- | ------- |
| RGB | 74.8 | 73.7 |
| Optical flow | 77.1 | 77.5 |
| Two-Stream | 80.7 | 81.2 |

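As a rough illustration of the test-time fusion described above, the sketch below averages the per-class scores of the two streams. The actual evaluation lives in `test.py`; the model and input names here are placeholders, not names from the repo.

```
import torch

def two_stream_predict(rgb_model, flow_model, rgb_clip, flow_clip):
    """Sketch: average per-class scores of the RGB and optical-flow streams."""
    rgb_model.eval()
    flow_model.eval()
    with torch.no_grad():
        rgb_scores = torch.softmax(rgb_model(rgb_clip), dim=1)     # (batch, num_classes)
        flow_scores = torch.softmax(flow_model(flow_clip), dim=1)  # (batch, num_classes)
        fused = (rgb_scores + flow_scores) / 2                     # two-stream average
    return fused.argmax(dim=1)                                     # predicted class ids
```
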
## Download and pre-process HMDB-51 data

Download the HMDB-51 video database from [here](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). Extract the videos with
```
mkdir rars && mkdir videos
unrar x hmdb51-org.rar rars/
for a in $(ls rars); do unrar x "rars/${a}" videos/; done;
```

Use the code provided in [https://github.com/yjxiong/temporal-segment-networks](https://github.com/yjxiong/temporal-segment-networks) to pre-process the raw videos, i.e. split the videos into RGB frames and compute optical flow frames:
```
git clone https://github.com/yjxiong/temporal-segment-networks
cd temporal-segment-networks
bash scripts/extract_optical_flow.sh /path/to/hmdb51/videos /path/to/rawframes/output
```
Edit the _C.DATASET.DIR option in [default.py](default.py) to point to the rawframes input data directory.
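
For reference, the frame layout that the dataset loader added later in this commit expects under `_C.DATASET.DIR` looks roughly like the sketch below. The class and clip names are illustrative; the file-name patterns (`img_*`, `flow_x_*`, `flow_y_*`) come from `dataset.py`.

```
import os

data_root = "/path/to/rawframes/output/"  # value to set for _C.DATASET.DIR
clip_dir = os.path.join(data_root, "brush_hair", "some_clip")    # <class>/<video>
rgb_frame = os.path.join(clip_dir, "img_{:05}.jpg".format(1))    # RGB frames
flow_x = os.path.join(clip_dir, "flow_x_{:05}.jpg".format(1))    # optical flow, x
flow_y = os.path.join(clip_dir, "flow_y_{:05}.jpg".format(1))    # optical flow, y
print(rgb_frame, flow_x, flow_y)
```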

## Setup environment

```
conda env create -f environment.yaml
conda activate i3d
```

## Download pretrained models

```
bash download_models.sh
```

## Fine-tune pretrained models on HMDB-51

Train the RGB model:
```
python train.py --cfg config/train_rgb.yaml
```

Train the flow model:
```
python train.py --cfg config/train_flow.yaml
```

Evaluate the combined model:
```
python test.py
```

\[1\] J. Carreira and A. Zisserman. Quo vadis, action recognition? A new model and the Kinetics dataset. In CVPR, 2017.
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
MODEL:
  NAME: "i3d_flow"
TRAIN:
  MODALITY: "flow"
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
MODEL:
  NAME: "i3d_rgb"
TRAIN:
  MODALITY: "RGB"
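
These small YAML files override the defaults defined in `default.py` when passed to `train.py` via `--cfg`. Below is a minimal sketch of how such an override is typically merged, assuming the usual yacs-style pattern suggested by the `_C.` naming; the helper function and the default values are illustrative, not taken from the repo.

```
from yacs.config import CfgNode as CN

# Illustrative defaults; the real ones live in default.py.
_C = CN()
_C.MODEL = CN()
_C.MODEL.NAME = "i3d_rgb"
_C.TRAIN = CN()
_C.TRAIN.MODALITY = "RGB"
_C.DATASET = CN()
_C.DATASET.DIR = "/datadir/rawframes/"

def load_config(cfg_file):
    """Clone the defaults and merge a YAML override such as config/train_flow.yaml."""
    cfg = _C.clone()
    cfg.merge_from_file(cfg_file)
    cfg.freeze()
    return cfg
```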
Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# Adapted from https://github.com/feiyunzhang/i3d-non-local-pytorch/blob/master/dataset.py

import torch.utils.data as data
import torch

from PIL import Image
import os
import os.path
import numpy as np
from numpy.random import randint
from pathlib import Path

import torchvision
from torchvision import datasets, transforms
from videotransforms import (
    GroupRandomCrop, GroupRandomHorizontalFlip,
    GroupScale, GroupCenterCrop, GroupNormalize, Stack
)

from itertools import cycle

class VideoRecord(object):
    """Lightweight wrapper around a (video directory, label) pair."""

    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        # Number of extracted RGB frames in the video directory
        return int(
            len([x for x in Path(
                self._data[0]).glob('img_*')]) - 1)

    @property
    def label(self):
        return int(self._data[1])

class I3DDataSet(data.Dataset):
    def __init__(self, data_root, split=1, sample_frames=64,
                 modality='RGB', transform=lambda x: x,
                 train_mode=True, sample_frames_at_test=False):

        self.data_root = data_root
        self.split = split
        self.sample_frames = sample_frames
        self.modality = modality
        self.transform = transform
        self.train_mode = train_mode
        self.sample_frames_at_test = sample_frames_at_test

        self._parse_split_files()

    def _parse_split_files(self):
        # Class labels are assigned by sorting the file names in the
        # ./data/hmdb51_splits directory
        file_list = sorted(Path('./data/hmdb51_splits').glob('*' + str(self.split) + '.txt'))
        video_list = []
        for class_idx, f in enumerate(file_list):
            class_name = str(f).strip().split('/')[2][:-16]
            for line in open(f):
                tokens = line.strip().split(' ')
                video_path = self.data_root + class_name + '/' + tokens[0][:-4]
                record = (video_path, class_idx)
                # 1 indicates video should be in the training set
                if self.train_mode and (tokens[-1] == '1'):
                    video_list.append(VideoRecord(record))
                # 2 indicates video should be in the test set
                elif (not self.train_mode) and (tokens[-1] == '2'):
                    video_list.append(VideoRecord(record))

        self.video_list = video_list

    def _load_image(self, directory, idx):
        if self.modality == 'RGB':
            img_path = os.path.join(directory, 'img_{:05}.jpg'.format(idx))
            try:
                img = Image.open(img_path).convert('RGB')
            except Exception:
                print("Couldn't load image:{}".format(img_path))
                return None
            return img
        else:
            try:
                img_path = os.path.join(directory, 'flow_x_{:05}.jpg'.format(idx))
                x_img = Image.open(img_path).convert('L')
            except Exception:
                print("Couldn't load image:{}".format(img_path))
                return None
            try:
                img_path = os.path.join(directory, 'flow_y_{:05}.jpg'.format(idx))
                y_img = Image.open(img_path).convert('L')
            except Exception:
                print("Couldn't load image:{}".format(img_path))
                return None
            # Combine flow images into a single PIL image
            x_img = np.array(x_img, dtype=np.float32)
            y_img = np.array(y_img, dtype=np.float32)
            img = np.asarray([x_img, y_img]).transpose([1, 2, 0])
            img = Image.fromarray(img.astype('uint8'))
            return img

    def _sample_indices(self, record):
        if record.num_frames > self.sample_frames:
            start_pos = randint(record.num_frames - self.sample_frames + 1)
            indices = range(start_pos, start_pos + self.sample_frames, 1)
        else:
            indices = [x for x in range(record.num_frames)]
            if len(indices) < self.sample_frames:
                self._loop_indices(indices)
        return indices

    def _loop_indices(self, indices):
        # Repeat frame indices cyclically until sample_frames indices are available
        indices_cycle = cycle(indices)
        while len(indices) < self.sample_frames:
            indices.append(next(indices_cycle))

    def __getitem__(self, index):
        record = self.video_list[index]
        # Sample frames from the video for training, or if sampling is
        # turned on at test time
        if self.train_mode or self.sample_frames_at_test:
            segment_indices = self._sample_indices(record)
        else:
            segment_indices = [i for i in range(record.num_frames)]
        # Image files are 1-indexed
        segment_indices = [i + 1 for i in segment_indices]
        # Get video frame images
        images = []
        for i in segment_indices:
            seg_img = self._load_image(record.path, i)
            if seg_img is None:
                raise ValueError("Couldn't load", record.path, i)
            images.append(seg_img)
        # Apply transformations
        transformed_images = self.transform(images)

        return transformed_images, record.label

    def __len__(self):
        return len(self.video_list)

if __name__ == '__main__':

    input_size = 224
    resize_small_edge = 256

    train_rgb = I3DDataSet(
        data_root='/datadir/rawframes/',
        split=1,
        sample_frames=64,
        modality='RGB',
        train_mode=True,
        sample_frames_at_test=False,
        transform=torchvision.transforms.Compose([
            GroupScale(resize_small_edge),
            GroupRandomCrop(input_size),
            GroupRandomHorizontalFlip(),
            GroupNormalize(modality="RGB"),
            Stack(),
        ])
    )
    item = train_rgb.__getitem__(10)
    print("train_rgb:")
    print(item[0].size())
    print("max=", item[0].max())
    print("min=", item[0].min())
    print("label=", item[1])

    val_rgb = I3DDataSet(
        data_root='/datadir/rawframes/',
        split=1,
        sample_frames=64,
        modality='RGB',
        train_mode=False,
        sample_frames_at_test=False,
        transform=torchvision.transforms.Compose([
            GroupScale(resize_small_edge),
            GroupCenterCrop(input_size),
            GroupNormalize(modality="RGB"),
            Stack(),
        ])
    )
    item = val_rgb.__getitem__(10)
    print("val_rgb:")
    print(item[0].size())
    print("max=", item[0].max())
    print("min=", item[0].min())
    print("label=", item[1])

    train_flow = I3DDataSet(
        data_root='/datadir/rawframes/',
        split=1,
        sample_frames=64,
        modality='flow',
        train_mode=True,
        sample_frames_at_test=False,
        transform=torchvision.transforms.Compose([
            GroupScale(resize_small_edge),
            GroupRandomCrop(input_size),
            GroupRandomHorizontalFlip(),
            GroupNormalize(modality="flow"),
            Stack(),
        ])
    )
    item = train_flow.__getitem__(100)
    print("train_flow:")
    print(item[0].size())
    print("max=", item[0].max())
    print("min=", item[0].min())
    print("label=", item[1])

    val_flow = I3DDataSet(
        data_root='/datadir/rawframes/',
        split=1,
        sample_frames=64,
        modality='flow',
        train_mode=False,
        sample_frames_at_test=False,
        transform=torchvision.transforms.Compose([
            GroupScale(resize_small_edge),
            GroupCenterCrop(input_size),
            GroupNormalize(modality="flow"),
            Stack(),
        ])
    )
    item = val_flow.__getitem__(100)
    print("val_flow:")
    print(item[0].size())
    print("max=", item[0].max())
    print("min=", item[0].min())
    print("label=", item[1])

0 commit comments