Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions beginner_source/blitz/cifar10_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"""
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
Expand All @@ -69,9 +69,10 @@
# BrokenPipeError or RuntimeError related to multiprocessing, try setting
# the num_worker of torch.utils.data.DataLoader() to 0.

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

Expand Down Expand Up @@ -191,7 +192,7 @@ def forward(self, x):
########################################################################
# Let's quickly save our trained model:

PATH = './cifar_net.pth'
PATH = './cifar_net.pt'
torch.save(net.state_dict(), PATH)

########################################################################
Expand Down Expand Up @@ -302,7 +303,7 @@ def forward(self, x):
# Let's first define our device as the first visible cuda device if we have
# CUDA available:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device(torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

Expand Down Expand Up @@ -355,9 +356,9 @@ def forward(self, x):
# - `Discuss PyTorch on the Forums`_
# - `Chat with other users on Slack`_
#
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/main/imagenet
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/main/dcgan
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/main/word_language_model
# .. _More examples: https://github.com/pytorch/examples
# .. _More tutorials: https://github.com/pytorch/tutorials
# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
Expand Down
Loading