Sequence-to-sequence example in Keras (character-level).
This script demonstrates how to implement a basic character-level CNN sequence-to-sequence model. We apply it to translating short English sentences into short French sentences, character-by-character. Note that it is fairly unusual to do character-level machine translation, as word-level models are much more common in this domain. This example is for demonstration purposes only.
Summary of the algorithm
- We start with input sequences from a domain (e.g. English sentences) and corresponding target sequences from another domain (e.g. French sentences).
- An encoder CNN encodes the input character sequence.
- A decoder CNN is trained to turn the target sequences into the same sequence but offset by one timestep in the future, a training process called "teacher forcing" in this context. It also uses the encoder output, through a dot-product attention layer. Effectively, the decoder learns to generate targets[t+1...] given targets[...t], conditioned on the input sequence.
- In inference mode, when we want to decode unknown input sequences, we:
  - Encode the input sequence (in this script, the full model re-encodes it at every step).
  - Start with a target sequence of size 1 (just the start-of-sequence character, a tab).
  - Feed the input sequence and the 1-char target sequence to the decoder to produce predictions for the next character.
  - Sample the next character using these predictions (we simply use argmax).
  - Append the sampled character to the target sequence.
  - Repeat until we hit the character limit (max_decoder_seq_length). The decoded text is then cut at the first end-of-sequence character ("\n").
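The one-timestep offset behind teacher forcing can be sketched in NumPy. This is a minimal illustration, not part of the script. The helper name make_decoder_arrays and the toy target "\tab\n" are hypothetical. The indexing follows the same rule as the script's vectorization loop: the character at input step t becomes the target at step t - 1.

```python
import numpy as np

def make_decoder_arrays(target_texts, token_index, max_len):
    """Build one-hot decoder inputs and targets for teacher forcing.

    The target at step t is the input character at step t + 1, so the
    decoder learns to predict the next character from the ones before it.
    """
    n, v = len(target_texts), len(token_index)
    dec_in = np.zeros((n, max_len, v), dtype="float32")
    dec_out = np.zeros((n, max_len, v), dtype="float32")
    for i, text in enumerate(target_texts):
        for t, ch in enumerate(text):
            dec_in[i, t, token_index[ch]] = 1.0
            if t > 0:
                # Shift left by one step; the start character is never a target.
                dec_out[i, t - 1, token_index[ch]] = 1.0
    return dec_in, dec_out


# Toy example: "\t" starts a sequence and "\n" ends it, as in the script.
texts = ["\tab\n"]
index = {ch: i for i, ch in enumerate(sorted(set("".join(texts))))}
inv = {i: ch for ch, i in index.items()}
dec_in, dec_out = make_decoder_arrays(texts, index, max_len=4)
print([inv[int(k)] for k in dec_in[0].argmax(-1)])       # ['\t', 'a', 'b', '\n']
print([inv[int(k)] for k in dec_out[0, :3].argmax(-1)])  # ['a', 'b', '\n']
```

The last target row stays all zeros. There is no character after the end marker, and padded positions carry no target.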
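The inference steps above amount to greedy decoding. Below is a minimal sketch under stated assumptions. greedy_decode and toy_predict are hypothetical names. toy_predict stands in for a trained model and returns one next-character distribution per decoder position, the way the script's decoder produces one softmax per timestep.

```python
import numpy as np

def greedy_decode(predict_fn, vocab, max_len, start="\t", end="\n"):
    """Greedy (argmax) decoding, one character per step.

    predict_fn(prefix) must return a (len(prefix), len(vocab)) array of
    next-character probabilities, one row per prefix position.
    """
    prefix = [start]
    for _ in range(max_len - 1):
        probs = predict_fn(prefix)
        # Only the prediction at the last position is new information.
        nxt = vocab[int(np.argmax(probs[-1]))]
        if nxt == end:
            break
        prefix.append(nxt)
    return "".join(prefix[1:])  # drop the start character


vocab = ["\t", "\n", "k", "o"]
target = "\tok\n"

def toy_predict(prefix):
    # Hypothetical stand-in for model.predict: puts all probability on the
    # character that follows each prefix position in `target`.
    probs = np.zeros((len(prefix), len(vocab)))
    for t in range(len(prefix)):
        probs[t, vocab.index(target[t + 1])] = 1.0
    return probs

print(repr(greedy_decode(toy_predict, vocab, max_len=10)))  # 'ok'
```

This sketch stops as soon as it emits "\n". The script below works differently: it decodes a whole batch for a fixed number of steps and cuts each output at the first "\n" afterwards.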
Data download
English to French sentence pairs (the script expects the unpacked file at fra-eng/fra.txt).
Lots of other neat sentence-pair datasets.
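Each line of the sentence-pair file holds an English sentence and its French translation, separated by a tab. Some releases of these files append a third, tab-separated attribution column. That extra column makes a strict two-way unpacking of line.split('\t') fail. The sketch below keeps only the first two columns and adds the same start ("\t") and end ("\n") markers as the script. The parse_pairs helper and the two-line sample are made up for illustration.

```python
def parse_pairs(lines, num_samples):
    """Split tab-separated sentence-pair lines into inputs and decoder targets.

    Only the first two columns are used, so any trailing column
    (e.g. an attribution) is ignored.
    """
    inputs, targets = [], []
    for line in lines[:num_samples]:
        if not line.strip():
            continue  # skip blank lines, e.g. the one after the final newline
        cols = line.split("\t")
        inputs.append(cols[0])
        # "\t" marks start-of-sequence and "\n" end-of-sequence, as in the script.
        targets.append("\t" + cols[1] + "\n")
    return inputs, targets


# Made-up sample: the first line has a third (attribution) column, the second does not.
sample = "Go.\tVa !\t(attribution)\nHi.\tSalut !\n"
inputs, targets = parse_pairs(sample.split("\n"), num_samples=10)
print(inputs)   # ['Go.', 'Hi.']
print(targets)  # ['\tVa !\n', '\tSalut !\n']
```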
References
- lstm_seq2seq.py
- https://wanasit.github.io/attention-based-sequence-to-sequence-in-keras.html
from __future__ import print_function
import numpy as np
from keras.layers import Input, Convolution1D, Dot, Dense, Activation, Concatenate
from keras.models import Model
batch_size = 64 # Batch size for training.
epochs = 100 # Number of epochs to train for.
num_samples = 10000 # Number of samples to train on.
# Path to the data txt file on disk.
data_path = 'fra-eng/fra.txt'
# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    # Keep only the first two tab-separated columns (English, French);
    # some releases of fra.txt append a third attribution column.
    input_text, target_text = line.split('\t')[:2]
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])
print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_input_data holds the full target text, start character included.
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
# Encoder
x_encoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal')(encoder_inputs)
x_encoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal', dilation_rate=2)(x_encoder)
x_encoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal', dilation_rate=4)(x_encoder)
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# Decoder
x_decoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal')(decoder_inputs)
x_decoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal', dilation_rate=2)(x_decoder)
x_decoder = Convolution1D(256, kernel_size=3, activation='relu',
                          padding='causal', dilation_rate=4)(x_decoder)
# Attention
attention = Dot(axes=[2, 2])([x_decoder, x_encoder])
attention = Activation('softmax')(attention)
context = Dot(axes=[2, 1])([attention, x_encoder])
decoder_combined_context = Concatenate(axis=-1)([context, x_decoder])
decoder_outputs = Convolution1D(64, kernel_size=3, activation='relu',
                                padding='causal')(decoder_combined_context)
decoder_outputs = Convolution1D(64, kernel_size=3, activation='relu',
                                padding='causal')(decoder_outputs)
# Output
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()
# Run training
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
# Save model
model.save('cnn_s2s.h5')
# Next: inference mode (sampling).
# Define sampling models
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
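nb_examples = min(100, len(input_texts))
in_encoder = encoder_input_data[:nb_examples]
# The decoder input batch must have as many samples as in_encoder.
in_decoder = np.zeros(
    (nb_examples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
# Every sequence starts with the start-of-sequence character ("\t").
in_decoder[:, 0, target_token_index["\t"]] = 1
predict = np.zeros(
    (nb_examples, max_decoder_seq_length),
    dtype='float32')
# Greedy decoding: at step i, run the whole model, take the argmax at
# position i and write it into position i + 1 of the decoder input.
for i in range(max_decoder_seq_length - 1):
    predict = model.predict([in_encoder, in_decoder])
    predict = predict.argmax(axis=-1)
    predict_ = predict[:, i].ravel().tolist()
    for j, x in enumerate(predict_):
        in_decoder[j, i + 1, x] = 1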
for seq_index in range(nb_examples):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    output_seq = predict[seq_index, :].ravel().tolist()
    decoded = []
    for x in output_seq:
        if reverse_target_char_index[x] == "\n":
            break
        else:
            decoded.append(reverse_target_char_index[x])
    decoded_sentence = "".join(decoded)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
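The attention block in the script (Dot, softmax, Dot, Concatenate) can be traced with plain NumPy to follow the tensor shapes. This sketch assumes batch-first arrays, as Keras uses. The random inputs and the feature size of 8 are arbitrary; the script uses 256 filters. dot_attention is a hypothetical helper, not a Keras API.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dot_attention(dec, enc):
    """dec: (batch, T_dec, d), enc: (batch, T_enc, d)."""
    # Dot(axes=[2, 2]): score every decoder step against every encoder step.
    scores = np.einsum("btd,bsd->bts", dec, enc)       # (batch, T_dec, T_enc)
    # Activation('softmax'): normalise over the last (encoder) axis.
    weights = softmax(scores, axis=-1)
    # Dot(axes=[2, 1]): weighted sum of encoder features per decoder step.
    context = np.einsum("bts,bsd->btd", weights, enc)  # (batch, T_dec, d)
    # Concatenate(axis=-1)([context, x_decoder])
    return np.concatenate([context, dec], axis=-1), weights


rng = np.random.default_rng(0)
dec = rng.normal(size=(2, 5, 8))  # 2 samples, 5 decoder steps, 8 features
enc = rng.normal(size=(2, 7, 8))  # 2 samples, 7 encoder steps, 8 features
combined, weights = dot_attention(dec, enc)
print(combined.shape)                          # (2, 5, 16)
print(np.allclose(weights.sum(axis=-1), 1.0))  # True
```

Each decoder step gets its own distribution over all encoder steps. Because the attention weights are not masked, zero-padded encoder positions also receive some weight.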
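Each Convolution1D with padding='causal' left-pads its input, so output step t depends only on steps up to t. Stacking three layers with kernel size 3 and dilation rates 1, 2 and 4 gives a receptive field of 1 + 2 * (1 + 2 + 4) = 15 characters. The sketch below checks this with a single-channel, linear (no ReLU) NumPy stand-in for those layers, by pushing a unit impulse through the stack. causal_conv1d and receptive_field are hypothetical helpers, not Keras APIs.

```python
import numpy as np

def causal_conv1d(x, w, dilation=1):
    """Single-channel causal dilated convolution (cross-correlation, as in Keras).

    x: (T,) signal, w: (k,) kernel. The input is left-padded with
    (k - 1) * dilation zeros, so y[t] only sees x[t - (k - 1) * dilation .. t].
    """
    k = len(w)
    pad = (k - 1) * dilation
    xp = np.concatenate([np.zeros(pad), x])
    return np.array([sum(w[j] * xp[t + j * dilation] for j in range(k))
                     for t in range(len(x))])

def receptive_field(layers):
    """layers: (kernel_size, dilation_rate) pairs; each adds (k - 1) * d steps."""
    return 1 + sum((k - 1) * d for k, d in layers)


stack = [(3, 1), (3, 2), (3, 4)]  # kernel/dilation of each conv in the script's stacks
x = np.zeros(32)
x[0] = 1.0                        # unit impulse at t = 0
y = x
for k, d in stack:
    y = causal_conv1d(y, np.ones(k), dilation=d)

print(receptive_field(stack))                     # 15
print(int(np.count_nonzero(y)))                   # 15: the impulse reaches t = 0..14
print(receptive_field(stack + [(3, 1), (3, 1)]))  # 19
```

The last line adds the two kernel-size-3 convolutions that follow the attention block. Together they extend the decoder's view of its own past to 19 characters. The encoder side is not limited this way, because attention reads every encoder position.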