I attempted to copy an implementation of the Mahalanobis distance from the PyTorch library. I'm not sure it is correct, or whether it is more complicated than it needs to be. I would like a working, simplified implementation of the Mahalanobis distance that I can use and understand.
I know the formula for the Mahalanobis distance doesn't look this complicated, so I'm wondering whether this many reshapes, permutes, and triangular solves is really necessary. Is it only this complicated because the inputs are more than 2-D? I don't know.
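For context, the formula I have in mind is the standard one: with covariance \Sigma = L L^\top (Cholesky factorization, L lower triangular),

    D^2(x) = (x - \mu)^\top \Sigma^{-1} (x - \mu) = \lVert L^{-1} (x - \mu) \rVert^2,

so as far as I understand, the whole computation should reduce to one triangular solve L z = x - \mu followed by summing z^2. Everything beyond that in the code below looks like batch bookkeeping to me.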
The shapes of both the mean and covariance tensors are [20, 500, 17].
#include <torch/torch.h>

torch::Tensor Mahalanobis(torch::Tensor tril, torch::Tensor diff) {
    auto n = diff.size(-1);
    // sizes() holds int64_t, so the vector must be int64_t (std::vector<int> narrows)
    auto diff_batch_shape = std::vector<int64_t>(diff.sizes().begin(), diff.sizes().end() - 1);
    auto diff_batch_dims = diff_batch_shape.size(); // unused in this reduced port
    auto tril_batch_dims = tril.dim() - 2;          // unused in this reduced port
    // Prepend two singleton dims so diff can be flattened against tril's batch dims
    diff = diff.unsqueeze(0).unsqueeze(0);
    auto flat_L = tril.reshape({-1, n, n});              // collapse tril's batch into one dim
    auto flat_x = diff.reshape({-1, flat_L.size(0), n}); // shape = c x b x n
    auto flat_x_swap = flat_x.permute({1, 2, 0});        // shape = b x n x c
    // Solve L z = x batch-wise, then sum z^2 over the n dimension
    auto M_swap = torch::linalg::solve_triangular(
                      torch::nan_to_num(flat_L), torch::nan_to_num(flat_x_swap),
                      /*upper=*/false, /*left=*/true, /*unitriangular=*/false)
                      .pow(2)
                      .sum(-2);
    auto M = M_swap.t(); // shape = c x b
    // Undo the flattening and permutation to restore the original batch shape
    auto permuted_M = M.reshape({1, 1, diff_batch_shape[0], diff_batch_shape[1]});
    permuted_M = permuted_M.permute({0, 2, 1, 3});
    auto reshaped_M = permuted_M.reshape({diff_batch_shape[0], diff_batch_shape[1]});
    return torch::nan_to_num(reshaped_M);
}
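For what it's worth, here is the simplified version I came up with (my own sketch, not something from the PyTorch sources, and MahalanobisSimple is my own name for it). It assumes tril is the lower-triangular Cholesky factor of the covariance with shape [..., n, n], that diff = x - mean has shape [..., n], and that the two batch shapes broadcast against each other. As far as I can tell, the extra reshapes and permutes in the PyTorch code are only there to handle mismatched batch shapes efficiently, so when the shapes broadcast this should compute the same thing:

// Minimal sketch: squared Mahalanobis distance via one triangular solve.
// Assumes tril: [..., n, n] lower-triangular Cholesky factor, diff: [..., n],
// with broadcastable batch shapes.
torch::Tensor MahalanobisSimple(const torch::Tensor& tril, const torch::Tensor& diff) {
    // Treat diff as a column vector [..., n, 1] and solve L z = diff
    auto z = torch::linalg::solve_triangular(
        tril, diff.unsqueeze(-1),
        /*upper=*/false, /*left=*/true, /*unitriangular=*/false);
    // Sum z^2 over the n dimension; result has the broadcast batch shape [...]
    return z.squeeze(-1).pow(2).sum(-1);
}

Usage would be something like this (with .sqrt() at the end if the actual distance, rather than its square, is wanted; note the PyTorch code above also returns the squared distance):

    auto tril = torch::linalg::cholesky(cov);       // cov: [..., n, n]
    auto d2   = MahalanobisSimple(tril, x - mean);  // squared distance, shape [...]

Is this correct, and does it miss any case that the original handles?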
I have asked on the PyTorch forums and on Stack Overflow, but I haven't gotten an answer explaining why the function was written this way, or what the most simplified version for these tensors would look like.
Thank you.