scandeval.task_utils.text_to_text

"""Utility functions related to the text-to-text task group."""

import logging
import typing as t

import evaluate
import numpy as np
from evaluate import EvaluationModule

from ..constants import METRIC_ATTRIBUTES_TAKING_UP_MEMORY
from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
from ..exceptions import InvalidBenchmark
from ..utils import (
    HiddenPrints,
    clear_memory,
    raise_if_model_output_contains_nan_values,
)

if t.TYPE_CHECKING:
    from ..types import Labels, Predictions


logger = logging.getLogger("scandeval")


def compute_metrics(
    model_outputs_and_labels: tuple["Predictions", "Labels"],
    dataset_config: "DatasetConfig",
    benchmark_config: "BenchmarkConfig",
) -> dict[str, float]:
    """Compute the metrics needed for evaluation.

    Args:
        model_outputs_and_labels:
            The first sequence contains the model outputs and the second sequence
            contains the true labels.
        dataset_config:
            The configuration of the dataset.
        benchmark_config:
            The configuration of the benchmark.

    Returns:
        A dictionary with the names of the metrics as keys and the metric values as
        values.
    """
    model_outputs, labels = model_outputs_and_labels
    raise_if_model_output_contains_nan_values(model_output=model_outputs)

    metrics = {
        metric_cfg.name: (
            evaluate.load(
                path=metric_cfg.huggingface_id, cache_dir=benchmark_config.cache_dir
            )
            if metric_cfg.huggingface_id != ""
            else None
        )
        for metric_cfg in dataset_config.task.metrics
    }

    model_output_dtype = np.asarray(model_outputs).dtype
    output_is_prob = model_output_dtype in [np.float16, np.float32, np.float64]
    if output_is_prob:
        predictions = np.asarray(model_outputs).argmax(axis=-1)
    else:
        predictions = model_outputs

    results: dict[str, float] = dict()
    for cfg in dataset_config.task.metrics:
        metric = metrics[cfg.name]
        assert isinstance(metric, EvaluationModule)

        # Some metrics can be computed on hardware accelerators. In this case we
        # start by setting the device to the same device as the model
        if cfg.compute_kwargs.get("device", None) == "auto":
            cfg.compute_kwargs["device"] = benchmark_config.device.type

        while True:
            try:
                with HiddenPrints():
                    score_dict: dict[str, float] | None = metric.compute(
                        predictions=predictions, references=labels, **cfg.compute_kwargs
                    )

                # Clear the cache of the BERTScorer to avoid memory leaks
                for attribute in METRIC_ATTRIBUTES_TAKING_UP_MEMORY:
                    if hasattr(metric, attribute):
                        delattr(metric, attribute)

                clear_memory()
                break
            except Exception as e:
                # Clear the cache of the BERTScorer to avoid memory leaks
                if hasattr(metric, "cached_bertscorer"):
                    del metric.cached_bertscorer
                    clear_memory()

                oom_error = [
                    "CUDA out of memory",
                    "CUDA error",
                    "MPS backend out of memory",
                ]
                if not any(error in str(e) for error in oom_error):
                    raise InvalidBenchmark(str(e))

                if cfg.compute_kwargs.get("batch_size", 1) > 1:
                    batch_size = cfg.compute_kwargs["batch_size"]
                    cfg.compute_kwargs["batch_size"] = batch_size // 2
                    logger.debug(
                        "Out of memory error occurred during the computation of "
                        f"the metric {cfg.pretty_name}. Reducing the batch size to "
                        f"{cfg.compute_kwargs['batch_size']}."
                    )
                elif cfg.compute_kwargs.get("device", "cpu") != "cpu":
                    cfg.compute_kwargs["batch_size"] = 32
                    cfg.compute_kwargs["device"] = "cpu"
                    logger.debug(
                        "Out of memory error occurred during the computation of "
                        f"the metric {cfg.pretty_name}. Moving the computation to "
                        "the CPU."
                    )
                else:
                    raise InvalidBenchmark(str(e))

        # The metric returns None if we are running on multi-GPU and the current
        # process is not the main process
        if score_dict is not None:
            scores = score_dict[cfg.results_key]
            if isinstance(scores, list):
                scores = sum(scores) / len(scores)
            results[cfg.name] = scores

    return results


def extract_labels_from_generation(
    input_batch: dict[str, list], model_output: "GenerativeModelOutput"
) -> list[t.Any]:
    """Extract the predicted labels from the generated output.

    Args:
        input_batch:
            The input batch, where the keys are the feature names and the values
            are lists with the feature values.
        model_output:
            The raw generated output of the model.

    Returns:
        The predicted labels.
    """
    return model_output.sequences
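
As a usage illustration, the sketch below shows how compute_metrics might be invoked once predictions, references and the relevant configuration objects are available. It is a minimal sketch, not part of this module: the dataset_config and benchmark_config variables are assumed to be supplied by the surrounding benchmarking pipeline, and the concrete metric names in the output depend on which metrics the dataset's task declares.

# Minimal usage sketch (illustrative only). The dataset_config and
# benchmark_config objects are assumed to come from the benchmarking
# pipeline; they are not constructed by this module.
from scandeval.task_utils.text_to_text import compute_metrics

# Model outputs and gold references for a text-to-text task, e.g. summarisation.
predictions = ["A short summary of the document.", "Another generated summary."]
references = ["The reference summary of the document.", "A second reference summary."]

scores = compute_metrics(
    model_outputs_and_labels=(predictions, references),
    dataset_config=dataset_config,      # assumed DatasetConfig instance
    benchmark_config=benchmark_config,  # assumed BenchmarkConfig instance
)

# `scores` maps metric names to floats, e.g. {"bertscore": ..., "rouge_l": ...},
# depending on the metrics defined for the task.
print(scores)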