加入收藏 | 设为首页 | 会员中心 | 我要投稿 PHP编程网 - 湛江站长网 (https://www.0759zz.com/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 创业 > 正文

利用Keras中的权重约束减少深度神经网络中的过拟合

发布时间:2020-07-15 16:39:13 所属栏目:创业 来源:站长网
导读:副标题#e# 权重约束提供了一种方法,用于减少深度学习神经网络模型对训练数据的过度拟合,并改善模型对新数据(例如测试集)的性能。有多种类型的权重约束,例如最大和单位向量规范,有些需要必须配置的超参数。 在本教程中,您将发现Keras API,用于向深度

# define model  model = Sequential()  model.add(Dense(500, input_dim=2, activation='relu'))  model.add(Dense(1, activation='sigmoid'))  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 

然后,定义的模型拟合4,000个训练数据,默认批量大小为32。我们还将使用测试数据集作为验证数据集。

# fit model  history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=4000, verbose=0) 

我们可以在测试数据集上评估模型的性能并报告结果。

# evaluate the model  _, train_acc = model.evaluate(trainX, trainy, verbose=0)  _, test_acc = model.evaluate(testX, testy, verbose=0)  print('Train: %.3f, Test: %.3f' % (train_acc, test_acc)) 

最后,我们将在每个时期的训练集和测试集上绘制模型的性能。如果模型确实过度拟合训练数据集,我们将期望训练集上的准确度线图继续增加并且测试设置上升然后随着模型在训练数据集中学习统计噪声而再次下降。

# plot history  pyplot.plot(history.history['acc'], label='train')  pyplot.plot(history.history['val_acc'], label='test')  pyplot.legend()  pyplot.show() 

我们可以将所有这些部分组合在一起; 下面列出了完整的示例。

# mlp overfit on the moons dataset  from sklearn.datasets import make_moons  from keras.layers import Dense  from keras.models import Sequential  from matplotlib import pyplot  # generate 2d classification dataset  X, y = make_moons(n_samples=100, noise=0.2, random_state=1)  # split into train and test  n_train = 30  trainX, testX = X[:n_train, :], X[n_train:, :] trainy, testy = y[:n_train], y[n_train:]  # define model  model = Sequential()  model.add(Dense(500, input_dim=2, activation='relu'))  model.add(Dense(1, activation='sigmoid'))  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])  # fit model  history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=4000, verbose=0)  # evaluate the model  _, train_acc = model.evaluate(trainX, trainy, verbose=0)  _, test_acc = model.evaluate(testX, testy, verbose=0)  print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))  # plot history  pyplot.plot(history.history['acc'], label='train')  pyplot.plot(history.history['val_acc'], label='test')  pyplot.legend()  pyplot.show() 

(编辑:PHP编程网 - 湛江站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

热点阅读