scandeval.data_loading

"""Functions related to the loading of the data."""

import logging
import sys
import time

from datasets import Dataset, DatasetDict, load_dataset
from datasets.exceptions import DatasetsError
from huggingface_hub.utils import HfHubHTTPError
from numpy.random import Generator

from .data_models import BenchmarkConfig, DatasetConfig
from .exceptions import InvalidBenchmark
from .utils import unscramble

logger = logging.getLogger("scandeval")


def load_data(
    rng: Generator, dataset_config: "DatasetConfig", benchmark_config: "BenchmarkConfig"
) -> list[DatasetDict]:
    """Load the raw bootstrapped datasets.

    Args:
        rng:
            The random number generator to use.
        dataset_config:
            The configuration for the dataset.
        benchmark_config:
            The configuration for the benchmark.

    Returns:
        A list of bootstrapped datasets, one for each iteration.
    """
    num_attempts = 5
    for _ in range(num_attempts):
        try:
            dataset = load_dataset(
                path=dataset_config.huggingface_id,
                cache_dir=benchmark_config.cache_dir,
                token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_"),
            )
            break
        except (FileNotFoundError, DatasetsError):
            logger.warning(
                f"Failed to load dataset {dataset_config.huggingface_id!r}. Retrying..."
            )
            time.sleep(1)
            continue
        except HfHubHTTPError as e:
            raise InvalidBenchmark("The Hugging Face Hub seems to be down.") from e
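    # The `else` clause runs only if the loop finished without a `break`, i.e.
    # every attempt to load the dataset failed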
    else:
        raise InvalidBenchmark(
            f"Failed to load dataset {dataset_config.huggingface_id!r} after "
            f"{num_attempts} attempts."
        )

    assert isinstance(dataset, DatasetDict)  # type: ignore[used-before-def]

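    # Keep only the train, validation and test splits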
    dataset = DatasetDict({key: dataset[key] for key in ["train", "val", "test"]})

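    # If we should not evaluate on the test split, use the validation split in
    # its place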
    if not benchmark_config.evaluate_test_split:
        dataset["test"] = dataset["val"]

    # Remove empty examples from the datasets
    for text_feature in ["tokens", "text"]:
        if text_feature in dataset["train"].features:
            dataset = dataset.filter(lambda x: len(x[text_feature]) > 0)

    # If we are running unit tests, truncate the test set to a single example
    if hasattr(sys, "_called_from_test"):
        dataset["test"] = dataset["test"].select(range(1))

    # Bootstrap the splits by sampling indices uniformly at random with
    # replacement, producing one resampled copy of each split per iteration
    bootstrapped_splits: dict[str, list[Dataset]] = dict()
    for split in ["train", "val", "test"]:
        bootstrap_indices = rng.integers(
            0,
            len(dataset[split]),
            size=(benchmark_config.num_iterations, len(dataset[split])),
        )
        bootstrapped_splits[split] = [
            dataset[split].select(bootstrap_indices[idx])
            for idx in range(benchmark_config.num_iterations)
        ]

    datasets = [
        DatasetDict(
            {
                split: bootstrapped_splits[split][idx]
                for split in ["train", "val", "test"]
            }
        )
        for idx in range(benchmark_config.num_iterations)
    ]
    return datasets
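
For orientation, here is a minimal sketch of how load_data might be called. The
dataset_config and benchmark_config objects are assumed to be pre-built
DatasetConfig and BenchmarkConfig instances; in practice they are constructed
by the benchmarker, so their creation is elided here.

import numpy as np

from scandeval.data_loading import load_data

# `dataset_config` and `benchmark_config` are assumed to be pre-built
# DatasetConfig / BenchmarkConfig instances (construction elided)
rng = np.random.default_rng(seed=4242)
bootstrapped = load_data(
    rng=rng, dataset_config=dataset_config, benchmark_config=benchmark_config
)

# One DatasetDict per bootstrap iteration, each with train/val/test splits
assert len(bootstrapped) == benchmark_config.num_iterations
print(bootstrapped[0]["train"].num_rows)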