# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Set, Tuple
import numpy as np
import torch
from ..utils import tiny_value_of_dtype
def _choice(num_words: int, num_samples: int, seed=None) -> Tuple[np.ndarray, int]:
"""
在vocab里面做不放回采样
Chooses `num_samples` samples without replacement from [0, ..., num_words).
Returns a tuple (samples, num_tries).
"""
num_tries = 0
num_chosen = 0
def get_buffer() -> np.ndarray:
r"""
实际上是$ \exp^{\alpha \log^{ \vert V \vert + 1 }} - 1$
$\alpha$是[0, 1)之间的均匀分布。
这个分布会使得采样时大部分样本集中在idx比较小的word上,idx比较大的word采样得到的概率很低
"""
log_samples = np.random.random(num_samples) * np.log(num_words + 1)
samples = np.exp(log_samples).astype("int64") - 1
return np.clip(samples, a_min=0, a_max=num_words - 1)
sample_buffer = get_buffer()
buffer_index = 0
samples: Set[int] = set()
# 做不放回采样,并且要保证一定要采集到num_samples个元素,并且要记录取了多少次才取出这些元素
while num_chosen < num_samples:
num_tries += 1
# choose sample
sample_id = sample_buffer[buffer_index]
if sample_id not in samples:
samples.add(sample_id)
num_chosen += 1
buffer_index += 1
if buffer_index == num_samples:
# Reset the buffer
sample_buffer = get_buffer()
buffer_index = 0
return np.array(list(samples)), num_tries
[docs]
class SampledSoftmaxLoss(torch.nn.Module):
    """
    Sampled softmax loss, based on the default log_uniform_candidate_sampler
    in tensorflow.

    In training mode (or in eval mode when `use_sampled_softmax_in_eval` is
    set) the loss is computed against the true target plus a set of sampled
    negative ids; otherwise a full softmax cross-entropy over every row of
    `weights` is used.

    Parameters
    ----------
    weights: torch.Tensor
        2-D output embedding matrix; rows are looked up by word id.
    bias: torch.Tensor
        2-D bias whose second axis has size 1; rows are looked up by word id.
    num_words: int
        Vocabulary size used by the sampler; must be greater than
        `num_samples`.
    num_samples: int
        Number of negative samples; must be smaller than `num_words`.
    reduction: str
        `"mean"` or `"sum"`; any other value returns the unreduced
        per-example loss.
    device
        Stored on the module as `_device`.
    use_sampled_softmax_in_eval: bool, optional
        Whether or not using sampled softmax in eval mode, default False.
    """
def __init__(
    self,
    weights: torch.Tensor,
    bias: torch.Tensor,
    num_words: int,
    num_samples: int,
    reduction: str = "mean",
    device=None,
    use_sampled_softmax_in_eval: bool = False,
) -> None:
    """
    Validate the arguments and store the softmax parameters and settings.

    `weights` must be 2-D and `bias` must be 2-D with a second axis of
    size 1; `num_samples` must be strictly smaller than `num_words`.
    """
    super().__init__()

    # Argument checks, in the same order as they have always fired.
    assert num_samples < num_words
    assert weights.ndim == 2, "weights should be a 2-D tensor"
    # torch.nn.functional.embedding can only be applied on 2-D tensor,
    # hence the (num_words, 1) bias layout.
    assert (
        bias.ndim == 2 and bias.shape[1] == 1
    ), "bias should be a 2-D tensor, and the size of the second axis should be 1"

    # Softmax parameters.
    self.w = weights
    self.b = bias

    # Sampling configuration.
    self.choice_func = _choice
    self.use_sampled_softmax_in_eval = use_sampled_softmax_in_eval
    self._num_samples = num_samples
    self._num_words = num_words

    # Loss configuration.
    self._reduction = reduction
    self._device = device
    self.eval_loss = torch.nn.CrossEntropyLoss(reduction="none")

    # Precompute the per-id sampling probabilities.
    self.initialize_num_words(self._num_words)
[docs]
def initialize_num_words(self, num_words: int):
    """
    Precompute the log-uniform sampling probability of every word id.

    Stores `log(num_words + 1)` in `_log_num_words_p1` and, in `_probs`, an
    array whose entry `k` is
    `(log(k + 2) - log(k + 1)) / log(num_words + 1)`.
    """
    normaliser = np.log1p(num_words)
    self._log_num_words_p1 = normaliser

    # One-based ranks 1..num_words; id k has rank k + 1.
    ranks = np.arange(num_words) + 1
    log_gap = np.log(ranks + 1) - np.log(ranks)
    self._probs = log_gap / normaliser
[docs]
def forward(
    self,
    embeddings: torch.Tensor,
    targets: torch.LongTensor,
    mask: torch.BoolTensor = None,
) -> torch.Tensor:
    """
    Compute the loss for a batch and apply the configured reduction.

    Uses the sampled-softmax path when the module is in training mode or
    `use_sampled_softmax_in_eval` is set, and the full-softmax path
    otherwise.

    Parameters
    ----------
    embeddings: torch.Tensor
        Input vectors; the first axis is the batch axis.
    targets: torch.LongTensor
        Index of the true class for each row of `embeddings`.
    mask: torch.BoolTensor, optional
        Per-row flag; rows whose flag is False contribute zero loss.

    Returns
    -------
    torch.Tensor
        - `reduction == "sum"`: the summed loss.
        - `reduction == "mean"`: the loss averaged over the rows kept by
          `mask`, or over all rows when `mask` is None.
        - any other value: the unreduced per-row loss.
        An empty batch always yields a scalar zero on the embeddings' device.
    """
    if embeddings.shape[0] == 0:
        # empty batch
        return torch.tensor(0.0).to(embeddings.device)

    if self.training or self.use_sampled_softmax_in_eval:
        batch_loss = self._forward_train(embeddings, targets, mask)
    else:
        batch_loss = self._forward_eval(embeddings, targets, mask)

    if self._reduction == "sum":
        return batch_loss.sum()

    if self._reduction == "mean":
        if mask is None:
            # No mask: average over every row of the batch.
            return batch_loss.mean()
        kept = mask.sum()
        # A fully masked batch has zero summed loss; avoid 0 / 0 = NaN.
        kept = kept.masked_fill(kept == 0, 1)
        return batch_loss.sum() / kept

    return batch_loss
def _forward_train(
    self,
    embeddings: torch.Tensor,
    targets: torch.LongTensor,
    mask: torch.BoolTensor,
) -> torch.Tensor:
    """
    Compute the per-example sampled-softmax loss.

    One set of negative ids is drawn for the whole batch. For each example
    the logits are `[true logit, sampled logits...]`, each corrected by the
    log of its expected sampling count, and the loss is the negative
    log-softmax of the true logit.

    Parameters
    ----------
    embeddings: torch.Tensor
        Shape (B, C).
    targets: torch.LongTensor
        Shape (B,).
    mask: torch.BoolTensor
        Shape (B,), or None. `embeddings` and `targets` may contain padding;
        rows whose mask is False have their loss multiplied by zero.

    Returns
    -------
    torch.Tensor
        Shape (B,), the unreduced loss.
    """
    # NOTE(review): `log_uniform_candidate_sampler` is not visible in this
    # part of the file. From the way its outputs are used below it presumably
    # returns ids of shape (num_samples,) and expected counts of shape (B,)
    # and (num_samples,) -- confirm against its definition.
    (
        sampled_ids,
        target_expected_count,
        sampled_expected_count,
    ) = self.log_uniform_candidate_sampler(targets, choice_func=self.choice_func)

    # Look up the softmax parameters for targets and samples in one call.
    # all_ids: (B + num_samples,)
    batch_size = targets.shape[0]
    all_ids = torch.cat([targets, sampled_ids], dim=0)
    # (B + num_samples, C)
    all_w = torch.nn.functional.embedding(all_ids, self.w)
    # (B + num_samples,)
    all_b = torch.nn.functional.embedding(all_ids, self.b).squeeze(1)
    true_w, sampled_w = all_w[:batch_size, :], all_w[batch_size:, :]
    true_b, sampled_b = all_b[:batch_size], all_b[batch_size:]

    # Compute the logits and remove the log expected counts.
    # (B, 1)
    true_logits = (
        (true_w * embeddings).sum(dim=1)
        + true_b
        - torch.log(
            target_expected_count + tiny_value_of_dtype(target_expected_count.dtype)
        )
    ).unsqueeze(dim=1)
    # (B, num_samples)
    sampled_logits = (
        torch.matmul(embeddings, sampled_w.t())
        + sampled_b
        - torch.log(
            sampled_expected_count
            + tiny_value_of_dtype(sampled_expected_count.dtype)
        )
    )

    # The negatives are shared by the whole batch, so a sampled id can equal
    # some example's true target. Compare every sample with every target,
    # giving a (B, num_samples) mask, and push those logits to the most
    # negative finite value so they vanish after the softmax.
    true_in_sample_mask = sampled_ids == targets.unsqueeze(1)
    masked_sampled_logits = sampled_logits.masked_fill(
        true_in_sample_mask, torch.finfo(sampled_logits.dtype).min
    )

    # (B, 1 + num_samples); column 0 holds the true-class logit.
    logits = torch.cat([true_logits, masked_sampled_logits], dim=1)
    log_softmax = torch.nn.functional.log_softmax(logits, dim=1)

    # Only column 0 carries a non-zero label, so select it directly instead
    # of building a dense one-hot label matrix. This also avoids creating
    # part of that matrix on `self._device`, which may differ from the
    # device the logits live on.
    batch_loss = -log_softmax[:, 0]
    if mask is not None:
        batch_loss = batch_loss * mask
    return batch_loss
def _forward_eval(
self,
embeddings: torch.Tensor,
targets: torch.LongTensor,
mask: torch.BoolTensor,
) -> torch.Tensor:
"""
embeddings: shape is (batch_size, C)
targets: shape is (batch_size,)
mask: shape is (batch_size,)
"""
# evaluation mode, use full softmax
w = self.w
b = self.b
# (batch_size, num_classes)
logits = torch.matmul(embeddings, w.t()) + b.t()
# (batch_size,)
batch_loss = self.eval_loss(logits, targets)
if mask is not None:
batch_loss *= mask
return batch_loss