|
| 1 | +# Licensed to the Technische Universität Darmstadt under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The Technische Universität Darmstadt |
| 5 | +# licenses this file to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +from typing import List, Dict |
| 17 | + |
| 18 | +from cassis import Cas |
| 19 | + |
| 20 | +from ariadne.classifier import Classifier |
| 21 | +from ariadne.protocol import TrainingDocument |
| 22 | +from collections import defaultdict |
| 23 | +from ariadne.contrib.inception_util import create_span_prediction, TOKEN_TYPE |
| 24 | +from cassis.typesystem import TYPE_NAME_STRING |
| 25 | + |
| 26 | +import logging |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
class DemoMultipleFeaturesRecommender(Classifier):
    """
    Demo recommender that behaves like DemoStringFeatureRecommender but ignores the
    provided `feature` parameter and instead trains/predicts on all string-typed
    features of the specified layer. A separate dictionary is learned for each
    feature (mapping mention -> best label).

    This recommender requires INCEpTION 39.0 or higher. Older versions of INCEpTION
    will only extract the configured feature, even though multiple features are trained
    and predicted.
    """

    def _get_string_features(self, cas: Cas, layer: str) -> List[str]:
        """Return the names of all features of `layer` whose range is a string.

        Only the features listed by the type's ``features`` attribute are inspected.
        The lookup is best-effort: if the layer cannot be resolved, an empty list is
        returned, and features that cannot be introspected are skipped. Both cases
        are logged instead of being silently ignored.

        Args:
            cas: CAS whose type system is inspected.
            layer: Fully qualified name of the annotation type.

        Returns:
            Feature names in the order the type system reports them.
        """
        try:
            annotation_type = cas.typesystem.get_type(layer)
        except Exception:
            # Deliberately broad: an unknown layer must not crash training or
            # prediction, but the reason should still be visible in the log.
            logger.warning(
                "Cannot resolve layer [%s] in the type system; no string features discovered",
                layer,
                exc_info=True,
            )
            return []

        features = []
        for feat in annotation_type.features:
            try:
                if feat.rangeType.name == TYPE_NAME_STRING:
                    features.append(feat.name)
            except Exception:
                # best-effort: skip features we cannot introspect
                logger.debug("Skipping a feature of [%s] that cannot be introspected", layer, exc_info=True)

        return features

    def fit(self, documents: List[TrainingDocument], layer: str, feature: str, project_id: str, user_id: str) -> None:
        """Learn, per string feature, the most frequent label for each mention.

        Mentions are the lower-cased covered texts of the `layer` annotations.
        Annotations with empty text and features without a value are ignored.
        The resulting ``feature -> mention -> label`` mapping is stored via
        ``self._save_model`` under `user_id`.

        Args:
            documents: Training documents; each provides its CAS as ``.cas``.
            layer: Fully qualified name of the annotation type to learn from.
            feature: Ignored; all string features of `layer` are trained.
            project_id: Only used for logging.
            user_id: Key under which the model is saved.
        """
        logger.info(
            "Training triggered for all string features on [%s] in [%d] documents from project [%s] for user [%s]",
            layer,
            len(documents),
            project_id,
            user_id,
        )

        # counts: feature -> mention -> label -> count
        counts: Dict[str, Dict[str, Dict[str, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

        # Discovered once from the first document and reused for all others.
        string_features = None

        for document in documents:
            cas = document.cas

            if string_features is None:
                string_features = self._get_string_features(cas, layer)

            for annotation in cas.select(layer):
                mention = annotation.get_covered_text().lower()

                if not mention:
                    continue

                for feat in string_features:
                    label = annotation.get(feat)
                    if label:
                        counts[feat][mention][label] += 1

        # For each feature, compute best_labels mapping mention -> top label.
        # A label-count map only exists once a label was counted, so it is never
        # empty here and `max` is always safe. Ties resolve to the label seen first.
        model: Dict[str, Dict[str, str]] = {}
        for feat, mention_map in counts.items():
            model[feat] = {
                mention: max(label_counts, key=label_counts.get) for mention, label_counts in mention_map.items()
            }

        logger.info("Trained multiple-feature model for features: %s", list(model.keys()))
        self._save_model(user_id, model)

        logger.info("Training finished for user [%s]", user_id)

    def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str) -> None:
        """Add span suggestions to `cas` for every token known to the trained model.

        For each token, the lower-cased token text is looked up in the dictionary of
        every string feature of `layer`; each hit yields one suggestion covering
        exactly the token. Nothing is predicted if no model is available.

        Args:
            cas: CAS to annotate; suggestions are added to it in place.
            layer: Fully qualified name of the annotation type to predict.
            feature: Ignored; all string features of `layer` are predicted.
            project_id: Only used for logging.
            document_id: Only used for logging.
            user_id: Key under which the model is loaded.
        """
        logger.info(
            "Prediction triggered on document [%s] for all string features on [%s] in project [%s] for user [%s]",
            document_id,
            layer,
            project_id,
            user_id,
        )

        model = self._load_model(user_id)
        if model is None:
            logger.debug("No trained model available for user [%s]; skipping prediction", user_id)
            return

        # Determine which string features to predict (use typesystem from cas) and
        # resolve their dictionaries once, instead of once per token.
        trained_features = []
        for feat in self._get_string_features(cas, layer):
            feature_model = model.get(feat)
            if feature_model:
                trained_features.append((feat, feature_model))

        # For each token, try to predict for each discovered string feature if the token text
        # exists in the per-feature dictionary
        suggestion_count = 0
        for token in cas.select(TOKEN_TYPE):
            mention = token.get_covered_text().lower()
            for feat, feature_model in trained_features:
                if mention not in feature_model:
                    continue
                label = feature_model[mention]
                # Use the token's own end offset: lower-casing can change the length
                # of the text, so `begin + len(mention)` may not match the token span.
                suggestion = create_span_prediction(cas, layer, feat, token.begin, token.end, label)
                logger.info("Creating suggestion for feature [%s]: %s -> %s", feat, mention, label)
                cas.add(suggestion)
                suggestion_count += 1

        logger.info("Prediction finished for user [%s]; suggestions created: %d", user_id, suggestion_count)
0 commit comments