Some PyTorch layers needed for MetNet
In a GRU cell, the output and the hidden state are the same tensor, so the last output must be equal to the last hidden state.
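The same property can be seen with PyTorch's built-in, non-convolutional `nn.GRU`. This is a minimal sketch using `nn.GRU` as a stand-in, not the ConvGRU defined in this notebook; the sizes (16 input features, 32 hidden units, 2 layers, 10 time steps) are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Built-in (non-convolutional) GRU stand-in: 2 stacked layers, inputs shaped (batch, time, features)
gru = nn.GRU(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
x = torch.rand(1, 10, 16)

# output: top-layer hidden state at every time step, shape (1, 10, 32)
# h_n: final hidden state of each layer, shape (2, 1, 32)
output, h_n = gru(x)

# The top layer's output at the last time step is its last hidden state
same = torch.allclose(output[:, -1], h_n[-1])
print(output.shape, h_n.shape, same)
```

Because the per-step output of a GRU is just its hidden state, `output[:, -1]` and `h_n[-1]` hold the same values.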
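import torch
cgru_cell = ConvGRUCell(16, 32, 3)  # 16 input channels, 32 hidden channels, 3x3 kernel
cgru_cell(torch.rand(1, 16, 16, 16)).shape  # input is (batch, channels, height, width)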
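To make the idea concrete, here is a minimal sketch of a convolutional GRU cell. The class name `SketchConvGRUCell` and its internals are hypothetical, not the `ConvGRUCell` implementation used in this notebook: it simply replaces the matrix products of a standard GRU with 2D convolutions and returns the new hidden state as its output.

```python
import torch
from torch import nn


class SketchConvGRUCell(nn.Module):
    """Hypothetical minimal ConvGRU cell: GRU gates computed with 2D convolutions."""

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        padding = kernel_size // 2  # 'same' padding for odd kernel sizes
        # Update (z) and reset (r) gates computed together from [x, h]
        self.gates = nn.Conv2d(input_dim + hidden_dim, 2 * hidden_dim, kernel_size, padding=padding)
        # Candidate hidden state computed from [x, r * h]
        self.candidate = nn.Conv2d(input_dim + hidden_dim, hidden_dim, kernel_size, padding=padding)

    def forward(self, x, h_prev=None):
        if h_prev is None:
            # Start from a zero hidden state with the input's spatial size
            b, _, height, width = x.shape
            h_prev = x.new_zeros(b, self.hidden_dim, height, width)
        z, r = torch.sigmoid(self.gates(torch.cat([x, h_prev], dim=1))).chunk(2, dim=1)
        h_tilde = torch.tanh(self.candidate(torch.cat([x, r * h_prev], dim=1)))
        # The new hidden state is also the cell's output
        return (1 - z) * h_prev + z * h_tilde


cell = SketchConvGRUCell(16, 32, 3)
h = cell(torch.rand(1, 16, 16, 16))
print(h.shape)  # torch.Size([1, 32, 16, 16])
```

With 'same' padding the spatial size is preserved, so only the channel dimension changes from 16 to 32.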
Let's check this on a stacked ConvGRU:
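cgru = ConvGRU(16, 32, (3, 3), 2)
cgru
layer_output, last_state_list = cgru(torch.rand(1, 10, 16, 6, 6))  # assuming batch_first: (batch, time, channels, height, width)
layer_output.shape
last_state_list.shape
layer_output, last_state_list = cgru(torch.rand(1, 10, 16, 6, 6), last_state_list)  # feed the final states back in as the initial hidden state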
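Passing the final states back in lets the recurrence continue where it stopped. A minimal sketch of why this works, again using PyTorch's non-convolutional `nn.GRU` as a stand-in for the notebook's ConvGRU (sizes are arbitrary assumptions): running a sequence in two chunks while carrying the hidden state over should match running it in one pass.

```python
import torch
from torch import nn

torch.manual_seed(0)
gru = nn.GRU(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
x = torch.rand(1, 10, 16)

# Process all 10 time steps in one pass
out_full, h_full = gru(x)

# Process the same sequence in two chunks, feeding the first chunk's final states into the second
out_a, h_a = gru(x[:, :5])
out_b, h_b = gru(x[:, 5:], h_a)

# Carrying the state over reproduces the single-pass result (up to float rounding)
outputs_match = torch.allclose(out_full[:, 5:], out_b, atol=1e-6)
states_match = torch.allclose(h_full, h_b, atol=1e-6)
print(outputs_match, states_match)
```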