Metric collection is an essential part of every machine learning project, enabling us to track model performance and monitor training progress. Ideally, metrics should be collected and computed without introducing any additional overhead to the training process. However, just like other components of the training loop, inefficient metric computation can introduce unnecessary overhead, increase training-step times and inflate training costs.
In this post — the seventh in our series on performance profiling and optimization in PyTorch — we will demonstrate how a naïve implementation of metric collection can negatively impact runtime performance.
To program our metric collection, we will use TorchMetrics, a popular library designed to simplify and standardize metric computation in PyTorch. Our goals will be to:
- Demonstrate the runtime overhead caused by a naïve implementation of metric collection.
- Use PyTorch Profiler to pinpoint performance bottlenecks introduced by metric computation.
- Demonstrate optimization techniques to reduce metric collection overhead.
To facilitate our discussion, we will define a toy PyTorch model and assess how metric collection can impact its runtime performance. We will run our experiments on an NVIDIA A40 GPU, with a PyTorch 2.5.1 Docker image and TorchMetrics 1.6.1.
It’s important to note that metric collection behavior can vary greatly depending on the hardware, runtime environment, and model architecture. The code snippets provided in this post are intended for demonstrative purposes only. Please do not interpret our mention of any tool or technique as an endorsement for its use.
In the code block below we define a simple image classification model with a ResNet-18 backbone.
import time
import torch
import torchvision

device = "cuda"
model = torchvision.models.resnet18().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
We define a synthetic dataset which we will use to train our toy model.
from torch.utils.data import Dataset, DataLoader

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return 100000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label

train_set = FakeDataset()

batch_size = 128
num_workers = 12

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True
)
We define a collection of standard metrics from TorchMetrics, along with a control flag to enable or disable metric calculation.
from torchmetrics import (
    MeanMetric,
    Accuracy,
    Precision,
    Recall,
    F1Score,
)

# toggle to enable/disable metric collection
capture_metrics = False

if capture_metrics:
    metrics = {
        "avg_loss": MeanMetric(),
        "accuracy": Accuracy(task="multiclass", num_classes=1000),
        "precision": Precision(task="multiclass", num_classes=1000),
        "recall": Recall(task="multiclass", num_classes=1000),
        "f1_score": F1Score(task="multiclass", num_classes=1000),
    }
    # Move all metrics to the device
    metrics = {name: metric.to(device) for name, metric in metrics.items()}
Next, we define a PyTorch Profiler instance, along with a control flag that allows us to enable or disable profiling. For a detailed tutorial on using PyTorch Profiler, please refer to the first post in this series.
from torch import profiler

# toggle to enable/disable profiling
enable_profiler = True

if enable_profiler:
    prof = profiler.profile(
        schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),
        on_trace_ready=profiler.tensorboard_trace_handler("./logs/"),
        profile_memory=True,
        with_stack=True
    )
    prof.start()
Lastly, we define a standard training step:
model.train()

t0 = time.perf_counter()
total_time = 0
count = 0

for idx, (data, target) in enumerate(train_loader):
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        # update metrics
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item()
                for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()
    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        total_time += batch_time
        count += 1

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = total_time / count
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')
Metric Collection Overhead
To measure the impact of metric collection on training step time, we ran our training script both with and without metric calculation. The results are summarized in the following table.
Our naïve metric collection resulted in a nearly 10% drop in runtime performance. While metric collection is essential for machine learning development, it usually involves relatively simple mathematical operations and hardly warrants such a significant overhead. What is going on?!!
To better understand the source of the performance degradation, we reran the training script with the PyTorch Profiler enabled. The resultant trace is shown below:
The trace reveals recurring “cudaStreamSynchronize” operations that coincide with noticeable drops in GPU utilization. These types of “CPU-GPU sync” events were discussed in detail in part two of our series. In a typical training step, the CPU and GPU work in parallel: the CPU manages tasks like copying data to the GPU and launching GPU kernels, while the GPU executes the model on the input data and updates its weights. Ideally, we would like to minimize the points of synchronization between the CPU and GPU in order to maximize performance. Here, however, we see that the metric collection triggers a CPU-to-GPU data copy that requires the CPU to suspend its processing until the GPU catches up, which, in turn, causes the GPU to sit idle while waiting for the CPU to resume launching subsequent kernels. The bottom line is that these synchronization points lead to inefficient utilization of both the CPU and GPU. Our metric collection adds eight such sync events to each training step.
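As a point of reference, here is a minimal standalone sketch (ours, not part of the training script above) of how such a sync point can arise: any operation that forces the CPU to read a result back from the GPU blocks the CPU until the GPU stream has caught up, and appears in the trace as a synchronization event.
import torch

x = torch.randn(1000, 1000, device="cuda")
y = x @ x                 # queued on the GPU stream; the CPU moves on immediately
val = y.sum().item()      # reading the value back blocks the CPU until the GPU
                          # finishes, showing up as a sync in the profiler trace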
A closer examination of the trace shows that the sync events are coming from the update call of the TorchMetrics MeanMetric. For an experienced profiling expert, this may be sufficient to identify the root cause, but we will go a step further and use the torch.profiler.record_function utility to identify the exact offending line of code.
To pinpoint the exact source of the sync event, we extended the MeanMetric class and overrode the update method using record_function context blocks. This approach allows us to profile individual operations within the method and identify performance bottlenecks.
class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight=1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()
We then updated our avg_loss metric to use the newly created ProfileMeanMetric and reran the training script.
The updated trace reveals that the sync event originates from the following line:
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
This operation converts the default scalar value weight=1.0 into a PyTorch tensor and places it on the GPU. The sync event occurs because this action triggers a CPU-to-GPU data copy, which requires the CPU to wait for the GPU to process the copied value.
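The standalone sketch below (our own illustration, not taken from the training script) reproduces the pattern: because update is called without an explicit weight, every call re-creates a GPU tensor from the Python scalar 1.0.
import torch
from torchmetrics import MeanMetric

# Standalone sketch of the pattern described above (made-up values):
avg_loss = MeanMetric().to("cuda")
loss = torch.tensor(0.75, device="cuda")   # stand-in for a loss value

# With the default weight, each call internally executes something like
#   torch.as_tensor(1.0, dtype=..., device="cuda")
# i.e. a fresh host-to-device copy, and a sync, on every training step
avg_loss.update(loss)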
Optimization 1: Specify Weight Value
Now that we have found the source of the issue, we can overcome it easily by specifying a weight value in our update call. This prevents the runtime from converting the default scalar weight=1.0 into a tensor on the GPU, avoiding the sync event:
# update metrics
if capture_metrics:
    metrics["avg_loss"].update(loss, weight=torch.ones_like(loss))
Rerunning the script after applying this change reveals that we have succeeded in eliminating the initial sync event… only to have uncovered a new one, this time coming from the _cast_and_nan_check_input function:
To explore our new sync event, we extended our custom metric with additional profiling probes and reran our script.
class ProfileMeanMetric(MeanMetric):
    def update(self, value, weight=1.0):
        # broadcast weight to value shape
        with profiler.record_function("process value"):
            if not isinstance(value, torch.Tensor):
                value = torch.as_tensor(value, dtype=self.dtype,
                                        device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
        with profiler.record_function("broadcast weight"):
            weight = torch.broadcast_to(weight, value.shape)
        with profiler.record_function("cast_and_nan_check"):
            value, weight = self._cast_and_nan_check_input(value, weight)

        if value.numel() == 0:
            return

        with profiler.record_function("update value"):
            self.mean_value += (value * weight).sum()
        with profiler.record_function("update weight"):
            self.weight += weight.sum()

    def _cast_and_nan_check_input(self, x, weight=None):
        """Convert input ``x`` to a tensor and check for Nans."""
        with profiler.record_function("process x"):
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype,
                                    device=self.device)
        with profiler.record_function("process weight"):
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)

        nans = torch.isnan(x)
        if weight is not None:
            nans_weight = torch.isnan(weight)
        else:
            nans_weight = torch.zeros_like(nans).bool()
            weight = torch.ones_like(x)

        with profiler.record_function("any nans"):
            anynans = nans.any() or nans_weight.any()

        with profiler.record_function("process nans"):
            if anynans:
                if self.nan_strategy == "error":
                    raise RuntimeError("Encountered `nan` values in tensor")
                if self.nan_strategy in ("ignore", "warn"):
                    if self.nan_strategy == "warn":
                        print("Encountered `nan` values in tensor."
                              " Will be removed.")
                    x = x[~(nans | nans_weight)]
                    weight = weight[~(nans | nans_weight)]
                else:
                    if not isinstance(self.nan_strategy, float):
                        raise ValueError(f"`nan_strategy` shall be float"
                                         f" but you pass {self.nan_strategy}")
                    x[nans | nans_weight] = self.nan_strategy
                    weight[nans | nans_weight] = self.nan_strategy

        with profiler.record_function("return value"):
            retval = x.to(self.dtype), weight.to(self.dtype)
            return retval
The resultant trace is captured below:
The trace points directly to the offending line:
anynans = nans.any() or nans_weight.any()
This operation checks for NaN values in the input tensors, but it introduces a costly CPU-GPU synchronization event: evaluating the boolean result in Python requires copying it from the GPU back to the CPU.
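The standalone sketch below (our own illustration) shows the same mechanism in isolation: the result of any() is still a zero-dimensional GPU tensor, and only when Python needs its truth value, as in an or expression or an if statement, is it copied back to the host.
import torch

# Standalone sketch (made-up values): evaluating a GPU boolean in Python
# control flow forces a device-to-host copy and a stream synchronization
x = torch.randn(1024, device="cuda")
nans = torch.isnan(x)

has_nans = nans.any()     # still a 0-dim GPU tensor; no sync yet
if has_nans:              # bool(<GPU tensor>) -> copy to CPU + sync
    print("found NaN values")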
Upon closer inspection of the TorchMetrics BaseAggregator class, we find several options for handling NaN value updates, all of which pass through the offending line of code. However, for our use case — calculating the average loss metric — this check is unnecessary and does not justify the runtime performance penalty.
Optimization 2: Disable NaN Value Checks
To eliminate the overhead, we propose disabling the NaN value checks by overriding the _cast_and_nan_check_input function. Instead of a static override, we implemented a dynamic solution that can be applied flexibly to any descendants of the BaseAggregator class.
from torchmetrics.aggregation import BaseAggregator

def suppress_nan_check(MetricClass):
    assert issubclass(MetricClass, BaseAggregator), MetricClass

    class DisableNanCheck(MetricClass):
        def _cast_and_nan_check_input(self, x, weight=None):
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x, dtype=self.dtype,
                                    device=self.device)
            if weight is not None and not isinstance(weight, torch.Tensor):
                weight = torch.as_tensor(weight, dtype=self.dtype,
                                         device=self.device)
            if weight is None:
                weight = torch.ones_like(x)
            return x.to(self.dtype), weight.to(self.dtype)

    return DisableNanCheck

NoNanMeanMetric = suppress_nan_check(MeanMetric)
metrics["avg_loss"] = NoNanMeanMetric().to(device)
Post-Optimization Results: Success
After implementing the two optimizations — specifying the weight value and disabling the NaN checks — we find the step time performance and the GPU utilization to match those of our baseline experiment. In addition, the resultant PyTorch Profiler trace shows that all of the added “cudaStreamSynchronize” events that were associated with the metric collection have been eliminated. With a few small changes, we have reduced the cost of training by ~10% without any changes to the behavior of the metric collection.
In the next section, we will explore an additional metric collection optimization.
In the previous section, the metric values resided on the GPU, making it logical to store and compute the metrics on the GPU. However, in scenarios where the values we wish to aggregate reside on the CPU, it might be preferable to store the metrics on the CPU to avoid unnecessary device transfers.
In the code block below, we modify our script to calculate the average step time using a MeanMetric on the CPU. This change has no impact on the runtime performance of our training step:
avg_time = NoNanMeanMetric()

t0 = time.perf_counter()

for idx, (data, target) in enumerate(train_loader):
    # move data to device
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if capture_metrics:
        metrics["avg_loss"].update(loss)
        for name, metric in metrics.items():
            if name != "avg_loss":
                metric.update(output, target)

        if (idx + 1) % 100 == 0:
            # compute metrics
            metric_results = {
                name: metric.compute().item()
                for name, metric in metrics.items()
            }
            # print metrics
            print(f"Step {idx + 1}: {metric_results}")
            # reset metrics
            for metric in metrics.values():
                metric.reset()
    elif (idx + 1) % 100 == 0:
        # print last loss value
        print(f"Step {idx + 1}: Loss = {loss.item():.4f}")

    batch_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    if idx > 10:  # skip first steps
        avg_time.update(batch_time)

    if enable_profiler:
        prof.step()

    if idx > 200:
        break

if enable_profiler:
    prof.stop()

avg_time = avg_time.compute().item()
print(f'Average step time: {avg_time}')
print(f'Throughput: {batch_size/avg_time:.2f} images/sec')
The problem arises when we attempt to extend our script to support distributed training. To demonstrate the problem, we modified our model definition to use DistributedDataParallel (DDP):
# toggle to enable/disable ddp
use_ddp = True

if use_ddp:
    import os
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)
    model = DDP(torchvision.models.resnet18().to(device))
else:
    model = torchvision.models.resnet18().to(device)

# insert training loop

# append to end of the script:
if use_ddp:
    # destroy the process group
    dist.destroy_process_group()
The DDP modification results in the following error:
RuntimeError: No backend type associated with device type cpu
By default, metrics in distributed training are programmed to synchronize across all of the devices in use. However, the NCCL process group that we initialized for DDP does not support synchronizing tensors that are stored on the CPU.
One way to solve this is to disable the cross-device metric synchronization:
avg_time = NoNanMeanMetric(sync_on_compute=False)
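Another option, which we did not benchmark, might be to keep the cross-device aggregation but route it through a CPU-capable gloo process group, passed to the metric via its process_group argument. The snippet below is an untested sketch of that idea:
# Untested sketch (assumes torch.distributed was already initialized with the
# nccl backend, as above): aggregate the CPU-resident metric over a gloo group
cpu_group = dist.new_group(backend="gloo")
avg_time = NoNanMeanMetric(process_group=cpu_group)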
In our case, where we are measuring the average time, disabling the synchronization is acceptable. However, in some cases the metric synchronization is essential, and we may have no choice but to move the metric onto the GPU:
avg_time = NoNanMeanMetric().to(device)
Unfortunately, this situation gives rise to a new CPU-GPU sync event coming from the update function.
This sync event should hardly come as a surprise—after all, we are updating a GPU metric with a value residing on the CPU, necessitating a memory copy. However, in the case of a scalar metric, this data transfer can be completely avoided with a simple optimization.
Optimization 3: Perform Metric Updates with Tensors instead of Scalars
The solution is straightforward: instead of updating the metric with a float scalar, we convert the value to a Tensor before calling update.
batch_time = torch.as_tensor(batch_time)
avg_time.update(batch_time, torch.ones_like(batch_time))
This minor change eliminates the sync event and restores step time to baseline performance.
At first glance, this result may seem surprising: we would expect that updating a GPU metric with a CPU tensor should still require a memory copy. However, PyTorch treats a zero-dimensional CPU tensor as a scalar and passes its value directly to the GPU kernel that performs the addition, avoiding both the explicit data transfer and the expensive synchronization that would otherwise occur.
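The standalone sketch below (our own illustration, with made-up values) contrasts the two cases:
import torch

gpu_sum = torch.zeros((), device="cuda")

# A 0-dim CPU tensor: its value is handed to the CUDA kernel as a scalar,
# so the addition requires no blocking host-to-device copy
gpu_sum += torch.as_tensor(0.123)

# By contrast, materializing the value on the GPU first performs an explicit
# host-to-device copy, which is where the sync point would appear
gpu_sum += torch.as_tensor(0.123, device="cuda")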
Our series on performance profiling and optimization in PyTorch has aimed to emphasize the critical role of performance analysis and optimization in machine learning development. Each post has focused on different stages of the training pipeline, demonstrating simple yet effective tools and techniques for analyzing and boosting runtime performance.
In this post, we explored how a naïve approach to TorchMetrics can introduce CPU-GPU synchronization events and significantly degrade PyTorch training performance. Using PyTorch Profiler, we identified the lines of code responsible for these sync events and applied targeted optimizations to eliminate them.
We have created a dedicated pull request on the TorchMetrics GitHub page covering some of the optimizations discussed in this post. Please feel free to contribute your own improvements and optimizations!