Skip to content

Commit c760153

Browse files
interval prune
1 parent 706d1c0 commit c760153

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

tsml/transformations/_interval_extraction.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class RandomIntervalTransformer(TransformerMixin, BaseTimeSeriesEstimator):
8181
The number of dimensions per case.
8282
series_length_ : int
8383
The length of each series.
84+
n_intervals_ : int
85+
8486
intervals_ : list of tuples
8587
Contains information for each feature extracted in fit. Each tuple contains the
8688
interval start, interval end, interval dimension and the feature(s) extracted.
@@ -190,12 +192,27 @@ def fit_transform(self, X, y=None):
190192
transformed_intervals,
191193
) = zip(*fit)
192194

193-
for i in intervals:
194-
self.intervals_.extend(i)
195+
current = []
196+
removed_idx = []
197+
self.n_intervals_ = 0
198+
for i, interval in enumerate(intervals):
199+
new_interval = (
200+
interval[0][0],
201+
interval[0][1],
202+
interval[0][2],
203+
interval[0][4],
204+
)
205+
if new_interval not in current:
206+
current.append(new_interval)
207+
self.intervals_.extend(interval)
208+
self.n_intervals_ += 1
209+
else:
210+
removed_idx.append(i)
195211

196212
Xt = transformed_intervals[0]
197213
for i in range(1, self.n_intervals):
198-
Xt = np.hstack((Xt, transformed_intervals[i]))
214+
if i not in removed_idx:
215+
Xt = np.hstack((Xt, transformed_intervals[i]))
199216

200217
return Xt
201218

@@ -219,8 +236,14 @@ def fit(self, X, y=None):
219236
_,
220237
) = zip(*fit)
221238

239+
current = []
240+
self.n_intervals_ = 0
222241
for i in intervals:
223-
self.intervals_.extend(i)
242+
interval = (i[0][0], i[0][1], i[0][2], i[0][4])
243+
if interval not in current:
244+
current.append(interval)
245+
self.intervals_.extend(i)
246+
self.n_intervals_ += 1
224247

225248
return self
226249

@@ -234,7 +257,7 @@ def transform(self, X, y=None):
234257
else:
235258
count = 0
236259
transform_features = []
237-
for _ in range(self.n_intervals):
260+
for _ in range(self.n_intervals_):
238261
for feature in self._features:
239262
if is_transformer(feature):
240263
nf = feature.n_transformed_features

tsml/transformations/tests/test_interval_extraction.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
# -*- coding: utf-8 -*-
2-
from tsml.transformations import Catch22Transformer, SupervisedIntervalTransformer
3-
from tsml.utils.numba_functions.stats import row_mean
2+
from tsml.transformations import (
3+
Catch22Transformer,
4+
RandomIntervalTransformer,
5+
SupervisedIntervalTransformer,
6+
)
7+
from tsml.utils.numba_functions.stats import row_mean, row_median
48
from tsml.utils.testing import generate_3d_test_data
59

610

11+
def test_interval_prune():
12+
X, y = generate_3d_test_data(random_state=0, n_channels=2, series_length=10)
13+
14+
rit = RandomIntervalTransformer(
15+
features=[row_mean, row_median],
16+
n_intervals=10,
17+
random_state=0,
18+
)
19+
X_t = rit.fit_transform(X, y)
20+
21+
assert X_t.shape == (10, 16)
22+
assert rit.transform(X).shape == (10, 16)
23+
24+
725
def test_supervised_transformers():
826
X, y = generate_3d_test_data(random_state=0)
927

0 commit comments

Comments
 (0)