trajdl.metrics.tul module

This module provides three evaluation metrics for trajectory-user linking (TUL) tasks: accuracy (ACC), accuracy at K (ACC_K), and Macro-F1 score.

All three metrics are computed by the TULMetrics class.

class trajdl.metrics.tul.TULMetrics(num_users: int, topk: int = 5)

Bases: object

Class for calculating evaluation metrics (accuracy, accuracy at K, and Macro-F1) in TUL tasks.

init_state() -> Dict[int, List[float]]

Initialize the state for evaluation metrics.

Returns:

A dictionary where keys are user indices and values are lists tracking prediction statistics for each user.

Return type:

Dict[int, List[float]]

property k: int

Return K, the number of top-ranked users considered when computing accuracy at K.

reset()

Reset all evaluation metrics to their initial state.

update(preds: ndarray, targets: ndarray)

Update the metrics based on the predictions and actual targets.

Parameters:
  • preds (np.ndarray) – Array of shape (B, num_users), where B is the batch size, containing the predicted score of every user for each sample.

  • targets (np.ndarray) – Array of shape (B,) containing the ground-truth user index of each sample.
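
To illustrate how inputs of these shapes relate to accuracy and accuracy at K, here is a minimal NumPy sketch. It does not reproduce the internals of TULMetrics; the helper name topk_hits and the sample arrays are hypothetical and exist only for this example.

```python
import numpy as np


def topk_hits(preds: np.ndarray, targets: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: True where the target is among the k highest scores.

    preds has shape (B, num_users); targets has shape (B,).
    """
    # Indices of the k largest scores in each row (unordered within the top k).
    topk_idx = np.argpartition(-preds, k - 1, axis=1)[:, :k]
    # A sample is a hit if its target index appears among those k indices.
    return (topk_idx == targets[:, None]).any(axis=1)


# Illustrative batch: B = 4 samples, num_users = 3.
preds = np.array([
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])
targets = np.array([1, 1, 0, 2])

# Plain accuracy: the top-scoring user must equal the target.
acc = float((preds.argmax(axis=1) == targets).mean())
# Accuracy at K with K = 2: the target must be among the two best scores.
acc_top2 = float(topk_hits(preds, targets, k=2).mean())

print(acc, acc_top2)
```

Accuracy at K is never lower than plain accuracy for K >= 1, since the top-scoring user is always among the K best.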

value() -> Dict[str, float]

Compute accuracy, accuracy at K, and Macro-F1 score.

Returns:

A dictionary containing the computed metrics:
  • 'acc': Accuracy

  • 'acc_topk': Accuracy at K

  • 'macro-f1': Macro F1 score

Return type:

Dict[str, float]
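
For reference, Macro F1 is the unweighted mean of the per-user F1 scores. The sketch below shows one common way to compute it with NumPy. It is an assumption-laden illustration rather than the TULMetrics implementation: the function name macro_f1 is hypothetical, and treating users with no predictions and no targets as F1 = 0 is a convention chosen here, not one confirmed by this page.

```python
import numpy as np


def macro_f1(pred_labels: np.ndarray, targets: np.ndarray, num_users: int) -> float:
    """Hypothetical helper: unweighted mean of per-user F1 scores."""
    f1_scores = np.zeros(num_users)
    for u in range(num_users):
        tp = np.sum((pred_labels == u) & (targets == u))  # true positives
        fp = np.sum((pred_labels == u) & (targets != u))  # false positives
        fn = np.sum((pred_labels != u) & (targets == u))  # false negatives
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); assumed to be 0 when the user never appears.
        f1_scores[u] = 2 * tp / denom if denom > 0 else 0.0
    return float(f1_scores.mean())


# Illustrative labels for 4 samples and 3 users.
pred_labels = np.array([0, 0, 1, 2])
targets = np.array([0, 1, 1, 2])

score = macro_f1(pred_labels, targets, num_users=3)
print(round(score, 4))
```

Because every user contributes equally to the mean, Macro F1 is sensitive to performance on users with few trajectories, which plain accuracy can hide.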