trajdl.datasets.open_source.conf module#

class trajdl.datasets.open_source.conf.GowallaDataset(dataset_name: str = 'gowalla', size: int = 105470044, url: str = 'https://snap.stanford.edu/data/loc-gowalla_totalCheckins.txt.gz', sha256_original: str = 'c1c3e19effba649b6c89aeab3c1f9459fad88cfdc2b460fc70fd54e295d83ea0', mmh3_cache: str = '8a2eb882146b2ab51774b4bf8b1432dc')[source]#

Bases: OpenSourceDataset

Gowalla Dataset class extending OpenSourceDataset with specific parameters.

cache() None[source]#

Cache the Gowalla dataset by decompressing and storing it as Parquet.

dataset_name: str = 'gowalla'#
mmh3_cache: str = '8a2eb882146b2ab51774b4bf8b1432dc'#
sha256_original: str = 'c1c3e19effba649b6c89aeab3c1f9459fad88cfdc2b460fc70fd54e295d83ea0'#
size: int = 105470044#
url: str = 'https://snap.stanford.edu/data/loc-gowalla_totalCheckins.txt.gz'#
class trajdl.datasets.open_source.conf.OpenSourceDataset(dataset_name: str, size: int, url: str | None, sha256_original: str, mmh3_cache: str)[source]#

Bases: ABC

Base class for open-source datasets.

Provides functionality for downloading, validating, and loading datasets.

dataset_name#

The name of the dataset.

Type:

str

size#

The expected size of the dataset.

Type:

int

url#

The URL to download the dataset from.

Type:

Optional[str]

sha256_original#

The expected SHA-256 hash of the original dataset.

Type:

str

mmh3_cache#

The expected MMH3 hash of the cached dataset.

Type:

str

check_original_dataset(path: str) bool[source]#

Validates the downloaded dataset.

download(chunk_size: int = 8192) str[source]#

Downloads the dataset if it does not exist or is invalid.

load_cache(table: pa.Table, return_as: str = 'pl') pl.DataFrame | pd.DataFrame | pa.Table[source]#

Loads the cached dataset in the specified format.

load(return_as: str = 'pl', chunk_size: int = 8192) pl.DataFrame | pd.DataFrame | pa.Table[source]#

Loads the dataset, attempting to use the cache first.

cache()[source]#

Abstract method to be implemented by subclasses for caching behavior.

set_path(path: str) None[source]#

Set the original dataset path.

abstract cache() None[source]#

Cache the dataset.

Implement this method to define how the dataset is cached, typically by extracting the downloaded file and saving it as a Parquet file.

property cache_path: str#

Return the path to the cached Parquet file.

check_original_dataset(path: str) bool[source]#

Check if the original dataset is valid.

Parameters:

path (str) – The path to the dataset to validate.

Returns:

True if the dataset is valid, False otherwise.

Return type:

bool

dataset_name: str#
download(chunk_size: int = 8192) str[source]#

Download the dataset.

The dataset is downloaded into CACHE_DATASET_DIR; the file path within that directory is determined by the dataset name.

Parameters:

chunk_size (int, optional) – The size of each download chunk (default is 8192).

Returns:

The path to the downloaded dataset.

Return type:

str

load(return_as: str = 'pl', chunk_size: int = 8192, original_dataset_path: str | None = None, unsafe: bool = False) DataFrame | DataFrame | Table[source]#

Load the dataset, checking the cache first.

If no valid cached data exists, the original dataset is loaded from original_dataset_path or, failing that, downloaded; it is then cached, and the data is loaded from the cache.

Parameters:
  • return_as (str, optional) – The format to return (default is ‘pl’).

  • chunk_size (int, optional) – The size of each loading chunk (default is 8192).

  • original_dataset_path (Optional[str], optional) – The path of the original dataset downloaded by the user (default is None).

  • unsafe (bool, optional) – If True, do not check the MMH3 hash of the cached dataset (default is False).

Returns:

The loaded dataset in the specified format.

Return type:

Union[pl.DataFrame, pd.DataFrame, pa.Table]

load_cache(table: Table, return_as: str = 'pl') DataFrame | DataFrame | Table[source]#

Load the cache in the specified format.

Parameters:
  • table (pa.Table) – The Arrow table to load.

  • return_as (str, optional) – The format to return the table in: ‘pl’ for Polars, ‘pd’ for Pandas, or ‘pa’ for PyArrow (default is ‘pl’).

Returns:

The loaded table in the specified format.

Return type:

Union[pl.DataFrame, pd.DataFrame, pa.Table]

Raises:

ValueError – If the specified return format is unsupported.

mmh3_cache: str#
property path: str#

Return the path to the downloaded dataset file.

set_path(path: str) None[source]#

Set the path of the original dataset.

If the user specifies the path of the original dataset, that path is saved in this dataset object.

Parameters:

path (str) – The path of the original dataset, e.g. ~/Downloads/loc-gowalla_totalCheckins.txt.gz.

set_url(url: str) None[source]#

Set the download URL of the original dataset.

If the user specifies a URL for a dataset, the dataset will be downloaded from that link.

Parameters:

url (str) – The download link of the original dataset.

sha256_original: str#
size: int#
url: str | None#
class trajdl.datasets.open_source.conf.PortoDataset(dataset_name: str = 'porto', size: int = 534065916, url: str = 'http://localhost:7077/taxi%2Bservice%2Btrajectory%2Bprediction%2Bchallenge%2Becml%2Bpkdd%2B2015.zip', sha256_original: str = 'a33e2a5e145607ae2bad0db5d21b7548c88b7e0f9db1ce15839f24c4c61f8c76', mmh3_cache: str = '5ef83abb3cf649583f28b80f16e1f4a7')[source]#

Bases: OpenSourceDataset

Porto Dataset class extending OpenSourceDataset with specific parameters.

cache() None[source]#

Cache the Porto dataset by decompressing and storing it as Parquet.

dataset_name: str = 'porto'#
mmh3_cache: str = '5ef83abb3cf649583f28b80f16e1f4a7'#
sha256_original: str = 'a33e2a5e145607ae2bad0db5d21b7548c88b7e0f9db1ce15839f24c4c61f8c76'#
size: int = 534065916#
url: str = 'http://localhost:7077/taxi%2Bservice%2Btrajectory%2Bprediction%2Bchallenge%2Becml%2Bpkdd%2B2015.zip'#