Skip to content

Commit

Permalink
correct error creating tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 27, 2023
1 parent 57f3077 commit ac9f5c6
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -110,7 +111,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
return ndarray;
}

Expand All @@ -124,6 +125,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -134,7 +136,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
return ndarray;
}

Expand All @@ -148,6 +150,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -158,7 +161,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
return ndarray;
}

Expand All @@ -172,6 +175,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
*/
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -182,7 +186,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
return ndarray;
}
}

0 comments on commit ac9f5c6

Please sign in to comment.