Skip to content

Commit e1a1a90

Browse files
authored
Merge pull request #71 from inception-project/feature/70-Provide-demo-for-recommender-that-suggests-multiple-features-at-the-same-time
#70 - Provide demo for recommender that suggests multiple features at the same time
2 parents 3e15155 + 6ec4f2e commit e1a1a90

File tree

7 files changed

+252
-11
lines changed

7 files changed

+252
-11
lines changed

ariadne/contrib/sbert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ def fit(self, documents: List[TrainingDocument], layer: str, feature: str, proje
7878
else:
7979
continue
8080

81-
assert (
82-
sentence.begin == annotation.begin and sentence.end == annotation.end
83-
), "Annotation should cover sentence fully!"
81+
assert sentence.begin == annotation.begin and sentence.end == annotation.end, (
82+
"Annotation should cover sentence fully!"
83+
)
8484

8585
label = getattr(annotation, feature)
8686

ariadne/contrib/sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def fit(self, documents: List[TrainingDocument], layer: str, feature: str, proje
5151
else:
5252
continue
5353

54-
assert (
55-
sentence.begin == annotation.begin and sentence.end == annotation.end
56-
), "Annotation should cover sentence fully!"
54+
assert sentence.begin == annotation.begin and sentence.end == annotation.end, (
55+
"Annotation should cover sentence fully!"
56+
)
5757

5858
label = getattr(annotation, feature)
5959

ariadne/demo/demo_link_feature.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,30 @@
2828

2929

3030
class DemoLinkFeatureRecommender(Classifier):
31+
"""Simple demo recommender that learns link roles between span annotations.
32+
33+
Training
34+
--------
35+
For each document, we iterate over all annotations of the given
36+
``layer`` and read the ``feature`` field, which is expected to contain
37+
link objects (UIMA link relations). The recommender counts how often a source span
38+
text (lowercased) was linked to a particular target span text with a
39+
given role. The model stored per-user is a nested mapping:
40+
``{source_text: {target_text: best_role}}``, where ``best_role`` is the
41+
role with the highest frequency for that (source, target) pair.
42+
43+
Prediction
44+
----------
45+
The ``predict`` method loads the per-user model and iterates over source
46+
annotations in the CAS. For each source whose lowercased covered text
47+
appears in the model, it looks for target annotations of the same
48+
``layer`` inside the covering sentence. If a target's lowercased text
49+
matches a target recorded for the source, the recommender creates a span
50+
prediction suggestion that contains a link to the found target using the
51+
learned role. The suggestion is added to the CAS as a span prediction
52+
feature structure.
53+
"""
54+
3155
def fit(self, documents: List[TrainingDocument], layer: str, feature: str, project_id, user_id: str):
3256
logger.info(
3357
f"Training triggered for [{feature}] on [{layer}] in [{len(documents)}] documents from project [{project_id}] for user [{user_id}]"
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Licensed to the Technische Universität Darmstadt under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The Technische Universität Darmstadt
5+
# licenses this file to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License.
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
from typing import List, Dict
17+
18+
from cassis import Cas
19+
20+
from ariadne.classifier import Classifier
21+
from ariadne.protocol import TrainingDocument
22+
from collections import defaultdict
23+
from ariadne.contrib.inception_util import create_span_prediction, TOKEN_TYPE
24+
from cassis.typesystem import TYPE_NAME_STRING
25+
26+
import logging
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class DemoMultipleFeaturesRecommender(Classifier):
    """
    Demo recommender that behaves like DemoStringFeatureRecommender but ignores the
    provided `feature` parameter and instead trains/predicts on all string-typed
    features of the specified layer. A separate dictionary is learned for each
    feature (mapping mention -> best label).

    This recommender requires INCEpTION 39.0 or higher. Older versions of INCEpTION
    will only extract the configured feature, even though multiple features are trained
    and predicted.
    """

    def _get_string_features(self, cas: Cas, layer: str) -> List[str]:
        """Return the names of all features of `layer` whose range is a string.

        The lookup is deliberately best-effort: if the layer cannot be resolved
        in the type system of `cas`, or a feature cannot be introspected, the
        problem is logged and the layer/feature is skipped instead of raising.

        :param cas: CAS whose type system is inspected.
        :param layer: Fully qualified name of the annotation type to inspect.
        :return: Feature names in the order the type reports them; empty if the
            layer could not be resolved.
        """
        try:
            annotation_type = cas.typesystem.get_type(layer)
        except Exception as e:
            # Best-effort: report the problem instead of hiding it completely,
            # but keep returning an empty list as callers expect.
            logger.warning("Unable to look up layer [%s] in the type system: %s", layer, e)
            return []

        features = []
        for feat in annotation_type.features:
            try:
                if feat.rangeType.name == TYPE_NAME_STRING:
                    features.append(feat.name)
            except Exception as e:
                # best-effort: skip features we cannot introspect
                logger.debug("Skipping feature [%r] of layer [%s]: %s", feat, layer, e)
                continue

        return features

    def fit(self, documents: List[TrainingDocument], layer: str, feature: str, project_id, user_id: str):
        """Learn, per string feature of `layer`, the most frequent label for each mention.

        The string features are discovered once, from the type system of the
        first training document. For every annotation of `layer`, the lowercased
        covered text is used as the mention key, and each non-empty feature value
        is counted. The resulting model has the shape
        ``{feature_name: {mention: most_frequent_label}}`` and is stored via
        ``self._save_model`` under `user_id`.

        :param documents: Training documents; each provides a ``cas`` attribute.
        :param layer: Name of the annotation type to learn from.
        :param feature: Ignored; all string features of `layer` are trained.
        :param project_id: Only used for logging.
        :param user_id: Key under which the model is saved.
        """
        logger.info(
            "Training triggered for all string features on [%s] in [%d] documents from project [%s] for user [%s]",
            layer,
            len(documents),
            project_id,
            user_id,
        )

        # counts: feature -> mention -> label -> count
        counts: Dict[str, Dict[str, Dict[str, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

        features_discovered = None

        for document in documents:
            cas = document.cas

            # Discover the string features only once, using the first document
            if features_discovered is None:
                features_discovered = self._get_string_features(cas, layer)

            for annotation in cas.select(layer):
                mention = annotation.get_covered_text().lower()

                if not mention:
                    continue

                for feat in features_discovered:
                    label = annotation.get(feat)
                    if not label:
                        continue
                    counts[feat][mention][label] += 1

        # For each feature, map every mention to its most frequent label.
        # A (feature, mention) entry only exists once a label has been counted,
        # so the per-mention count dictionaries are never empty here.
        model: Dict[str, Dict[str, str]] = {}
        for feat, mention_map in counts.items():
            model[feat] = {
                mention: max(label_counts, key=label_counts.get) for mention, label_counts in mention_map.items()
            }

        logger.info("Trained multiple-feature model for features: %s", list(model.keys()))
        self._save_model(user_id, model)

        logger.info("Training finished for user [%s]", user_id)

    def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):
        """Add suggestions for every string feature of `layer` to `cas`.

        The model stored for `user_id` is loaded via ``self._load_model``; if
        there is none, nothing is predicted. For each token of type
        ``TOKEN_TYPE`` whose lowercased covered text is known to a per-feature
        dictionary, a suggestion spanning exactly that token is built with
        ``create_span_prediction`` and added to `cas`.

        :param cas: CAS to read tokens from and add suggestions to.
        :param layer: Name of the annotation type to create suggestions for.
        :param feature: Ignored; all string features of `layer` are predicted.
        :param project_id: Only used for logging.
        :param document_id: Only used for logging.
        :param user_id: Key under which the model was saved.
        """
        logger.info(
            "Prediction triggered on document [%s] for all string features on [%s] in project [%s] for user [%s]",
            document_id,
            layer,
            project_id,
            user_id,
        )

        model = self._load_model(user_id)
        if model is None:
            logger.debug("No model available for user [%s]; nothing to predict", user_id)
            return

        # Determine which string features to predict (use typesystem from cas)
        # and resolve their dictionaries once instead of once per token.
        feature_models = [(feat, model.get(feat)) for feat in self._get_string_features(cas, layer)]
        feature_models = [(feat, feature_model) for feat, feature_model in feature_models if feature_model]

        # For each token, try to predict for each discovered string feature if the token text
        # exists in the per-feature dictionary
        suggestion_count = 0
        for token in cas.select(TOKEN_TYPE):
            mention = token.get_covered_text().lower()
            for feat, feature_model in feature_models:
                if mention not in feature_model:
                    continue

                label = feature_model[mention]
                # Use the token's own end offset: lowercasing may change the length
                # of the text, so begin + len(mention) is not a reliable end offset.
                suggestion = create_span_prediction(cas, layer, feat, token.begin, token.end, label)
                logger.info("Creating suggestion for feature [%s]: %s -> %s", feat, mention, label)
                cas.add(suggestion)
                suggestion_count += 1

        logger.info("Prediction finished for user [%s]; suggestions created: %d", user_id, suggestion_count)

scripts/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525

2626
def download_file(url: str, target_path: Path):
27-
2827
if target_path.exists():
2928
logging.info("File already exists: [%s]", str(target_path.resolve()))
3029
return
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Licensed to the Technische Universität Darmstadt under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The Technische Universität Darmstadt
5+
# licenses this file to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License.
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from ariadne.demo.demo_multiple_features import DemoMultipleFeaturesRecommender
18+
from ariadne.protocol import TrainingDocument
19+
from cassis import Cas
20+
from ariadne.contrib.inception_util import create_span_prediction, TOKEN_TYPE
21+
from tests.util import create_cas
22+
from ariadne.contrib.inception_util import IS_PREDICTION
23+
24+
25+
def test_demo_multiple_features_fit_and_predict():
    """Train on one span carrying two string features, then expect suggestions for both."""
    layer = "ariadne.testtype_multi"

    # Training CAS with a custom type that has two string features, so that the
    # recommender has to learn a separate dictionary for each of them
    training_cas = create_cas()
    training_cas.sofa_string = "Hello world"

    typesystem = training_cas.typesystem
    multi_type = typesystem.create_type(layer)
    feature_declarations = (
        ("value1", "uima.cas.String"),
        ("value2", "uima.cas.String"),
        (IS_PREDICTION, "uima.cas.Boolean"),
    )
    for feature_name, range_type in feature_declarations:
        typesystem.create_feature(multi_type, feature_name, range_type)

    # Two training annotations over the same text, each setting a different feature
    training_spans = [
        create_span_prediction(training_cas, layer, feature_name, 0, 5, label)
        for feature_name, label in (("value1", "GREETING"), ("value2", "HELLO"))
    ]
    for span in training_spans:
        training_cas.add(span)

    recommender = DemoMultipleFeaturesRecommender()
    # The feature argument is ignored by the recommender
    recommender.fit(
        [TrainingDocument(training_cas, "doc1", "user1")],
        layer,
        "value1",
        project_id=1,
        user_id="user1",
    )

    # Fresh CAS for prediction, sharing the type system and text of the training CAS
    target_cas = Cas(training_cas.typesystem)
    target_cas.sofa_string = training_cas.sofa_string

    token_type = target_cas.typesystem.get_type(TOKEN_TYPE)
    for begin, end in ((0, 5), (6, 11)):
        target_cas.add(token_type(begin=begin, end=end))

    recommender.predict(
        target_cas,
        layer,
        "value1",
        project_id=1,
        document_id="doc1",
        user_id="user1",
    )

    # Both features must have received at least one suggestion
    assert any(a.get("value1") == "GREETING" for a in target_cas.select(layer))
    assert any(a.get("value2") == "HELLO" for a in target_cas.select(layer))

wsgi.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
from ariadne.demo.demo_link_feature import DemoLinkFeatureRecommender
17+
from ariadne.demo.demo_multiple_features import DemoMultipleFeaturesRecommender
18+
from ariadne.demo.demo_relation import DemoRelationLayerRecommender
19+
from ariadne.demo.demo_string_array_feature import DemoStringArrayFeatureRecommender
20+
from ariadne.demo.demo_string_feature import DemoStringFeatureRecommender
1621
from ariadne.server import Server
1722
from ariadne.util import setup_logging
1823
from ariadne.contrib.spacy import SpacyNerClassifier
@@ -21,10 +26,11 @@
2126

2227
server = Server()
2328

24-
# server.add_classifier("demo_string_feature", DemoStringFeatureRecommender())
25-
# server.add_classifier("demo_string_array_feature", DemoStringArrayFeatureRecommender())
26-
# server.add_classifier("demo_link_feature", DemoLinkFeatureRecommender())
27-
# server.add_classifier("demo_relation_layer", DemoRelationLayerRecommender())
29+
server.add_classifier("demo_string_feature", DemoStringFeatureRecommender())
30+
server.add_classifier("demo_string_array_feature", DemoStringArrayFeatureRecommender())
31+
server.add_classifier("demo_link_feature", DemoLinkFeatureRecommender())
32+
server.add_classifier("demo_relation_layer", DemoRelationLayerRecommender())
33+
server.add_classifier("demo_multiple_features", DemoMultipleFeaturesRecommender())
2834

2935
server.add_classifier("spacy_ner", SpacyNerClassifier("en_core_web_sm"))
3036
# server.add_classifier("spacy_pos", SpacyPosClassifier("en_core_web_sm"))

0 commit comments

Comments
 (0)