
I have an array of arrival times and I want to convert it to count data using PyTorch in a differentiable way.

Example arrival times:

arrival_times = [2.1, 2.9, 5.1]

and let's say the total range is 6 seconds. What I want to have is:

counts = [0, 0, 2, 2, 2, 3]

For this task, a non-differentiable way works perfectly:

x = [1, 2, 3, 4, 5, 6]
counts = torch.sum(torch.Tensor(arrival_times)[:, None] < torch.Tensor(x), dim=0)

It turns out the < operation here is not differentiable. I need a differentiable approximation of this operation.
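For concreteness, a minimal check with the example values above shows why: the `<` comparison returns a boolean tensor with no `grad_fn`, so autograd has no path back to `arrival_times`.

    import torch

    arrival_times = torch.tensor([2.1, 2.9, 5.1], requires_grad=True)
    x = torch.tensor([1., 2., 3., 4., 5., 6.])

    counts = torch.sum(arrival_times[:, None] < x, dim=0)
    print(counts)          # tensor([0, 0, 2, 2, 2, 3])
    print(counts.grad_fn)  # None -> gradients cannot flow through the count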

What I could think of is to subtract x from arrival_times with broadcasting, which leads to the following array:

[
 [1.1, 0.1, -0.9, -1.9, -2.9, -3.9],
 [1.9, 0.9, -0.1, -1.1, -2.1, -3.1],
 [4.1, 3.1, 2.1, 1.1, 0.1, -0.9]
]

And then somehow count the number of negative (and preferably also zero) elements vertically, which will give us the counts [0, 0, 2, 2, 2, 3].
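As a sanity check (still using the hard, non-differentiable comparison), the broadcasted difference matrix and the column-wise count of its negative entries do reproduce the target counts:

    import torch

    arrival_times = torch.tensor([2.1, 2.9, 5.1])
    x = torch.tensor([1., 2., 3., 4., 5., 6.])

    diffs = arrival_times[:, None] - x      # the 3 x 6 matrix shown above
    hard_counts = (diffs < 0).sum(dim=0)    # negatives per column
    print(hard_counts)                      # tensor([0, 0, 2, 2, 2, 3])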

Is there a way to do this, or a completely new idea for such an approximation?

  • How many elements should the result contain? Your first example will return 3 values whereas in your second example you are showing a list containing 6 values. – Oxbowerce Nov 08 '21 at 16:12
  • My first example also shows 6 count values. The number of elements that is going to be returned is the length of array `x`. – iRestMyCaseYourHonor Nov 08 '21 at 16:20
  • When running the code from your first example I get back a tensor containing only three values (assuming that `arrival_times` is also a tensor): `tensor([4, 4, 1])`. – Oxbowerce Nov 08 '21 at 16:26
  • You are right, I have updated the question. The `dim` in `torch.sum` was not correct. – iRestMyCaseYourHonor Nov 09 '21 at 09:51

1 Answer


I tried a hacky way to do this. Still open for suggestions.

    # assumes arrival_times is already a torch tensor and x is the list from the question
    diffs = arrival_times[..., None] - torch.Tensor(x)      # shape (n_arrivals, n_bins)
    zeros = torch.zeros_like(diffs)
    minimums = torch.minimum(diffs, zeros)                   # clip positive diffs to 0
    eps, r_eps = 0.001, 1000
    epsilons = torch.ones_like(diffs) * -eps
    maximums = torch.maximum(minimums, epsilons) * -r_eps    # soft indicators in [0, 1]
    counts = torch.sum(maximums, dim=0)                      # sum over the arrival-time axis

The hackiness comes from the eps scaling: any arrival time that lies within 0.001 of the bin boundary just above it contributes only a fractional value to that bin's count instead of a full 1, polluting the result. It's still differentiable, but it did not give good results in training for my case.
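One alternative that may behave better in training (not part of the original answer, just a sketch under the same setup) is to replace the clamp-based ramp with a steep sigmoid. The temperature `tau` below is an assumed hyperparameter that trades off sharpness of the approximation against gradient signal:

    import torch

    arrival_times = torch.tensor([2.1, 2.9, 5.1])
    x = torch.tensor([1., 2., 3., 4., 5., 6.])

    diffs = arrival_times[:, None] - x     # (n_arrivals, n_bins)
    tau = 0.01                             # assumed temperature; smaller -> closer to the hard count
    soft_counts = torch.sigmoid(-diffs / tau).sum(dim=0)
    print(soft_counts)                     # close to [0, 0, 2, 2, 2, 3], but fully differentiable

Unlike the eps trick, the sigmoid provides small but nonzero gradients everywhere, not only for arrival times inside the eps band below a boundary.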