This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Pytorch code translation to Keras code #552

Open
albertopolito opened this issue Dec 18, 2020 · 0 comments

albertopolito commented Dec 18, 2020

Good morning,
I'm a beginner and this is the first time I am using Keras to implement a neural network.
I would like to write the same network as the one at this link, with the same activation function and forward mechanism.
I see that there is a tool that converts ONNX models to Keras models, but it does not seem to work with this code.
So I would like to translate it manually from PyTorch to Keras.
I have some questions:
- How can I write the following activation function in Keras?

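    import torch

    # thresh (firing threshold) and lens (surrogate-gradient window) are
    # hyper-parameters defined elsewhere in the original code.

    class ActFun(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.gt(thresh).float()  # spike (1.0) where input > thresh

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            temp = abs(input - thresh) < lens  # rectangular surrogate window
            return grad_input * temp.float()

    act_fun = ActFun.apply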
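For reference, the math of `ActFun` can be sketched framework-free in NumPy. The forward pass is a hard threshold; the backward pass replaces the derivative of the step (zero almost everywhere) with a rectangular window of half-width `lens` around `thresh`. The values `THRESH = 0.5` and `LENS = 0.5` are assumptions, since the issue does not show them. In TensorFlow/Keras, one common way to attach a hand-written gradient to a forward op is the `tf.custom_gradient` decorator, whose decorated function returns both the output and a gradient function. This sketch only shows the two pieces that would go into it.

```python
import numpy as np

# Assumed hyper-parameters: the issue uses `thresh` and `lens` without
# showing their values, so these numbers are placeholders.
THRESH = 0.5
LENS = 0.5


def spike_forward(mem):
    """Heaviside step: 1.0 where the membrane potential exceeds THRESH."""
    return (mem > THRESH).astype(np.float32)


def spike_surrogate_grad(mem, grad_output):
    """Rectangular surrogate gradient, mirroring ActFun.backward:
    pass the upstream gradient only where |mem - THRESH| < LENS."""
    window = (np.abs(mem - THRESH) < LENS).astype(np.float32)
    return grad_output * window


mem = np.array([0.0, 0.3, 0.6, 1.2], dtype=np.float32)
print(spike_forward(mem))
print(spike_surrogate_grad(mem, np.ones_like(mem)))
```

Note that the neuron at 1.2 fires but receives no gradient, while the one at 0.3 does not fire but still receives one. This is the point of the surrogate: gradients flow only near the threshold.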

- How can I write the membrane potential update mechanism in Keras?

    def mem_update(ops, x, mem, spike):
        # decay the old potential and reset it wherever the neuron just spiked
        mem = mem * decay * (1. - spike) + ops(x)
        spike = act_fun(mem)  # act_fun: approximate firing function
        return mem, spike
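A NumPy sketch of the update step, assuming `decay = 0.2` and `thresh = 0.5` (both hypothetical placeholders, since the issue does not show them). A hypothetical `ops` that scales its input by 0.45 stands in for a real layer. The `(1. - spike)` factor acts as a reset: a neuron that fired on the previous step starts again from `ops(x)` alone.

```python
import numpy as np

DECAY = 0.2   # assumed value; the issue does not show `decay`
THRESH = 0.5  # assumed value; the issue does not show `thresh`


def mem_update(ops, x, mem, spike):
    # Leak and reset: the old potential decays by DECAY and is zeroed
    # wherever the neuron spiked on the previous step.
    mem = mem * DECAY * (1.0 - spike) + ops(x)
    # Fire: hard threshold (the forward pass of ActFun).
    spike = (mem > THRESH).astype(np.float32)
    return mem, spike


def ops(x):
    # Hypothetical stand-in for a layer: scale the input by 0.45.
    return 0.45 * x


mem = np.zeros(1, dtype=np.float32)
spike = np.zeros(1, dtype=np.float32)
spikes = []
for t in range(4):
    mem, spike = mem_update(ops, np.ones(1, dtype=np.float32), mem, spike)
    spikes.append(float(spike[0]))
print(spikes)  # [0.0, 1.0, 0.0, 1.0]
```

With a constant input of 1.0 the potential goes 0.45, 0.54 (fire), then resets to 0.45, then 0.54 (fire) again, so the neuron spikes on every other step.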

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # batch_size, cfg_cnn, cfg_kernel, cfg_fc and device are defined
    # elsewhere in the original code.

    class SCNN(nn.Module):

        ...

        def forward(self, input, time_window=20):
            c1_mem = c1_spike = torch.zeros(batch_size, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0], device=device)
            c2_mem = c2_spike = torch.zeros(batch_size, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1], device=device)

            h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, cfg_fc[0], device=device)
            h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)

            for step in range(time_window):  # simulation time steps
                x = input > torch.rand(input.size(), device=device)  # prob. firing

                c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)

                x = F.avg_pool2d(c1_spike, 2)

                c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)

                x = F.avg_pool2d(c2_spike, 2)
                x = x.view(batch_size, -1)

                h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
                h1_sumspike += h1_spike
                h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
                h2_sumspike += h2_spike

            outputs = h2_sumspike / time_window  # average firing rate per output neuron
            return outputs
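The simulation loop can be sketched the same way. This hypothetical NumPy version keeps only the two fully connected layers: `w1` and `w2` are assumed weight matrices standing in for `self.fc1` and `self.fc2`. It drops the convolution and pooling stages for brevity, and `DECAY`/`THRESH` are again assumed values. It reproduces three parts of the original loop: the Bernoulli input encoding (`input > rand`), the per-step `mem_update` calls, and the spike-count average over `time_window`. In Keras, a loop like this would typically live in the `call` method of a subclassed `tf.keras.Model`, with the Keras layers playing the role of `ops`.

```python
import numpy as np

DECAY = 0.2   # assumed decay constant (not shown in the issue)
THRESH = 0.5  # assumed firing threshold (not shown in the issue)


def mem_update(ops, x, mem, spike):
    """Leaky integrate-and-fire step with reset, as in the issue's mem_update."""
    mem = mem * DECAY * (1.0 - spike) + ops(x)
    spike = (mem > THRESH).astype(np.float32)
    return mem, spike


def forward(inputs, w1, w2, time_window=20, seed=0):
    """Hypothetical dense-only version of SCNN.forward (no conv/pooling)."""
    rng = np.random.default_rng(seed)
    batch_size = inputs.shape[0]
    h1_mem = np.zeros((batch_size, w1.shape[1]), dtype=np.float32)
    h1_spike = np.zeros_like(h1_mem)
    h2_mem = np.zeros((batch_size, w2.shape[1]), dtype=np.float32)
    h2_spike = np.zeros_like(h2_mem)
    h2_sumspike = np.zeros_like(h2_mem)

    for _ in range(time_window):
        # Bernoulli rate coding: each input value is its firing probability per step.
        x = (inputs > rng.random(inputs.shape)).astype(np.float32)
        h1_mem, h1_spike = mem_update(lambda v: v @ w1, x, h1_mem, h1_spike)
        h2_mem, h2_spike = mem_update(lambda v: v @ w2, h1_spike, h2_mem, h2_spike)
        h2_sumspike += h2_spike

    # Firing rate of each output neuron over the simulation window.
    return h2_sumspike / time_window


print(forward(np.ones((1, 2)), np.ones((2, 3)), np.ones((3, 2))))
```

With all-zero inputs no input spike is ever drawn, so every output rate is 0. With all-one inputs and all-one weights every neuron fires on every step, so every output rate is 1.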

Thanks in advance for your time and your help.
