机器学习速成大法之断点续传篇

我们在下载文件时经常会用到断点续传,以至于下载暂停再开始时不用从头来一次。而机器学习里最耗时的操作无疑是训练,有的模型数据量大,要用高档GPU服务器训练好几天才能出结果。这中途如果因为断电或故障关机等原因造成训练中断,不能像下载中断断点续传一样的话,程序员估计会气的砸电脑。那么本篇我就带着大家来学习,如何将训练中的模型保存下来,并在需要继续训练时恢复出来使用。

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()

首先直接用手写数字识别的代码,构建和编译模型,然后在开始训练之前设置模型检查点回调函数,当训练开始后,tensorflow会调用这个回调函数,实时将已训练的状态信息保存到这个检查点里。

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

我们到检查点路径里,可以看到保存文件的信息。

os.listdir(checkpoint_dir)
['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']

这是我们试试不恢复已训练状态,直接构建模型,不训练,来预测结果看看,正确率只有11.5%,非常的低。

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

32/32 - 0s - loss: 2.3609 - sparse_categorical_accuracy: 0.1150
Untrained model, accuracy: 11.50%

现在我们试试将之前训练状态恢复,不进行再次训练,来预测结果看看,正确率86.4%,非常的高。这说明直接复用了上一次的训练成果,实现断点续传了。

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

32/32 - 0s - loss: 0.4329 - sparse_categorical_accuracy: 0.8640
Restored model, accuracy: 86.40%
机器学习速成大法之断点续传篇


上述代码是通过设置检查点checkpoint回调函数实现自动保存,我们也可以在需要时手动保存模型状态和加载已有模型的状态信息。代码如下:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

有时候我们想把训练好的模型发给别人使用,而别人并不知道我们构建模型所使用的代码怎么办呢?这时候就要求我们不仅仅是保存模型权重等状态信息了,而是要把整个模型保存下来,需要使用model.save()函数,保存的文件里就包含了模型构建、编译、训练的全部数据,别人拿去后可以直接加载使用。

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')

拿到一个完整的模型数据后,直接使用load_model加载出来,然后就可以用于预测结果了,省去了模型构建和编译过程。有了这个功能,才能让我们站在巨人的肩膀人,如果大家都把自己训练好的完整模型开源共享出来,那么全世界的数据中心机房将节省几亿度电。

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
展开阅读全文

页面更新:2024-06-17

标签:检查点   正确率   大法   函数   模型   加载   机器   状态   完整   代码   数据   信息

1 2 3 4 5

上滑加载更多 ↓
推荐阅读:
友情链接:
更多:

本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828  

© CopyRight 2008-2024 All Rights Reserved. Powered By bs178.com 闽ICP备11008920号-3
闽公网安备35020302034844号

Top