# Copyright 2020 Regents of the University of Minnesota.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Labels functionality."""
import threading
from abc import abstractmethod, ABCMeta
from queue import Queue
from typing import (
TYPE_CHECKING,
TypeVar,
NamedTuple,
Any,
Mapping,
Sequence,
Union,
Optional
)
if TYPE_CHECKING:
from mtap import (
Document
)
LocationLike = Union['Location', 'Label']
HasStart = Union['Location', 'Label', float]
HasEnd = Union['Location', 'Label', float]
[docs]
class Location(NamedTuple):
"""A location in text, a tuple of (`start_index`, `end_index`).
Used to perform comparison of labels based on their locations.
"""
start_index: float
"""The start index inclusive of the location in text."""
end_index: float
"""The end index exclusive of the location in text."""
[docs]
def covers(self, other: LocationLike):
"""Whether the span of text covered by this label completely overlaps
the span of text covered by the ``other`` label or location.
Args:
other: A location or label to compare against.
Returns:
``True`` if `other` is completely overlapped/covered
``False`` otherwise.
"""
return self.start_index <= other.start_index \
and self.end_index >= other.end_index
[docs]
def relative_to(self, location: HasStart) -> 'Location':
"""Creates a location relative to the same origin as ``location``
and makes it relative to ``location``.
Args:
location: A location to relativize this
location to.
Returns:
A copy with updated indices.
Examples:
>>> sentence = Location(10, 20)
>>> token = Location(10, 15)
>>> token.relative_to(sentence)
Location(start_index=0, end_index=5)
"""
try:
start_index = location.start_index
except AttributeError:
start_index = location
if not isinstance(start_index, int):
raise ValueError(
'location must be Label, Location, or an int value')
return Location(self.start_index - start_index,
self.end_index - start_index)
[docs]
def offset_by(self, location: HasStart) -> 'Location':
"""Creates a location by offsetting this location by an integer or the
``start_index`` of a location / label. De-relativizes this location.
Args:
location: A location to offset this location by.
Returns:
A copy with updated indices.
Examples:
>>> sentence = Location(10, 20)
>>> token_in_sentence = Location(0, 5)
>>> token_in_sentence.offset_by(sentence)
Location(start_index=10, end_index=15)
"""
try:
start_index = location.start_index
except AttributeError:
start_index = location
if not isinstance(start_index, int):
raise ValueError(
'location must be Label, Location, or an int value'
)
return Location(self.start_index + start_index,
self.end_index + start_index)
[docs]
class Label(metaclass=ABCMeta):
"""An abstract base class for a label of attributes on text.
"""
@property
@abstractmethod
def document(self) -> 'Document':
"""The parent document this label appears on."""
...
@document.setter
@abstractmethod
def document(self, value: 'Document'):
"""Sets the label's document, this will automatically be done when the
label is created via a Document (i.e. get_label_index) or added to a
document (i.e. via labeler or add_labels).
"""
...
@property
@abstractmethod
def label_index_name(self) -> str:
"""The label index this label appears on."""
...
@label_index_name.setter
@abstractmethod
def label_index_name(self, value: str):
"""Sets the name for the label index this label appears on. Will
automatically be called when a label is added to a document via
labeler or add_labels."""
...
@property
@abstractmethod
def identifier(self) -> int:
"""The index of the label within its label index."""
...
@identifier.setter
@abstractmethod
def identifier(self, value: int):
"""The index of the label within its label index. Labels will
automatically be assigned this when added to a document via labeler or
add_labels."""
...
@property
@abstractmethod
def start_index(self) -> int:
"""The index of the first character of the text covered by this label.
"""
...
@start_index.setter
@abstractmethod
def start_index(self, value: int):
...
@property
@abstractmethod
def end_index(self) -> int:
"""The index after the last character of the text covered by this
label.
"""
...
@end_index.setter
@abstractmethod
def end_index(self, value: int):
...
@property
def location(self) -> Location:
"""A tuple of (start_index, end_index) used to perform sorting and
comparison first based on start_index, then based on end_index.
"""
return Location(self.start_index, self.end_index)
@property
def text(self):
"""The slice of document text covered by this label. Will
retrieve from events server if it is not cached locally.
"""
return self.document.text[self.start_index:self.end_index]
[docs]
@abstractmethod
def shallow_fields_equal(self, other) -> bool:
"""Tests if the fields on this label and locations of references are
the same as another label.
Args:
other: The other label to test.
Returns:
True if all the fields are equal and the references are at the
same locations.
"""
pass
@abstractmethod
def collect_floating_references(self, s):
pass
L = TypeVar('L', bound=Label)
_repr_local = threading.local()
def _collect_refs_from_item(o, queue, s):
if isinstance(o, Label):
if o.identifier is None:
s.add(id(o))
elif isinstance(o, Mapping):
for _, v in o.items():
if v is not None:
queue.put(v)
elif isinstance(o, Sequence):
for v in o:
if v is not None:
queue.put(v)
[docs]
class GenericLabel(Label):
"""Default implementation of the Label class which uses a dictionary to
store attributes.
Will be suitable for the majority of use cases for labels.
Args:
start_index: The index of the first character in text to be included
in the label.
end_index: The index after the last character in text to be included
in the label.
document: The parent document of the label. This will be automatically
set if the label is created via labeler.
**kwargs : Arbitrary, any other fields that should be added to the
label, values must be json-serializable.
Examples:
>>> pos_tag = pos_tag_labeler(0, 5)
>>> pos_tag.tag = 'NNS'
>>> pos_tag.tag
'NNS'
>>> pos_tag2 = pos_tag_labeler(6, 10, tag='VB')
>>> pos_tag2.tag
'VB'
"""
def __init__(self, start_index: int, end_index: int, *,
identifier: Optional[int] = None,
document: 'Optional[Document]' = None,
label_index_name: Optional[str] = None,
fields: Optional[dict] = None,
reference_field_ids: Optional[dict] = None,
**kwargs):
self._document = document
self._label_index_name = label_index_name
self._identifier = identifier
self._start_index = int(start_index)
self._end_index = int(end_index)
if fields is None:
self.fields = {}
else:
self.fields = fields
if reference_field_ids is None:
self.reference_field_ids = {}
else:
self.reference_field_ids = reference_field_ids
self.reference_cache = {}
for key, value in kwargs.items():
setattr(self, key, value)
@property
def document(self) -> 'Document':
return self._document
@document.setter
def document(self, document: 'Document'):
self._document = document
@property
def label_index_name(self) -> str:
return self._label_index_name
@label_index_name.setter
def label_index_name(self, value: str):
self._label_index_name = value
@property
def identifier(self) -> int:
return self._identifier
@identifier.setter
def identifier(self, value: int):
self._identifier = value
@property
def start_index(self) -> int:
return self._start_index
@start_index.setter
def start_index(self, start_index: int):
self._start_index = start_index
@property
def end_index(self) -> int:
return self._end_index
@end_index.setter
def end_index(self, end_index: int):
self._end_index = end_index
def _is_reserved(self, key):
return key in self.__dict__.keys() or key in vars(
GenericLabel) or key in vars(Label)
def __getattr__(self, item):
try:
return self.fields[item]
except KeyError:
pass
try:
return self.reference_cache[item]
except KeyError:
pass
try:
ref_value = self.reference_field_ids[item]
self.reference_cache[item] = _dereference(ref_value, self.document)
return self.reference_cache[item]
except KeyError:
raise AttributeError(
'Key "{}" not in fields, reference cache, or reference ids.'
.format(item))
def __setattr__(self, key, value):
if key in ('document', 'label_index_name', 'identifier', 'start_index',
'end_index',
'_document', '_label_index_name', '_identifier',
'_start_index', '_end_index',
'fields', 'reference_field_ids', 'reference_cache'):
object.__setattr__(self, key, value)
return
if self._is_reserved(key):
raise ValueError('The key "{}" is a reserved key.'.format(key))
is_ref = _is_referential(value, [id(self)])
if is_ref:
self.reference_cache[key] = value
else:
self.fields[key] = value
def __eq__(self, other):
if not isinstance(other, GenericLabel):
return False
if other is self:
return True
if not self.location == other.location:
return False
return self.shallow_fields_equal(other)
def shallow_fields_equal(self, other):
if not self.fields == other.fields:
return False
refs = set(self.reference_field_ids.keys()).union(
self.reference_cache.keys())
other_refs = set(other.reference_field_ids.keys()).union(
other.reference_cache.keys())
if not refs == other_refs:
return False
for k in refs:
try:
if self.reference_field_ids[k] == other.reference_field_ids[k]:
continue
except KeyError:
pass
self_k = getattr(self, k)
other_k = getattr(other, k)
if not _collect_locations(self_k) == _collect_locations(other_k):
return False
return True
def __repr__(self):
try:
stack = _repr_local.stack
except AttributeError:
stack = set()
_repr_local.stack = stack
if id(self) in stack:
return 'GenericLabel(...)'
stack.add(id(self))
attributes = [repr(self.start_index), repr(self.end_index)]
for k, v in self.fields.items():
attributes.append("{}={}".format(k, repr(v)))
for k, v in self.reference_cache.items():
attributes.append("{}={}".format(k, repr(v)))
for k, v in self.reference_field_ids.items():
if k not in self.reference_cache:
attributes.append("{}=ref:{}".format(k, repr(v)))
stack.remove(id(self))
return "GenericLabel(".format() + ", ".join(attributes) + ")"
def collect_floating_references(self, s):
queue = Queue()
self._populate_queue(queue)
while not queue.empty():
o = queue.get_nowait()
_collect_refs_from_item(o, queue, s)
def _populate_queue(self, queue):
for k, v in self.reference_cache.items():
if v is not None:
queue.put(v)
[docs]
def label(start_index: int,
end_index: int,
*, document: 'Optional[Document]' = None,
**kwargs) -> GenericLabel:
"""An alias for :class:`GenericLabel`.
Args:
start_index: The index of the first character in text to be included
in the label.
end_index: The index after the last character in text to be included
in the label.
document: The parent document of the label. This will be automatically
set if the label is created via labeler.
**kwargs : Arbitrary, any other fields that should be added to the
label, values must be json-serializable.
"""
return GenericLabel(start_index, end_index, document=document, **kwargs)
def _is_referential(o: Any, parents=None) -> bool:
if parents is None:
parents = [id(o)]
if isinstance(o, (str, float, bool, int)) or o is None:
return False
elif isinstance(o, Label):
return True
elif isinstance(o, Mapping):
return _mapping_is_ref(o, parents)
elif isinstance(o, Sequence):
return _sequence_is_ref(o, parents)
else:
raise TypeError('Unrecognized type')
def _mapping_is_ref(o: Mapping, parents) -> bool:
map_is_ref = None
for v in o.values():
if id(v) in parents:
raise ValueError('Recursive loop')
x = _is_referential(v, parents + [id(v)])
if map_is_ref is None:
map_is_ref = x
elif x != map_is_ref:
raise TypeError(
'Label dictionaries cannot have mixes of references to '
'labels and primitive types.'
)
return map_is_ref
def _sequence_is_ref(o: Sequence, parents) -> bool:
seq_is_ref = None
for v in o:
if id(v) in parents:
raise ValueError('Recursive loop')
x = _is_referential(v, parents + [id(v)])
if seq_is_ref is None:
seq_is_ref = x
elif x != seq_is_ref:
raise TypeError(
'Label lists cannot have mixes of references to labels'
'and primitive types.')
return seq_is_ref
def _dereference(o: Any, document: 'Document') -> Any:
if o is None:
return o
if isinstance(o, str):
label_index_name, label_id = o.split(':')
label_index = document.labels[label_index_name]
label_ = label_index[int(label_id)]
return label_
if isinstance(o, Mapping):
replacement = {}
for k, v in o.items():
replacement[k] = _dereference(v, document)
return replacement
if isinstance(o, Sequence):
replacement = [_dereference(v, document) for v in o]
return replacement
def _collect_locations(o):
if o is None:
return None
if isinstance(o, Label):
return o.location
if isinstance(o, Mapping):
return {k: _collect_locations(v) for k, v in o.items()}
if isinstance(o, Sequence):
return [_collect_locations(v) for v in o]