## Issue

I am trying to understand the role of the `Flatten` function in Keras. Below is my code, which is a simple two-layer network. It takes in 2-dimensional data of shape (3, 2), and outputs 1-dimensional data of shape (1, 4):

```
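import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))  # applied to the last axis only
model.add(Activation('relu'))
model.add(Flatten())                      # (3, 16) -> (48,) per sample
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])  # one sample of shape (3, 2)
y = model.predict(x)
print(y.shape)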
```

This prints out that `y` has shape (1, 4). However, if I remove the `Flatten` line, then it prints out that `y` has shape (1, 3, 4).

I don’t understand this. From my understanding of neural networks, the `model.add(Dense(16, input_shape=(3, 2)))` call creates a hidden fully-connected layer with 16 nodes. Each of these nodes is connected to each of the 3×2 input elements. Therefore, the 16 nodes at the output of this first layer are already “flat”. So, the output shape of the first layer should be (1, 16). Then, the second layer takes this as an input, and outputs data of shape (1, 4).

So if the output of the first layer is already “flat” and of shape (1, 16), why do I need to further flatten it?

## Solution

If you read the Keras documentation entry for `Dense`, you will see that this call:

```
Dense(16, input_shape=(5,3))
```

would result in a `Dense` layer with 3 inputs and 16 outputs, applied independently to each of the 5 steps. So, if `D(x)` transforms a 3-dimensional vector into a 16-dimensional vector, what you get as output from the layer is a sequence of vectors, `[D(x[0,:]), D(x[1,:]), ..., D(x[4,:])]`, with shape `(5, 16)`. To get the behavior you describe, you can first `Flatten` your input into a single vector (15-d for this `(5, 3)` example, 6-d for your `(3, 2)` input) and then apply `Dense`:

```
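from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))  # (3, 2) -> (6,) per sample
model.add(Dense(16))                    # now connected to all 6 inputs
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')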
```
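To see why the position of `Flatten` changes the shapes, here is a minimal NumPy sketch of the same arithmetic. It is not Keras code: `dense` and `flatten` are hypothetical stand-ins written for this illustration, with random weights and a zero bias, and they assume only that a `Dense` layer multiplies the last axis of its input by a kernel matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(x, units):
    """Hypothetical stand-in for Dense: random kernel, zero bias,
    applied to the last axis of x only."""
    kernel = rng.standard_normal((x.shape[-1], units))
    bias = np.zeros(units)
    return x @ kernel + bias

def flatten(x):
    """Collapse every axis except the batch axis."""
    return x.reshape(x.shape[0], -1)

x = np.array([[[1, 2], [3, 4], [5, 6]]], dtype=float)  # shape (1, 3, 2)

# Model from the question: Dense(16) -> relu -> Flatten -> Dense(4)
h = np.maximum(dense(x, 16), 0)    # (1, 3, 16): one 16-d vector per row
y_question = dense(flatten(h), 4)  # (1, 48) -> (1, 4)

# Same model without Flatten: the second Dense also acts per row
y_no_flatten = dense(h, 4)         # (1, 3, 4)

# Model from the answer: Flatten -> Dense(16) -> relu -> Dense(4)
h_flat = np.maximum(dense(flatten(x), 16), 0)  # (1, 6) -> (1, 16)
y_answer = dense(h_flat, 4)                    # (1, 4)

print(h.shape, y_question.shape, y_no_flatten.shape, h_flat.shape, y_answer.shape)
# prints: (1, 3, 16) (1, 4) (1, 3, 4) (1, 16) (1, 4)
```

Both the question's model and the answer's model end at shape (1, 4), but only in the second one is each of the 16 hidden nodes connected to all 6 input elements, which is the behavior the question expected.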

**EDIT:**

As some people struggled to understand this, the original answer included an explanatory image (not reproduced here).

Answered By – Marcin Możejko

Answer Checked By – Gilberto Lyons (Easybugfix Admin)