
[bug] javacpp-pytorch: EmbeddingBagImpl may need an additional forward overload, and may also have a bug #1593

Open
mullerhai opened this issue Mar 7, 2025 · 2 comments

Comments

@mullerhai

Hi @saudet, EmbeddingBagImpl also does not work in JavaCPP, while the same logic and parameters work in Python. Please fix it, thanks.

In Python:

import torch
import torch.nn as nn

# an EmbeddingBag module containing 10 tensors of size 3
embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
# a batch of 2 samples of 4 indices each
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
offsets = torch.tensor([0, 4], dtype=torch.long)
out = embedding_sum(input, offsets)
print(out)

Console output:

tensor([[ 2.4000, -1.5095, -2.0853],
        [ 2.5163,  0.0308, -1.2409]], grad_fn=<EmbeddingBagBackward0>)
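To make the expected semantics concrete, here is a plain-Python reference sketch of what `mode='sum'` computes. It is not PyTorch's implementation. `offsets` marks where each bag starts in the flat index list, so `[0, 4]` splits the eight indices into two bags of four. Each output row is the sum of the weight rows those indices select. Indices are used as row positions, which is why they must be integers. The function name `embedding_bag_sum` and the toy weight table are hypothetical, and the sketch assumes `include_last_offset=False`.

```python
from typing import List, Sequence


def embedding_bag_sum(
    weight: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
) -> List[List[float]]:
    """Hypothetical reference for EmbeddingBag(mode='sum'), include_last_offset=False."""
    dim = len(weight[0])
    # Bag i covers indices[offsets[i]:offsets[i + 1]]; the last bag runs to the end.
    bounds = list(offsets) + [len(indices)]
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        acc = [0.0] * dim
        for idx in indices[start:end]:
            if not isinstance(idx, int):
                # Indices select rows of the weight table, so non-integers are rejected.
                raise TypeError(f"index {idx!r} must be an int")
            row = weight[idx]
            for d in range(dim):
                acc[d] += row[d]
        out.append(acc)
    return out


# Toy 10 x 3 weight table (hypothetical values): row i is [i, 10 * i, 100 * i].
weight = [[float(i), 10.0 * i, 100.0 * i] for i in range(10)]
indices = [1, 2, 4, 5, 4, 3, 2, 9]
offsets = [0, 4]
print(embedding_bag_sum(weight, indices, offsets))
# → [[12.0, 120.0, 1200.0], [18.0, 180.0, 1800.0]]
```

With this table, the first bag sums rows 1, 2, 4 and 5 (1 + 2 + 4 + 5 = 12 in the first column). The second bag sums rows 4, 3, 2 and 9 (18 in the first column).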

In pure JavaCPP (Scala):

val options = new EmbeddingBagOptions(10, 3)
options.mode().put(new kSum)

// Python equivalent: nn.EmbeddingBag(num_embeddings = 10, embedding_dim = 3, mode = "sum")
// Note: the indices are cast to float32 here, which triggers the error below
val input23 = torch.Tensor(Seq(1, 2, 4, 5, 4, 3, 2, 9)).to(torch.float32)
val offsets = torch.Tensor(Seq(0, 4))
val model = EmbeddingBagImpl(options)
val output = model.forward(input23.native, offsets.native, torch.zeros(input23.size).native)
println(s"output ${fromNative(output).shape}")

JavaCPP console output:

java.lang.RuntimeException: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CPUFloatType instead (while checking arguments for embedding_bag)
Exception raised from checkScalarTypes at D:\a\javacpp-presets\javacpp-presets\pytorch\cppbuild\windows-x86_64-gpu\pytorch\aten\src\ATen\TensorUtils.cpp:203 (most recent call first):
@saudet
Member

saudet commented Mar 8, 2025

Like the error says, you need an integer tensor, not a float tensor, for this to work.

@mullerhai
Author

> Like the error says, you need an integer tensor, not a float tensor, for this to work.

Hi @saudet, I have now cast all tensors to integer tensors, but it still does not work:

import org.bytedeco.pytorch.global.torch as tch
import org.bytedeco.pytorch.{EmbeddingBagImpl, EmbeddingBagOptions, kSum}

class EmbeddingBagRawSuite extends munit.FunSuite {
  test("EmbeddingBag output shapes") {
    val options = new EmbeddingBagOptions(10, 3)
    options.mode().put(new kSum)
    options.embedding_dim().put(3)
    options.num_embeddings().put(10)
    options.norm_type().put(2.0)
    options.include_last_offset().put(false)
    options.scale_grad_by_freq().put(false)
    options.sparse().put(false)
    // options.max_norm().put(2d)
    // options.padding_idx().put()
    // Python equivalent: nn.EmbeddingBag(num_embeddings = 10, embedding_dim = 3, mode = "sum")

    val input23 = torch.Tensor(Seq(1, 2, 4, 5, 4, 3, 2, 9)).to(torch.int32)
    val offsets = torch.Tensor(Seq(0, 4)).to(torch.int32)
    // Third forward argument (per_sample_weights), also cast to int32; see the error below
    val pre = torch.zeros(input23.size).to(torch.int32)
    val model = EmbeddingBagImpl(options)
    println(s"input type ${input23.native.dtype().name().getString()} offsets ${offsets.native.dtype().name().getString()} pre ${pre.native.dtype().name().getString()}")
    val output = model.forward(input23.native, offsets.native, pre.native)
    println(s"output ${fromNative(output).shape}")
  }
}

Console log:

(IntelliJ IDEA JUnit launch command omitted. Relevant classpath entries: OpenJDK 23.0.2, Scala 3.6.2, storch core_3/vision_3 0.1.9-2.4.3, javacpp 1.5.11, pytorch 2.5.1-1.5.11 (windows-x86_64-gpu), cuda 12.6-9.5-1.5.11, mkl 2025.0-1.5.11, openblas 0.3.28-1.5.11, munit 0.7.29.)
Testing started at 22:34 ...
input type int offsets int pre int

java.lang.RuntimeException: Expected tensor for argument #1 'weight' to have the same type as tensor for argument #1 'per_sample_weights'; but type CPUFloatType does not equal CPUIntType (while checking arguments for embedding_bag)
Exception raised from checkSameType at D:\a\javacpp-presets\javacpp-presets\pytorch\cppbuild\windows-x86_64-
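The second error concerns the third `forward` argument, `per_sample_weights`. PyTorch documents it as a float/double tensor that scales each selected embedding row before the rows are summed, supported only for `mode='sum'`, with `None` meaning every weight is 1. It therefore has to share the floating dtype of `weight` rather than the integer dtype of the indices. Even as a float tensor, all zeros would zero out every bag, so ones (or omitting the weights) is what matches the unweighted Python call. The plain-Python sketch below shows this. The name `weighted_embedding_bag_sum` and the toy table are hypothetical, and the float check only loosely mirrors PyTorch's dtype check.

```python
from typing import List, Optional, Sequence


def weighted_embedding_bag_sum(
    weight: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
    per_sample_weights: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical reference for EmbeddingBag(mode='sum') with per_sample_weights."""
    if per_sample_weights is None:
        # Omitting the weights is the same as weighting every index by 1.
        per_sample_weights = [1.0] * len(indices)
    if len(per_sample_weights) != len(indices):
        raise ValueError("per_sample_weights must have the same length as indices")
    if not all(isinstance(w, float) for w in per_sample_weights):
        # Loosely mirrors the dtype check: scaling factors must be floats like weight.
        raise TypeError("per_sample_weights must be floating point")
    dim = len(weight[0])
    bounds = list(offsets) + [len(indices)]
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        acc = [0.0] * dim
        for pos in range(start, end):
            row = weight[indices[pos]]
            scale = per_sample_weights[pos]
            for d in range(dim):
                acc[d] += scale * row[d]
        out.append(acc)
    return out


# Toy 10 x 3 weight table (hypothetical values): row i is [i, 10 * i, 100 * i].
weight = [[float(i), 10.0 * i, 100.0 * i] for i in range(10)]
indices = [1, 2, 4, 5, 4, 3, 2, 9]
offsets = [0, 4]

print(weighted_embedding_bag_sum(weight, indices, offsets, [0.0] * 8))
# → [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(weighted_embedding_bag_sum(weight, indices, offsets, [1.0] * 8))
# → [[12.0, 120.0, 1200.0], [18.0, 180.0, 1800.0]]
```

By analogy, the JavaCPP call would presumably need a float32 tensor of ones (or an undefined tensor) as the third argument instead of int32 zeros. This has not been verified against the presets.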
