# Copyright 2019 Regents of the University of Minnesota.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides functionality for measuring processor performance against gold
standards.
"""
import sys
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Sequence, NamedTuple, Callable, Tuple, \
TextIO
from mtap._document import Document
from mtap._label_indices import presorted_label_index
from mtap.descriptors import processor
from mtap.processing import DocumentProcessor
from mtap.types import LabelIndex, Label
from mtap.utilities import tokenization
__all__ = [
'Metric',
'Metrics',
'print_overlapping',
'fields_match_test',
'Accuracy',
'ConfusionMatrix',
'FirstTokenConfusion'
]
[docs]
class Metric(ABC):
"""Base class for metrics.
"""
name = None
@abstractmethod
def update(self,
document: Document,
tested_index: LabelIndex, target_index: LabelIndex) -> Any:
...
[docs]
@processor('mtap-metrics')
class Metrics(DocumentProcessor):
"""A document processor that computes a set of metrics.
Args:
tested: The name of the index to use as the hypothesis / predictions.
target: The name of the index to use as the ground truth / gold
standard.
tested_filter: A filter to apply to the tested index.
target_filter: A filter to apply to the target index.
"""
def __init__(self, *metrics: Metric,
tested: str,
target: str,
tested_filter: Callable[[Label], bool] = None,
target_filter: Callable[[Label], bool] = None):
self.tested = tested
self.target = target
self.metrics = metrics
self.tested_filter = tested_filter
self.target_filter = target_filter
def process_document(self, document: Document,
params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
tested = document.labels[self.tested]
if self.tested_filter is not None:
tested = tested.filter(self.tested_filter)
target = document.labels.get(self.target, presorted_label_index([]))
if self.target_filter is not None:
target = target.filter(self.target_filter)
local = {}
for metric in self.metrics:
local[metric.name] = metric.update(document, tested, target)
return local
def print_overlapping(document, target_label, tested_index):
print('Not found:', target_label)
print('"', target_label.text(document.text), '"')
overlapping = tested_index.overlapping(target_label)
print('Overlapping:', overlapping)
for overlap in overlapping:
print('"', overlap.text(document.text), '"')
[docs]
def fields_match_test(fields: Optional[Sequence[str]] = None):
"""Creates an equivalence test that tests whether the specified fields are
equal on both labels.
Args:
fields: The fields to test or `None` if all fields should be tested.
"""
def fields_match(tested_label: Label, target_label: Label) -> bool:
if fields is not ...:
return all(
getattr(tested_label, field) == getattr(target_label, field)
for field in fields)
else:
return tested_label.shallow_fields_equal(target_label)
return fields_match
[docs]
class Accuracy(Metric):
"""An accuracy metric with several options for equivalence.
Args:
name: An identifier for the metric.
mode: 'equals' - counts as a hit if there is one and only one
label at the same location in the tested index as the target
index, and it has the same values for its fields.
'location' - counts as a hit if there is one and only one
label at the same location in the tested index as the target
index.
'any' - counts as a hit if there is one or more labels at the
same location with the same values for its fields.
print_debug: If true will print debug information about the misses.
boundary_fuzz: How different the target label boundaries can be from
the tested boundaries before it doesn't count as a match.
equivalence_test: callable
A function which takes two argument labels, and returns true if
the labels are equivalent for the purpose of the test.
"""
def __init__(self,
name: str = 'accuracy',
mode: str = 'equals',
print_debug: bool = False,
boundary_fuzz: int = 0,
fields: Optional[Sequence[str]] = ...,
equivalence_test: Optional[
Callable[[Any, Any], bool]] = fields_match_test(...)):
self.correct = 0
self.total = 0
self.name = name
self.mode = mode
self.print_debug = print_debug
self.boundary_fuzz = boundary_fuzz
if fields is ...:
self.equivalence_test = equivalence_test
else:
self.equivalence_test = fields_match_test(fields)
@property
def value(self) -> float:
return self.correct / self.total if self.total > 0 else 1.
def find_candidates(self,
tested_index: LabelIndex,
target_label: Label) -> LabelIndex:
if self.boundary_fuzz == 0:
index = tested_index.at(target_label)
else:
index = tested_index.inside(
target_label.start_index - 1,
target_label.end_index + 1
).covering(
target_label.start_index + 1,
target_label.end_index - 1
)
return index
def has_match(self, candidates: LabelIndex, target_label: Label) -> bool:
if self.mode == 'equals':
return len(candidates) == 1 and self.equivalence_test(
candidates[0], target_label)
elif self.mode == 'location':
return len(candidates) > 0
elif self.mode == 'any':
return any(
self.equivalence_test(candidate, target_label) for candidate in
candidates)
def update(self, document: Document,
tested_index: LabelIndex,
target_index: LabelIndex) -> Any:
correct = 0
total = 0
for target_label in target_index:
total += 1
candidates = self.find_candidates(tested_index, target_label)
if self.has_match(candidates, target_label):
correct += 1
elif self.print_debug:
print_overlapping(document, target_label, tested_index)
self.correct += correct
self.total += total
return correct / total if total != 0 else 1.
def _collect_tokens(index, return_insides=True):
begins = set()
insides = set()
for label in index:
for i, token in enumerate(
tokenization.word_tokenize(label.text, label.start_index)):
if i == 0:
begins.add(token)
if not return_insides:
break
else:
insides.add(token)
return begins, insides
[docs]
class ConfusionMatrix(NamedTuple):
"""A representation of a confusion matrix.
"""
true_positives: float = 0
"""Count of true positive examples."""
false_positives: float = 0
"""Count of false positive examples."""
false_negatives: float = 0
"""Count of false negative examples."""
def __add__(self, other):
return ConfusionMatrix(self.true_positives + other.true_positives,
self.false_positives + other.false_positives,
self.false_negatives + other.false_negatives)
@property
def precision(self):
predicted_positives = self.true_positives + self.false_positives
if predicted_positives == 0:
return 1
return self.true_positives / predicted_positives
@property
def recall(self):
ground_truth_positives = self.true_positives + self.false_negatives
if ground_truth_positives == 0:
return 1
return self.true_positives / ground_truth_positives
@property
def f1(self):
divisor = (2 * self.true_positives + self.false_positives
+ self.false_negatives)
if divisor == 0:
return 1
return 2 * self.true_positives / divisor
[docs]
class FirstTokenConfusion(Metric):
"""A metric which treats the first word token in every label as an example
of the positive class and calculates the precision, recall, and f1
confusion matrix metrics for that positive class. Useful for evaluation of
segmentation tasks.
precision = true positives / (true positives + false positives)
recall = true positives / (true positives + false negatives)
f1 = 2 * true positives / (2 * true positives + false positives + false negatives)
Args:
name: An identifying name for the metric.
tested_filter: A filter to apply to the tested index.
target_filter: A filter to apply to the target index.
print_debug: An argument to print failing examples. 'fp' prints only
false positive errors, 'fn' prints only false negative errors,
'all' prints both false positive and false negative errors
debug_range: The range before and after the example to print.
debug_handle: A text io file handle to print the debug information to.
"""
def __init__(self, name: str = 'first_token_confusion',
tested_filter: Callable[[Label], bool] = None,
target_filter: Callable[[Label], bool] = None,
print_debug: str = None,
debug_range: int = 30,
debug_handle: TextIO = sys.stdout):
self.name = name
self._matrix = ConfusionMatrix()
self.tested_filter = tested_filter
self.target_filter = target_filter
self.print_debug = print_debug
self.debug_range = debug_range
self.debug_handle = debug_handle
@property
def precision(self) -> float:
"""Ratio of true positives to the total number of positive
predictions.
"""
return self._matrix.precision
@property
def recall(self) -> float:
"""Ratio of true positives to the total number of positive ground
truths.
"""
return self._matrix.recall
@property
def f1(self) -> float:
"""The harmonic mean of precision and recall."""
return self._matrix.f1
def update(self,
document: Document,
tested_index: LabelIndex,
target_index: LabelIndex) -> Any:
if self.tested_filter is not None:
tested_index = tested_index.filter(self.tested_filter)
if self.target_filter is not None:
target_index = target_index.filter(self.target_filter)
tested_tokens, _ = _collect_tokens(tested_index, return_insides=False)
target_tokens, _ = _collect_tokens(target_index, return_insides=False)
false_positives = tested_tokens.difference(target_tokens)
false_negatives = target_tokens.difference(tested_tokens)
if self.print_debug in ('fp', 'all'):
self.debug_handle.write('False Positives\n')
for false_positive in false_positives:
_print_example(document.text, false_positive, self.debug_range,
self.debug_handle)
self.debug_handle.write('\n')
if self.print_debug in ('fn', 'all'):
self.debug_handle.write('False Negatives\n')
for false_negative in false_negatives:
_print_example(document.text, false_negative, self.debug_range,
self.debug_handle)
self.debug_handle.write('\n')
local = ConfusionMatrix(
true_positives=len(tested_tokens.intersection(target_tokens)),
false_negatives=len(false_negatives),
false_positives=len(false_positives)
)
self._matrix += local
return {
'precision': local.precision,
'recall': local.recall,
'f1': local.f1
}
def _print_example(text: str,
token: Tuple[int, int],
debug_range: int,
debug_handle: TextIO):
start, end = token
print_start = max(0, start - debug_range)
print_end = min(end + debug_range, len(text))
text = text[print_start:start] + '{' + text[start:end] + '}' + text[
end:print_end]
text = text.replace('\n', ' ') + '\n'
debug_handle.write(text)