定义基础模型模块的类
class DoubleConv(nn.Module):
# 实现init函数,在这里定义模块的网络结构
def __init__()
# 实现forward函数,在这里写出数据是如何在上述网络结构中流动的
def forward()
class UNET(nn.Module):
# 实现init函数,在这里定义模块的网络结构
def __init__()
# 实现forward函数,在这里写出数据是如何在上述网络结构中流动的
def forward()
训练模型代码
# Hyperparameters 定义超参数,
Learning rate
Batch size
Image dir
Val dir
Number of epoch
Image height
Image width
…
# 定义训练函数, for循环loader,获取数据,forward, backward
def train_fn(loader, model, optimizer, loss_fn, scaler)
for batch_idx, (data, targets) in enumerate(loop):
# 主函数
def main():
train_transform
val_transforms
# 定义模型
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
# loss函数
# 优化器
# 定义好数据的loader
train_loader, val_loader = get_loaders()
# 开始一个个epoch循环
for epoch in range(NUM_EPOCHS):
# 把上面四个东西,给训练函数,开始训练
Train_fn()
# 保存模型
Save_checkpoint()
#检查精度。 需要一个model和测试数据的loader
check_accuracy(loader, model, device = "cuda"):
定义数据集的类
class CarvanaDataset(Dataset):
# 需要输入获取图片和标签的文件夹路径
def __init__(self, image_dir, mask_dir, transform = None):
def __len__(self):
# 需要输入要获取的索引
def __getitem__(self, index):
定义一些工具类函数
#保存模型的函数
def save_checkpoint():
#加载模型的函数
def load_checkpoint(checkpoint, model):
# 获取数据loader的函数
# 需要知道数据文件夹的路径,还有batchsize,还有要没有数据变换
def get_loaders(
train_dir,
train_maskdir,
val_dir,
val_maskdir,
batch_size,
train_transform,
val_transform,
num_workers=4,
pin_memory=True,
):
# 训练过程中查看网络的精度
def check_accuracy(loader, model, device = "cuda"):
# 训练过程中保存数据
def save_predictions_as_imgs():