From 4b3b65e3d8ff1eebd077661c20b0377e1b39ba87 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Fri, 27 Oct 2023 18:31:56 +0200 Subject: [PATCH] correct error creating tensors --- .../pytorch/javacpp/tensor/JavaCPPTensorBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java index c44184a..d964d85 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java @@ -110,7 +110,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -135,7 +135,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -160,7 +160,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -185,7 +185,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; }