diff --git a/matchzoo/dataloader/callbacks/__init__.py b/matchzoo/dataloader/callbacks/__init__.py index eaeb31f..f10e9ef 100755 --- a/matchzoo/dataloader/callbacks/__init__.py +++ b/matchzoo/dataloader/callbacks/__init__.py @@ -4,3 +4,4 @@ from .padding import BasicPadding from .padding import DRMMPadding from .padding import BertPadding +from .window import Window diff --git a/matchzoo/dataloader/callbacks/window.py b/matchzoo/dataloader/callbacks/window.py new file mode 100644 index 0000000..db984f1 --- /dev/null +++ b/matchzoo/dataloader/callbacks/window.py @@ -0,0 +1,127 @@ +from typing import List, Dict, Tuple +from itertools import product, chain, zip_longest + +import numpy as np + +import matchzoo as mz +from matchzoo.engine.base_callback import BaseCallback + + +class Window(BaseCallback): + """ + Generate document match window for each query term. + + :param half_window_size: half of the matching-window size, not including the + center word, so the full window size is 2 * half_window_size + 1 + :param max_match: a term should have fewer than max_match matching-windows, + the excess will be discarded + + Example: + >>> import matchzoo as mz + >>> from matchzoo.dataloader.callbacks import Ngram + >>> data = mz.datasets.toy.load_data() + >>> preprocessor = mz.preprocessors.BasicPreprocessor(ngram_size=3) + >>> data = preprocessor.fit_transform(data) + >>> callback = Ngram(preprocessor=preprocessor, mode='index') + >>> dataset = mz.dataloader.Dataset( + ... data, callbacks=[callback]) + >>> _ = dataset[0] + + """ + + def __init__( + self, + half_window_size: int = 5, + max_match: int = 20, + ): + """Init.""" + self._half_window_size = half_window_size + self._max_match = max_match + + def on_batch_unpacked(self, x, y): + """Extract `window_right`, `window_position_right`, `term_window_num` for `x`.""" + batch_size = len(x['text_left']) + x['window_right'] = [... for _ in range(batch_size)] + x['window_position_right'] = [... for _ in range(batch_size)] + x['term_window_num'] = [... for _ in range(batch_size)] + for idx, (query, query_len, doc, doc_len) in enumerate(zip( + x['text_left'], x['length_left'], x['text_right'], x['length_right'])): + window_right, window_position_right, term_window_num = \ + self._build_window(query, query_len, doc, doc_len) + x['window_right'][idx] = window_right + x['window_position_right'][idx] = window_position_right + x['term_window_num'][idx] = term_window_num + + array_query_window_num = np.array([array.shape[0] for array in x['window_right']]) + array_window_right = _pad_sequence(x['window_right'], pad_value=-1) + array_window_position_right = \ + _pad_sequence(x['window_position_right'], pad_value=-1) + array_term_window_num = _pad_sequence(x['term_window_num'], pad_value=-1) + + x['query_window_num'] = array_query_window_num + x['window_right'] = array_window_right + x['window_position_right'] = array_window_position_right + x['term_window_num'] = array_term_window_num + + def _build_window(self, query: list, query_len: int, doc: list, doc_len: int) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + window_of_term = [[] for _ in range(query_len)] + window_position_of_term = [[] for _ in range(query_len)] + window_num_of_term = [0 for _ in range(query_len)] + + # get doc window for each query term + for doc_window_position in range(doc_len): + padding_doc_window_position = doc_window_position + self._half_window_size + doc_term = doc[padding_doc_window_position] + for query_term_position in range(query_len): + if window_num_of_term[query_term_position] > self._max_match: + continue + query_term = query[query_term_position] + if query_term == doc_term: + window = self._get_window(doc=doc, center=padding_doc_window_position) + # window: list, len=full_window_size, element: int, token_id + window_of_term[query_term_position].append(window) + window_position_of_term[query_term_position].append( + doc_window_position) + window_num_of_term[query_term_position] += 1 + + # window_of_term: list[list[list[int]]]: len=query_len, + # window_of_term[i]: list[list]: len: window_num of term_i + # window_of_term[i][j]: list: len: len=full_window_size, + # window_of_term[i][j][k]: int, token_id + # + # window_position_of_term: list[list[int]]: len=query_len + # window_position_of_term[i]: list[int]: len: window_num + # window_position_of_term[i][j]: int, position index of window center + # + # window_num_of_term: list[int], len=query_len + # window_num_of_term[i]: int, window_num of term_i, sum() + + # flatten + window_of_term = list(chain.from_iterable(window_of_term)) + window_position_of_term = list(chain.from_iterable(window_position_of_term)) + + # to array + window_of_term = np.stack(window_of_term, axis=0) if len(window_of_term) > 0 \ + else np.zeros((0, 2 * self._half_window_size + 1), dtype=np.long) + window_position_of_term = np.array(window_position_of_term) + window_num_of_term = np.array(window_num_of_term) + + return window_of_term, window_position_of_term, window_num_of_term + + def _get_window(self, doc: list, center: int) -> list: + return doc[center - self._half_window_size: center + self._half_window_size + 1] + + +def _pad_sequence(list_of_array: List[np.ndarray], pad_value): + """Padding list of array to an array, like pytorch pad_sequence.""" + batch_size = len(list_of_array) + max_shape = \ + np.array([array.shape for array in list_of_array]).max(axis=0).tolist() + batch_array = \ + np.ones([batch_size, *max_shape], dtype=list_of_array[0].dtype) * pad_value + for i in range(batch_size): + array = list_of_array[i] + array_slice = [slice(None, end, None) for end in array.shape] + batch_array[(i, *array_slice)] = array + return batch_array diff --git a/matchzoo/models/__init__.py b/matchzoo/models/__init__.py index 161ab47..3cff0ae 100644 --- a/matchzoo/models/__init__.py +++ b/matchzoo/models/__init__.py @@ -22,6 +22,7 @@ from .re2 import RE2 from .hcrn import HCRN from .dynamic_clip import DynamicClip +from .deep_rank import DeepRank def list_available() -> list: from matchzoo.engine.base_model import BaseModel diff --git a/matchzoo/models/deep_rank.py b/matchzoo/models/deep_rank.py new file mode 100644 index 0000000..887b6da --- /dev/null +++ b/matchzoo/models/deep_rank.py @@ -0,0 +1,410 @@ +"""An implementation of DSSM, Deep Structured Semantic Model.""" +import typing + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +import numpy as np + +from matchzoo import preprocessors +from matchzoo.engine.param_table import ParamTable +from matchzoo.engine.param import Param +from matchzoo.engine.base_model import BaseModel +from matchzoo.engine.base_preprocessor import BasePreprocessor +from matchzoo.dataloader import callbacks +from matchzoo.engine.base_callback import BaseCallback + + +class DeepRank(BaseModel): + """ + Deep structured semantic model. + + Examples: + >>> model = DeepRank() + >>> embedding_matrix = np.ones((3000, 50), dtype=float) + >>> term_weight_embedding_matrix = np.ones((3000, 1), dtype=float) + >>> model.params['embedding'] = embedding_matrix + >>> model.params['embedding_input_dim'] = embedding_matrix.shape[0] + >>> model.params['embedding_output_dim'] = embedding_matrix.shape[1] + >>> model.params['term_weight_embedding'] = term_weight_embedding_matrix + >>> model.params['embedding_freeze'] = False + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.build() + + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """:return: model default parameters.""" + params = super().get_default_params( + with_multi_layer_perceptron=False, with_embedding=True) + params.add(Param(name='term_weight_embedding', + desc="Query term weight embedding matrix")) + params.add(Param(name='reduce_out_dim', value=1, + desc="Output dimension of word embedding reduction")) + params.add(Param(name='reduce_conv_kernel_size', value=1, + desc="Kernel size of convolution word embedding reduction")) + params.add(Param(name='reduce_conv_stride', value=1, + desc="Stride of convolution word embedding reduction")) + params.add(Param(name='reduce_conv_padding', value=0, + desc="Zero-padding added to both side of convolution \ + word embedding reduction")) + params.add(Param(name='half_window_size', value=5, + desc="Half of matching-window size, not including center term")) + params.add(Param(name='encode_out_dim', value=4, + desc="Output dimension of encode")) + params.add(Param(name='encode_conv_kernel_size', value=3, + desc="Kernel size of convolution encode")) + params.add(Param(name='encode_conv_stride', value=1, + desc="Stride of convolution encode")) + params.add(Param(name='encode_conv_padding', value=1, + desc="Zero-padding added to both side of convolution encode")) + params.add(Param(name='encode_pool_out', value=1, + desc="Pooling size of global max-pooling for encoding matrix")) + params.add(Param(name='encode_leaky', value=0.2, + desc="Relu leaky of encoder")) + params.add(Param(name='gru_hidden_dim', value=3, + desc="Aggregation Network gru hidden dimension")) + + return params + + @classmethod + def get_default_preprocessor( + cls, + truncated_mode: str = 'pre', + truncated_length_left: int = None, + truncated_length_right: int = None, + filter_low_freq: float = 5, + half_window_size: int = 5, + padding_token_index: int = 0 + ) -> BasePreprocessor: + """ + Model default preprocessor. + + The preprocessor's transform should produce a correctly shaped data + pack that can be used for training. + + :return: Default preprocessor. + """ + return preprocessors.DeepRankPreprocessor( + truncated_mode=truncated_mode, + truncated_length_left=truncated_length_left, + truncated_length_right=truncated_length_right, + filter_low_freq=filter_low_freq, + half_window_size=half_window_size, + padding_token_index=padding_token_index, + ) + + @classmethod + def get_default_padding_callback( + cls, + pad_word_value: typing.Union[int, str] = 0, + pad_word_mode: str = 'post', + ) -> BaseCallback: + """ + Only padding query. + + :return: Default padding callback. + """ + return callbacks.BasicPadding( + pad_word_value=pad_word_value, + pad_word_mode=pad_word_mode + ) + + def build(self): + """Build model structure.""" + # self.embedding = self._make_embedding_layer( + # freeze=self._params["embedding_freeze"], + # embedding=self._params["embedding"] + # ) + self.embedding = self._make_default_embedding_layer() + + + if self._params["term_weight_embedding"] is not None: + term_weight_embedding = self._params["term_weight_embedding"] + else: + # self._params['embedding_input_dim'] = ( + # self._params['embedding'].shape[0] + # ) + # self._params['embedding_output_dim'] = 1 + term_weight_embedding = np.ones( + (self._params["embedding_input_dim"], 1), dtype=float) + self.term_weight_embedding = self._make_embedding_layer( + freeze=self._params["embedding_freeze"], + embedding=term_weight_embedding + ) + + self.query_reduce = nn.Conv1d( + in_channels=self._params["embedding_output_dim"], + out_channels=self._params["reduce_out_dim"], + kernel_size=self._params["reduce_conv_kernel_size"], + stride=self._params["reduce_conv_stride"], + padding=self._params["reduce_conv_padding"] + ) + + self.doc_reduce = nn.Conv1d( + in_channels=self._params["embedding_output_dim"], + out_channels=self._params["reduce_out_dim"], + kernel_size=self._params["reduce_conv_kernel_size"], + stride=self._params["reduce_conv_stride"], + padding=self._params["reduce_conv_padding"] + ) + + interact_tensor_channel = 2 * self._params["reduce_out_dim"] + 1 + self.encoding = nn.Sequential( + nn.Conv2d( + in_channels=interact_tensor_channel, + out_channels=self._params["encode_out_dim"], + kernel_size=self._params["encode_conv_kernel_size"], + stride=self._params["encode_conv_stride"], + padding=self._params["encode_conv_padding"], + ), + nn.AdaptiveAvgPool2d(output_size=self._params["encode_pool_out"]), + nn.LeakyReLU(self._params["encode_leaky"]) + ) + + gru_in_dim = \ + self._params["encode_out_dim"] * self._params["encode_pool_out"]**2 + 1 + self.gru = nn.GRU( + input_size=gru_in_dim, + hidden_size=self._params["gru_hidden_dim"], + bidirectional=True + ) + + self.pool = nn.AdaptiveAvgPool1d(1) + + self.out = self._make_output_layer( + in_features=2 * self._params["gru_hidden_dim"] + ) + + def forward(self, inputs): + """Forward.""" + # Process left & right input. + all_query: torch.LongTensor = inputs["text_left"] + all_query_term_window_num: torch.LongTensor = inputs['term_window_num'] + all_query_len: torch.LongTensor = inputs["length_left"] + + all_window: torch.LongTensor = inputs['window_right'] + all_window_position: torch.LongTensor = inputs['window_position_right'] + all_query_window_num: torch.LongTensor = inputs['query_window_num'] + + # all_query: [batch_size, max_q_seq_len] + # all_query_term_window_num: [batch_size, max_q_seq_len] + # all_query_len: [batch_size] + + # all_window: [batch_size, max_window_num, full_window_size] + # all_window_position: [batch_size, max_window_num] + # all_query_window_num: [batch_size] + + batch_size: int = all_query.shape[0] + full_window_size: int = all_window.shape[2] + device: torch.device = all_query.device + + ############################## + # query embedding and reduce + ############################## + + # all_query embedding + # all_query: [batch_size, max_q_seq_len] + # -embedding-> [batch_size, max_q_seq_len, embed_dim] + # -permute(0, 2, 1)-> [batch_size, embed_dim, max_q_seq_len] + all_query_embed = self.embedding(all_query).permute(0, 2, 1) + + # all_query reduce + # all_query_embed: [batch_size, embed_dim, max_q_seq_len] + # -all_query_reduce-> [batch_size, reduce_out_dim, max_q_seq_len] + all_query_reduce = self.query_reduce(all_query_embed) + + ############################## + # query term weight + ############################## + + # all_query: [batch_size, max_q_seq_len] + # -term_weight_embedding-> [batch_size, max_q_seq_len, 1] + # -squeeze-> [batch_size, max_q_seq_len] + all_query_term_weight = self.term_weight_embedding(all_query).squeeze(2) + + out = [] + for i in range(batch_size): + one_query_seq_len = all_query_len[i].item() + one_query_window_num = all_query_window_num[i].item() + if one_query_window_num == 0: + out.append(torch.zeros(self._params["gru_hidden_dim"] * 2, device=device)) + continue + + ############################## + # window embedding and reduce + ############################## + + # all_window: [batch_size, max_window_num, full_window_size] + # -[i]-> [one_query_window_num, full_window_size] + one_query_windows = all_window[i][:one_query_window_num] + + # one_query_windows: [one_query_window_num, full_window_size] + # -embedding-> [one_query_window_num, full_window_size, embed_dim] + # -permute(0, 2, 1)-> [one_query_window_num, embed_dim, full_window_size] + one_query_windows_embed = self.embedding(one_query_windows).permute(0, 2, 1) + + # one_query_windows_embed: [one_query_window_num, embed_dim, full_window_size] + # -doc_reduce-> [one_query_window_num, reduce_out_dim, full_window_size] + one_query_windows_reduce = self.doc_reduce(one_query_windows_embed) + + ############################## + # interaction signal + ############################## + + # all_query_embed: [batch_size, embed_dim, max_q_seq_len] + # -[]-> [embed_dim, one_query_seq_len] + one_query_embed = all_query_embed[i, :, :one_query_seq_len] + + # one_query_embed: [embed_dim, one_query_seq_len] + # one_query_windows_embed: [one_query_window_num, embed_dim, full_window_size] + # -einsum("el,new->nlw")-> + # [one_query_window_num, one_query_seq_len, full_window_size] + # -[]-> [one_query_window_num, 1, one_query_seq_len, full_window_size] + interacition_signal = torch.einsum( + "el,new->nlw", one_query_embed, one_query_windows_embed)[:, None, :, :] + + ############################## + # encoding + ############################## + + # all_query_reduce: [batch_size, reduce_out_dim, max_q_seq_len] + # -[]-> [reduce_out_dim, one_query_seq_len] + one_query_reduce = all_query_reduce[i, :, :one_query_seq_len] + + # one_query_reduce: [reduce_dim, one_query_seq_len] + # -[]-> [1, reduce_dim, one_query_seq_len, 1] + # -expand-> + # [one_query_window_num, reduce_dim, one_query_seq_len, full_window_size] + one_query_reduce_expand = one_query_reduce[None, :, :, None] \ + .expand(one_query_window_num, -1, -1, full_window_size) + + # one_query_windows_reduce: + # [one_query_window_num, reduce_dim, full_window_size] + # -[]-> [one_query_window_num, reduce_dim, 1, full_window_size] + # -expand-> + # [one_query_window_num, reduce_dim, one_query_seq_len, full_window_size] + one_query_windows_reduce_expand = one_query_windows_reduce[:, :, None, :] \ + .expand(-1, -1, one_query_seq_len, -1) + + # one_query_reduce_expand: + # [one_query_window_num, reduce_dim, one_query_seq_len, full_window_size] + # one_query_windows_reduce_expand: + # [one_query_window_num, reduce_dim, one_query_seq_len, full_window_size] + # interacition_signal: + # [one_query_window_num, 1, one_query_seq_len, full_window_size] + # -stack-> + # [one_query_window_num, 2*reduce_dim + 1, + # one_query_seq_len, full_window_size] + encoder_input_tensor = torch.cat( + [one_query_reduce_expand, + one_query_windows_reduce_expand, + interacition_signal], + dim=1) + + # encoder_input_tensor: + # [one_query_window_num, 2*reduce_dim + 1, + # one_query_seq_len, full_window_size] + # -encoding-> + # -Conv2d-> + # [one_query_window_num, encode_out_dim, + # one_query_seq_len, full_window_size] + # -AdaptiveAvgPool2d+ReLU-> + # [one_query_window_num, encode_out_dim, + # encode_pool_out, encode_pool_out] + # -flatten-> + # [one_query_window_num, encode_out_dim * encode_pool_out^2] + encoder_output_tensor = self.encoding(encoder_input_tensor).flatten(1) + + ############################## + # position encoding and gru + ############################## + + # all_window_position: [batch_size, max_window_num] + # -[]-> [one_query_window_num, 1] + one_query_window_position = \ + all_window_position[i, :one_query_window_num, None] + + # one_query_window_position: [one_query_window_num, 1] + # --> [one_query_window_num, 1] + window_position_encoding = (1. / (one_query_window_position + 1.)).float() + + # encoder_output_tensor: + # [one_query_window_num, encode_out_dim * encode_pool_out^2] + # window_position_encoding: + # [one_query_window_num, 1] gru_in_dim + # -cat-> + # [one_query_window_num, gru_in_dim], + # gru_in_dim = encode_out_dim * encode_pool_out^2 + 1 + enc_and_pos = \ + torch.cat([encoder_output_tensor, window_position_encoding], dim=1) + + # all_query_term_window_num: [batch_size, max_q_seq_len] + # -[]-> [one_query_seq_len] + # -tolist-> list, len=one_query_seq_len + one_query_term_window_num = \ + all_query_term_window_num[i, :one_query_seq_len].tolist() + + # enc_and_pos: [one_query_window_num, gru_in_dim] + # -split-> + # tuple of tensor, len=one_query_seq_len, + # element i: tensor, [one_query_term_window_num[i], gru_in_dim] + # -pad_sequence-> + # [max(one_query_term_window_num[i]), one_query_seq_len, gru_in_dim] + one_query_pad_split_windows = \ + pad_sequence(enc_and_pos.split(one_query_term_window_num, dim=0)) + + # one_query_pad_split_windows: + # [max(one_query_term_window_num[i]), one_query_seq_len, gru_in_dim] + # -gru-> + # gru_out: [max(one_query_term_window_num[i]), + # one_query_seq_len, gru_hidden_dim * 2] + # gru_hidden: [num_layers, one_query_seq_len, gru_hidden_dim * 2] + gru_out, gru_hidden = self.gru(one_query_pad_split_windows) + # type: torch.Tensor, torch.Tensor + + ############################## + # aggregate + ############################## + + # gru_out: + # [max(one_query_term_window_num[i]), + # one_query_seq_len, gru_hidden_dim * 2] + # -permute(1,2,0)-> + # [one_query_seq_len, gru_hidden_dim * 2, + # max(one_query_term_window_num[i])] + gru_out = gru_out.permute(1, 2, 0) + + # gru_out: + # [one_query_seq_len, gru_hidden_dim * 2, + # max(one_query_term_window_num[i])] + # -pool-> [one_query_seq_len, gru_hidden_dim * 2, 1] + # -squeeze-> [one_query_seq_len, gru_hidden_dim * 2] + pool_gru_out = self.pool(gru_out).squeeze(2) + + # all_query_term_weight: [batch_size, max_q_seq_len] + # -[]-> [one_query_seq_len] + one_query_term_weight = all_query_term_weight[i, :one_query_seq_len] + + # pool_gru_out: [one_query_seq_len, gru_hidden_dim * 2] + # one_query_term_weight: [one_query_seq_len] + # -einsum-> [gru_hidden_dim * 2] + final_embed = torch.einsum("lh,l->h", pool_gru_out, one_query_term_weight) + + out.append(final_embed) + pass + + ############################## + # output score + ############################## + + # out: list, len=batch_size, element i: tensor, [gru_hidden_dim * 2] + # -stack-> [batch_size, gru_hidden_dim * 2] + out = torch.stack(out, dim=0) + + # out: [batch_size, gru_hidden_dim * 2] + # -out-> [batch_size, out_dim] + out = self.out(out) + return out diff --git a/matchzoo/preprocessors/__init__.py b/matchzoo/preprocessors/__init__.py index 90ba40c..7239be9 100755 --- a/matchzoo/preprocessors/__init__.py +++ b/matchzoo/preprocessors/__init__.py @@ -2,6 +2,7 @@ from .naive_preprocessor import NaivePreprocessor from .basic_preprocessor import BasicPreprocessor from .bert_preprocessor import BertPreprocessor +from .deeprank_preprocessor import DeepRankPreprocessor def list_available() -> list: diff --git a/matchzoo/preprocessors/deeprank_preprocessor.py b/matchzoo/preprocessors/deeprank_preprocessor.py new file mode 100644 index 0000000..ff650cd --- /dev/null +++ b/matchzoo/preprocessors/deeprank_preprocessor.py @@ -0,0 +1,169 @@ +"""DeepRank Preprocessor.""" + +from tqdm import tqdm +import typing + +from . import units +from matchzoo import DataPack +from matchzoo.engine.base_preprocessor import BasePreprocessor +from .build_vocab_unit import build_vocab_unit +from .build_unit_from_data_pack import build_unit_from_data_pack +from .chain_transform import chain_transform + +tqdm.pandas() + + +class DeepRankPreprocessor(BasePreprocessor): + """ + DeepRank model preprocessor helper. + + For pre-processing, all the words in documents and queries are white-space + tokenized, lower-cased, and stemmed using the Krovetz stemmer. + Stopword removal is performed on query and document words using the + INQUERY stop list. + Words occurred less than 5 times in the collection are removed from all + the document. + Querys are truncated below a max_length. + + :param filter_low_freq: Float, lower bound value used by + :class:`FrequenceFilterUnit`. + :param half_window_size: int, half of the match window size (not including + the center word), so the real window size is 2 * half_window_size + 1 + :padding_token_index: int: vocabulary index of pad token, default 0 + + Example: + >>> import matchzoo as mz + >>> train_data = mz.datasets.toy.load_data('train') + >>> test_data = mz.datasets.toy.load_data('test') + >>> preprocessor = mz.preprocessors.DeepRankPreprocessor( + ... filter_low_freq=5, + ... half_window_size=10, + ... padding_token_index=0, + ... ) + >>> preprocessor = preprocessor.fit(train_data, verbose=0) + >>> preprocessor.context['vocab_size'] + 105 + >>> processed_train_data = preprocessor.transform(train_data, + ... verbose=0) + >>> type(processed_train_data) + + >>> test_data_transformed = preprocessor.transform(test_data, + ... verbose=0) + >>> type(test_data_transformed) + + + """ + + def __init__(self, + truncated_mode: str = 'pre', + truncated_length_left: int = None, + truncated_length_right: int = None, + filter_low_freq: float = 5, + half_window_size: int = 5, + padding_token_index: int = 0): + """Initialization.""" + super().__init__() + + self._truncated_mode = truncated_mode + self._truncated_length_left = truncated_length_left + self._truncated_length_right = truncated_length_right + if self._truncated_length_left: + self._left_truncatedlength_unit = units.TruncatedLength( + self._truncated_length_left, self._truncated_mode + ) + if self._truncated_length_right: + self._right_truncatedlength_unit = units.TruncatedLength( + self._truncated_length_right, self._truncated_mode + ) + + self._counter_unit = units.FrequencyCounter() + self._filter_unit = units.FrequencyFilter( + low=filter_low_freq, + mode="tf" + ) + + self._units = [ + units.Tokenize(), + units.Lowercase(), + units.Stemming(stemmer="krovetz"), + units.stop_removal.StopRemoval(), # todo: INQUERY stop list ? + ] + + self._context["filter_low_freq"] = filter_low_freq + self._context["half_window_size"] = half_window_size + self._context['padding_unit'] = units.PaddingLeftAndRight( + left_padding_num=self._context["half_window_size"], + right_padding_num=self._context["half_window_size"], + pad_value=padding_token_index, + ) + + def fit(self, data_pack: DataPack, verbose: int = 1): + """ + Fit pre-processing context for transformation. + + :param data_pack: data_pack to be preprocessed. + :param verbose: Verbosity. + :return: class:`BasicPreprocessor` instance. + """ + data_pack = data_pack.apply_on_text(chain_transform(self._units), + verbose=verbose) + fitted_counter_unit = build_unit_from_data_pack(self._counter_unit, + data_pack, + flatten=False, + mode='right', + verbose=verbose) + self._context['counter_unit'] = fitted_counter_unit + self._context['term_idf'] = fitted_counter_unit.context["idf"] + + fitted_filter_unit = build_unit_from_data_pack(self._filter_unit, + data_pack, + flatten=False, + mode='right', + verbose=verbose) + data_pack = data_pack.apply_on_text(fitted_filter_unit.transform, + mode='right', verbose=verbose) + self._context['filter_unit'] = fitted_filter_unit + + vocab_unit = build_vocab_unit(data_pack, verbose=verbose) + self._context['vocab_unit'] = vocab_unit + + vocab_size = len(vocab_unit.state['term_index']) + self._context['vocab_size'] = vocab_size + self._context['embedding_input_dim'] = vocab_size + + return self + + def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: + """ + Apply transformation on data. + + :param data_pack: Inputs to be preprocessed. + :param verbose: Verbosity. + + :return: Transformed data as :class:`DataPack` object. + """ + data_pack = data_pack.copy() + # simple preprocessing + data_pack.apply_on_text(chain_transform(self._units), inplace=True, + verbose=verbose) + # filter + data_pack.apply_on_text(self._context['filter_unit'].transform, + mode='right', inplace=True, verbose=verbose) + # token to id + data_pack.apply_on_text(self._context['vocab_unit'].transform, + mode='both', inplace=True, verbose=verbose) + # truncate + if self._truncated_length_left: + data_pack.apply_on_text(self._left_truncatedlength_unit.transform, + mode='left', inplace=True, verbose=verbose) + if self._truncated_length_right: + data_pack.apply_on_text(self._right_truncatedlength_unit.transform, + mode='right', inplace=True, + verbose=verbose) + # add length + data_pack.append_text_length(inplace=True, verbose=verbose) + data_pack.drop_empty(inplace=True) + # padding on left and right for matching window + data_pack.apply_on_text(self._context['padding_unit'].transform, + mode='right', inplace=True, verbose=verbose) + return data_pack diff --git a/matchzoo/preprocessors/units/__init__.py b/matchzoo/preprocessors/units/__init__.py index 950b794..9531f6b 100755 --- a/matchzoo/preprocessors/units/__init__.py +++ b/matchzoo/preprocessors/units/__init__.py @@ -15,6 +15,8 @@ from .character_index import CharacterIndex from .word_exact_match import WordExactMatch from .truncated_length import TruncatedLength +from .frequency_counter import FrequencyCounter +from .padding_left_and_right import PaddingLeftAndRight def list_available() -> list: diff --git a/matchzoo/preprocessors/units/frequency_counter.py b/matchzoo/preprocessors/units/frequency_counter.py new file mode 100644 index 0000000..d51e39d --- /dev/null +++ b/matchzoo/preprocessors/units/frequency_counter.py @@ -0,0 +1,64 @@ +import collections +import typing + +import numpy as np + +from .stateful_unit import StatefulUnit + + +class FrequencyCounter(StatefulUnit): + """ + Frequency counter unit. + + Examples:: + >>> from collections import Counter + >>> import matchzoo as mz + + To filter based on document frequency (df): + >>> unit = mz.preprocessors.units.FrequencyCounter() + >>> unit.fit([['A', 'B'], ['B', 'C']]) + >>> unit.context + {'tf': Counter({'B': 2, 'A': 1, 'C': 1}), + 'df': Counter({'B': 2, 'A': 1, 'C': 1}), + 'idf': Counter({'A': 1.4054651081081644, 'C': 1.4054651081081644, 'B': 1.0})} + + """ + + def __init__(self): + """Frequency counter unit.""" + super().__init__() + + def fit(self, list_of_tokens: typing.List[typing.List[str]]): + """Fit `list_of_tokens` by calculating tf/df/idf states.""" + + self._context["tf"] = self._tf(list_of_tokens) + self._context["df"] = self._df(list_of_tokens) + self._context["idf"] = self._idf(list_of_tokens) + + def transform(self, input_: list) -> list: + """Transform do nothing.""" + return input_ + + @classmethod + def _tf(cls, list_of_tokens: list) -> dict: + stats = collections.Counter() + for tokens in list_of_tokens: + stats.update(tokens) + return stats + + @classmethod + def _df(cls, list_of_tokens: list) -> dict: + stats = collections.Counter() + for tokens in list_of_tokens: + stats.update(set(tokens)) + return stats + + @classmethod + def _idf(cls, list_of_tokens: list) -> dict: + num_docs = len(list_of_tokens) + stats = cls._df(list_of_tokens) + for key, val in stats.most_common(): + stats[key] = np.log((1 + num_docs) / (1 + val)) + 1 + return stats + + diff --git a/matchzoo/preprocessors/units/padding_left_and_right.py b/matchzoo/preprocessors/units/padding_left_and_right.py new file mode 100644 index 0000000..10894dc --- /dev/null +++ b/matchzoo/preprocessors/units/padding_left_and_right.py @@ -0,0 +1,39 @@ +from .stateful_unit import StatefulUnit + + +class PaddingLeftAndRight(StatefulUnit): + """ + Vocabulary class. + + :param pad_value: The string value for the padding position. + :param oov_value: The string value for the out-of-vocabulary terms. + + Examples: + >>> unit = PaddingLeftAndRight( + ... left_padding_num=5, right_padding_num=3, pad_value=0) + >>> unit.transform([3, 1, 4, 1, 5]) + [0, 0, 0, 0, 0, 3, 1, 4, 1, 5, 0, 0, 0] + + """ + + def __init__( + self, + left_padding_num: int, + right_padding_num: int, + pad_value: str or int = 0): + """Vocabulary unit initializer.""" + super().__init__() + self._context["left_padding_num"] = left_padding_num + self._context["right_padding_num"] = right_padding_num + self._context["pad_value"] = pad_value + + def fit(self, tokens: list): + """Do nothing.""" + pass + + def transform(self, input_: list) -> list: + """Padding on left and right.""" + return \ + [self._context["pad_value"]] * self._context["left_padding_num"] \ + + input_ \ + + [self._context["pad_value"]] * self._context["right_padding_num"]