scandeval.task_utils.sequence_classification

source module scandeval.task_utils.sequence_classification

Utility functions related to the sequence-classification task group.

Functions

source compute_metrics(model_outputs_and_labels: tuple[Predictions, Labels], dataset_config: DatasetConfig, benchmark_config: BenchmarkConfig) → dict[str, float]

Compute the metrics needed for evaluation.

Parameters

  • model_outputs_and_labels : tuple[Predictions, Labels]

    The first sequence contains the model outputs and the second sequence contains the true labels.

  • dataset_config : DatasetConfig

    The configuration of the dataset.

  • benchmark_config : BenchmarkConfig

    The configuration of the benchmark.

Returns

  • dict[str, float] A dictionary with the names of the metrics as keys and the metric values as values.
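
Which metrics are computed depends on the dataset and benchmark configurations. As a rough, self-contained illustration of the input and output contract (not the library's actual implementation), the sketch below builds a metric dictionary from integer-encoded predictions and labels, using Matthews correlation coefficient and macro-F1 as assumed metric choices:

```python
# Minimal sketch only: the metric names and choices here (MCC, macro-F1) are
# assumptions for illustration, not necessarily what the library computes.
from sklearn.metrics import f1_score, matthews_corrcoef


def compute_metrics_sketch(
    model_outputs_and_labels: tuple[list[int], list[int]],
) -> dict[str, float]:
    """Build a metric dictionary from (predictions, labels)."""
    predictions, labels = model_outputs_and_labels
    return {
        "mcc": float(matthews_corrcoef(labels, predictions)),
        "macro_f1": float(f1_score(labels, predictions, average="macro")),
    }


# Dummy integer-encoded predictions and gold labels.
print(compute_metrics_sketch(([0, 1, 1, 2], [0, 1, 2, 2])))
```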

source extract_labels_from_generation(input_batch: dict[str, list], model_output: GenerativeModelOutput, dataset_config: DatasetConfig) → list[str]

Extract the predicted labels from the generated output.

Parameters

  • input_batch : dict[str, list]

    The input batch, where the keys are the feature names and the values are lists with the feature values.

  • model_output : GenerativeModelOutput

    The raw generated output of the model.

  • dataset_config : DatasetConfig

    The configuration of the dataset.

Returns

  • list[str] The predicted labels.
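
The raw generated output can carry either token logprobs or just decoded text, and label extraction has to handle both cases. The sketch below is a rough, self-contained illustration of that flow, using a hypothetical stand-in for GenerativeModelOutput; the two helper functions documented below show the matching strategies in more detail:

```python
# Illustrative sketch only; DummyGenerativeOutput and both branches are
# assumptions about the flow, not the library's internals.
from dataclasses import dataclass


@dataclass
class DummyGenerativeOutput:
    """Illustrative stand-in for GenerativeModelOutput."""

    sequences: list[str]  # decoded generations, one per sample
    scores: list[list[list[tuple[str, float]]]] | None = None  # optional logprobs


def extract_labels_sketch(
    output: DummyGenerativeOutput, candidate_labels: list[str]
) -> list[str]:
    if output.scores is not None:
        # Logprob branch: take the most probable first generated token per
        # sample and snap it to the candidate label it is a prefix of.
        labels = []
        for sample_scores in output.scores:
            token, _ = max(sample_scores[0], key=lambda pair: pair[1])
            matches = [
                label
                for label in candidate_labels
                if label.lower().startswith(token.strip().lower())
            ]
            labels.append(matches[0] if matches else candidate_labels[0])
        return labels
    # Text branch: fall back to matching the decoded text itself.
    return [
        next(
            (label for label in candidate_labels if label.lower() in seq.lower()),
            candidate_labels[0],
        )
        for seq in output.sequences
    ]


# Example usage with a purely text-based output (no logprobs available).
output = DummyGenerativeOutput(sequences=["I think this is Positive."])
print(extract_labels_sketch(output, ["positive", "negative"]))  # -> ['positive']
```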

source get_closest_logprobs_labels(generation_logprobs: list[list[list[tuple[str, float]]]], dataset_config: DatasetConfig) → list[str]

Get the labels with the highest predicted logprob value.

If a candidate label is split into multiple tokens, only the first token is used to compute the logprob value. For example, if the candidate label "positive" is tokenised as ["pos", "itive"], the logprob value of "pos" represents the logprob value of the entire label.

Parameters

  • generation_logprobs : list[list[list[tuple[str, float]]]]

    The logprobs of the generated tokens for all samples in the batch, of shape (batch_size, num_tokens, num_logprobs).

  • dataset_config : DatasetConfig

    The configuration of the dataset.

Returns

  • list[str] The predicted labels.

Raises

  • InvalidBenchmark

    If no candidate label can be found for any of the generated labels.
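
The first-token behaviour described above can be illustrated with a small, self-contained sketch. Everything here is hypothetical: the tokenisation is faked with a lowercase prefix check, and a plain ValueError stands in for the library's InvalidBenchmark:

```python
def closest_logprobs_labels_sketch(
    generation_logprobs: list[list[list[tuple[str, float]]]],
    candidate_labels: list[str],
) -> list[str]:
    """Pick, per sample, the candidate label with the highest first-token logprob."""
    predicted: list[str] = []
    for sample_logprobs in generation_logprobs:
        # Only the first generated token is inspected: a multi-token label such
        # as "positive" -> ["pos", "itive"] is represented by "pos" alone.
        first_token_logprobs = sample_logprobs[0]
        best_label, best_logprob = None, float("-inf")
        for token, logprob in first_token_logprobs:
            cleaned = token.strip().lower()
            for label in candidate_labels:
                if cleaned and label.lower().startswith(cleaned) and logprob > best_logprob:
                    best_label, best_logprob = label, logprob
        if best_label is None:
            # The documented function raises InvalidBenchmark here; a plain
            # ValueError stands in for it in this sketch.
            raise ValueError("No candidate label matches any generated token.")
        predicted.append(best_label)
    return predicted


# One sample, one generated token, top-3 logprobs for that token.
batch = [[[("pos", -0.1), ("neg", -2.3), ("neu", -3.0)]]]
print(closest_logprobs_labels_sketch(batch, ["positive", "negative", "neutral"]))
# -> ['positive']
```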

source get_closest_word_edit_labels(generated_sequences: list[str], dataset_config: DatasetConfig) → list[str]

Get the candidate labels with the smallest edit distance to the generated sequences.

Parameters

  • generated_sequences : list[str]

    The generated sequences from the model.

  • dataset_config : DatasetConfig

    The configuration of the dataset.

Returns

  • list[str] The candidate labels with the smallest edit distance to the generated sequences.
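
As a rough illustration of the idea, the self-contained sketch below snaps each generated sequence to the candidate label with the smallest Levenshtein distance; the names are illustrative and not the library's internals:

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming Levenshtein (edit) distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(
                min(
                    prev[j] + 1,               # deletion
                    curr[j - 1] + 1,           # insertion
                    prev[j - 1] + (ca != cb),  # substitution
                )
            )
        prev = curr
    return prev[-1]


def closest_edit_labels_sketch(
    generated_sequences: list[str], candidate_labels: list[str]
) -> list[str]:
    """Map each generated sequence to its nearest candidate label."""
    return [
        min(
            candidate_labels,
            key=lambda label: levenshtein(sequence.strip().lower(), label.lower()),
        )
        for sequence in generated_sequences
    ]


# Noisy generations get snapped to the nearest candidate label.
print(closest_edit_labels_sketch(["Positiv!", " negativ"], ["positive", "negative"]))
# -> ['positive', 'negative']
```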