trajdl.metrics.tul module
This module provides accuracy (ACC), accuracy at K (ACC_K), and macro-F1 score for trajectory-user linking (TUL) tasks.
All three metrics are computed by the TULMetrics class.
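The three metrics can be illustrated on a single batch of scores. The following is a minimal NumPy sketch, not the TULMetrics implementation. The function names are hypothetical. ACC@K counts a hit when the target user is among the K highest-scoring users. Macro-F1 is assumed to average the per-user F1 of top-1 predictions over users that appear as a target or a prediction. Ties in scores are broken by NumPy's sort order.

```python
import numpy as np


def accuracy(preds: np.ndarray, targets: np.ndarray) -> float:
    """ACC: fraction of trajectories whose highest-scoring user is the target."""
    return float(np.mean(np.argmax(preds, axis=1) == targets))


def accuracy_at_k(preds: np.ndarray, targets: np.ndarray, k: int) -> float:
    """ACC@K: fraction of trajectories whose target is among the k top-scoring users."""
    # Column indices of the k largest scores in each row.
    topk = np.argsort(-preds, axis=1)[:, :k]
    return float(np.mean(np.any(topk == targets[:, None], axis=1)))


def macro_f1(preds: np.ndarray, targets: np.ndarray) -> float:
    """Macro-F1: unweighted mean of per-user F1 over top-1 predictions.

    Assumption: only users occurring as a target or a prediction are averaged.
    """
    pred_labels = np.argmax(preds, axis=1)
    f1_scores = []
    for u in np.union1d(pred_labels, targets):
        tp = np.sum((pred_labels == u) & (targets == u))
        fp = np.sum((pred_labels == u) & (targets != u))
        fn = np.sum((pred_labels != u) & (targets == u))
        # F1 = 2TP / (2TP + FP + FN); the denominator is > 0 for every u here.
        f1_scores.append(2 * tp / (2 * tp + fp + fn))
    return float(np.mean(f1_scores))


# B = 4 trajectories scored against num_users = 3 candidate users.
preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1],
                  [0.2, 0.5, 0.3]])
targets = np.array([0, 1, 1, 1])

# ACC = 0.5, ACC@2 = 1.0, macro-F1 ≈ 0.389
print(accuracy(preds, targets), accuracy_at_k(preds, targets, k=2), macro_f1(preds, targets))
```

ACC@K can only be greater than or equal to ACC, because the top-1 user is always inside the top K. Macro-F1 weights every user equally, so rarely seen users affect it as much as frequent ones.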
- class trajdl.metrics.tul.TULMetrics(num_users: int, topk: int = 5)
Bases: object
Class for calculating accuracy metrics in TUL tasks.
- init_state() → Dict[int, List[float]]
Initialize the state for evaluation metrics.
- Returns:
A dictionary where keys are user indices and values are lists tracking prediction statistics for each user.
- Return type:
Dict[int, List[float]]
- property k: int
Return the K used for top-K accuracy (ACC_K).
- update(preds: ndarray, targets: ndarray)
Update the metric state with a batch of predictions and their ground-truth targets.
- Parameters:
preds (np.ndarray) – Array of shape (B, num_users) containing the predicted score of each user for each trajectory.
targets (np.ndarray) – Array of shape (B,) containing the ground-truth user index for each trajectory.
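A streaming accumulator in the spirit of init_state, k, and update can be sketched as follows. StreamingTULMetrics and its compute() method are hypothetical, because this page does not document how TULMetrics reports results. The per-user state is kept as three NumPy count arrays (true positives, false positives, false negatives). This differs from the Dict[int, List[float]] returned by init_state, whose list contents are not specified here. Macro-F1 again averages only over users that have occurred.

```python
import numpy as np


class StreamingTULMetrics:
    """Hypothetical accumulator mirroring the update(preds, targets) contract."""

    def __init__(self, num_users: int, topk: int = 5):
        self.num_users = num_users
        self._k = topk
        self.total = 0
        self.top1_hits = 0
        self.topk_hits = 0
        # Assumed per-user state: true positives, false positives, false negatives.
        self.tp = np.zeros(num_users, dtype=np.int64)
        self.fp = np.zeros(num_users, dtype=np.int64)
        self.fn = np.zeros(num_users, dtype=np.int64)

    @property
    def k(self) -> int:
        return self._k

    def update(self, preds: np.ndarray, targets: np.ndarray) -> None:
        # Enforce the documented shapes: preds (B, num_users), targets (B,).
        if preds.ndim != 2 or preds.shape[1] != self.num_users:
            raise ValueError("preds must have shape (B, num_users)")
        if targets.shape != (preds.shape[0],):
            raise ValueError("targets must have shape (B,)")
        top1 = np.argmax(preds, axis=1)
        topk = np.argsort(-preds, axis=1)[:, : self._k]
        correct = top1 == targets
        self.total += len(targets)
        self.top1_hits += int(correct.sum())
        self.topk_hits += int(np.any(topk == targets[:, None], axis=1).sum())
        # bincount tallies user indices; minlength keeps one slot per user.
        n = self.num_users
        self.tp += np.bincount(targets[correct], minlength=n)
        self.fp += np.bincount(top1[~correct], minlength=n)
        self.fn += np.bincount(targets[~correct], minlength=n)

    def compute(self) -> dict:
        if self.total == 0:
            raise ValueError("no samples have been passed to update()")
        denom = 2 * self.tp + self.fp + self.fn
        seen = denom > 0  # assumption: average only over users that occurred
        f1 = 2 * self.tp[seen] / denom[seen]
        return {
            "acc": self.top1_hits / self.total,
            f"acc@{self._k}": self.topk_hits / self.total,
            "macro_f1": float(f1.mean()),
        }


preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1],
                  [0.2, 0.5, 0.3]])
targets = np.array([0, 1, 1, 1])

metrics = StreamingTULMetrics(num_users=3, topk=2)
metrics.update(preds[:2], targets[:2])  # first batch
metrics.update(preds[2:], targets[2:])  # second batch
print(metrics.compute())
```

The design keeps additive counts rather than per-batch averages. As a result, splitting the data across several update calls gives the same ACC, ACC@K, and macro-F1 as scoring it in one batch, even when batch sizes differ.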