Source code for trajdl.grid.hierarchy

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

import pandas as pd
from rtree import index
from tqdm.contrib import tenumerate

from .. import trajdl_cpp
from ..tokenizers import SimpleTokenizer
from .base import BaseGridSystem, SimpleGridSystem


class HierarchyGridSystem(BaseGridSystem):
    """Multi-level grid system built by recursively subdividing a rectangle.

    The root ``boundary`` is split with a ``SimpleGridSystem`` using
    ``steps[0]``; every resulting cell is split again using ``steps[1]``, and
    so on. The cells of the deepest level are stored one per row in
    ``all_grids`` and are indexed in an R-tree for point lookup.
    """

    def __init__(
        self, boundary: trajdl_cpp.RectangleBoundary, steps: List[Tuple[float, float]]
    ):
        """Build the leaf-cell table and its spatial index.

        Args:
            boundary: Root rectangle that is subdivided.
            steps: One ``(step_x, step_y)`` pair per hierarchy level, passed to
                ``SimpleGridSystem`` when splitting a cell of that level.
        """
        super().__init__(boundary=boundary)
        self.steps = steps
        self.all_grids = self.recursive_split_regions()
        self.rtree = self.build_rtree()

    def __len__(self) -> int:
        """Return the number of leaf cells."""
        return len(self.all_grids)

    def __getstate__(self):
        """Return the picklable state, i.e. everything except the R-tree."""
        state = self.__dict__.copy()
        # TODO: pickling the rtree is problematic; in principle rtree's own
        # serialization could be used and integrated later. For now the rtree
        # is dropped here and rebuilt when the object is loaded.
        state.pop("rtree", None)
        return state

    def __setstate__(self, state):
        """Restore the attributes and rebuild the R-tree from ``all_grids``."""
        self.__dict__.update(state)
        self.rtree = self.build_rtree()

    @property
    def grid_id_cols(self) -> List[str]:
        """Names of the per-level id columns, one per entry of ``steps``."""
        return [f"grid_id_level_{level}" for level in range(len(self.steps))]

    def recursive_split_regions(self) -> pd.DataFrame:
        """Split the root boundary level by level and tabulate the leaf cells.

        Returns:
            A DataFrame with one row per leaf cell and, in this order, the
            columns listed by ``grid_id_cols`` (per-level ids of the form
            ``"{depth}-{sub_grid_id}"``), ``boundary`` (the result of
            ``to_tuple()`` on the leaf boundary) and ``grid_id`` (the
            per-level ids joined with ``"-"``).
        """
        steps = self.steps
        rows = []

        def split(
            boundary: trajdl_cpp.RectangleBoundary, level_ids: List[str], depth: int
        ) -> None:
            if depth >= len(steps):
                # Leaf reached: emit the row, including the joined full id, so
                # no row-wise post-processing of the DataFrame is needed.
                rows.append([*level_ids, boundary.to_tuple(), "-".join(level_ids)])
                return

            grid_system = SimpleGridSystem(
                boundary=boundary, step_x=steps[depth][0], step_y=steps[depth][1]
            )
            for sub_boundary, sub_grid_id in grid_system:
                # A fresh list per child keeps sibling branches independent.
                split(
                    sub_boundary,
                    [*level_ids, f"{depth}-{sub_grid_id}"],
                    depth + 1,
                )

        split(self.boundary, [], 0)
        return pd.DataFrame(
            rows,
            columns=[*self.grid_id_cols, "boundary", "grid_id"],
        )

    def build_rtree(self) -> index.Index:
        """Index every leaf boundary in a new R-tree.

        Each boundary is inserted under its row position in ``all_grids``.
        Because ``all_grids`` is created with the default integer index, that
        position is also the row label used later by ``locate_unsafe``.

        Returns:
            The populated ``rtree.index.Index``.
        """
        rtree = index.Index()
        for idx, boundary in tenumerate(
            self.all_grids["boundary"],
            total=self.all_grids.shape[0],
            desc="construct rtree for hierarchy grid system",
        ):
            rtree.insert(idx, boundary)
        return rtree

    def locate_unsafe(self, x: float, y: float) -> Optional[str]:
        """Find the leaf cell that contains the point ``(x, y)``.

        Args:
            x: First coordinate of the query point.
            y: Second coordinate of the query point.

        Returns:
            The ``grid_id`` of the first candidate cell whose boundary reports
            ``in_boundary(x, y)``, or ``None`` when no cell contains the point.
        """
        # The R-tree yields candidate ids lazily; materialise them so the
        # empty case can short-circuit and ``.loc`` receives a plain list.
        candidates = list(self.rtree.intersection((x, y, x, y)))
        if not candidates:
            return None

        matched = self.all_grids.loc[candidates]
        # The R-tree only pre-filters; membership is decided by the boundary.
        for boundary_tuple, grid_id in zip(matched["boundary"], matched["grid_id"]):
            boundary = trajdl_cpp.RectangleBoundary.from_tuple(boundary_tuple)
            if boundary.in_boundary(x, y):
                return grid_id
        return None

    def build_simple_tokenizer(self) -> SimpleTokenizer:
        """Build a ``SimpleTokenizer`` that maps each ``grid_id`` to its row index."""
        vocab = dict(zip(self.all_grids["grid_id"], self.all_grids.index))
        return SimpleTokenizer.build(init_vocab=vocab)