Source code for trajdl.metrics.tul

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains accuracy (ACC), accuracy at K (ACC_K), and Macro F1 score in TUL tasks.

These three metrics are integrated into TULMetrics.
"""

from typing import Dict, List

import numpy as np


[docs] class TULMetrics: """Class for calculating accuracy metrics in TUL tasks.""" def __init__(self, num_users: int, topk: int = 5): """ Initialize TULMetrics with the number of users and the evaluation parameter top-k. Parameters ---------- num_users : int The number of users in the TUL task. Must be greater than 1. topk : int, optional Evaluate on top-k predicted users (default is 5). Must be less than num_users. Raises ------ ValueError If `num_users` is less than or equal to 1 or less than or equal to `topk`. """ self.num_users = num_users self._k = topk if num_users <= 1: raise ValueError("`num_users` should be greater than 1.") if num_users <= topk: raise ValueError("`num_users` should be greater than `topk`.") self.num_samples = 0 self.acc = 0 self.acc_topk = 0 # eval_dict[user_idx][0] tracks correct top-1 predictions. # eval_dict[user_idx][1] tracks the number of times any user index is predicted as the top-1. # eval_dict[user_idx][2] tracks how many times each user index appears in the actual labels. # eval_dict[user_idx][3] stores the calculated Macro-F1 score for each user index. self.eval_dict = self.init_state() @property def k(self) -> int: """Return the top-k evaluation value.""" return self._k
[docs] def init_state(self) -> Dict[int, List[float]]: """ Initialize the state for evaluation metrics. Returns ------- Dict[int, List[float]] A dictionary where keys are user indices and values are lists tracking prediction statistics for each user. """ return {user_idx: [0, 0, 0, 0] for user_idx in range(self.num_users)}
[docs] def reset(self): """Reset all evaluation metrics to their initial state.""" self.num_samples = 0 self.acc = 0 self.acc_topk = 0 self.eval_dict = self.init_state()
[docs] def update(self, preds: np.ndarray, targets: np.ndarray): """ Update the metrics based on the predictions and actual targets. Parameters ---------- preds : np.ndarray Array of shape (B, num_users) containing predicted scores for each user. targets : np.ndarray Array of shape (B,) containing the actual user indices (targets). """ for idx in range(targets.shape[0]): self.num_samples += 1 vec = preds[idx] user_idx = targets[idx] self.eval_dict[user_idx][2] += 1 topk = np.argpartition(a=-vec, kth=self.k)[: self.k] top1 = np.argpartition(a=-vec, kth=1)[0] self.eval_dict[top1][1] += 1 for index in range(self.k): if topk[index] == user_idx: self.acc_topk += 1 break if top1 == user_idx: self.acc += 1 self.eval_dict[user_idx][0] += 1
[docs] def value(self) -> Dict[str, float]: """Compute accuracy, accuracy at K, and Macro-F1 score. Returns ------- Dict[str, float] A dictionary containing the computed metrics: - 'acc': Accuracy - 'acc_topk': Accuracy at K - 'macro-f1': Macro F1 score """ macro = 0 for i in self.eval_dict.keys(): if self.eval_dict[i][1] + self.eval_dict[i][2] > 0: self.eval_dict[i][3] = (2 * self.eval_dict[i][0]) / ( self.eval_dict[i][1] + self.eval_dict[i][2] ) macro += self.eval_dict[i][3] macro = macro * 100 / len(self.eval_dict) acc = self.acc * 100 / self.num_samples acc_topk = self.acc_topk * 100 / self.num_samples print( "\nacc1: {:.2f}%, acck: {:.2f}%, macro-f1: {:.2f}%".format( acc, acc_topk, macro ), flush=True, ) return {"acc": acc, "acc_topk": acc_topk, "macro-f1": macro}