网站首页 > 技术文章 正文
使用Keras实现Generative Adversarial Network(GAN)模型来生成MNIST数字图像的步骤如下:
1)导入所需的库:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout
from keras.optimizers import Adam
2)加载MNIST数据集:
(X_train, _), (_, _) = mnist.load_data()
# 对数据做归一化和重新调整形状
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
3)定义生成器(Generator)和判别器(Discriminator)模型:
def build_generator():
model = Sequential()
model.add(Dense(128 * 7 * 7, activation="relu", input_dim=100))
model.add(Reshape((7, 7, 128)))
model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(Conv2DTranspose(1, kernel_size=4, strides=2, padding="same", activation="tanh"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
def build_discriminator():
model = Sequential()
model.add(Conv2D(64, kernel_size=4, strides=2, padding="same", input_shape=(28, 28, 1)))
model.add(LeakyReLU(0.2))
model.add(Conv2D(128, kernel_size=4, strides=2, padding="same"))
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return model
# 实例化生成器和判别器模型
generator = build_generator()
discriminator = build_discriminator()
4)定义GAN模型:
def build_gan(generator, discriminator):
discriminator.trainable = False
gan_input = Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(loss="binary_crossentropy", optimizer=Adam(lr=0.0002, beta_1=0.5))
return gan
# 实例化GAN模型
gan = build_gan(generator, discriminator)
5)训练GAN模型:
def train_gan(gan, generator, discriminator, X_train, epochs=50, batch_size=128, sample_interval=200):
for epoch in range(epochs):
# 随机选择一批真实图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# 生成一批噪声作为输入
noise = np.random.normal(0, 1, (batch_size, 100))
# 生成假图像
generated_images = generator.predict(noise)
# 训练判别器
discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
# 输出训练过程中的损失
if epoch % sample_interval == 0:
print(f"Epoch {epoch}: discriminator loss = {discriminator_loss}, generator loss = {generator_loss}")
# 保存生成的图像
sample_images(generator, epoch)
# 定义保存生成图像的函数
def sample_images(generator, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
generated_images = generator.predict(noise)
generated_images = 0.5 * generated_images + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(generated_images[cnt, :, :, 0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig(f"images/mnist_{epoch}.png")
plt.close()
# 开始训练GAN模型
train_gan(gan, generator, discriminator, X_train)
通过以上步骤,你可以使用Keras实现一个简单的GAN模型来生成MNIST数字图像。训练过程中,生成器和判别器模型会相互竞争,生成器尝试生成接近真实图像的假图像,而判别器则尝试区分真实图像和假图像。随着训练的进行,生成器会逐渐学习生成逼真的图像。
猜你喜欢
- 2024-11-23 太强了,竟然可以根据指纹图像预测性别
- 2024-11-23 深度残差网络+自适应参数化ReLU(调参记录24)Cifar10~95.80%
- 2024-11-23 从零开始构建:使用CNN和TensorFlow进行人脸特征检测
- 2024-11-23 每个ML从业人员都必须知道的10个TensorFlow技巧
- 2024-11-23 基于OpencvCV的情绪检测
- 2024-11-23 LeNet-5 一个应用于图像分类问题的卷积神经网络
- 2024-11-23 使用TensorBoard进行超参数优化
- 2024-11-23 如何实现CNN特征层可视化?终于懂了....
- 2024-11-23 计算卷积神经网络参数总数和输出形状
- 2024-11-23 使用卷积神经网络和 Python 进行图像分类
- 08-06中等生如何学好初二数学函数篇
- 08-06C#构造函数
- 08-06初中数学:一次函数学习要点和方法
- 08-06仓颉编程语言基础-数据类型—结构类型
- 08-06C++实现委托机制
- 08-06初中VS高中三角函数:从"固定镜头"到"360°全景",数学视野升级
- 08-06一文讲透PLC中Static和Temp变量的区别
- 08-06类三剑客:一招修改所有对象!类方法与静态方法的核心区别!
- 最近发表
- 标签列表
-
- cmd/c (90)
- c++中::是什么意思 (84)
- 标签用于 (71)
- 主键只能有一个吗 (77)
- c#console.writeline不显示 (95)
- pythoncase语句 (88)
- es6includes (74)
- sqlset (76)
- windowsscripthost (69)
- apt-getinstall-y (100)
- node_modules怎么生成 (87)
- chromepost (71)
- flexdirection (73)
- c++int转char (80)
- mysqlany_value (79)
- static函数和普通函数 (84)
- el-date-picker开始日期早于结束日期 (70)
- asynccallback (71)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- java (73)
- js数组插入 (83)
- mac安装java (72)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)