def self_distillation_train(model, train_dataset, optimizer,
                    reg_coef=1e-4, epochs=30, teacher=None):
  """Train ``model`` with a hand-written gradient-tape loop.

  For every batch yielded by ``train_dataset`` the model is run in
  training mode, the loss is obtained from ``self_distillation_loss``,
  and one optimizer step is applied to ``model.trainable_weights``.

  Args:
    model: Callable model exposing ``trainable_weights``; it is updated
      in place.
    train_dataset: Iterable yielding ``(x_batch, y_batch)`` pairs. It is
      iterated once per epoch, so it must be re-iterable.
    optimizer: Object providing ``apply_gradients``.
    reg_coef: Forwarded unchanged to ``self_distillation_loss``
      (presumably a regularization weight -- confirm against that helper).
    epochs: Number of full passes over ``train_dataset``.
    teacher: Forwarded unchanged to ``self_distillation_loss``; may be
      ``None``.

  Returns:
    The same ``model`` object that was passed in, after training.
  """
  for _ in range(epochs):
    # The batch index was never used, so iterate the dataset directly
    # instead of going through enumerate().
    for x_batch_train, y_batch_train in train_dataset:
      # Only the forward pass and the loss need to be recorded on the tape.
      with tf.GradientTape() as tape:
        logits = model(x_batch_train, training=True)
        loss_value = self_distillation_loss(
            y_batch_train, logits, model, reg_coef, teacher, x_batch_train)
      grads = tape.gradient(loss_value, model.trainable_weights)
      optimizer.apply_gradients(zip(grads, model.trainable_weights))
  return model

# Number of successive train-then-become-teacher rounds to run.
distillation_steps = 10

# The first round has no previous model to act as teacher.
teacher = None

for step in range(distillation_steps):
  # Every round gets a newly obtained model and its own Adam optimizer.
  model = get_resnet_model()
  optimizer = keras.optimizers.Adam(learning_rate=0.0001)
  # The trained model of this round becomes the teacher of the next one.
  teacher = model = self_distillation_train(
      model, train_dataset, optimizer,
      reg_coef=reg_coef, epochs=epochs, teacher=teacher)
