diff --git a/scrapely/extraction/__init__.py b/scrapely/extraction/__init__.py index a4c10c4..99c6d67 100644 --- a/scrapely/extraction/__init__.py +++ b/scrapely/extraction/__init__.py @@ -35,6 +35,7 @@ class InstanceBasedLearningExtractor(object): RepeatedDataExtractor, RecordExtractor, ] + _ext_items_max_number = 2 def __init__(self, td_pairs, trace=False, apply_extrarequired=True): """Initialise this extractor @@ -77,11 +78,11 @@ def __init__(self, td_pairs, trace=False, apply_extrarequired=True): modified_parsed_tdpairs.append((parsed, (t, descriptor))) # templates with more attributes are considered first sorted_tdpairs = sorted(modified_parsed_tdpairs, - key=lambda x: _annotation_count(x[0]), reverse=True) + key=lambda x: _annotation_count(x[0]), reverse=True) self.extraction_trees = [ self.build_extraction_tree(p, td[1], trace) for p, td in sorted_tdpairs - ] + ] self.validated = dict( (td[0].page_id, td[1].validated if td[1] else self._filter_not_none) for _, td in sorted_tdpairs @@ -108,23 +109,27 @@ def extract(self, html, pref_template_id=None): If pref_template_url is specified, the template with that url will be used first. """ + max_extracted_value = {} + correctly_extracted_template = '' extraction_page = parse_extraction_page(self.token_dict, html) if pref_template_id is not None: extraction_trees = sorted(self.extraction_trees, - key=lambda x: x.template.id != pref_template_id) + key=lambda x: x.template.id != pref_template_id) else: extraction_trees = self.extraction_trees for extraction_tree in extraction_trees: extracted = extraction_tree.extract(extraction_page) correctly_extracted = self.validated[extraction_tree.template.id](extracted) - if len(correctly_extracted) > 0: - return correctly_extracted, extraction_tree.template - return None, None + if len(correctly_extracted[0]) > len(max_extracted_value): + max_extracted_value = correctly_extracted[0] + correctly_extracted_template = extraction_tree.template + + return [max_extracted_value], correctly_extracted_template def __str__(self): return "InstanceBasedLearningExtractor[\n%s\n]" % \ - (',\n'.join(map(str, self.extraction_trees))) + (',\n'.join(map(str, self.extraction_trees))) @staticmethod def _filter_not_none(items):