[FIXED] Keras can't save model with CuDNNLSTM as SavedModel

Issue

I have recently encountered a problem with Keras. My model looks like:

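# Imports assumed (not shown in the question); in TF 2.x, CuDNNLSTM
# is only available under tf.compat.v1
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import Input, Embedding, Bidirectional, Dense, Reshape
from tensorflow.keras.models import Model

# max_sequence_len, word_idx, embedding_matrix, n_of_stocks and stock_size
# are assumed to be defined earlier in the script
inputs = Input(shape=(max_sequence_len,))

# Embedding layer (frozen, initialised from a pretrained matrix)
embedding = Embedding(
        input_length=max_sequence_len,
        input_dim=len(word_idx),
        output_dim=100,
        weights=[embedding_matrix],
        trainable=False
)(inputs)

# Recurrent layers
heart = Bidirectional(CuDNNLSTM(256))(embedding)

# Fully connected layer
dense = Dense((n_of_stocks * stock_size * 4), activation='relu')(heart)

# One output head per stock, each reshaped to (stock_size, 4)
preoutput = []
outputs = []
for i in range(n_of_stocks):
    preoutput.append(Dense(stock_size * 4, activation='linear')(dense))
    outputs.append(Reshape((stock_size, 4))(preoutput[i]))

# Build and compile the model
model = Model(inputs=inputs, outputs=outputs, name="the_model")

model.summary()
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

model.save("mynetwork")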

When I try to save the model, it fails with this error:

Traceback (most recent call last):
  File "D:\Projects\Project T\neural\network.py", line 95, in <module>
    model.save("mynetwork")
  File "C:\Users\nkart\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\nkart\AppData\Local\Programs\Python\Python310\lib\site-packages\keras\layers\rnn\base_rnn.py", line 282, in _use_input_spec_as_call_signature
    if self.unroll:
AttributeError: 'CuDNNLSTM' object has no attribute 'unroll'

Am I doing something wrong? Should I try to save it as h5?

Solution

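Don’t use CuDNNLSTM; use the regular LSTM layer with its default parameters instead. In TensorFlow 2.x, LSTM automatically switches to the cuDNN kernel when it runs on a GPU with cuDNN properly installed, as long as activation='tanh', recurrent_activation='sigmoid', recurrent_dropout=0, unroll=False and use_bias=True, which are all the defaults. CuDNNLSTM is a TensorFlow 1.x layer: in TensorFlow 2.x it only survives under tf.compat.v1, and, as the traceback shows, it lacks the unroll attribute that the SavedModel export code reads from RNN layers.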

heart = Bidirectional(LSTM(256))(embedding)
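The "default parameters" condition matters because the cuDNN kernel is only selected when the layer's hyperparameters match what cuDNN implements. A minimal sketch, assuming TensorFlow 2.x: the hypothetical helper `cudnn_eligible` compares an `LSTM` layer's config against the documented requirements listed above. It only inspects hyperparameters; it does not check runtime conditions such as whether a GPU is present.

```python
import tensorflow as tf


def cudnn_eligible(lstm_layer):
    """Hypothetical helper: check an LSTM's hyperparameters against the
    documented cuDNN requirements. Runtime conditions (GPU present, masks
    strictly right-padded, eager outermost context) are NOT checked."""
    cfg = lstm_layer.get_config()
    return (
        cfg["activation"] == "tanh"
        and cfg["recurrent_activation"] == "sigmoid"
        and cfg["recurrent_dropout"] == 0
        and not cfg["unroll"]
        and bool(cfg["use_bias"])
    )


print(cudnn_eligible(tf.keras.layers.LSTM(256)))  # True: all defaults
print(cudnn_eligible(tf.keras.layers.LSTM(256, recurrent_dropout=0.2)))  # False
```

Changing any one of these arguments keeps the layer working but silently falls back to the slower generic implementation.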

Depending on your setup, you may need to import LSTM from tensorflow.keras.layers rather than keras.layers.
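Putting the fix together, here is a sketch of the question's model rebuilt with `tf.keras.layers.LSTM` and saved as a SavedModel. The sizes, the randomly initialised embedding (standing in for `embedding_matrix`), the smaller `LSTM(16)` and the helper `save_as_savedmodel` are all hypothetical. The helper assumes that newer Keras versions (tf.keras 2.13+ and Keras 3) provide `Model.export()` for SavedModel output, since Keras 3's `model.save()` only accepts `.keras`/`.h5` paths; on older versions it falls back to `model.save()` with an extension-less path, as in the question.

```python
import os
import tempfile

import tensorflow as tf

layers = tf.keras.layers

# Hypothetical sizes standing in for the question's variables
MAX_SEQUENCE_LEN = 20
VOCAB_SIZE = 1000
N_OF_STOCKS = 2
STOCK_SIZE = 5


def build_model():
    inputs = layers.Input(shape=(MAX_SEQUENCE_LEN,))
    # Frozen, randomly initialised embedding (stand-in for embedding_matrix)
    embedding = layers.Embedding(input_dim=VOCAB_SIZE, output_dim=100, trainable=False)(inputs)
    # Plain LSTM with default arguments instead of CuDNNLSTM
    heart = layers.Bidirectional(layers.LSTM(16))(embedding)
    dense = layers.Dense(N_OF_STOCKS * STOCK_SIZE * 4, activation="relu")(heart)
    # One (STOCK_SIZE, 4) output head per stock, as in the question
    outputs = [
        layers.Reshape((STOCK_SIZE, 4))(layers.Dense(STOCK_SIZE * 4, activation="linear")(dense))
        for _ in range(N_OF_STOCKS)
    ]
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="the_model")
    model.compile(optimizer="adam", loss="mse")
    return model


def save_as_savedmodel(model, path):
    if hasattr(model, "export"):
        # Newer Keras: writes an inference-only SavedModel directory
        model.export(path)
    else:
        # Older tf.keras: a path without .h5 is saved in SavedModel format
        model.save(path)


model = build_model()
export_dir = os.path.join(tempfile.mkdtemp(), "mynetwork")
save_as_savedmodel(model, export_dir)
print(os.path.isfile(os.path.join(export_dir, "saved_model.pb")))  # True
```

The exported directory can then be loaded for inference with `tf.saved_model.load(export_dir)`.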

Answered By – Djinn

Answer Checked By – Dawn Plyler (Easybugfix Volunteer)
