当前位置:首页 > python > 正文内容

PyTorch基础入门教程

zhangsir1个月前 (06-26)python22

一、PyTorch简介

PyTorch是一个基于Python的科学计算包,主要用途包括NumPy的替代品(以使用GPU和其他加速器的强大功能)以及一个用于实现神经网络的自动微分库。它具有动态计算图的特点,可以根据计算需要实时改变计算图,与TensorFlow的静态计算图不同。由于其简洁性和符合Python风格的特点,PyTorch在机器学习领域得到了广泛应用。

二、安装PyTorch

在开始使用PyTorch之前,需要先进行安装。以下是安装PyTorch的基本步骤:

  1. 安装Conda:访问Conda官网,根据你的操作系统选择合适的安装包进行下载和安装。

  2. 创建Conda环境:在Anaconda Prompt中输入命令创建一个新的Conda环境,例如conda create -n pytorch python=3.9

  3. 激活环境:使用命令conda activate pytorch激活新创建的环境。

  4. 安装PyTorch:在激活的Conda环境中,根据PyTorch官网的指导,选择合适的安装命令进行安装,例如conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

  5. 检验安装:进入Python环境,导入PyTorch并检验CUDA可用性,例如import torchtorch.cuda.is_available()

三、Tensor基础操作

Tensor(张量)是PyTorch中的基本操作对象,可以理解为NumPy中的ndarray。以下是Tensor的一些基础操作:

1. 创建Tensor

Tensor可以通过多种方式创建,包括直接从数据创建、从NumPy数组创建、从另一个Tensor创建,以及指定特定值或随机创建。

import torch
import numpy as np

# 直接从数据创建
data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)

# 从NumPy数组创建
np_array = np.array(data)
x_np = torch.from_numpy(np_array)

# 从另一个Tensor创建
x_ones = torch.ones_like(x_data)
x_rand = torch.rand_like(x_data, dtype=torch.float)

# 指定特定值或随机创建
shape = (2, 3)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

2. Tensor的属性

可以通过.shape().size()方法来获取Tensor的形状;通过.dtype()来查看其中存放的数据类型;通过.device()方法来查看运算所处于的设备。

pythontensor = torch.rand(3, 4)print(f"Shape of tensor: {tensor.shape}")print(f"Datatype of tensor: {tensor.dtype}")print(f"Device tensor is stored on: {tensor.device}")

3. Tensor的运算

Tensor支持多种运算,包括索引和切片、张量的连接、算数操作(如加法)、矩阵乘法以及元素级乘法等。

tensor = torch.rand(3, 4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")

四、自动微分(Autograd)

PyTorch的自动微分机制使得反向传播求梯度变得非常简单。通过设置requires_grad=True,可以使得Tensor在参与计算时自动记录梯度信息。

pythonx = torch.tensor(2.0, requires_grad=True)y = x ** 2y.backward()  # 反向传播求梯度print(x.grad)  # 输出梯度值

五、数据加载(DataLoader)

PyTorch通过Dataset和DataLoader进行构建数据管道。Dataset用于封装数据集,而DataLoader则用于将数据集分批加载到模型中进行训练。

# 索引和切片
tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")
tensor[:, 1] = 0
print(tensor)

# 张量的连接
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)

# 算数操作(以加法为例)
t = torch.randn(4, 4)
print(tensor)
print(t)
print(tensor + t)  # 直接相加
print(torch.add(tensor, t))  # 使用add
result = torch.empty(4, 4)
torch.add(tensor, t, out=result)  # 指定输出到result中
print(result)
t.add_(tensor)  # inplace操作
print(t)

# 矩阵乘法
y1 = tensor @ tensor.T  # tensor.T 是 tensor 的转置
y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

# 元素级乘法
# 计算 tensor 的元素级乘积,即每个位置上的元素相乘

六、模型构建与训练

在PyTorch中,可以通过继承Module类来构造模型。以下是一个简单的模型构建和训练示例:

import torch.nn as nn
import torch.optim as optim

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(3, 1)  # 输入特征3,输出特征1

    def forward(self, x):
        return self.fc(x)

# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 训练模型
for epoch in range(100):
    for batch_data, batch_labels in dataloader:
        # 前向传播
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels.float())  # 假设标签是浮点数

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

七、结语

通过本文的学习,你已经掌握了PyTorch的基础知识,包括Tensor的创建和操作、自动微分机制、数据加载以及模型构建和训练等。希望这些内容能够帮助你快速入门PyTorch,并开启你的深度学习之旅。


zhangsir版权f8防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://www.mianka.xyz/post/184.html

分享给朋友:

“PyTorch基础入门教程” 的相关文章

Python怎么获取命令行参数

输入:“ import sys”,导入 sys 模块。插入语句:“print(sys.argv)”,打印获取的命令行参数。...

Selenium添加Cookie来实现自动登录

Selenium添加Cookie来实现自动登录第一步获取你登录的cookie,以csdn为例from selenium import webdriver driver = webdriver.Chrome() driver.get('...

解决Django的request.POST获取不到请求参数的问题

这个是Django自身的问题:只要在请求头的添加"content-type":'application/x-www-form-urlencoded'就行。...

python 爬虫 报错:UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0x8b in position”解决方案

发现报错“UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1:invalid start byte”,方法一:根据报错提示,错误原因有一条是这样的:“'Accept-Encodi...

python 给电脑设置闹钟

python会自动触发windows桌面通知,提示重要事项,比如说:您已工作两小时,该休息了我们可以设定固定时间提示,比如隔10分钟、1小时等用到的第三方库:win10toast - 用于发送桌面通知的工具from win10toast import ToastNoti...

python selenium 使用代理ip

代码如下:from selenium import webdriver chromeOptions = webdriver.ChromeOptions() chromeOptions.add_argument("--proxy-serv...