Skip to content

API Reference

datacollective.datasets

get_dataset_details(dataset_id)

Return dataset details from the MDC API as a dictionary.

Parameters:

Name Type Description Default
dataset_id str

The dataset ID (as shown in MDC platform).

required

Returns:

Type Description
dict[str, Any]

A dict with dataset details as returned by the API.

Raises:

Type Description
ValueError

If dataset_id is empty.

FileNotFoundError

If the dataset does not exist (404).

PermissionError

If access is denied (403).

RuntimeError

If rate limit is exceeded (429).

HTTPError

For other non-2xx responses.

Source code in src/datacollective/datasets.py
def get_dataset_details(dataset_id: str) -> dict[str, Any]:
    """
    Fetch the details of a single dataset from the MDC API.

    Args:
        dataset_id: The dataset ID (as shown in MDC platform).

    Returns:
        A dict with dataset details as returned by the API.

    Raises:
        ValueError: If dataset_id is empty.
        FileNotFoundError: If the dataset does not exist (404).
        PermissionError: If access is denied (403).
        RuntimeError: If rate limit is exceeded (429).
        requests.HTTPError: For other non-2xx responses.
    """
    # Reject None/empty/whitespace-only IDs before touching the network.
    if not (dataset_id and dataset_id.strip()):
        raise ValueError("`dataset_id` must be a non-empty string")

    response = send_api_request(
        method="GET",
        url=f"{_get_api_url()}/datasets/{dataset_id}",
    )
    return dict(response.json())

load_dataset(dataset_id, download_directory=None, show_progress=True, overwrite_existing=False, overwrite_extracted=False)

Download (if needed), extract (if not already extracted), and load the dataset into a pandas DataFrame.

If the dataset archive already exists in the download directory, it will not be re-downloaded unless overwrite_existing=True.

If there is a directory with the same name as the archive file without the suffix extension, we assume it has already been extracted, and it will not be re-extracted unless overwrite_extracted=True.

Uses the dataset schema to determine task-specific loading logic.

Automatically resumes interrupted downloads if a .checksum file exists from a previous attempt.

Parameters:

Name Type Description Default
dataset_id str

The dataset ID (as shown in MDC platform).

required
download_directory str | None

Directory where to save the downloaded archive file. If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

None
show_progress bool

Whether to show a progress bar during download.

True
overwrite_existing bool

Whether to overwrite existing archive.

False
overwrite_extracted bool

Whether to overwrite existing extracted files by re-extracting the archive file. Only makes sense when overwrite_existing is False. Will check in the download directory for existing extracted files with the default naming of the folder.

False

Returns: A pandas DataFrame with the loaded dataset.

Raises:

Type Description
ValueError

If dataset_id is empty or schema is unsupported.

FileNotFoundError

If the dataset does not exist (404).

PermissionError

If access is denied (403) or download directory is not writable.

RuntimeError

If rate limit is exceeded (429) or unexpected response format.

HTTPError

For other non-2xx responses.

Source code in src/datacollective/datasets.py
def load_dataset(
    dataset_id: str,
    download_directory: str | None = None,
    show_progress: bool = True,
    overwrite_existing: bool = False,
    overwrite_extracted: bool = False,
) -> pd.DataFrame:
    """
    Download (if needed), extract (if not already extracted), and load the dataset into a pandas DataFrame.

    If the dataset archive already exists in the download directory, it will not be re-downloaded
    unless `overwrite_existing=True`.

    If there is a directory with the same name as the archive file without the suffix extension, we assume
    it has already been extracted, and it will not be re-extracted unless `overwrite_extracted=True`.

    Uses the dataset schema to determine task-specific loading logic.

    Automatically resumes interrupted downloads if a .checksum file exists from a
    previous attempt.

    Args:
        dataset_id: The dataset ID (as shown in MDC platform).
        download_directory: Directory where to save the downloaded archive file.
            If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.
        show_progress: Whether to show a progress bar during download.
        overwrite_existing: Whether to overwrite existing archive.
        overwrite_extracted: Whether to overwrite existing extracted files by re-extracting the archive file.
            Only makes sense when overwrite_existing is False.
            Will check in the download directory for existing extracted files with the default naming of the folder.

    Returns:
        A pandas DataFrame with the loaded dataset.

    Raises:
        ValueError: If dataset_id is empty or schema is unsupported.
        FileNotFoundError: If the dataset does not exist (404).
        PermissionError: If access is denied (403) or download directory is not writable.
        RuntimeError: If rate limit is exceeded (429) or unexpected response format.
        requests.HTTPError: For other non-2xx responses.
    """
    schema = get_dataset_schema(dataset_id)
    if schema is None:
        # No schema in the registry: distinguish "dataset does not exist" from
        # "dataset exists but has no registered schema" so the user gets an
        # actionable error message.
        try:
            get_dataset_details(dataset_id)
        except FileNotFoundError as exc:
            # Chain explicitly so the original 404 traceback is preserved
            # instead of the implicit "during handling ..." context.
            raise RuntimeError(
                f"Dataset '{dataset_id}' does not exist in MDC or the ID is mistyped. "
            ) from exc
        raise RuntimeError(
            f"Dataset '{dataset_id}' exists but is not supported by load_dataset yet. "
            f"You can download the raw archive with: save_dataset_to_disk('{dataset_id}'). "
            f"If you are the data owner consider submitting a schema for your dataset via the registry: https://mozilla-data-collective.github.io/dataset-schema-registry/"
        )

    # Capture the archive checksum up-front; it is handed to _resolve_schema
    # below (presumably for cache validation — see DatasetSchema.checksum).
    download_plan = get_download_plan(dataset_id, download_directory)
    archive_checksum = download_plan.checksum

    archive_path = save_dataset_to_disk(
        dataset_id=dataset_id,
        download_directory=download_directory,
        show_progress=show_progress,
        overwrite_existing=overwrite_existing,
    )
    base_dir = resolve_download_dir(download_directory)
    extract_dir = _extract_archive(
        archive_path=archive_path,
        dest_dir=base_dir,
        overwrite_extracted=overwrite_extracted,
    )

    schema = _resolve_schema(dataset_id, extract_dir, archive_checksum)
    return load_dataset_from_schema(schema, extract_dir)

save_dataset_to_disk(dataset_id, download_directory=None, show_progress=True, overwrite_existing=False)

Download the dataset archive to a local directory and return the archive path. Skips download if the target file already exists (unless overwrite_existing=True).

Automatically resumes interrupted downloads if a matching .checksum file exists from a previous attempt.

Parameters:

Name Type Description Default
dataset_id str

The dataset ID (as shown in MDC platform).

required
download_directory str | None

Directory where to save the downloaded archive file. If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

None
show_progress bool

Whether to show a progress bar during download.

True
overwrite_existing bool

Whether to overwrite the existing archive file.

False

Returns:

Type Description
Path

Path to the downloaded dataset archive.

Raises:

Type Description
ValueError

If dataset_id is empty.

FileNotFoundError

If the dataset does not exist (404).

PermissionError

If access is denied (403) or download directory is not writable.

RuntimeError

If rate limit is exceeded (429) or unexpected response format.

HTTPError

For other non-2xx responses.

Source code in src/datacollective/datasets.py
def save_dataset_to_disk(
    dataset_id: str,
    download_directory: str | None = None,
    show_progress: bool = True,
    overwrite_existing: bool = False,
) -> Path:
    """
    Download the dataset archive to a local directory and return the archive path.
    Skips download if the target file already exists (unless `overwrite_existing=True`).

    Automatically resumes interrupted downloads if a matching .checksum file exists from a
    previous attempt.

    Args:
        dataset_id: The dataset ID (as shown in MDC platform).
        download_directory: Directory where to save the downloaded archive file.
            If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.
        show_progress: Whether to show a progress bar during download.
        overwrite_existing: Whether to overwrite the existing archive file.

    Returns:
        Path to the downloaded dataset archive.

    Raises:
        ValueError: If dataset_id is empty.
        FileNotFoundError: If the dataset does not exist (404).
        PermissionError: If access is denied (403) or download directory is not writable.
        RuntimeError: If rate limit is exceeded (429) or unexpected response format.
        requests.HTTPError: For other non-2xx responses.
    """
    plan = get_download_plan(dataset_id, download_directory)
    target = plan.target_filepath

    # A complete archive is already on disk and we are not overwriting: done.
    if target.exists() and not overwrite_existing:
        logger.info(
            f"File already exists. "
            f"Skipping download: `{str(target)}`"
        )
        return Path(target)

    if overwrite_existing:
        # Drop any stale partial-download state plus the old archive itself.
        cleanup_partial_download(plan.tmp_filepath, plan.checksum_filepath)
        if target.exists():
            target.unlink()

    # Decide whether a previous interrupted download can be resumed.
    resume_checksum = determine_resume_state(plan)

    # Persist the expected checksum before downloading so a later attempt can
    # resume safely if this one is interrupted.
    if plan.checksum and not resume_checksum:
        write_checksum_file(plan.checksum_filepath, plan.checksum)

    execute_download_plan(plan, resume_checksum, show_progress)

    # Download complete: atomically promote the temp file to the target name
    # and discard the now-redundant checksum sidecar.
    plan.tmp_filepath.replace(target)
    if plan.checksum_filepath.exists():
        plan.checksum_filepath.unlink()

    logger.info(f"Saved dataset to `{str(target)}`")
    return Path(target)

datacollective.download

cleanup_partial_download(tmp_filepath, checksum_filepath)

Remove partial download files (.part and .checksum).

Source code in src/datacollective/download.py
def cleanup_partial_download(tmp_filepath: Path, checksum_filepath: Path) -> None:
    """
    Remove partial download files (.part and .checksum).

    Missing files are ignored, so this is safe to call regardless of how far a
    previous download attempt got.

    Args:
        tmp_filepath: Path to the temporary `.part` download file.
        checksum_filepath: Path to the `.checksum` sidecar file.
    """
    # `missing_ok=True` (Python 3.8+) replaces the exists()/unlink() pair,
    # avoiding the check-then-act race if the file disappears in between.
    tmp_filepath.unlink(missing_ok=True)
    checksum_filepath.unlink(missing_ok=True)

determine_resume_state(download_plan)

Determine whether to resume a download based on existing files.

Cases handled

Case 1: .checksum and .part exist, checksum matches -> resume download.
Case 2: .checksum and .part exist, checksum does NOT match -> start fresh.
Case 3: .part exists but no .checksum -> start fresh (cannot safely resume).
Case 4: .checksum exists but no .part -> start fresh (orphaned checksum).
Case 5: Neither .checksum nor .part exist -> start fresh.

Parameters:

Name Type Description Default
download_plan DownloadPlan

The DownloadPlan object with download details.

required

Returns:

Name Type Description
resume_checksum str | None

The checksum to use for resumption, or None if starting fresh.

Source code in src/datacollective/download.py
def determine_resume_state(download_plan: DownloadPlan) -> str | None:
    """
    Determine whether to resume a download based on existing files.

    Cases handled:
        Case 1: .checksum and .part exist, checksum matches -> resume download.
        Case 2: .checksum and .part exist, checksum does NOT match -> start fresh.
        Case 3: .part exists but no usable .checksum -> start fresh (cannot safely resume).
        Case 4: .checksum exists but no .part -> start fresh (orphaned checksum).
        Case 5: Neither .checksum nor .part exist -> start fresh.

    Args:
        download_plan: The DownloadPlan object with download details.

    Returns:
        resume_checksum: The checksum to use for resumption, or None if starting fresh.
    """
    tmp_filepath = download_plan.tmp_filepath

    # Check existence of .part and .checksum files
    part_exists = tmp_filepath.exists()
    checksum_file_exists = download_plan.checksum_filepath.exists()
    stored_checksum = (
        _read_checksum_file(download_plan.checksum_filepath)
        if checksum_file_exists
        else None
    )

    # Cases 1/2 (and empty-checksum variant of Case 3): both files exist
    if part_exists and checksum_file_exists:
        if stored_checksum and stored_checksum == download_plan.checksum:
            # Case 1: Checksum matches -> resume download
            logger.info("Resuming previously interrupted download...")
            return stored_checksum
        if stored_checksum:
            # Case 2: Checksum does not match, i.e. dataset was updated -> start fresh
            logger.info(
                "Dataset has been updated since the previous download attempt. "
                "Starting fresh download..."
            )
        else:
            # Bug fix: an empty/unreadable .checksum file alongside a .part
            # previously matched none of the cases and returned None WITHOUT
            # cleanup, leaving a stale .part that the next (append-mode)
            # download would extend. Treat it like Case 3 instead.
            logger.warning(
                "Partial download found without checksum file. Starting fresh download..."
            )
        cleanup_partial_download(tmp_filepath, download_plan.checksum_filepath)
        return None

    # Case 3: .part exists but no .checksum: cannot safely resume -> start fresh
    if part_exists and not checksum_file_exists:
        logger.warning(
            "Partial download found without checksum file. Starting fresh download..."
        )
        cleanup_partial_download(tmp_filepath, download_plan.checksum_filepath)
        return None

    # Case 4: .checksum exists but no .part -> start fresh
    if checksum_file_exists and not part_exists:
        cleanup_partial_download(tmp_filepath, download_plan.checksum_filepath)
        return None

    # Case 5: Neither .checksum nor .part exist -> start fresh
    return None

execute_download_plan(download_plan, resume_download_checksum, show_progress)

Execute the download plan, downloading the dataset to a temporary path.

Parameters:

Name Type Description Default
download_plan DownloadPlan

The DownloadPlan object with download details.

required
resume_download_checksum str | None

Provide the checksum to resume a previously interrupted download.

required
show_progress bool

Whether to show a progress bar during download.

required

Raises:

Type Description
DownloadError

If the download fails or is interrupted.

Source code in src/datacollective/download.py
def execute_download_plan(
    download_plan: DownloadPlan,
    resume_download_checksum: str | None,
    show_progress: bool,
) -> None:
    """
    Execute the download plan, streaming the dataset to its temporary path.

    Args:
        download_plan: The DownloadPlan object with download details.
        resume_download_checksum: Provide the checksum to resume a previously interrupted download.
        show_progress: Whether to show a progress bar during download.

    Raises:
        DownloadError: If the download fails or is interrupted.
    """

    # Range headers (and the byte offset already on disk) for resumed downloads.
    headers, resumed_bytes = _prepare_download_headers(
        download_plan.tmp_filepath, resume_download_checksum
    )

    session_bytes = 0                 # bytes fetched by this call only
    total_bytes = resumed_bytes       # bytes on disk overall, incl. resumed part
    progress_bar = None
    logger.info(f"Downloading dataset: {download_plan.filename}")
    if show_progress:
        progress_bar = ProgressBar(download_plan.size_bytes)
        progress_bar.update(resumed_bytes)
        progress_bar._display()
    try:
        with send_api_request(
            method="GET",
            url=download_plan.download_url,
            stream=True,
            timeout=HTTP_TIMEOUT,
            extra_headers=headers,
            include_auth_headers=False,  # Download URL is pre-signed, no auth needed
        ) as response:
            # Append mode so a resumed download continues where it left off.
            with open(download_plan.tmp_filepath, "ab") as f:
                # Stream in 64KB chunks to keep memory usage bounded.
                for chunk in response.iter_content(chunk_size=1 << 16):
                    if not chunk:
                        continue
                    f.write(chunk)
                    chunk_len = len(chunk)
                    session_bytes += chunk_len
                    total_bytes += chunk_len
                    if progress_bar:
                        progress_bar.update(chunk_len)

            if progress_bar:
                progress_bar.finish()
    except (Exception, KeyboardInterrupt) as e:
        # Wrap everything (incl. Ctrl-C) so the caller can report how far the
        # download got before failing.
        raise DownloadError(
            session_bytes=session_bytes,
            total_downloaded_bytes=total_bytes,
            total_archive_bytes=download_plan.size_bytes,
            checksum=download_plan.checksum,
        ) from e

get_download_plan(dataset_id, download_directory)

Send a POST request to the API to receive the download session details for a dataset.

Parameters:

Name Type Description Default
dataset_id str

The dataset ID (as shown in MDC platform).

required
download_directory str | None

Directory where to save the downloaded dataset. If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

required

Returns:

Type Description
DownloadPlan

a DownloadPlan containing:

  • a download session URL created by the API
  • the filename for the dataset archive defined by the API
  • the final target filepath on disk where the archive will be saved
  • a temporary path for atomic download
  • the size of the dataset archive in bytes
  • the checksum of the dataset

Raises:

Type Description
ValueError

If dataset_id is empty.

FileNotFoundError

If the dataset does not exist (404).

PermissionError

If access is denied (403) or download directory is not writable.

RuntimeError

If rate limit is exceeded (429) or unexpected response format.

HTTPError

For other non-2xx responses.

Source code in src/datacollective/download.py
def get_download_plan(dataset_id: str, download_directory: str | None) -> DownloadPlan:
    """
    Send a POST request to the API to receive the download session details for a dataset.

    Args:
        dataset_id: The dataset ID (as shown in MDC platform).
        download_directory: Directory where to save the downloaded dataset.
            If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

    Returns:
        a DownloadPlan containing:
        - a download session URL created by the API
        - the filename for the dataset archive defined by the API
        - the final target filepath on disk where the archive will be saved
        - a temporary path for atomic download
        - the size of the dataset archive in bytes
        - the checksum of the dataset

    Raises:
        ValueError: If dataset_id is empty.
        FileNotFoundError: If the dataset does not exist (404).
        PermissionError: If access is denied (403) or download directory is not writable.
        RuntimeError: If rate limit is exceeded (429) or unexpected response format.
        requests.HTTPError: For other non-2xx responses.
    """
    if not dataset_id or not dataset_id.strip():
        raise ValueError("`dataset_id` must be a non-empty string")

    base_dir = resolve_download_dir(download_directory)

    # Create a download session to get `downloadUrl` and `filename`
    session_url = f"{_get_api_url()}/datasets/{dataset_id}/download"
    resp = send_api_request(method="POST", url=session_url)

    payload: dict[str, Any] = dict(resp.json())
    download_url = payload.get("downloadUrl")
    filename = payload.get("filename")
    size_bytes = payload.get("sizeBytes")
    checksum = payload.get("checksum")

    # All three are mandatory; checksum may legitimately be absent.
    if not download_url or not filename or not size_bytes:
        raise RuntimeError(f"Unexpected response format: {payload}")

    target_filepath = base_dir / filename

    # Stream download to a temporary file for atomicity
    tmp_filepath = target_filepath.with_name(target_filepath.name + ".part")

    checksum_filepath = _get_checksum_filepath(target_filepath)

    download_plan = DownloadPlan(
        download_url=download_url,
        filename=filename,
        target_filepath=target_filepath,
        tmp_filepath=tmp_filepath,
        size_bytes=int(size_bytes),
        checksum=checksum,
        checksum_filepath=checksum_filepath,
    )
    # Bug fix: the debug message previously interpolated a literal placeholder
    # instead of the actual archive filename.
    logger.debug(
        f"Download plan: filename={filename}, size={int(size_bytes)} bytes, target={target_filepath}",
    )
    return download_plan

resolve_download_dir(download_directory)

Resolve and ensure the download directory exists and is writable.

Parameters:

Name Type Description Default
download_directory str | None

User-specified download directory. If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

required

Returns:

Type Description
Path

The resolved Path object for the download directory.

Source code in src/datacollective/download.py
def resolve_download_dir(download_directory: str | None) -> Path:
    """
    Resolve and ensure the download directory exists and is writable.

    Args:
        download_directory: User-specified download directory.
            If None or empty, falls back to env MDC_DOWNLOAD_PATH or default.

    Returns:
        The resolved Path object for the download directory.

    Raises:
        PermissionError: If the resolved directory is not writable.
    """
    # Explicit argument wins; otherwise fall back to the env var, then the
    # built-in default location under the user's home.
    if download_directory and download_directory.strip():
        chosen = download_directory
    else:
        chosen = os.getenv(ENV_DOWNLOAD_PATH, "~/.mozdata/datasets")

    resolved = Path(os.path.expanduser(chosen))
    resolved.mkdir(parents=True, exist_ok=True)
    if not os.access(resolved, os.W_OK):
        raise PermissionError(f"Directory `{str(resolved)}` is not writable")
    logger.debug(f"Resolved download directory: {resolved}")
    return resolved

write_checksum_file(checksum_filepath, checksum)

Write the checksum to the .checksum file.

Source code in src/datacollective/download.py
def write_checksum_file(checksum_filepath: Path, checksum: str) -> None:
    """Persist the expected archive checksum to its `.checksum` sidecar file."""
    # Overwrites any previous content; the file holds only the checksum string.
    with open(checksum_filepath, "w") as fh:
        fh.write(checksum)

datacollective.api_utils

send_api_request(method, url, stream=False, extra_headers=None, timeout=HTTP_TIMEOUT, include_auth_headers=True)

Send an HTTP request to the MDC API with appropriate headers and error handling.

Parameters:

Name Type Description Default
method str

HTTP method (e.g., 'GET', 'POST').

required
url str

Full URL for the API endpoint.

required
stream bool

Whether to stream the response (default: False).

False
extra_headers dict[str, str] | None

Additional headers to include in the request (default: None). E.g. for resuming

None
timeout tuple[int, int] | None

A tuple specifying (connect timeout, read timeout) in seconds (default: HTTP_TIMEOUT).

HTTP_TIMEOUT
include_auth_headers bool

Whether to include authentication (API KEY) headers (default: True).

True

Returns:

Type Description
Response

The HTTP response object.

Raises:

Type Description
FileNotFoundError

If the resource is not found (404).

PermissionError

If access is denied (403).

RuntimeError

If rate limit is exceeded (429).

ValueError

If API key is missing when authentication is required.

HTTPError

For other non-2xx responses.

Source code in src/datacollective/api_utils.py
def send_api_request(
    method: str,
    url: str,
    stream: bool = False,
    extra_headers: dict[str, str] | None = None,
    timeout: tuple[int, int] | None = HTTP_TIMEOUT,
    include_auth_headers: bool = True,
) -> requests.Response:
    """
    Send an HTTP request to the MDC API with appropriate headers and error handling.

    Args:
        method: HTTP method (e.g., 'GET', 'POST').
        url: Full URL for the API endpoint.
        stream: Whether to stream the response (default: False).
        extra_headers: Additional headers to include in the request (default: None),
            e.g. Range headers for resuming a download.
        timeout: A tuple specifying (connect timeout, read timeout) in seconds
            (default: HTTP_TIMEOUT).
        include_auth_headers: Whether to include authentication (API KEY) headers (default: True).

    Returns:
        The HTTP response object.

    Raises:
        FileNotFoundError: If the resource is not found (404).
        PermissionError: If access is denied (403).
        RuntimeError: If rate limit is exceeded (429).
        ValueError: If API key is missing when authentication is required.
        requests.HTTPError: For other non-2xx responses.
    """
    # Build headers in precedence order: base UA < auth < caller extras.
    request_headers: dict[str, str] = {"User-Agent": _get_user_agent()}
    if include_auth_headers:
        request_headers.update(_auth_headers())
    if extra_headers:
        request_headers.update(extra_headers)

    logger.debug(f"API request: {method.upper()} {url} (stream={stream})")

    response = requests.request(
        method=method.upper(),
        url=url,
        stream=stream,
        headers=request_headers,
        timeout=timeout,
    )

    # Translate well-known status codes into idiomatic Python exceptions.
    status = response.status_code
    if status == 404:
        raise FileNotFoundError("Dataset not found")
    if status == 403:
        raise PermissionError(
            "Access denied. If the dataset is public, make sure you have read thoroughly and agreed"
            " to the dataset's Terms & Conditions in its respective page on the MDC platform before downloading. "
        )
    if status == 429:
        raise RuntimeError(RATE_LIMIT_ERROR)
    response.raise_for_status()

    return response

datacollective.schema

ColumnMapping

Bases: BaseModel

A single column mapping entry inside a schema.

Used by index-based tasks to describe how columns in the index file map to logical fields and their data types.

Source code in src/datacollective/schema.py
class ColumnMapping(BaseModel):
    """
    A single column mapping entry inside a schema.

    Used by index-based tasks to describe how columns in the
    index file map to logical fields and their data types.
    """

    # Instances are immutable once constructed.
    model_config = ConfigDict(frozen=True)

    source_column: str | int = Field(
        description="column name (str) or positional index (int) for headerless files"
    )
    # dtype name for the mapped column (default "string").
    dtype: str = "string"
    # presumably marks the column as allowed to be absent from the index file
    # — confirm against the loader that consumes these mappings.
    optional: bool = False

ContentMapping

Bases: BaseModel

Describes how file contents / metadata map to DataFrame columns.

Used by glob-based tasks (e.g. LM) to specify how to extract text and metadata from files found via glob patterns. For example, the text content might come from the file contents, while metadata (e.g. language code) might come from the file name or parent directory.

Source code in src/datacollective/schema.py
class ContentMapping(BaseModel):
    """
    Describes how file contents / metadata map to DataFrame columns.

    Used by glob-based tasks (e.g. LM) to specify how to extract text and metadata
    from files found via glob patterns.  For example, the text content might come
    from the file contents, while metadata (e.g. language code) might come from
    the file name or parent directory.
    """

    # Instances are immutable once constructed.
    model_config = ConfigDict(frozen=True)

    # Where the text value comes from, e.g. "file_content".
    text: str | None = Field(default=None, description='e.g. "file_content"')
    # Where metadata comes from, e.g. "file_name".
    meta_source: str | None = Field(default=None, description='e.g. "file_name"')

DatasetSchema

Bases: BaseModel

Task-agnostic representation of a dataset schema, as defined by a schema.yaml file.

Every schema must have dataset_id and task. The remaining fields depend on the task type and the root_strategy ("index" vs "glob").

New task types only need to populate the fields they care about; the loader registered for that task will decide which fields are required at load time.

Source code in src/datacollective/schema.py
class DatasetSchema(BaseModel):
    """
    Task-agnostic representation of a dataset schema, as defined by a ``schema.yaml`` file.

    Every schema **must** have ``dataset_id`` and ``task``.  The remaining
    fields depend on the task type and the ``root_strategy``
    (``"index"`` vs ``"glob"``).

    New task types only need to populate the fields they care about;
    the loader registered for that task will decide which fields are
    required at load time.
    """

    # Unlike the mapping models, instances stay mutable (frozen=False) —
    # presumably so fields (e.g. checksum) can be set after parsing; confirm
    # against callers before relying on this.
    model_config = ConfigDict(frozen=False)

    dataset_id: str = Field(
        description="Unique identifier for the dataset in the registry"
    )
    task: str = Field(
        description="A task as defined in the MDC Platform e.g. ASR, TTS etc"
    )

    # --- Index-based strategy (ASR / TTS) ---
    format: str | None = Field(default=None, description='e.g. "csv", "tsv", "pipe"')
    index_file: str | None = Field(default=None, description='e.g. "train.csv"')
    base_audio_path: str | None = Field(default=None, description='e.g. "clips/"')
    columns: dict[str, ColumnMapping] = Field(
        default_factory=dict, description="Mapping of index columns to logical fields"
    )
    separator: str | None = Field(
        default=None, description='explicit separator override (e.g. "|")'
    )
    has_header: bool = Field(
        default=True, description="whether the index file has a header row"
    )
    encoding: str = Field(
        default="utf-8", description='file encoding (e.g. "utf-8-sig" for BOM)'
    )

    # --- Glob-based strategy (LM, paired-file TTS) ---
    root_strategy: str | None = Field(
        default=None, description='"glob" | "paired_glob" | "multi_split"'
    )
    file_pattern: str | None = Field(default=None, description='e.g. "**/*.txt"')
    audio_extension: str | None = Field(
        default=None, description='for paired-file TTS: e.g. ".webm"'
    )
    content_mapping: ContentMapping | None = Field(
        default=None, description="Mapping for glob-based content extraction"
    )

    # --- Multi-split strategy (e.g. Common Voice) ---
    splits: list[str] | None = Field(
        default=None, description='split names to load, e.g. ["train", "dev", "test"]'
    )
    splits_file_pattern: str | None = Field(
        default=None, description='glob pattern for split files, e.g. "**/*.tsv"'
    )

    # --- Schema versioning ---
    checksum: str | None = Field(
        default=None, description="archive checksum for cache validation"
    )

    # --- Catch-all for future / unknown keys ---
    extra: dict[str, Any] = Field(
        default_factory=dict, description="Catch-all for future / unknown keys"
    )

    def to_yaml_dict(self) -> dict[str, Any]:
        """
        Serialise the schema to a plain dict suitable for YAML output.

        Excludes fields that are at their default values so that the
        generated ``schema.yaml`` stays compact and readable.  The
        ``extra`` dict is merged into the top level.
        """
        # `extra` is excluded from the dump and merged manually so its keys
        # appear at the top level of the emitted YAML.
        data = self.model_dump(exclude_defaults=True, exclude={"extra"})
        # Merge extra keys into the top level
        if self.extra:
            data.update(self.extra)
        return data

to_yaml_dict()

Serialise the schema to a plain dict suitable for YAML output.

Excludes fields that are at their default values so that the generated schema.yaml stays compact and readable. The extra dict is merged into the top level.

Source code in src/datacollective/schema.py
def to_yaml_dict(self) -> dict[str, Any]:
    """
    Serialise the schema to a plain dict suitable for YAML output.

    Only non-default fields are included, keeping the generated
    ``schema.yaml`` compact and readable; the contents of ``extra``
    are merged into the top level of the result.
    """
    serialised = self.model_dump(exclude_defaults=True, exclude={"extra"})
    if self.extra:
        serialised = {**serialised, **self.extra}
    return serialised

get_dataset_schema(dataset_id)

Download and return the schema.yaml content for dataset_id.

Parameters:

Name Type Description Default
dataset_id str

The registry dataset ID (the folder name under /registry/).

required

Returns:

Type Description
DatasetSchema | None

A fully-populated DatasetSchema for the given dataset, or None if

DatasetSchema | None

the dataset is not found in the registry (HTTP 404).

Raises:

Type Description
RuntimeError

For any other network / HTTP error.

Source code in src/datacollective/schema.py
def get_dataset_schema(dataset_id: str) -> DatasetSchema | None:
    """
    Download and parse the registry ``schema.yaml`` for *dataset_id*.

    Args:
        dataset_id: The registry dataset ID (the folder name under /registry/).

    Returns:
        A fully-populated `DatasetSchema`, or ``None`` when the dataset has
        no entry in the registry (HTTP 404).

    Raises:
        RuntimeError: For any other network / HTTP error.
    """
    url = f"{SCHEMA_REGISTRY_RAW_BASE_URL}/main/registry/{dataset_id}/schema.yaml"

    try:
        with urllib.request.urlopen(url) as response:
            body = response.read()
        return parse_schema(body.decode("utf-8"))
    except urllib.error.HTTPError as exc:
        # A 404 simply means the dataset is not in the registry.
        if exc.code == 404:
            return None
        raise RuntimeError(f"HTTP {exc.code} while fetching {url}") from exc
    except urllib.error.URLError as exc:
        raise RuntimeError(f"Network error while fetching {url}: {exc.reason}") from exc

parse_schema(raw)

Parse a schema from a YAML string, a dict, or a file path.

Parameters:

Name Type Description Default
raw str | dict[str, Any] | Path

YAML string, already-parsed dict, or Path to a YAML file.

required

Returns:

Type Description
DatasetSchema

A fully-populated DatasetSchema.

Raises:

Type Description
ValueError

If required fields are missing or the input cannot be parsed.

Source code in src/datacollective/schema.py
def parse_schema(raw: str | dict[str, Any] | Path) -> DatasetSchema:
    """
    Parse a schema from a YAML string, a dict, or a file path.

    Args:
        raw: YAML string, already-parsed dict, or ``Path`` to a YAML file.

    Returns:
        A fully-populated `DatasetSchema`.

    Raises:
        ValueError: If required fields are missing or the input cannot be parsed.
    """
    # Normalise the input: Path -> YAML text -> dict.
    if isinstance(raw, Path):
        raw = raw.read_text(encoding="utf-8")
    if isinstance(raw, str):
        raw = yaml.safe_load(raw)
    if not isinstance(raw, dict):
        raise ValueError(f"Expected a dict after YAML parsing, got {type(raw)}")

    data: dict[str, Any] = raw

    dataset_id = data.get("dataset_id")
    task = data.get("task")
    if not dataset_id or not task:
        raise ValueError("schema.yaml must contain 'dataset_id' and 'task'")

    # Columns (index-based)
    columns: dict[str, ColumnMapping] = {}
    raw_columns = data.get("columns", {})
    if isinstance(raw_columns, dict):
        for col_name, col_def in raw_columns.items():
            if not isinstance(col_def, dict):
                continue
            # Raise the documented ValueError instead of leaking a bare
            # KeyError when a column definition omits 'source_column'.
            if "source_column" not in col_def:
                raise ValueError(
                    f"Column '{col_name}' must define 'source_column'"
                )
            columns[col_name] = ColumnMapping(
                source_column=col_def["source_column"],  # str or int
                dtype=col_def.get("dtype", "string"),
                optional=col_def.get("optional", False),
            )

    # Content mapping (glob-based)
    content_mapping: ContentMapping | None = None
    raw_cm = data.get("content_mapping")
    if isinstance(raw_cm, dict):
        content_mapping = ContentMapping(
            text=raw_cm.get("text"),
            meta_source=raw_cm.get("meta_source"),
        )

    # Recognised top-level keys; anything else is preserved in `extra`.
    known_keys = {
        "dataset_id",
        "task",
        "format",
        "index_file",
        "base_audio_path",
        "columns",
        "separator",
        "has_header",
        "encoding",
        "root_strategy",
        "file_pattern",
        "audio_extension",
        "content_mapping",
        "splits",
        "splits_file_pattern",
        "checksum",
    }
    extra = {k: v for k, v in data.items() if k not in known_keys}

    return DatasetSchema(
        dataset_id=str(dataset_id),
        task=str(task).upper(),
        format=data.get("format"),
        index_file=data.get("index_file"),
        base_audio_path=data.get("base_audio_path"),
        columns=columns,
        separator=data.get("separator"),
        has_header=data.get("has_header", True),
        encoding=data.get("encoding", "utf-8"),
        root_strategy=data.get("root_strategy"),
        file_pattern=data.get("file_pattern"),
        audio_extension=data.get("audio_extension"),
        content_mapping=content_mapping,
        splits=data.get("splits"),
        splits_file_pattern=data.get("splits_file_pattern"),
        checksum=data.get("checksum"),
        extra=extra,
    )

datacollective.schema_loaders.base

BaseSchemaLoader

Bases: ABC

Interface that every task-specific loader must implement.

Parameters:

Name Type Description Default
schema DatasetSchema

The parsed schema for the dataset.

required
extract_dir Path

The directory where the dataset files have been extracted.

required
Source code in src/datacollective/schema_loaders/base.py
class BaseSchemaLoader(abc.ABC):
    """
    Abstract interface that every task-specific loader must implement.

    Args:
        schema (DatasetSchema): The parsed schema for the dataset.
        extract_dir (Path): The directory where the dataset files have been extracted.
    """

    def __init__(self, schema: DatasetSchema, extract_dir: Path) -> None:
        self.schema = schema
        self.extract_dir = extract_dir

    @abc.abstractmethod
    def load(self) -> pd.DataFrame:
        """Load the dataset into a pandas DataFrame according to ``self.schema``."""
        ...

    def _load_index_file(self) -> pd.DataFrame:
        """Read the dataset's index file into a raw `~pandas.DataFrame`.

        The separator comes from ``schema.separator`` when set, otherwise it
        is looked up in `FORMAT_SEP` by ``schema.format``; the file itself
        is located via `_resolve_index_file`.

        Used by all index-based loaders (ASR, TTS, ...) so that each loader
        only needs to call `_apply_column_mappings` on the result.

        Returns:
            The unmapped DataFrame exactly as read from the index file.
        """
        path = self._resolve_index_file()
        fmt = self.schema.format or ""
        delimiter = self.schema.separator or FORMAT_SEP.get(fmt, ",")

        logger.debug(f"Reading index file: {path} (sep={delimiter!r})")
        return pd.read_csv(
            path,
            sep=delimiter,
            header="infer" if self.schema.has_header else None,
            encoding=self.schema.encoding,
        )

    def _resolve_index_file(self) -> Path:
        """Locate the index file inside the extracted directory.

        Searches recursively; when several files match, the one with the
        fewest path components (the shallowest) wins.

        Raises:
            FileNotFoundError: If no matching file is found.
        """
        assert self.schema.index_file is not None
        matches = list(self.extract_dir.rglob(self.schema.index_file))
        if not matches:
            raise FileNotFoundError(
                f"Index file '{self.schema.index_file}' not found "
                f"under '{self.extract_dir}'"
            )
        # Prefer the shallowest match (fewest path components).
        return min(matches, key=lambda p: len(p.parts))

    def _apply_column_mappings(self, raw_df: pd.DataFrame) -> pd.DataFrame:
        """Project *raw_df* onto the schema's logical columns.

        Each schema column is selected via its ``source_column``, converted
        according to its ``dtype``, and renamed to its logical name.

        Raises:
            KeyError: If a required column is not found in *raw_df*.
        """
        mapped: dict[str, pd.Series] = {}

        for logical_name, col_map in self.schema.columns.items():
            source = col_map.source_column

            if source not in raw_df.columns:
                if not col_map.optional:
                    raise KeyError(
                        f"Required column '{source}' not found in index file. "
                        f"Available columns: {list(raw_df.columns)}"
                    )
                logger.debug(f"Optional column '{source}' not found — skipping.")
                continue

            mapped[logical_name] = self._coerce_series(raw_df[source], col_map.dtype)

        return pd.DataFrame(mapped)

    def _coerce_series(self, series: pd.Series, dtype: str) -> pd.Series:
        """Convert *series* to the schema dtype (helper for `_apply_column_mappings`)."""
        if dtype == "file_path":
            # Resolve relative paths against the extraction root.
            base = self.schema.base_audio_path or ""
            return series.apply(lambda v: str(self.extract_dir / base / str(v)))
        if dtype == "category":
            return series.astype("category")
        if dtype == "int":
            return pd.to_numeric(series, errors="coerce").astype("Int64")
        if dtype == "float":
            return pd.to_numeric(series, errors="coerce")
        # Default: treat everything else as string.
        return series.astype(str)

load() abstractmethod

Load the dataset into a pandas DataFrame according to self.schema.

Source code in src/datacollective/schema_loaders/base.py
@abc.abstractmethod
def load(self) -> pd.DataFrame:
    """Load the dataset into a pandas DataFrame according to ``self.schema``.

    Implemented by each task-specific subclass (e.g. ASR, TTS loaders).
    """
    ...

Strategy

Bases: StrEnum

Loading strategies recognised by schema loaders.

Source code in src/datacollective/schema_loaders/base.py
class Strategy(StrEnum):
    """Loading strategies recognised by schema loaders.

    Values correspond to the ``root_strategy`` field of a dataset schema.
    """

    MULTI_SPLIT = "multi_split"  # one index file per split (e.g. Common Voice)
    PAIRED_GLOB = "paired_glob"  # audio file + same-stem text file pairs
    GLOB = "glob"  # files matched by `file_pattern` (glob-based extraction)

datacollective.schema_loaders.registry

get_task_loader(task)

Return the loader class for task.

Raises:

Type Description
ValueError

If no loader is registered for the given task.

Source code in src/datacollective/schema_loaders/registry.py
def get_task_loader(task: str) -> Type[BaseSchemaLoader]:
    """
    Look up the loader class registered for *task* (case-insensitive).

    Raises:
        ValueError: If no loader is registered for the given task.
    """
    normalized = task.upper()
    loader_cls = _TASK_REGISTRY.get(normalized)
    if loader_cls is None:
        supported = ", ".join(sorted(_TASK_REGISTRY))
        raise ValueError(
            f"No schema loader registered for task '{normalized}'. "
            f"Supported tasks: {supported}"
        )
    return loader_cls

load_dataset_from_schema(schema, extract_dir)

Instantiate the appropriate loader for schema.task and return the loaded ~pandas.DataFrame.

Parameters:

Name Type Description Default
schema DatasetSchema

Parsed dataset schema.

required
extract_dir Path

Root directory where the dataset archive was extracted.

required

Returns:

Type Description
DataFrame

A pandas DataFrame with the loaded dataset.

Source code in src/datacollective/schema_loaders/registry.py
def load_dataset_from_schema(schema: DatasetSchema, extract_dir: Path) -> pd.DataFrame:
    """
    Resolve the loader registered for ``schema.task``, run it over the
    extracted files, and return the loaded `~pandas.DataFrame`.

    Args:
        schema: Parsed dataset schema.
        extract_dir: Root directory where the dataset archive was extracted.

    Returns:
        A pandas DataFrame with the loaded dataset.
    """
    task_loader_cls = get_task_loader(schema.task)
    task_loader = task_loader_cls(schema=schema, extract_dir=extract_dir)
    logger.info(f"Loading dataset '{schema.dataset_id}' with {task_loader_cls.__name__}")
    return task_loader.load()

datacollective.schema_loaders.cache_schema

datacollective.schema_loaders.tasks.asr

ASRLoader

Bases: BaseSchemaLoader

Load an ASR dataset described by a DatasetSchema.

Source code in src/datacollective/schema_loaders/tasks/asr.py
class ASRLoader(BaseSchemaLoader):
    """Load an ASR dataset described by a `DatasetSchema`."""

    def __init__(self, schema: DatasetSchema, extract_dir: Path) -> None:
        super().__init__(schema, extract_dir)
        # Validate the schema eagerly so configuration errors surface at
        # construction time rather than midway through loading.
        if schema.root_strategy == Strategy.MULTI_SPLIT:
            if not schema.splits:
                raise ValueError(
                    "ASR multi_split schema must specify 'splits' (list of split names)"
                )
        else:
            if not schema.index_file:
                raise ValueError("ASR schema must specify 'index_file'")
            if not schema.format:
                raise ValueError("ASR schema must specify 'format' (csv or tsv)")
            if not schema.columns:
                raise ValueError(
                    "ASR schema must specify at least two column mappings for audio and transcription"
                )

    def load(self) -> pd.DataFrame:
        """Load the dataset, dispatching on the schema's root strategy."""
        if self.schema.root_strategy == Strategy.MULTI_SPLIT:
            return self._load_multi_split()
        raw_df = self._load_index_file()
        return self._apply_column_mappings(raw_df)

    def _load_multi_split(self) -> pd.DataFrame:
        """
        Load all split TSV/CSV files whose stems match the ``splits`` list,
        add a ``split`` column to each, apply column mappings, and concatenate.

        Raises:
            RuntimeError: If no files matching the split pattern are found.
        """
        assert self.schema.splits is not None

        pattern = self.schema.splits_file_pattern or "**/*.tsv"
        allowed_splits = set(self.schema.splits)

        split_files: dict[str, Path] = {}
        for path in self.extract_dir.rglob(pattern):
            if path.stem in allowed_splits:
                # Prefer the shallowest match per split name
                if path.stem not in split_files or len(path.parts) < len(
                    split_files[path.stem].parts
                ):
                    split_files[path.stem] = path

        if not split_files:
            raise RuntimeError(
                f"No split files matching pattern '{pattern}' with stems in "
                f"{sorted(allowed_splits)} found under '{self.extract_dir}'"
            )

        sep = self.schema.separator or FORMAT_SEP.get(self.schema.format or "tsv", "\t")
        # Honour `has_header` exactly like `_load_index_file`; hard-coding
        # header="infer" (i.e. header=0) would silently consume the first data
        # row of header-less split files as a fake header row.
        header = "infer" if self.schema.has_header else None
        frames: list[pd.DataFrame] = []

        for split_name, file_path in sorted(split_files.items()):
            logger.debug(f"Reading split '{split_name}' from {file_path}")
            raw_df = pd.read_csv(
                file_path, sep=sep, header=header, encoding=self.schema.encoding
            )
            raw_df["split"] = split_name

            if self.schema.columns:
                mapped = self._apply_column_mappings(raw_df)
                mapped["split"] = split_name
                frames.append(mapped)
            else:
                frames.append(raw_df)

        return pd.concat(frames, ignore_index=True)

datacollective.schema_loaders.tasks.tts

TTSLoader

Bases: BaseSchemaLoader

Load a TTS dataset described by a DatasetSchema.

See docs/loaders/tts.md for details on supported loading strategies and schema fields.

Source code in src/datacollective/schema_loaders/tasks/tts.py
class TTSLoader(BaseSchemaLoader):
    """Load a TTS dataset described by a `DatasetSchema`.

    See docs/loaders/tts.md for details on supported loading strategies and schema fields.
    """

    # NOTE: the base-class constructor is sufficient; no override needed.

    def load(self) -> pd.DataFrame:
        """Load the dataset, dispatching on the schema's root strategy."""
        if self.schema.root_strategy == Strategy.PAIRED_GLOB:
            return self._load_paired_glob()
        return self._load_based_on_index()

    def _load_based_on_index(self) -> pd.DataFrame:
        """
        Load a TTS dataset using the "index" strategy, where an index file (e.g. CSV) maps audio paths to transcriptions.

        Raises:
            ValueError: If the schema lacks 'index_file', or lacks both
                'format' and 'separator'.
        """
        if not self.schema.index_file:
            raise ValueError("TTS index-based schema must specify 'index_file'")
        if not self.schema.format and not self.schema.separator:
            raise ValueError(
                "TTS index-based schema must specify 'format' or 'separator'"
            )

        raw_df = self._load_index_file()

        if not self.schema.columns:
            # No column mapping -> return the raw dataframe as-is
            return raw_df

        return self._apply_column_mappings(raw_df)

    def _load_paired_glob(self) -> pd.DataFrame:
        """
        Load a TTS dataset using the "paired_glob" strategy, where each audio file has a
        matching `.txt` file containing the transcription. The loader searches
        recursively for all text files matching the specified `file_pattern`,
        reads their contents, and pairs them with the corresponding audio files based
        on the same filename stem. The parent directory name of each text/audio pair
        is captured as a `split` column in the resulting DataFrame.

        Raises:
            ValueError: If 'file_pattern' or 'audio_extension' is missing.
            FileNotFoundError: If no text files match, or no text file has a
                matching audio file.
        """
        if not self.schema.file_pattern:
            raise ValueError("TTS paired_glob schema must specify 'file_pattern'")
        if not self.schema.audio_extension:
            raise ValueError("TTS paired_glob schema must specify 'audio_extension'")

        text_files = sorted(self.extract_dir.rglob(self.schema.file_pattern))
        if not text_files:
            raise FileNotFoundError(
                f"No files matching '{self.schema.file_pattern}' "
                f"found under '{self.extract_dir}'"
            )

        logger.debug(
            f"Found {len(text_files)} text files matching '{self.schema.file_pattern}'"
        )

        audio_ext = self.schema.audio_extension
        rows: list[dict[str, str]] = []

        for txt_path in text_files:
            # Pair by stem: the audio file shares the text file's name
            # apart from its extension.
            audio_path = txt_path.with_suffix(audio_ext)
            if not audio_path.exists():
                logger.debug(
                    f"No matching audio file for '{txt_path.name}' — skipping."
                )
                continue

            transcription = txt_path.read_text(encoding=self.schema.encoding).strip()
            row: dict[str, str] = {
                "audio_path": str(audio_path),
                "transcription": transcription,
            }

            # Derive domain / split from parent directory name if present
            parent_name = txt_path.parent.name
            if parent_name:
                row["split"] = parent_name

            rows.append(row)

        if not rows:
            raise FileNotFoundError(
                f"No paired (text + {audio_ext}) files found under '{self.extract_dir}'"
            )

        return pd.DataFrame(rows)