Skip to content

Commit

Permalink
reorganize pre and post processing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 9, 2024
1 parent 5f95fa9 commit f4e00b5
Showing 1 changed file with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.bioimage.modelrunner.tensor.Tensor;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;

/**
* Class that executes the pre- or post-processing associated to a given tensor.
Expand Down Expand Up @@ -132,15 +133,18 @@ List<Tensor<R>> preprocess(List<Tensor<T>> tensorList){
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<R>> preprocess(List<Tensor<T>> tensorList, boolean inplace) {
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
if (preMap.entrySet().size() == 0) return Cast.unchecked(tensorList);
for (Entry<String, List<TransformationInstance>> ee : this.preMap.entrySet()) {
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
if (tt == null)
continue;
ee.getValue().forEach(trans -> {
trans.run(tt, inplace);
});
for (TransformationInstance trans : ee.getValue()) {
List<Tensor<R>> outList = trans.run(tt, inplace);
outputs.addAll(outList);
}
}
return null;
return outputs;
}

/**
Expand Down Expand Up @@ -174,14 +178,17 @@ List<Tensor<R>> postprocess(List<Tensor<T>> tensorList){
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<R>> postprocess(List<Tensor<T>> tensorList, boolean inplace) {
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
if (postMap.entrySet().size() == 0) return Cast.unchecked(tensorList);
for (Entry<String, List<TransformationInstance>> ee : this.postMap.entrySet()) {
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
if (tt == null)
continue;
ee.getValue().forEach(trans -> {
trans.run(tt, inplace);
});
for (TransformationInstance trans : ee.getValue()) {
List<Tensor<R>> outList = trans.run(tt, inplace);
outputs.addAll(outList);
}
}
return null;
return outputs;
}
}

0 comments on commit f4e00b5

Please sign in to comment.