diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java index 5ce0bd4..c44184a 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java @@ -100,6 +100,7 @@ public static > org.bytedeco.pytorch.Tensor build(RandomAccess */ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval tensor) { + long[] ogShape = tensor.dimensionsAsLongArray(); tensor = Utils.transpose(tensor); PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); @@ -110,7 +111,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; blocks.copy( new long[tensorShape.length], flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -124,6 +125,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI */ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval tensor) { + long[] ogShape = tensor.dimensionsAsLongArray(); tensor = Utils.transpose(tensor); PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); @@ -134,7 +136,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; blocks.copy( new long[tensorShape.length], flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -148,6 +150,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn */ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval tensor) { + long[] ogShape = tensor.dimensionsAsLongArray(); tensor = Utils.transpose(tensor); PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); @@ -158,7 +161,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; blocks.copy( new long[tensorShape.length], flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } @@ -172,6 +175,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible */ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval tensor) { + long[] ogShape = tensor.dimensionsAsLongArray(); tensor = Utils.transpose(tensor); PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); @@ -182,7 +186,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; blocks.copy( new long[tensorShape.length], flatArr, sArr ); - org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape); + org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape); return ndarray; } }