# Source code for trajdl.datasets.open_source.hasher
# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import Iterable
import mmh3
import pyarrow as pa
import pyarrow.parquet as pq
[docs]
class Hasher:
    """Compute hex digests of files, PyArrow tables and Parquet files.

    A single underlying hash object is kept on the instance and is replaced
    with a fresh one after every digest, so one ``Hasher`` can be reused for
    several independent digests.
    """

    def __init__(self, hasher_type: str):
        """Initialize the Hasher with a specified hashing algorithm.

        Parameters
        ----------
        hasher_type : str
            The type of hasher to use. Options are "sha256" and "mmh3".

        Raises
        ------
        ValueError
            If the specified hasher_type is not supported.
        """
        self.hasher_type = hasher_type
        self.init_hasher()

    def init_hasher(self):
        """(Re)create ``self.hasher`` according to ``self.hasher_type``.

        Raises
        ------
        ValueError
            If ``self.hasher_type`` is neither "sha256" nor "mmh3".
        """
        if self.hasher_type == "sha256":
            self.hasher = hashlib.sha256()
        elif self.hasher_type == "mmh3":
            self.hasher = mmh3.mmh3_x64_128()
        else:
            raise ValueError(
                f"hasher_type should be one of {'sha256', 'mmh3'}, not '{self.hasher_type}'"
            )

    def _get_hash(self, iterable: Iterable[bytes]) -> str:
        """Compute the hash for a given iterable of byte data.

        Parameters
        ----------
        iterable : Iterable[bytes]
            An iterable yielding byte objects to hash.

        Returns
        -------
        str
            The hexadecimal representation of the hash value.
        """
        try:
            for chunk in iterable:
                self.hasher.update(chunk)
            return self.hasher.digest().hex()
        finally:
            # Always start the next digest from a clean state. Without the
            # ``finally`` an exception raised while consuming ``iterable``
            # would leave partially fed data in ``self.hasher`` and silently
            # corrupt the next digest computed with this instance.
            self.init_hasher()

    def digest_file(self, path: str) -> str:
        """Digest a file and produce its hash.

        The file is opened in binary mode and read in chunks of 1 MiB
        (``1 << 20`` bytes) rather than being loaded in one read.

        Parameters
        ----------
        path : str
            The path to the file to be hashed.

        Returns
        -------
        str
            The hexadecimal representation of the file hash.
        """
        with open(path, "rb") as f:
            # ``iter(callable, sentinel)`` keeps reading until ``read``
            # returns ``b""``, i.e. end of file.
            return self._get_hash(iter(lambda: f.read(1 << 20), b""))

    def digest_arrow(self, table: "pa.Table", max_chunksize: int = 8192) -> str:
        """Digest a PyArrow Table and produce its hash.

        The table is split into record batches; each batch is serialized and
        the hex form of the serialized buffer is fed to the hasher.

        Parameters
        ----------
        table : pa.Table
            The PyArrow table to be hashed.
        max_chunksize : int, optional
            Forwarded to ``table.to_batches(max_chunksize=...)``
            (default is 8192).

        Returns
        -------
        str
            The hexadecimal representation of the table hash.
        """
        # NOTE(review): the digest is built from per-batch serializations, so
        # it presumably depends on ``max_chunksize`` and on the table's chunk
        # layout -- confirm before comparing digests produced with different
        # settings.
        return self._get_hash(
            batch.serialize().hex()
            for batch in table.to_batches(max_chunksize=max_chunksize)
        )

    def digest_parquet(self, path: str, max_chunksize: int = 8192) -> str:
        """Digest a Parquet file and produce its hash.

        The file is read into a PyArrow table with ``pq.read_table`` and the
        result is hashed with :meth:`digest_arrow`; the hash is therefore of
        the decoded table, not of the raw file bytes.

        Parameters
        ----------
        path : str
            The path to the Parquet file to be hashed.
        max_chunksize : int, optional
            Forwarded to :meth:`digest_arrow` (default is 8192).

        Returns
        -------
        str
            The hexadecimal representation of the Parquet file hash.
        """
        table = pq.read_table(path)
        return self.digest_arrow(table=table, max_chunksize=max_chunksize)