
Multi-output models in Deep Learning


We can use the functional API to build models with multiple outputs (or multiple heads). A simple example is a network that simultaneously predicts different properties of its input data, such as a network that takes as input a series of social media posts from a single anonymous person and tries to predict attributes of that person, like age, gender, and income level.

Functional API implementation of a three-output model

from keras import layers
from keras import Input
from keras.models import Model

vocabulary_size = 50000
num_income_groups = 10

# Variable-length sequences of integer word indices
posts_input = Input(shape=(None,), dtype='int32', name='posts')
# Embedding takes (input_dim, output_dim): vocabulary size first, then vector size
embedded_posts = layers.Embedding(vocabulary_size, 256)(posts_input)
x = layers.Conv1D(128, 5, activation='relu')(embedded_posts)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.Conv1D(256, 5, activation='relu')(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dense(128, activation='relu')(x)

# Note that the output layers are given names.
age_prediction = layers.Dense(1, name='age')(x)
income_prediction = layers.Dense(num_income_groups,
                                 activation='softmax',
                                 name='income')(x)
gender_prediction = layers.Dense(1, activation='sigmoid', name='gender')(x)

model = Model(posts_input,
              [age_prediction, income_prediction, gender_prediction])
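One practical consequence of this architecture is that every convolution and pooling layer shortens the sequence, so very short inputs leave nothing for the last convolution to work on. The sketch below estimates the minimum usable input length. It assumes the Keras defaults for these layers ('valid' padding, convolution stride 1, pooling stride equal to the pool size); the helper names are hypothetical and not part of Keras.

```python
# Layer stack from the model above, as (kind, size) pairs.
STACK = [
    ("conv", 5), ("pool", 5),
    ("conv", 5), ("conv", 5), ("pool", 5),
    ("conv", 5), ("conv", 5),
]

def conv1d_length(length, kernel_size):
    """Sequence length after a 'valid' 1D convolution with stride 1."""
    return length - kernel_size + 1

def pool1d_length(length, pool_size):
    """Sequence length after 'valid' max pooling with stride == pool_size."""
    return (length - pool_size) // pool_size + 1

def output_length(length):
    """Sequence length reaching GlobalMaxPooling1D, or 0 if the input is too short."""
    for kind, size in STACK:
        if kind == "conv":
            length = conv1d_length(length, size)
        else:
            length = pool1d_length(length, size)
        if length <= 0:
            return 0
    return length

# Smallest input length that still yields at least one timestep.
min_length = next(n for n in range(1, 10_000) if output_length(n) >= 1)
print(min_length)  # 269
```

Under these assumptions, batches of posts should be padded to at least 269 tokens before being passed to the model.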
Importantly, training such a model requires the ability to specify different loss functions for different heads of the network: for example, age prediction is a scalar regression task, but gender prediction is a binary classification task, requiring a different training procedure. But because gradient descent requires us to minimize a scalar, we must combine these losses into a single value in order to train the model. The simplest way to combine different losses is to sum them all. In Keras, we can pass either a list or a dictionary of losses to compile to specify a different loss for each output; the resulting loss values are summed into a global loss, which is minimized during training.

Compilation options of a multi-output model: multiple losses

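# The optimizer argument is missing from the original listing; 'rmsprop' is an assumed choice.
model.compile(optimizer='rmsprop',
              loss=['mse', 'categorical_crossentropy', 'binary_crossentropy'])

# Equivalent (possible only if we give names to the output layers)
model.compile(optimizer='rmsprop',
              loss={'age': 'mse',
                    'income': 'categorical_crossentropy',
                    'gender': 'binary_crossentropy'})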
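To make the idea of a summed global loss concrete, here is a minimal pure-Python sketch of the arithmetic for one tiny batch. It is not how Keras computes losses internally (Keras works on tensors and clips probabilities for numerical stability); the sample values, the three income groups, and the function names are assumptions chosen for illustration.

```python
import math

def mse(y_true, y_pred):
    """Mean squared error over a batch of scalars."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def categorical_crossentropy(y_true, y_pred):
    """Mean cross-entropy between one-hot rows and predicted probability rows."""
    per_sample = [
        -sum(t * math.log(p) for t, p in zip(true_row, pred_row))
        for true_row, pred_row in zip(y_true, y_pred)
    ]
    return sum(per_sample) / len(per_sample)

def binary_crossentropy(y_true, y_pred):
    """Mean cross-entropy between 0/1 labels and predicted probabilities."""
    per_sample = [
        -(t * math.log(p) + (1 - t) * math.log(1 - p))
        for t, p in zip(y_true, y_pred)
    ]
    return sum(per_sample) / len(per_sample)

# Hypothetical targets and predictions for a batch of two people.
age_loss = mse([30.0, 40.0], [32.0, 37.0])
income_loss = categorical_crossentropy(
    [[1, 0, 0], [0, 1, 0]],
    [[0.5, 0.25, 0.25], [0.25, 0.5, 0.25]],
)
gender_loss = binary_crossentropy([1, 0], [0.5, 0.5])

# The scalar that gradient descent minimizes is the sum of the head losses.
total_loss = age_loss + income_loss + gender_loss
print(age_loss, income_loss, gender_loss, total_loss)
```

In this made-up batch the age loss is several times larger than the other two, which is exactly the kind of imbalance that a plain sum does not correct.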
Note that very imbalanced loss contributions will cause the model representations to be optimized preferentially for the task with the largest individual loss, at the expense of the other tasks. To remedy this, we can assign different levels of importance to the loss values in their contribution to the final loss. This is especially useful if the losses' values use different scales. For example, the mean squared error (MSE) loss used for the age-regression task typically takes a value around 3–5, whereas the cross-entropy loss used for the gender-classification task can be as low as 0.1. In such a situation, to balance the contribution of the different losses, we can assign a weight of 10 to the cross-entropy loss and a weight of 0.25 to the MSE loss.

Compilation options of a multi-output model: loss weighting

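# The optimizer argument is missing from the original listing; 'rmsprop' is an assumed choice.
model.compile(optimizer='rmsprop',
              loss=['mse', 'categorical_crossentropy', 'binary_crossentropy'],
              loss_weights=[0.25, 1., 10.])

# Equivalent (possible only if we give names to the output layers)
model.compile(optimizer='rmsprop',
              loss={'age': 'mse',
                    'income': 'categorical_crossentropy',
                    'gender': 'binary_crossentropy'},
              loss_weights={'age': 0.25,
                            'income': 1.,
                            'gender': 10.})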
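The effect of these weights can be checked with a few lines of plain Python. The sketch below uses an age MSE of 4.0 (inside the 3–5 range mentioned above) and a gender cross-entropy of 0.1; the income loss of 1.0 and the helper name `weighted_total` are assumptions made only for illustration.

```python
def weighted_total(losses, weights):
    """Weighted sum of per-head losses, the scalar minimized during training."""
    return sum(weights[name] * value for name, value in losses.items())

# Illustrative per-head loss values (the income value is assumed).
losses = {"age": 4.0, "income": 1.0, "gender": 0.1}
weights = {"age": 0.25, "income": 1.0, "gender": 10.0}

# Contribution of each head to the global loss after weighting.
contributions = {name: weights[name] * value for name, value in losses.items()}

unweighted_total = sum(losses.values())
total = weighted_total(losses, weights)

print(contributions)
print(unweighted_total, total)
```

Without weights, the age head accounts for most of the total; with the weights above, each head contributes about the same amount, so no single task dominates the gradient.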
Much as in the case of multi-input models, we can pass NumPy data to the model for training either via a list of arrays or via a dictionary of arrays.

Feeding data to a multi-output model

# age_targets, income_targets, and gender_targets are assumed to be NumPy arrays.
model.fit(posts, [age_targets, income_targets, gender_targets],
          epochs=10, batch_size=64)

# Equivalent (possible only if we give names to the output layers)
model.fit(posts, {'age': age_targets,
                  'income': income_targets,
                  'gender': gender_targets},
          epochs=10, batch_size=64)
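The fit calls above do not show what the arrays look like. As a rough sketch, the following builds random placeholder data with shapes that match the three heads. The sample count, sequence length, and value ranges are arbitrary assumptions, and real data would come from tokenized posts and actual labels.

```python
import numpy as np

num_samples = 8           # hypothetical batch of people
max_len = 300             # hypothetical padded post length, in tokens
vocabulary_size = 50000
num_income_groups = 10

rng = np.random.default_rng(0)

# Input: integer word indices, shape (num_samples, max_len).
posts = rng.integers(1, vocabulary_size, size=(num_samples, max_len), dtype=np.int32)

# 'age' head (Dense(1), mse): one float per sample.
age_targets = rng.uniform(18, 80, size=(num_samples, 1)).astype("float32")

# 'income' head (softmax, categorical_crossentropy): one-hot rows.
income_classes = rng.integers(0, num_income_groups, size=num_samples)
income_targets = np.eye(num_income_groups, dtype="float32")[income_classes]

# 'gender' head (sigmoid, binary_crossentropy): 0/1 labels.
gender_targets = rng.integers(0, 2, size=(num_samples, 1)).astype("float32")

print(posts.shape, age_targets.shape, income_targets.shape, gender_targets.shape)
```

The one-hot encoding for income matters: categorical_crossentropy expects one row of length num_income_groups per sample, not a single integer class index.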

Mansoor Ahmed is a chemical engineer, web developer, and writer currently living in Pakistan. His interests range from technology to web development, and he is also interested in programming, writing, and reading.