pelinker.onto

ChunkMapper dataclass

Bases: BaseDataclass

Maps encoder chunks back to documents and optional pooled span rows.

When pelinker.util.texts_to_vrep runs multiple word_modes, it calls pelinker.util.render_elementary_tensor_table repeatedly on the same instance. The fields text_word_spans_list, token_word_spans_list, tt_expressions, and mapping_table therefore reflect only the last grouping pass; read per-mode results from ReportBatch instead.
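
A minimal access sketch, assuming a ReportBatch named report_batch built upstream with more than one word grouping and a WordGrouping value wg it was built with (both names are illustrative):

holder = report_batch[wg]                  # mode-specific ExpressionHolderBatch
expressions = holder.expression_data       # one entry per input text

# By contrast, these ChunkMapper fields were overwritten on every grouping
# pass and only describe the last one:
last_pass_only = report_batch.chunk_mapper.tt_expressions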

Source code in pelinker/onto.py
@dataclasses.dataclass
class ChunkMapper(BaseDataclass):
    """Maps encoder chunks back to documents and optional pooled span rows.

    When :func:`pelinker.util.texts_to_vrep` runs multiple ``word_modes``, it calls
    :func:`pelinker.util.render_elementary_tensor_table` repeatedly on the **same**
    instance. Fields ``text_word_spans_list``, ``token_word_spans_list``,
    ``tt_expressions``, and ``mapping_table`` therefore reflect **only the last**
    grouping pass; read per-mode results from :class:`ReportBatch` instead.
    """

    tensor: torch.Tensor  # n_layers x n_batch x n_len x n_emb - tensor where n_batch dim goes over all chunks
    chunks: list[str]  # flat list of chunks
    token_spans_list: list[
        list[tuple[int, int]]
    ]  # for each chunk contains a list of token spans
    it_ic: list[tuple[int, int]]
    cumulative_lens: list[list[int]]
    text_word_spans_list: list[list[tuple[int, int]]] | None = None
    token_word_spans_list: list[list[tuple[int, int]]] | None = None
    mapping_table: list[tuple[int, int, tuple[int, int], tuple[int, int]]] | None = None
    text_chunk_map: defaultdict[int, list] = field(
        default_factory=lambda: defaultdict(list)
    )
    tt_expressions: list[torch.Tensor] = field(
        default_factory=list
    )  # n_expressions [n_len x n_emb]

    def set_token_word_spans(self, word_int_bounds):
        from pelinker.util import map_words_to_tokens_list

        self.token_word_spans_list, self.text_word_spans_list = (
            map_words_to_tokens_list(self.token_spans_list, word_int_bounds)
        )

    def set_mapping_table(self):
        it_ic = sorted(self.it_ic)
        self.mapping_table = []
        self.text_chunk_map = defaultdict(list)

        if self.text_word_spans_list is None:
            # Nothing to map until set_token_word_spans has populated the spans.
            return
        for (ichunk, (itext, ichunk_local)), chsp in zip(
            enumerate(it_ic), self.text_word_spans_list
        ):
            self.text_chunk_map[itext].append(ichunk)
            chunk_offset = self.cumulative_lens[itext][ichunk_local]
            for a, b in chsp:
                self.mapping_table += [
                    (itext, ichunk, (a, b), (a + chunk_offset, b + chunk_offset))
                ]

    def map_chunk_to_text(self, itext, ichunk_local, a=0):
        return a + self.cumulative_lens[itext][ichunk_local]

    def ichunk_to_itext_ichunk_local(self, ichunk: int):
        return self.it_ic[ichunk]

    def ichunk_char_offset(self, ichunk: int):
        itext, ichunk_local = self.ichunk_to_itext_ichunk_local(ichunk)
        return self.cumulative_lens[itext][ichunk_local]
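
A short usage sketch of the offset helpers, assuming a populated ChunkMapper instance named cm; the chunk index and span below are illustrative:

# Map a flat chunk index back to (text index, chunk index within that text).
ichunk = 3
itext, ichunk_local = cm.ichunk_to_itext_ichunk_local(ichunk)

# Character offset of this chunk inside its source text.
offset = cm.ichunk_char_offset(ichunk)   # == cm.cumulative_lens[itext][ichunk_local]

# A chunk-local span (a, b) becomes an absolute span by adding that offset,
# which is exactly what set_mapping_table records per row.
a, b = 10, 17
a_abs = cm.map_chunk_to_text(itext, ichunk_local, a)   # a + offset
b_abs = cm.map_chunk_to_text(itext, ichunk_local, b)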

MentionCandidate dataclass

Bases: BaseDataclass

Typed mention payload used by pelinker.model.Linker predictions.

Source code in pelinker/onto.py
@dataclasses.dataclass
class MentionCandidate(BaseDataclass):
    """Typed mention payload used by :class:`pelinker.model.Linker` predictions."""

    mention: str
    a: int | None
    b: int | None
    a_abs: int | None = None
    b_abs: int | None = None
    itext: int | None = None
    ichunk: int | None = None
    word_grouping: WordGrouping | None = None
    lemma: str = ""

ReportBatch dataclass

Bases: BaseDataclass

Batch of texts with shared encoder state and one holder list per WordGrouping.

chunk_mapper is shared across groupings; its span/pooling fields may match only the last mode processed—use _data / __getitem__ for mode-specific embeddings and expressions.
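
A minimal sketch of grouping-aware access, assuming report_batch is a ReportBatch instance (the name is illustrative):

# Iterate only the groupings this batch was actually built with.
for wg in report_batch.available_groupings():
    holder = report_batch[wg]              # ExpressionHolderBatch for this mode
    per_text = holder.expression_data      # one entry per text in report_batch.texts

# Indexing with a grouping that was never computed raises ValueError.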

Source code in pelinker/onto.py
@dataclasses.dataclass
class ReportBatch(BaseDataclass):
    """Batch of texts with shared encoder state and one holder list per ``WordGrouping``.

    ``chunk_mapper`` is shared across groupings; its span/pooling fields may match
    only the **last** mode processed—use ``_data`` / ``__getitem__`` for mode-specific
    embeddings and expressions.
    """

    _data: list[ExpressionHolderBatch]
    texts: list[str]
    chunk_mapper: ChunkMapper

    def __post_init__(self):
        for item in self._data:
            if len(self.texts) != len(item.expression_data):
                raise ValueError(
                    "The number of ExpressionHolders does not match the number of texts"
                )

    def __getitem__(self, wg: WordGrouping) -> ExpressionHolderBatch:
        filtered = [item for item in self._data if item.word_grouping == wg]
        if filtered:
            return filtered[0]
        else:
            raise ValueError(f"{wg} not available")

    def get_data_for_grouping(self, wg: WordGrouping):
        return [item for item in self._data if item.word_grouping == wg]

    def available_groupings(self):
        return [item.word_grouping for item in self._data]

    def get_text_embeddings(self, layers_spec: str | list[int]):
        """
        Extract sentence-level embeddings for each text in the batch.

        Args:
            layers_spec: Layer specification (string digits or negative indices); see
                :func:`pelinker.util.normalize_layers_spec`. Not ``\"sent\"``.

        Returns:
            list[torch.Tensor]: List of embeddings, one per text in self.texts
        """
        from pelinker.util import normalize_layers_spec, tt_aggregate_normalize

        layers = normalize_layers_spec(
            layers_spec,
            n_hidden_states=self.chunk_mapper.tensor.shape[0],
        )
        # Extract chunk-level embeddings from chunk_mapper.tensor
        # chunk_mapper.tensor has shape: n_layers x n_chunks x n_tokens x n_emb
        chunk_embeddings = tt_aggregate_normalize(self.chunk_mapper.tensor, layers)

        # Map chunks back to texts
        # Group chunks by text index
        text_chunks = {}
        for ichunk, (itext, ichunk_local) in enumerate(self.chunk_mapper.it_ic):
            if itext not in text_chunks:
                text_chunks[itext] = []
            text_chunks[itext].append(ichunk)

        # Aggregate embeddings for each text (average if multiple chunks)
        text_embeddings = []
        for itext in range(len(self.texts)):
            chunk_indices = text_chunks.get(itext, [])
            if chunk_indices:
                # Get embeddings for all chunks of this text
                chunk_embs = chunk_embeddings[chunk_indices]
                # Average over chunks if multiple chunks
                if len(chunk_indices) > 1:
                    text_emb = chunk_embs.mean(dim=0)
                else:
                    text_emb = chunk_embs[0]
                text_embeddings.append(text_emb)
            else:
                # Fallback: create zero embedding if no chunks found
                # This shouldn't happen, but handle it gracefully
                text_embeddings.append(torch.zeros_like(chunk_embeddings[0]))

        return text_embeddings

get_text_embeddings(layers_spec)

Extract sentence-level embeddings for each text in the batch.

Parameters:

    layers_spec (str | list[int], required):
        Layer specification (string digits or negative indices); see
        pelinker.util.normalize_layers_spec. Not "sent".

Returns:

    list[torch.Tensor]: List of embeddings, one per text in self.texts

Source code in pelinker/onto.py
def get_text_embeddings(self, layers_spec: str | list[int]):
    """
    Extract sentence-level embeddings for each text in the batch.

    Args:
        layers_spec: Layer specification (string digits or negative indices); see
            :func:`pelinker.util.normalize_layers_spec`. Not ``\"sent\"``.

    Returns:
        list[torch.Tensor]: List of embeddings, one per text in self.texts
    """
    from pelinker.util import normalize_layers_spec, tt_aggregate_normalize

    layers = normalize_layers_spec(
        layers_spec,
        n_hidden_states=self.chunk_mapper.tensor.shape[0],
    )
    # Extract chunk-level embeddings from chunk_mapper.tensor
    # chunk_mapper.tensor has shape: n_layers x n_chunks x n_tokens x n_emb
    chunk_embeddings = tt_aggregate_normalize(self.chunk_mapper.tensor, layers)

    # Map chunks back to texts
    # Group chunks by text index
    text_chunks = {}
    for ichunk, (itext, ichunk_local) in enumerate(self.chunk_mapper.it_ic):
        if itext not in text_chunks:
            text_chunks[itext] = []
        text_chunks[itext].append(ichunk)

    # Aggregate embeddings for each text (average if multiple chunks)
    text_embeddings = []
    for itext in range(len(self.texts)):
        chunk_indices = text_chunks.get(itext, [])
        if chunk_indices:
            # Get embeddings for all chunks of this text
            chunk_embs = chunk_embeddings[chunk_indices]
            # Average over chunks if multiple chunks
            if len(chunk_indices) > 1:
                text_emb = chunk_embs.mean(dim=0)
            else:
                text_emb = chunk_embs[0]
            text_embeddings.append(text_emb)
        else:
            # Fallback: create zero embedding if no chunks found
            # This shouldn't happen, but handle it gracefully
            text_embeddings.append(torch.zeros_like(chunk_embeddings[0]))

    return text_embeddings
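
A usage sketch, assuming report_batch was built from a list of texts; the layer selection is illustrative and must be accepted by pelinker.util.normalize_layers_spec (it cannot be "sent"):

# Pool the last four hidden layers (an illustrative choice) into one embedding per text.
text_embeddings = report_batch.get_text_embeddings([-4, -3, -2, -1])
assert len(text_embeddings) == len(report_batch.texts)
first = text_embeddings[0]   # torch.Tensor embedding for the first text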