Open In Colab

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
vocab_size = 20000  # Only consider the top 20k most frequent words
maxlen = 200  # Keep at most 200 tokens per movie review (see truncation note below)
# load_data returns (train, test) splits; the IMDB *test* split is used as the
# validation set throughout this notebook.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
# pad_sequences pads and truncates at the *start* by default ("pre"): reviews
# longer than `maxlen` keep their LAST `maxlen` tokens, and shorter reviews are
# left-padded with zeros. The defaults are spelled out so the intent is explicit.
x_train = keras.preprocessing.sequence.pad_sequences(
    x_train, maxlen=maxlen, padding="pre", truncating="pre"
)
x_val = keras.preprocessing.sequence.pad_sequences(
    x_val, maxlen=maxlen, padding="pre", truncating="pre"
)
25000 Training sequences
25000 Validation sequences
class Distiller(keras.Model):
    """Knowledge-distillation wrapper that trains `student` to mimic `teacher`.

    Only the student's weights are updated: the teacher is always called with
    `training=False` outside the gradient tape, and only
    `self.student.trainable_variables` are handed to the optimizer.
    """

    def __init__(self, student, teacher):
        """Store the two models.

        Args:
            student: Keras model to be trained. Its outputs are treated as
                logits (divided by the temperature and passed through softmax).
            teacher: Already-trained Keras model whose outputs (also treated as
                logits) provide the soft targets.
        """
        super().__init__()
        # Assign each model to its own attribute. Swapping these would make the
        # distiller fine-tune the teacher and distill from the untrained student.
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions. Its value is
                multiplied by temperature**2 before weighting.
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one distillation step on a batch; return the metrics dict.

        The reported "distillation_loss" is the temperature**2-scaled value
        that actually enters the combined objective.
        """
        # Unpack data
        x, y = data

        # Teacher supplies soft targets only; it is computed outside the tape
        # so no gradients flow into it.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Hard-label loss against the ground truth.
            student_loss = self.student_loss_fn(y, student_predictions)
            # Soft-target loss between temperature-softened distributions.
            # softmax over axis=1 assumes predictions are (batch, num_classes)
            # logits -- TODO confirm for other model heads.
            # Gradients of the softened loss shrink roughly as 1/T**2, so the
            # loss is rescaled by T**2 to keep the two terms comparable when
            # the temperature changes (Hinton et al., 2015).
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            ) * (self.temperature ** 2)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients for the student only.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student alone on a batch; return the metrics dict."""
        # Unpack the data
        x, y = data

        # Compute predictions (inference mode, no teacher involved)
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
# Create the teacher
# Input: integer token ids, padded/truncated to a fixed length of `maxlen`
inputs = keras.Input(shape=(maxlen,), dtype="int32")
# Embed each integer in a 128-dimensional vector
x = layers.Embedding(vocab_size, 128)(inputs)
# One bidirectional LSTM (returns the full sequence) followed by one
# bidirectional GRU (returns only its final output)
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
x = layers.Bidirectional(layers.GRU(128))(x)
# Classifier head: 2 raw logits (no softmax); the losses use from_logits=True
outputs = layers.Dense(2)(x)
teacher = keras.Model(inputs, outputs, name='teacher')
teacher.summary()


# Create the student: same layout as the teacher, but narrower layers
# Input: integer token ids, padded/truncated to a fixed length of `maxlen`
inputs = keras.Input(shape=(maxlen,), dtype="int32")
# Embed each integer in a 64-dimensional vector
x = layers.Embedding(vocab_size, 64)(inputs)
# One bidirectional LSTM (returns the full sequence) followed by one
# bidirectional GRU (returns only its final output)
x = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(x)
x = layers.Bidirectional(layers.GRU(32))(x)
# Classifier head: 2 raw logits (no softmax)
outputs = layers.Dense(2)(x)
student = keras.Model(inputs, outputs, name='student')
student.summary()

# Clone the student architecture for a from-scratch baseline. clone_model
# builds new layers with new weights (per the Keras docs), so this copy is
# independent of the student that gets distilled.
student_scratch = keras.models.clone_model(student)
Model: "teacher"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 200)]             0         
_________________________________________________________________
embedding (Embedding)        (None, 200, 128)          2560000   
_________________________________________________________________
bidirectional (Bidirectional (None, 200, 128)          98816     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 256)               198144    
_________________________________________________________________
dense (Dense)                (None, 2)                 514       
=================================================================
Total params: 2,857,474
Trainable params: 2,857,474
Non-trainable params: 0
_________________________________________________________________
Model: "student"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 200)]             0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 200, 64)           1280000   
_________________________________________________________________
bidirectional_2 (Bidirection (None, 200, 64)           24832     
_________________________________________________________________
bidirectional_3 (Bidirection (None, 64)                18816     
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 130       
=================================================================
Total params: 1,323,778
Trainable params: 1,323,778
Non-trainable params: 0
_________________________________________________________________
# Compile the teacher with a plain supervised objective: cross-entropy on the
# integer labels against its raw logits, tracking accuracy.
teacher_optimizer = keras.optimizers.Adam()
teacher_loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
teacher_metrics = [keras.metrics.SparseCategoricalAccuracy()]
teacher.compile(
    optimizer=teacher_optimizer,
    loss=teacher_loss_fn,
    metrics=teacher_metrics,
)

# Fit for three epochs on the training split, then score on the held-out split.
teacher.fit(x=x_train, y=y_train, epochs=3)
teacher.evaluate(x=x_val, y=y_val)
Epoch 1/3
782/782 [==============================] - 43s 55ms/step - loss: 0.3714 - sparse_categorical_accuracy: 0.8351
Epoch 2/3
782/782 [==============================] - 43s 55ms/step - loss: 0.2004 - sparse_categorical_accuracy: 0.9232
Epoch 3/3
782/782 [==============================] - 42s 54ms/step - loss: 0.1297 - sparse_categorical_accuracy: 0.9519
782/782 [==============================] - 12s 15ms/step - loss: 0.4922 - sparse_categorical_accuracy: 0.8590





[0.49218621850013733, 0.8590400218963623]
# Wrap the trained teacher and the untrained student in a distiller.
distiller = Distiller(teacher=teacher, student=student)

# alpha=0.1 puts 10% of the weight on the hard-label loss and 90% on the
# soft-target loss; temperature=10 strongly softens both distributions.
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Train the student against the teacher's soft targets, then evaluate the
# student alone on the held-out split.
distiller.fit(x=x_train, y=y_train, epochs=3)
distiller.evaluate(x=x_val, y=y_val)
Epoch 1/3
782/782 [==============================] - 53s 68ms/step - sparse_categorical_accuracy: 0.9664 - student_loss: 0.1382 - distillation_loss: 0.0084
Epoch 2/3
782/782 [==============================] - 54s 69ms/step - sparse_categorical_accuracy: 0.9781 - student_loss: 0.1119 - distillation_loss: 0.0085
Epoch 3/3
782/782 [==============================] - 54s 69ms/step - sparse_categorical_accuracy: 0.9868 - student_loss: 0.0937 - distillation_loss: 0.0088
782/782 [==============================] - 12s 15ms/step - sparse_categorical_accuracy: 0.8476 - student_loss: 0.3908





0.8475599884986877
# Train the student as usual (no teacher) to serve as a baseline
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the from-scratch student; it is evaluated on (x_val, y_val) separately.
student_scratch.fit(x_train, y_train, epochs=3)

Epoch 1/3
782/782 [==============================] - 31s 39ms/step - loss: 0.3796 - sparse_categorical_accuracy: 0.8246
Epoch 2/3
782/782 [==============================] - 30s 39ms/step - loss: 0.1918 - sparse_categorical_accuracy: 0.9283
Epoch 3/3
782/782 [==============================] - 30s 39ms/step - loss: 0.1123 - sparse_categorical_accuracy: 0.9602



---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-7-2e22fcfd29d9> in <module>()
      8 # Train and evaluate student trained from scratch.
      9 student_scratch.fit(x_train, y_train, epochs=3)
---> 10 student_scratch.evaluate(x_test, y_test)


NameError: name 'x_test' is not defined
# Score the from-scratch baseline on the same held-out split as the others.
student_scratch.evaluate(x=x_val, y=y_val)
782/782 [==============================] - 12s 15ms/step - loss: 0.3720 - sparse_categorical_accuracy: 0.8596





[0.37198370695114136, 0.8595600128173828]