
Perf issue: Torchsharp is slower than pytorch on cuda on some operators #1442

Open
LittleLittleCloud opened this issue Feb 7, 2025 · 3 comments
Labels
enhancement New feature or request

Comments


LittleLittleCloud commented Feb 7, 2025

I ran some benchmarks to compare the performance of TorchSharp and PyTorch; both use libtorch 2.2.1 + CUDA 12.1. I noticed that TorchSharp is slower than PyTorch on most operators. Below are the benchmark results.

TorchSharp

[benchmark results screenshot]

PyTorch

[benchmark results screenshot]

Observation

I can achieve comparable results between TorchSharp and PyTorch if I replace each operator with its in-place version. Performance also improves considerably if I explicitly dispose the current scope during each test.

For example, in the add benchmark, TorchSharp runs nearly as fast as PyTorch if I use tensor.add_ instead of tensor.add.

Considering that the major difference between an operator and its in-place version is that the in-place version doesn't create a new Tensor object, the main overhead likely lies in the Tensor constructor.
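The allocation hypothesis can be illustrated with a small stdlib-only Python sketch (no TorchSharp or PyTorch involved; `Wrapper` is a hypothetical stand-in for a managed tensor handle): the out-of-place path constructs a fresh wrapper object on every iteration, while the in-place path mutates one it already has, so any per-object construction cost is paid N times vs. once.

```python
import time

# Hypothetical stand-in for a tensor wrapper object; this is NOT a
# TorchSharp or PyTorch API, just an illustration of per-call allocation.
class Wrapper:
    __slots__ = ("data",)
    def __init__(self, data):
        self.data = data

def out_of_place(a, b, n):
    # Allocates a new Wrapper on every iteration, like tensor.add.
    for _ in range(n):
        c = Wrapper(a.data + b.data)
    return c

def in_place(a, b, n):
    # Reuses one Wrapper, like tensor.add_ reuses the existing tensor.
    c = Wrapper(0)
    for _ in range(n):
        c.data = a.data + b.data
    return c

a, b, n = Wrapper(1), Wrapper(2), 100_000
t0 = time.perf_counter(); out_of_place(a, b, n); t1 = time.perf_counter()
in_place(a, b, n); t2 = time.perf_counter()
print(f"out-of-place: {t1 - t0:.4f}s, in-place: {t2 - t1:.4f}s")
```

The analogy is loose (TorchSharp's wrapper also registers a native handle for disposal), but it shows why removing the per-call object construction could close the gap.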

Source code

using TorchSharp;

// Initialize CUDA device
var device = torch.CUDA;

var repeatTime = 10000;
// Test randn
var startTime = DateTime.Now;
for (int i = 0; i < repeatTime; i++)
{
    var _ = torch.randn(new long[] { 1000, 1000 }, device: device);
}

Console.WriteLine("Time taken for randn: " + (DateTime.Now - startTime).TotalSeconds);

// Test matmul
startTime = DateTime.Now;
var a = torch.randn(new long[] { 1000, 1000 }, device: device);
var b = torch.randn(new long[] { 1000, 1000 }, device: device);

for (int i = 0; i < repeatTime; i++)
{
    var c = torch.matmul(a, b);
}

Console.WriteLine("Time taken for matmul: " + (DateTime.Now - startTime).TotalSeconds);

// Test concat
startTime = DateTime.Now;
a = torch.randn(new long[] { 1000, 1000 }, device: device);
b = torch.randn(new long[] { 1000, 1000 }, device: device);

for (int i = 0; i < repeatTime; i++)
{
    var c = torch.cat(new[] { a, b }, 0);
}

Console.WriteLine("Time taken for concat: " + (DateTime.Now - startTime).TotalSeconds);

// Test slice
startTime = DateTime.Now;
a = torch.randn(new long[] { 1000, 1000 }, device: device);

for (int i = 0; i < repeatTime; i++)
{
    var c = a[.., 0..500];
}

Console.WriteLine("Time taken for slice: " + (DateTime.Now - startTime).TotalSeconds);

// Test add
startTime = DateTime.Now;
a = torch.randn(new long[] { 1000, 1000 }, device: device);
b = torch.randn(new long[] { 1000, 1000 }, device: device);

for (int i = 0; i < repeatTime; i++)
{
    var c = a + b;
}

Console.WriteLine("Time taken for add: " + (DateTime.Now - startTime).TotalSeconds);
# Benchmarks for PyTorch on CUDA

import torch
import time
repeat = 10000
total_time = 0
start_time = time.time()
for _ in range(repeat):
    a = torch.randn(1000, 1000).cuda()
print("Time taken for randn: " , time.time()-start_time)

start_time = time.time()
# test matmul
a = torch.randn(1000, 1000).cuda()
b = torch.randn(1000, 1000).cuda()
for _ in range(repeat):
    c = torch.matmul(a, b)
    

print("Time taken for matmul: ", time.time()-start_time)

start_time = time.time()

# test concat   
a = torch.randn(1000, 1000).cuda()
b = torch.randn(1000, 1000).cuda()

for _ in range(repeat):
    c = torch.cat((a, b), 0)

print("Time taken for concat: ", time.time()-start_time)

start_time = time.time()
# test slice
a = torch.randn(1000, 1000).cuda()

for _ in range(repeat):
    c = a[:, 0:500]

print("Time taken for slice: ", time.time()-start_time)

start_time = time.time()
# test add
a = torch.randn(1000, 1000).cuda()
b = torch.randn(1000, 1000).cuda()

for _ in range(repeat):
    c = a + b

print("Time taken for add: ", time.time()-start_time)
ds5678 commented Feb 8, 2025

For something like this, you should use BenchmarkDotNet. It handles all the edge cases with .NET benchmarking (like JIT warmup).
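The same principle applies to the Python side of the comparison; as a rough stdlib-only analogue of what a harness like BenchmarkDotNet does (this is a sketch, not BenchmarkDotNet itself), one can warm up first and then keep the best of several timed repeats, so one-time costs don't dominate the measurement:

```python
import timeit

# Sketch of a benchmark harness: run a warmup pass, then time several
# repeats with timeit and keep the best, so one-time costs (warmup,
# caches) don't skew the per-call estimate.
def bench(fn, repeats=5, number=1000):
    fn()  # warmup pass
    times = timeit.repeat(fn, repeat=repeats, number=number)
    return min(times) / number  # best per-call time in seconds

per_call = bench(lambda: sum(range(100)))
print(f"sum(range(100)): {per_call * 1e6:.2f} us per call")
```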

@ozanMSFT ozanMSFT added the enhancement New feature or request label Feb 11, 2025
ozanMSFT commented Feb 11, 2025

Hey @ds5678,

Thanks for reporting the issue. I noticed that you used libtorch 2.2.1.

The latest version of TorchSharp (0.105.0) uses libtorch 2.5.1, and there have been some performance improvements around native calls and garbage collection.

Could you please check your results again with the latest version 0.105.0?

ds5678 commented Feb 11, 2025

@ozanMSFT I think you meant to ping @LittleLittleCloud
