网站首页 > 技术文章 正文
使用Keras实现Generative Adversarial Network(GAN)模型来生成MNIST数字图像的步骤如下:
1)导入所需的库:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout
from keras.optimizers import Adam
2)加载MNIST数据集:
(X_train, _), (_, _) = mnist.load_data()
# 对数据做归一化和重新调整形状
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
3)定义生成器(Generator)和判别器(Discriminator)模型:
def build_generator():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=100))
model.add(Reshape((7, 7, 128)))
model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(Conv2DTranspose(1, kernel_size=4, strides=2, padding="same", activation="tanh"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
def build_discriminator():
model = Sequential()
model.add(Conv2D(64, kernel_size=4, strides=2, padding="same", input_shape=(28, 28, 1)))
model.add(LeakyReLU(0.2))
model.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
# 实例化生成器和判别器模型
generator = build_generator()
discriminator = build_discriminator()
4)定义GAN模型:
def build_gan(generator, discriminator):
discriminator.trainable = False
gan_input = Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return gan
# 实例化GAN模型
gan = build_gan(generator, discriminator)
5)训练GAN模型:
def train_gan(gan, generator, discriminator, X_train, epochs=50, batch_size=128, sample_interval=200):
for epoch in range(epochs):
# 随机选择一批真实图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# 生成一批噪声作为输入
noise = np.random.normal(0, 1, (batch_size, 100))
# 生成假图像
generated_images = generator.predict(noise)
# 训练判别器
discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
# 输出训练过程中的损失
if epoch % sample_interval == 0:
print(f"Epoch {epoch}: discriminator loss = {discriminator_loss}, generator loss = {generator_loss}")
# 保存生成的图像
sample_images(generator, epoch)
# 定义保存生成图像的函数
def sample_images(generator, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
generated_images = generator.predict(noise)
generated_images = 0.5 * generated_images + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig(f"images/mnist_{epoch}.png")
plt.close()
# 开始训练GAN模型
train_gan(gan, generator, discriminator, X_train)
通过以上步骤,你可以使用Keras实现一个简单的GAN模型来生成MNIST数字图像。训练过程中,生成器和判别器模型会相互竞争,生成器尝试生成接近真实图像的假图像,而判别器则尝试区分真实图像和假图像。随着训练的进行,生成器会逐渐学习生成逼真的图像。
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 LeNet-5 一个应用于图像分类问题的卷积神经网络
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 1508℃桌面软件开发新体验!用 Blazor Hybrid 打造简洁高效的视频处理工具
- 520℃Dify工具使用全场景:dify-sandbox沙盒的原理(源码篇·第2期)
- 490℃MySQL service启动脚本浅析(r12笔记第59天)
- 469℃服务器异常重启,导致mysql启动失败,问题解决过程记录
- 467℃启用MySQL查询缓存(mysql8.0查询缓存)
- 447℃「赵强老师」MySQL的闪回(赵强iso是哪个大学毕业的)
- 427℃mysql服务怎么启动和关闭?(mysql服务怎么启动和关闭)
- 424℃MySQL server PID file could not be found!失败
- 最近发表
- 标签列表
-
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- js判断是否是json字符串 (67)
- checkout-b (67)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)