# Source code for trajdl.datasets.open_source.utils
# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import os
import shutil
import zipfile
import requests
from tqdm import tqdm
[docs]
def decompress_gz(gz_path: str, output_path: str) -> None:
    """Write the decompressed contents of a gzip archive to a new file.

    Parameters
    ----------
    gz_path : str
        Location of the ``.gz`` file to read.
    output_path : str
        Location of the file that receives the decompressed bytes. It is
        opened in binary write mode, so existing content is overwritten.
    """
    # The archive is opened first, so a missing/unreadable archive fails
    # before the output file is created.
    with gzip.open(gz_path, "rb") as compressed, open(output_path, "wb") as target:
        shutil.copyfileobj(compressed, target)
[docs]
def unzip_file(zip_file_path: str, output_folder: str) -> None:
    """Extract every member of a zip archive into a target directory.

    Parameters
    ----------
    zip_file_path : str
        Location of the zip archive to read.
    output_folder : str
        Directory that receives the extracted members.
    """
    with zipfile.ZipFile(zip_file_path, mode="r") as archive:
        archive.extractall(path=output_folder)
[docs]
def download_file(url: str, path: str, chunk_size: int = 8192) -> None:
    """Download a file from a URL, streaming it to disk with a progress bar.

    Parameters
    ----------
    url : str
        The URL to download the file from.
    path : str
        The path where the downloaded file will be saved.
    chunk_size : int, optional
        The maximum number of bytes requested per chunk (default is 8192).

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.

    Notes
    -----
    - Errors raised by ``requests`` itself (e.g. connection failures) are not
      wrapped and propagate to the caller unchanged.
    - Progress is tracked in bytes rather than in chunks, so the bar stays
      accurate when chunks are shorter than ``chunk_size``.
    """
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            raise RuntimeError(
                "Downloading dataset failed! Check your network. "
                f"(HTTP status {response.status_code})"
            )

        # A missing (or zero) Content-Length means the size is unknown; passing
        # None makes tqdm show a running byte count instead of a 0-total bar.
        total_size = int(response.headers.get("content-length", 0))

        with open(path, "wb") as f, tqdm(
            total=total_size or None,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc="Downloading dataset...",
        ) as progress:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if not chunk:
                    # Skip empty keep-alive chunks.
                    continue
                f.write(chunk)
                progress.update(len(chunk))
[docs]
def remove_path(path: str) -> None:
    """Remove the specified file, directory or symbolic link (best effort).

    Parameters
    ----------
    path : str
        The path to the file, directory or symbolic link to be deleted.

    Returns
    -------
    None

    Notes
    -----
    - This function does not raise: any exception that occurs while deleting
      is caught and reported with a printed message.
    - If the specified path does not exist, a message indicating that will be
      printed.
    - If the path is a real directory, it will be removed recursively.
    - If the path is a symbolic link to a directory, or a dangling symbolic
      link, only the link itself is removed; the link target is left
      untouched.
    """
    try:
        if os.path.isfile(path):
            # Also covers symlinks to regular files: os.remove deletes the link.
            os.remove(path)
            print(f"File '{path}' has been removed.")
        elif os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
            print(f"Directory '{path}' has been removed.")
        elif os.path.islink(path):
            # shutil.rmtree refuses symlinks, and isfile/isdir are both False
            # for dangling links, so links need to be unlinked explicitly.
            os.unlink(path)
            print(f"Symbolic link '{path}' has been removed.")
        else:
            print(f"Path '{path}' does not exist.")
    except Exception as e:
        # Deliberately best-effort: report the failure instead of propagating.
        print(f"Error while deleting '{path}': {e}")