|
3 | 3 | """ |
4 | 4 | from collections import defaultdict |
5 | 5 | from typing import Dict, List, Tuple, Union |
| 6 | + |
| 7 | +from labelbox.data.annotation_types.annotation import ClassificationAnnotation, Checklist, Radio |
6 | 8 | try: |
7 | 9 | from typing import Literal |
8 | 10 | except ImportError: |
9 | 11 | from typing_extensions import Literal |
10 | 12 |
|
11 | 13 | from ..annotation_types.feature import FeatureSchema |
12 | | -from ..annotation_types import ObjectAnnotation, Label, LabelList |
| 14 | +from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label, LabelList |
13 | 15 |
|
14 | 16 |
|
15 | 17 | def get_identifying_key( |
@@ -56,6 +58,14 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: |
56 | 58 | all_names = True |
57 | 59 | all_schemas = True |
58 | 60 | for feature in features: |
| 61 | + if isinstance(feature, ClassificationAnnotation): |
| 62 | + if isinstance(feature.value, Checklist): |
| 63 | + all_names, all_schemas = all_have_key(feature.value.answer) |
| 64 | + else: |
| 65 | + if feature.value.answer.name is None: |
| 66 | + all_names = False |
| 67 | + if feature.value.answer.feature_schema_id is None: |
| 68 | + all_schemas = False |
59 | 69 | if feature.name is None: |
60 | 70 | all_names = False |
61 | 71 | if feature.feature_schema_id is None: |
@@ -155,7 +165,17 @@ def _create_feature_lookup(features: List[FeatureSchema], |
155 | 165 | """ |
156 | 166 | grouped_features = defaultdict(list) |
157 | 167 | for feature in features: |
158 | | - grouped_features[getattr(feature, key)].append(feature) |
| 168 | + if isinstance(feature, ClassificationAnnotation): |
| 169 | + #checklists |
| 170 | + if isinstance(feature.value, Checklist): |
| 171 | + for answer in feature.value.answer: |
| 172 | + new_feature = Radio(answer=answer) |
| 173 | + grouped_features[getattr(answer, key)] = new_feature |
| 174 | + else: |
| 175 | + grouped_features[getattr(feature.value.answer, |
| 176 | + key)].append(feature) |
| 177 | + else: |
| 178 | + grouped_features[getattr(feature, key)].append(feature) |
159 | 179 | return grouped_features |
160 | 180 |
|
161 | 181 |
|
|
0 commit comments