scandeval.utils
source module scandeval.utils
Utility functions to be used in other scripts.
Classes
-
HiddenPrints — Context manager which removes all terminal output.
Functions
-
create_model_cache_dir — Create cache directory for a model.
-
clear_memory — Clears the memory of unused items.
-
enforce_reproducibility — Ensures reproducibility of experiments.
-
is_module_installed — Check if a module is installed.
-
block_terminal_output — Blocks libraries from writing output to the terminal.
-
get_class_by_name — Get a class by its name.
-
kebab_to_pascal — Converts a kebab-case string to PascalCase.
-
internet_connection_available — Checks if internet connection is available by pinging google.com.
-
get_special_token_metadata — Get the special token metadata for a tokenizer.
-
raise_if_model_output_contains_nan_values — Raise an exception if the model output contains NaN values.
-
should_prompts_be_stripped — Determine if we should strip the prompts for few-shot evaluation.
-
should_prefix_space_be_added_to_labels — Determine if we should add a prefix space to the labels.
-
get_end_of_chat_token_ids — Get the end-of-chat token IDs for chat models.
-
scramble — Scramble a string in a bijective manner.
-
unscramble — Unscramble a string in a bijective manner.
-
log_once — Log a message once.
source create_model_cache_dir(cache_dir: str, model_id: str) → str
Create cache directory for a model.
Parameters
-
cache_dir : str —
The cache directory.
-
model_id : str —
The model ID.
Returns
-
str — The path to the cache directory.
source clear_memory()
Clears the memory of unused items.
source enforce_reproducibility(seed: int = 4242)
Ensures reproducibility of experiments.
Parameters
-
seed : int —
Seed for the random number generator.
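As a rough illustration of what seeding for reproducibility involves, the sketch below seeds Python's built-in generator and, if it happens to be installed, NumPy. The name enforce_reproducibility_sketch is hypothetical, and which libraries the real function seeds is not stated on this page, so treat the choice of libraries as an assumption.

```python
import random


def enforce_reproducibility_sketch(seed: int = 4242) -> None:
    """Seed the random number generators we know about (illustrative only)."""
    # Python's built-in pseudo-random number generator.
    random.seed(seed)

    # NumPy is optional here; seeding it is an assumption, not documented behaviour.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)


# Re-seeding with the same value replays the same sequence of draws.
enforce_reproducibility_sketch()
first_draw = random.random()
enforce_reproducibility_sketch()
second_draw = random.random()
```

Calling the function again with the same seed resets the generators, so first_draw and second_draw are equal.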
source is_module_installed(module: str) → bool
Check if a module is installed.
This is used when dealing with spaCy models, as these are installed as separate Python packages.
Parameters
-
module : str —
The name of the module.
Returns
-
bool — Whether the module is installed or not.
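One common way to perform such a check without importing the module is importlib.util.find_spec from the standard library. The sketch below uses it under the hypothetical name is_module_installed_sketch; whether the real function works this way is an assumption.

```python
import importlib.util


def is_module_installed_sketch(module: str) -> bool:
    """Return True if `module` can be located on the current Python path."""
    try:
        # find_spec returns None for a missing top-level module.
        return importlib.util.find_spec(module) is not None
    except (ImportError, ValueError):
        # Raised for a dotted name whose parent package is missing, or for a
        # module whose __spec__ attribute is unset.
        return False


json_installed = is_module_installed_sketch("json")
missing_installed = is_module_installed_sketch("surely_not_a_real_module_xyz")
```

Because spaCy models are distributed as ordinary Python packages, passing the model's package name to a check like this is enough to tell whether it has been installed.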
source block_terminal_output()
Blocks libraries from writing output to the terminal.
This filters warnings from some libraries, sets the logging level to ERROR for some libraries, disables tokeniser progress bars when using Hugging Face tokenisers, and disables most of the logging from the transformers library.
source get_class_by_name(class_name: str | list[str], module_name: str) → t.Type | None
Get a class by its name.
Parameters
-
class_name : str | list[str] —
The name of the class, written in kebab-case. The corresponding class name must be the same, but written in PascalCase, and lying in a module with the same name, but written in snake_case. If a list of strings is passed, the first class that is found is returned.
-
module_name : str —
The name of the module where the class is located.
Returns
-
t.Type | None — The class. If the class is not found, None is returned.
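The lookup described above can be sketched with importlib and getattr. This is a simplified illustration under stated assumptions: the name get_class_by_name_sketch is hypothetical, module_name is imported exactly as given (the real function may resolve it relative to the scandeval package), and the kebab-case name is converted to PascalCase inline.

```python
import collections
import importlib
import typing as t


def get_class_by_name_sketch(
    class_name: t.Union[str, t.List[str]], module_name: str
) -> t.Optional[type]:
    """Return the first class in `module_name` matching a kebab-case name."""
    names = [class_name] if isinstance(class_name, str) else class_name
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None

    for name in names:
        # "ordered-dict" -> "OrderedDict"
        pascal_name = "".join(part[:1].upper() + part[1:] for part in name.split("-"))
        candidate = getattr(module, pascal_name, None)
        if isinstance(candidate, type):
            return candidate
    return None


found = get_class_by_name_sketch(["missing-class", "ordered-dict"], "collections")
not_found = get_class_by_name_sketch("missing-class", "collections")
```

Here found is collections.OrderedDict, because the first name in the list does not resolve and the second does, while not_found is None.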
source kebab_to_pascal(kebab_string: str) → str
Converts a kebab-case string to PascalCase.
Parameters
-
kebab_string : str —
The kebab-case string.
Returns
-
str — The PascalCase string.
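A minimal sketch of the conversion, assuming the input is made of hyphen-separated words. The name kebab_to_pascal_sketch is hypothetical and the real implementation may handle edge cases differently.

```python
def kebab_to_pascal_sketch(kebab_string: str) -> str:
    """Convert e.g. 'named-entity-recognition' to 'NamedEntityRecognition'."""
    # Upper-case the first character of each hyphen-separated word and join them.
    return "".join(word[:1].upper() + word[1:] for word in kebab_string.split("-"))


converted = kebab_to_pascal_sketch("named-entity-recognition")
```

With this input, converted is "NamedEntityRecognition".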
source internet_connection_available() → bool
Checks if internet connection is available by pinging google.com.
Returns
-
bool — Whether or not internet connection is available.
source get_special_token_metadata(tokenizer: PreTrainedTokenizer) → dict
Get the special token metadata for a tokenizer.
Parameters
-
tokenizer : PreTrainedTokenizer —
The tokenizer.
Returns
-
dict — The special token metadata.
source class HiddenPrints()
Context manager which removes all terminal output.
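A context manager of this kind is often written by temporarily pointing sys.stdout and sys.stderr at os.devnull. The sketch below takes that approach under the hypothetical name HiddenPrintsSketch; it is not necessarily how the real class is implemented.

```python
import os
import sys


class HiddenPrintsSketch:
    """Silence stdout and stderr inside a `with` block (illustrative only)."""

    def __enter__(self) -> "HiddenPrintsSketch":
        # Remember the original streams so they can be restored on exit.
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        sys.stderr = self._devnull
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore the streams even if the block raised; exceptions propagate.
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self._devnull.close()


original_stdout = sys.stdout
with HiddenPrintsSketch():
    print("this line is discarded")
    was_redirected = sys.stdout is not original_stdout
was_restored = sys.stdout is original_stdout
```

Note that this only affects output written through Python's sys.stdout and sys.stderr objects; output written directly to the underlying file descriptors by native code would not be hidden by this sketch.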
source raise_if_model_output_contains_nan_values(model_output: Predictions) → None
Raise an exception if the model output contains NaN values.
Parameters
-
model_output : Predictions —
The model output to check.
Raises
-
If the model output contains NaN values.
source should_prompts_be_stripped(labels_to_be_generated: list[str], tokenizer: PreTrainedTokenizer) → bool
Determine if we should strip the prompts for few-shot evaluation.
This is the case if the tokenizer needs to include the space as part of the label token. The strategy is thus to tokenize a label with a preceding colon (as in the prompts), i.e., ": positive", and check if the tokenization starts with the tokens of ": ". If this is the case, then we should not strip the prompts, since the tokenizer produces the whitespace token separately.
Parameters
-
labels_to_be_generated : list[str] —
The labels that are to be generated.
-
tokenizer : PreTrainedTokenizer —
The tokenizer used to tokenize the labels.
Returns
-
bool — Whether we should strip the prompts.
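The strategy described above can be illustrated without a real tokenizer. In the sketch below, the tokenizer is replaced by a plain function from a string to a list of tokens, and two toy regex tokenizers stand in for the two behaviours. All names here (should_prompts_be_stripped_sketch, separate_space_tokenize, merged_space_tokenize) are hypothetical and are not part of the library.

```python
import re
from typing import Callable, List

Tokenize = Callable[[str], List[str]]


def should_prompts_be_stripped_sketch(labels: List[str], tokenize: Tokenize) -> bool:
    """Return False if any ': <label>' tokenization starts with the tokens of ': '."""
    prefix_tokens = tokenize(": ")
    for label in labels:
        label_tokens = tokenize(": " + label)
        if label_tokens[: len(prefix_tokens)] == prefix_tokens:
            # The whitespace is tokenized on its own, so keep the prompts as-is.
            return False
    return True


def separate_space_tokenize(text: str) -> List[str]:
    # Toy tokenizer that emits every whitespace character as its own token.
    return re.findall(r"\w+|[^\w\s]|\s", text)


def merged_space_tokenize(text: str) -> List[str]:
    # Toy tokenizer that attaches a leading space to the following word.
    return re.findall(r"\s?\w+|[^\w\s]|\s", text)


labels = ["positive", "negative"]
strip_with_separate = should_prompts_be_stripped_sketch(labels, separate_space_tokenize)
strip_with_merged = should_prompts_be_stripped_sketch(labels, merged_space_tokenize)
```

With the first toy tokenizer, ": positive" becomes the tokens ":", " ", "positive", which start with the tokens of ": ", so the prompts are not stripped. With the second, the space is merged into " positive", the prefix check fails, and the prompts are stripped.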
source should_prefix_space_be_added_to_labels(labels_to_be_generated: list[str], tokenizer: PreTrainedTokenizer) → bool
Determine if we should add a prefix space to the labels.
This is the case if the prompts are stripped and the tokenizer doesn't automatically add prefix whitespaces to the labels.
Parameters
-
labels_to_be_generated : list[str] —
The labels that are to be generated.
-
tokenizer : PreTrainedTokenizer —
The tokenizer used to tokenize the labels.
Returns
-
bool — Whether we should add a prefix space to the labels.
source get_end_of_chat_token_ids(tokenizer: PreTrainedTokenizer) → list[int] | None
Get the end-of-chat token IDs for chat models.
This is only relevant for tokenizers with a chat template.
Parameters
-
tokenizer : PreTrainedTokenizer —
The tokenizer.
Returns
-
list[int] | None — The token IDs used to end chats, or None if the tokenizer does not have a chat template.
Raises
-
ValueError —
If the end-of-chat token could not be located.
source scramble(text: str) → str
Scramble a string in a bijective manner.
Parameters
-
text : str —
The string to scramble.
Returns
-
str — The scrambled string.
source unscramble(scrambled_text: str) → str
Unscramble a string in a bijective manner.
Parameters
-
scrambled_text : str —
The scrambled string to unscramble.
Returns
-
str — The unscrambled string.
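One way to obtain a bijective scramble is to permute the characters with a permutation that depends only on the string length, so that it can be regenerated and inverted later. The sketch below does this with a fixed-seed shuffle. The function names and the seed value are assumptions made for illustration, and the real scramble and unscramble functions may use a different scheme.

```python
import random
from typing import List


def _permutation(length: int, seed: int = 4242) -> List[int]:
    """Deterministic permutation of range(length); the seed is an arbitrary choice."""
    indices = list(range(length))
    random.Random(seed).shuffle(indices)
    return indices


def scramble_sketch(text: str) -> str:
    # Position `pos` of the output takes the character at index perm[pos].
    perm = _permutation(len(text))
    return "".join(text[index] for index in perm)


def unscramble_sketch(scrambled_text: str) -> str:
    # Invert the permutation: put each character back at its source index.
    perm = _permutation(len(scrambled_text))
    characters = [""] * len(scrambled_text)
    for position, source_index in enumerate(perm):
        characters[source_index] = scrambled_text[position]
    return "".join(characters)


original = "a secret value"
round_trip = unscramble_sketch(scramble_sketch(original))
```

Because the permutation is rebuilt from the length alone, round_trip equals original for any input string. A local random.Random instance is used so that the global random state is left untouched.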
source log_once(message: str, level: int = logging.INFO) → None
Log a message once.
This is ensured by caching the input/output pairs of this function, using the functools.cache decorator.
Parameters
-
message : str —
The message to log.
-
level : int —
The logging level. Defaults to logging.INFO.
Raises
-
ValueError
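The caching idea behind log_once can be sketched as follows. The logger name, the list-collecting handler and the name log_once_sketch are hypothetical, and the condition that triggers ValueError in the real function is not stated here, so it is not modelled.

```python
import functools
import logging
from typing import List

logger = logging.getLogger("log_once_sketch")
logger.setLevel(logging.DEBUG)


class ListHandler(logging.Handler):
    """Collect emitted messages in a list so the behaviour can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: List[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)


@functools.cache
def log_once_sketch(message: str, level: int = logging.INFO) -> None:
    # functools.cache remembers each argument combination, so the body
    # only runs the first time a given combination is seen.
    logger.log(level, message)


log_once_sketch("loading model")
log_once_sketch("loading model")  # cached, so nothing is logged again
log_once_sketch("done")
```

After these calls, handler.messages contains each message exactly once. One caveat of this approach is that functools.cache keys on the arguments as passed, so calling with an explicit level and then without one counts as two different calls even when the level is the same.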