From 90cde13c97393312a03d0b877aa709a77cc5d30d Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 5 May 2026 23:01:43 +0530 Subject: [PATCH 1/2] Modernize transforms tutorial to torchvision v2 API --- beginner_source/basics/transforms_tutorial.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/beginner_source/basics/transforms_tutorial.py b/beginner_source/basics/transforms_tutorial.py index 33076958bf5..320137b26e8 100644 --- a/beginner_source/basics/transforms_tutorial.py +++ b/beginner_source/basics/transforms_tutorial.py @@ -23,42 +23,42 @@ The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors. -To make these transformations, we use ``ToTensor`` and ``Lambda``. +To make these transformations, we use the ``torchvision.transforms.v2`` API along with ``torch.nn.functional.one_hot``. """ import torch from torchvision import datasets -from torchvision.transforms import ToTensor, Lambda +from torchvision.transforms import v2 ds = datasets.FashionMNIST( root="data", train=True, download=True, - transform=ToTensor(), - target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) + transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]), + target_transform=v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()) ) ################################################# -# ToTensor() +# ToImage() and ToDtype() # ------------------------------- # -# `ToTensor `_ -# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales -# the image's pixel intensity values in the range [0., 1.] +# The ``torchvision.transforms.v2`` API replaces the legacy ``ToTensor`` transform with a two-step pipeline. +# `v2.ToImage `_ +# converts a PIL image or NumPy ``ndarray`` into a ``torchvision.tv_tensors.Image`` tensor, and +# `v2.ToDtype `_ +# with ``scale=True`` casts it to ``float32`` and scales the pixel intensity values to the range [0., 1.]. # ############################################## # Lambda Transforms # ------------------------------- # -# Lambda transforms apply any user-defined lambda function. Here, we define a function -# to turn the integer into a one-hot encoded tensor. -# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls -# `scatter_ `_ which assigns a -# ``value=1`` on the index as given by the label ``y``. +# Lambda transforms apply any user-defined lambda function. Here, we use +# `torch.nn.functional.one_hot `_ +# to turn the integer label into a one-hot encoded tensor of size 10 (the number of labels in our dataset), +# then cast it to ``float`` to match the expected dtype. -target_transform = Lambda(lambda y: torch.zeros( - 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) +target_transform = v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()) ###################################################################### # -------------- From a2ec6656dd29ad90f2040dc2feec454b84236855 Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Wed, 6 May 2026 20:26:48 +0530 Subject: [PATCH 2/2] use `F` instead of full name and update docs --- beginner_source/basics/transforms_tutorial.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/beginner_source/basics/transforms_tutorial.py b/beginner_source/basics/transforms_tutorial.py index 320137b26e8..03ebbeee737 100644 --- a/beginner_source/basics/transforms_tutorial.py +++ b/beginner_source/basics/transforms_tutorial.py @@ -27,6 +27,7 @@ """ import torch +import torch.nn.functional as F from torchvision import datasets from torchvision.transforms import v2 @@ -35,7 +36,9 @@ train=True, download=True, transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]), - target_transform=v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()) + target_transform=v2.Lambda( + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() + ), ) ################################################# @@ -58,7 +61,9 @@ # to turn the integer label into a one-hot encoded tensor of size 10 (the number of labels in our dataset), # then cast it to ``float`` to match the expected dtype. -target_transform = v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()) +target_transform = v2.Lambda( + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() +) ###################################################################### # -------------- @@ -67,4 +72,5 @@ ################################################################# # Further Reading # ~~~~~~~~~~~~~~~~~ -# - `torchvision.transforms API `_ +# - `Getting started with transforms v2 `_ +# - `torchvision.transforms.v2 API `_