当前位置: 首页 > news >正文

綦江集团网站建设企业推广策划公司

文章目录:FashionMNIST数据集、需求库导入与数据迭代器生成、设备选择、样例图片展示、日志写入、评估—计数器、模型构建、训练函数、整体代码、训练过程、日志。FashionMNIST(时尚 MNIST)是一个用于图像分类的数据集,旨在替代传统的手写数字 MNIST 数据集。

文章目录

    • FashionMNIST数据集
    • 需求库导入、数据迭代器生成
    • 设备选择
    • 样例图片展示
    • 日志写入
    • 评估—计数器
    • 模型构建
    • 训练函数
    • 整体代码
    • 训练过程
    • 日志

FashionMNIST数据集

  • FashionMNIST(时尚 MNIST)是一个用于图像分类的数据集,旨在替代传统的手写数字MNIST数据集。它由 Zalando Research 创建,适用于深度学习和计算机视觉的实验。
    • FashionMNIST 包含 10 个类别,分别对应不同的时尚物品。这些类别包括 T恤/上衣、裤子、套头衫、裙子、外套、凉鞋、衬衫、运动鞋、包和踝靴。
    • 每个类别有 6,000 张训练图像和 1,000 张测试图像,总计 70,000 张图像。
    • 每张图像的尺寸为 28x28 像素,与MNIST数据集相同。
    • 数据集中的每个图像都是单通道灰度图像,原始像素值在 0 到 255 之间;下文代码中的 transforms.ToTensor() 会将其转换为取值在 [0, 1] 之间的浮点张量。
      在这里插入图片描述

需求库导入、数据迭代器生成

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoaderimport torchvision
from torchvision import transformsimport argparse
from tqdm import tqdmimport matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter


def _load_data(batch_size=None, num_workers=None):
    """Download FashionMNIST (if not cached) and build the train/test loaders.

    Args:
        batch_size: Samples per batch. ``None`` falls back to the module-level
            ``args.batch_size``.
        num_workers: DataLoader worker processes. ``None`` falls back to the
            module-level ``args.num_works``.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    if batch_size is None:
        batch_size = args.batch_size
    if num_workers is None:
        num_workers = args.num_works
    trans = transforms.Compose([transforms.ToTensor()])
    train_dataset = torchvision.datasets.FashionMNIST(
        root='./data/', train=True, download=True, transform=trans)
    test_dataset = torchvision.datasets.FashionMNIST(
        root='./data/', train=False, download=True, transform=trans)
    train_loader = DataLoader(train_dataset, shuffle=True,
                              batch_size=batch_size, num_workers=num_workers)
    # The test metrics are sample-weighted averages over the whole set, so the
    # order does not matter; iterate it in a fixed order instead of reshuffling.
    test_loader = DataLoader(test_dataset, shuffle=False,
                             batch_size=batch_size, num_workers=num_workers)
    return train_loader, test_loader

设备选择

def _device():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")return device

样例图片展示

"""display data examples"""
def _image_label(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def _show_images(imgs, rows, columns, titles=None, scale=1.5):figsize = (rows * scale, columns * 1.5)fig, axes = plt.subplots(rows, columns, figsize=figsize)axes = axes.flatten()for i, (img, ax) in enumerate(zip(imgs, axes)):ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])plt.show()return axesdef _show_examples():train_loader, test_loader = _load_data()for images, labels in train_loader:images = images.squeeze(1)_show_images(images, 3, 3, _image_label(labels))break

日志写入

class _logger():
    """Thin wrapper around TensorBoard's ``SummaryWriter``.

    Args:
        log_dir: Base directory for the event files.
        log_history: When True, log into a new timestamped sub-directory of
            ``log_dir`` so each run gets its own folder.
    """

    def __init__(self, log_dir, log_history=True):
        if log_history:
            log_dir = os.path.join(
                log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
        self.summary = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a dict of related scalars (``{series_name: value}``) under ``tag``."""
        self.summary.add_scalars(tag, value, step)

    def images_summary(self, tag, image_tensor, step):
        """Log a batch of images under ``tag``."""
        self.summary.add_images(tag, image_tensor, step)

    def figure_summary(self, tag, figure, step):
        """Log a matplotlib figure under ``tag``."""
        self.summary.add_figure(tag, figure, step)

    def graph_summary(self, model, input_to_model=None):
        """Log the model graph.

        Args:
            model: The ``nn.Module`` to log.
            input_to_model: Example input used to trace ``model``; forwarded
                to ``SummaryWriter.add_graph``.
        """
        self.summary.add_graph(model, input_to_model)

    def close(self):
        """Flush and close the underlying writer."""
        self.summary.close()

评估—计数器

class AverageMeter():
    """Track the latest value and the sample-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples.

        Args:
            val: Metric value, e.g. a per-batch mean.
            n: Number of samples ``val`` represents; used as its weight.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been seen.
        self.avg = self.sum / self.count if self.count else 0

模型构建

class Conv3x3(nn.Module):
    """Two Conv-BN-ReLU stages, optionally halving the spatial size.

    The first convolution is 3x3 / stride 1 / padding 1, which preserves the
    spatial size. The second is the same 3x3 convolution or, when
    ``down_sample`` is True, a 2x2 / stride 2 convolution that halves H and W.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count of both stages.
        down_sample: Use the stride-2 second convolution.
    """

    def __init__(self, in_channels, out_channels, down_sample=False):
        super(Conv3x3, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Choose the second conv up front instead of building a 3x3 and then
        # overwriting index 3; the submodule names ("conv.0" .. "conv.5") are
        # unchanged, so existing state_dicts still load.
        if down_sample:
            layers.append(nn.Conv2d(out_channels, out_channels, 2, 2, 0))
        else:
            layers.append(nn.Conv2d(out_channels, out_channels, 3, 1, 1))
        layers += [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class SimpleNet(nn.Module):
    """Four ``Conv3x3`` stages followed by a linear classifier.

    ``conv2`` and ``conv4`` each halve the spatial size (floor division), so
    an ``input_size`` x ``input_size`` image reaches the classifier as
    256 channels of ``input_size // 4`` squared (7 x 7 for 28 x 28 inputs).

    Args:
        in_channels: Channels of the input images (1 for grayscale).
        out_channels: Number of output logits (classes).
        input_size: Height/width of the square input images. Defaults to 28.
    """

    def __init__(self, in_channels, out_channels, input_size=28):
        super(SimpleNet, self).__init__()
        self.conv1 = Conv3x3(in_channels, 32)
        self.conv2 = Conv3x3(32, 64, down_sample=True)
        self.conv3 = Conv3x3(64, 128)
        self.conv4 = Conv3x3(128, 256, down_sample=True)
        side = input_size // 4
        self.fc = nn.Linear(256 * side * side, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

训练函数

def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):
    """Train ``model``, evaluating on ``test_loader`` after every epoch.

    Each epoch prints the mean loss and classification accuracy (labelled
    "Dice" in the printed output and "precision" in TensorBoard) of both the
    training and the test pass, and logs them through ``writer``. ``writer``
    is closed when training finishes.

    Args:
        model: Network to optimise.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizor: Optimizer over ``model``'s parameters; the learning rate
            of its first param group is printed each epoch.
        epochs: Number of epochs to run.
        device: Device each batch is moved to.
        writer: Logger exposing ``scalar_summary(tag, dict, step)`` and
            ``close()``.
        save_weight: When True, save ``model.state_dict()`` every time the
            test accuracy beats the best value so far (initially
            ``args.best_dice``) into ``args.weight_dir/args.model/<timestamp>``.
    """
    train_loss = AverageMeter()
    test_loss = AverageMeter()
    train_precision = AverageMeter()
    test_precision = AverageMeter()
    time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

    best_dice = None
    weight_dir = None
    if save_weight:
        # Read the starting threshold once so an improvement raises the bar
        # for later epochs (re-reading it each epoch saved every epoch).
        best_dice = args.best_dice
        weight_dir = os.path.join(args.weight_dir, args.model, time_tick)
        os.makedirs(weight_dir, exist_ok=True)

    for epoch in range(epochs):
        lr = optimizor.param_groups[0]['lr']
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, lr))
        # The meters report per-epoch statistics, so clear them every epoch.
        for meter in (train_loss, test_loss, train_precision, test_precision):
            meter.reset()

        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            # backward
            loss = criterion(output, labels)
            optimizor.zero_grad()
            loss.backward()
            optimizor.step()
            # Batch accuracy: fraction of argmax predictions matching the label.
            train_acc = (torch.argmax(output, dim=1) == labels).float().mean()
            train_loss.update(loss.item(), inputs.size(0))
            train_precision.update(train_acc.item(), inputs.size(0))

        model.eval()
        with torch.no_grad():
            for X, y in tqdm(test_loader):
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                loss_te = criterion(y_hat, y)
                test_acc = (torch.argmax(y_hat, dim=1) == y).float().mean()
                test_loss.update(loss_te.item(), X.size(0))
                test_precision.update(test_acc.item(), X.size(0))

        if save_weight and test_precision.avg > best_dice:
            best_dice = test_precision.avg
            name = os.path.join(
                weight_dir,
                args.model + '_' + str(epoch)
                + '_test_loss-' + str(round(test_loss.avg, 4))
                + '_test_dice-' + str(round(best_dice, 4)) + '.pt')
            torch.save(model.state_dict(), name)

        print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))
        print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))
        # summary
        writer.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)
        writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)
    writer.close()

整体代码

import os
import random
import numpy as np
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoaderimport torchvision
from torchvision import transformsimport argparse
from tqdm import tqdmimport matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter"""Reproduction experiment"""
def setup_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every generator.
        deterministic: Also make cuDNN choose deterministic kernels and turn
            off its benchmark autotuning. This is slower but makes GPU runs
            repeatable. It is off by default, matching the previous behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, including the current one.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True


"""data related"""
def _base_options():parser = argparse.ArgumentParser(description="Train setting for FashionMNIST")# about datasetparser.add_argument('--batch_size', default=8, type=int, help='the batch size of dataset')parser.add_argument('--num_works', default=4, type=int, help="the num_works used")# trainparser.add_argument('--epochs', default=100, type=int, help='train iterations')parser.add_argument('--lr', default=0.001, type=float, help='learning rate')parser.add_argument('--model', default="SimpleNet", choices=["SimpleNet"], help="the model choosed")# log dirparser.add_argument('--log_dir', default="./logger/", help='the path of log file')#parser.add_argument('--best_dice', default=-100, type=int, help='for save weight')parser.add_argument('--weight_dir', default="./weight/", help='the dir for save weight')args = parser.parse_args()return argsdef _load_data():"""download the data, and generate the dataloader"""trans = transforms.Compose([transforms.ToTensor()])train_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=True, download=True, transform=trans)test_dataset = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=trans)# print(len(train_dataset), len(test_dataset))train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_works)return (train_loader, test_loader)def _device():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")return device"""display data examples"""
def _image_label(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def _show_images(imgs, rows, columns, titles=None, scale=1.5):figsize = (rows * scale, columns * 1.5)fig, axes = plt.subplots(rows, columns, figsize=figsize)axes = axes.flatten()for i, (img, ax) in enumerate(zip(imgs, axes)):ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])plt.show()return axesdef _show_examples():train_loader, test_loader = _load_data()for images, labels in train_loader:images = images.squeeze(1)_show_images(images, 3, 3, _image_label(labels))break"""log"""
class _logger():
    """Thin wrapper around TensorBoard's ``SummaryWriter``.

    Args:
        log_dir: Base directory for the event files.
        log_history: When True, log into a new timestamped sub-directory of
            ``log_dir`` so each run gets its own folder.
    """

    def __init__(self, log_dir, log_history=True):
        if log_history:
            log_dir = os.path.join(
                log_dir, datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
        self.summary = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a dict of related scalars (``{series_name: value}``) under ``tag``."""
        self.summary.add_scalars(tag, value, step)

    def images_summary(self, tag, image_tensor, step):
        """Log a batch of images under ``tag``."""
        self.summary.add_images(tag, image_tensor, step)

    def figure_summary(self, tag, figure, step):
        """Log a matplotlib figure under ``tag``."""
        self.summary.add_figure(tag, figure, step)

    def graph_summary(self, model, input_to_model=None):
        """Log the model graph.

        Args:
            model: The ``nn.Module`` to log.
            input_to_model: Example input used to trace ``model``; forwarded
                to ``SummaryWriter.add_graph``.
        """
        self.summary.add_graph(model, input_to_model)

    def close(self):
        """Flush and close the underlying writer."""
        self.summary.close()


"""evaluate the result"""
class AverageMeter():
    """Track the latest value and the sample-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples.

        Args:
            val: Metric value, e.g. a per-batch mean.
            n: Number of samples ``val`` represents; used as its weight.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been seen.
        self.avg = self.sum / self.count if self.count else 0


"""define the Net"""
class Conv3x3(nn.Module):
    """Two Conv-BN-ReLU stages, optionally halving the spatial size.

    The first convolution is 3x3 / stride 1 / padding 1, which preserves the
    spatial size. The second is the same 3x3 convolution or, when
    ``down_sample`` is True, a 2x2 / stride 2 convolution that halves H and W.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count of both stages.
        down_sample: Use the stride-2 second convolution.
    """

    def __init__(self, in_channels, out_channels, down_sample=False):
        super(Conv3x3, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Choose the second conv up front instead of building a 3x3 and then
        # overwriting index 3; the submodule names ("conv.0" .. "conv.5") are
        # unchanged, so existing state_dicts still load.
        if down_sample:
            layers.append(nn.Conv2d(out_channels, out_channels, 2, 2, 0))
        else:
            layers.append(nn.Conv2d(out_channels, out_channels, 3, 1, 1))
        layers += [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class SimpleNet(nn.Module):
    """Four ``Conv3x3`` stages followed by a linear classifier.

    ``conv2`` and ``conv4`` each halve the spatial size (floor division), so
    an ``input_size`` x ``input_size`` image reaches the classifier as
    256 channels of ``input_size // 4`` squared (7 x 7 for 28 x 28 inputs).

    Args:
        in_channels: Channels of the input images (1 for grayscale).
        out_channels: Number of output logits (classes).
        input_size: Height/width of the square input images. Defaults to 28.
    """

    def __init__(self, in_channels, out_channels, input_size=28):
        super(SimpleNet, self).__init__()
        self.conv1 = Conv3x3(in_channels, 32)
        self.conv2 = Conv3x3(32, 64, down_sample=True)
        self.conv3 = Conv3x3(64, 128)
        self.conv4 = Conv3x3(128, 256, down_sample=True)
        side = input_size // 4
        self.fc = nn.Linear(256 * side * side, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


"""progress of train/test"""
def train(model, train_loader, test_loader, criterion, optimizor, epochs, device, writer, save_weight=False):
    """Train ``model``, evaluating on ``test_loader`` after every epoch.

    Each epoch prints the mean loss and classification accuracy (labelled
    "Dice" in the printed output and "precision" in TensorBoard) of both the
    training and the test pass, and logs them through ``writer``. ``writer``
    is closed when training finishes.

    Args:
        model: Network to optimise.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizor: Optimizer over ``model``'s parameters; the learning rate
            of its first param group is printed each epoch.
        epochs: Number of epochs to run.
        device: Device each batch is moved to.
        writer: Logger exposing ``scalar_summary(tag, dict, step)`` and
            ``close()``.
        save_weight: When True, save ``model.state_dict()`` every time the
            test accuracy beats the best value so far (initially
            ``args.best_dice``) into ``args.weight_dir/args.model/<timestamp>``.
    """
    train_loss = AverageMeter()
    test_loss = AverageMeter()
    train_precision = AverageMeter()
    test_precision = AverageMeter()
    time_tick = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

    best_dice = None
    weight_dir = None
    if save_weight:
        # Read the starting threshold once so an improvement raises the bar
        # for later epochs (re-reading it each epoch saved every epoch).
        best_dice = args.best_dice
        weight_dir = os.path.join(args.weight_dir, args.model, time_tick)
        os.makedirs(weight_dir, exist_ok=True)

    for epoch in range(epochs):
        lr = optimizor.param_groups[0]['lr']
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, epochs, lr))
        # The meters report per-epoch statistics, so clear them every epoch.
        for meter in (train_loss, test_loss, train_precision, test_precision):
            meter.reset()

        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            # backward
            loss = criterion(output, labels)
            optimizor.zero_grad()
            loss.backward()
            optimizor.step()
            # Batch accuracy: fraction of argmax predictions matching the label.
            train_acc = (torch.argmax(output, dim=1) == labels).float().mean()
            train_loss.update(loss.item(), inputs.size(0))
            train_precision.update(train_acc.item(), inputs.size(0))

        model.eval()
        with torch.no_grad():
            for X, y in tqdm(test_loader):
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                loss_te = criterion(y_hat, y)
                test_acc = (torch.argmax(y_hat, dim=1) == y).float().mean()
                test_loss.update(loss_te.item(), X.size(0))
                test_precision.update(test_acc.item(), X.size(0))

        if save_weight and test_precision.avg > best_dice:
            best_dice = test_precision.avg
            name = os.path.join(
                weight_dir,
                args.model + '_' + str(epoch)
                + '_test_loss-' + str(round(test_loss.avg, 4))
                + '_test_dice-' + str(round(best_dice, 4)) + '.pt')
            torch.save(model.state_dict(), name)

        print("train" + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=train_loss.avg, dice=train_precision.avg))
        print("test " + '---Loss: {loss:.4f} | Dice: {dice:.4f}'.format(loss=test_loss.avg, dice=test_precision.avg))
        # summary
        writer.scalar_summary("Loss/loss", {"train": train_loss.avg, "test": test_loss.avg}, epoch)
        writer.scalar_summary("Loss/precision", {"train": train_precision.avg, "test": test_precision.avg}, epoch)
    writer.close()


if __name__ == "__main__":
    # config
    args = _base_options()
    device = _device()
    # data
    train_loader, test_loader = _load_data()
    # logger
    writer = _logger(log_dir=os.path.join(args.log_dir, args.model))
    # model
    model = SimpleNet(in_channels=1, out_channels=10).to(device)
    optimizor = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()
    train(model, train_loader, test_loader, criterion, optimizor, args.epochs, device, writer, save_weight=True)
    # To preview sample images instead of training, run only:
    #     args = _base_options()
    #     _show_examples()  # ———> show sample images

训练过程

在这里插入图片描述

日志

在这里插入图片描述

http://www.mnyf.cn/news/50349.html

相关文章:

  • 重点建设学科网站网站建设免费
  • 网站营销队伍网络广告电话
  • 无为县建设局网站百度投放广告流程
  • 合肥网站优化价格关键词优化教程
  • 我帮诈骗团伙做诈骗网站获利怎么查找关键词排名
  • 动态网站开发实训课程标准seo排名工具给您好的建议
  • 谷歌做自己的网站seo海外推广
  • 长垣县住房和城乡建设局网站七台河网站seo
  • 主题猫-wordpress网站seo推广公司靠谱吗
  • 网站建设通网络设计
  • 吉林建设集团网站安全优化大师
  • 公司做网站是做什么账务处理手机百度网盘登录入口
  • 网站建设频教程免费推广方式有哪些
  • 网站keywords多少字外贸网站搭建推广
  • 邯郸做网站服务商免费私人网站建设平台
  • 天津建设工程信息网中标aso具体优化
  • 帮做网站设计与规划作业葫岛百度seo
  • 网站建设报道稿seo详细教程
  • 室内装修设计软件电脑版百度seo刷排名软件
  • 网站设置为默认主页论坛推广网站
  • 做门户网站代码质量方面具体需要注意什么网络兼职平台
  • 河北涿州建设局网站企业快速建站
  • 博客网站开发视频深圳网络营销和推广方案
  • 简单的公司资料网站怎么做网络宣传
  • 顺德做营销网站公司新闻类软文营销案例
  • 求网页设计网站个人博客网站模板
  • 三网合一网站源代码广告公司怎么找客户资源
  • 全网营销网站建设郑州网站推广培训
  • 三级分销网站建设近期重大新闻事件
  • wordpress最好的图片压缩seo关键词优化公司哪家好