Skip to content

Commit

Permalink
Fix bug for run affable-shark model (#368)
Browse files Browse the repository at this point in the history
* fix bugs; add progress

* npm format
  • Loading branch information
Nanguage authored Dec 6, 2023
1 parent d405121 commit cd64674
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 60 deletions.
7 changes: 5 additions & 2 deletions src/components/ResourceItemInfo.vue
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@

<br />
<test-run-form
v-if="resourceItem.type == 'model' && modelAvailable"
v-if="resourceItem.type === 'model' && modelAvailable"
:resourceItem="resourceItem"
>
</test-run-form>
<div class="not-available" v-if="!modelAvailable">
<div
class="not-available"
v-if="resourceItem.type === 'model' && !modelAvailable"
>
This model is not available for testing.
</div>

Expand Down
114 changes: 84 additions & 30 deletions src/components/TestRun.vue
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,16 @@
<h3>Settings for image tiling</h3>
<div style="display: flex; gap: 30px">
<div style="width: 30%">
<b-field v-if="'x' in inputMinShape" label="Tile size(X)">
<b-field
v-if="'x' in inputMinShape && 'y' in inputMinShape"
label="Tile size(XY)"
>
<b-numberinput
v-model="tileSizes.x"
:min="inputMinShape.x"
:max="inputMaxShape.x"
></b-numberinput>
</b-field>
<b-field v-if="'y' in inputMinShape" label="Tile size(Y)">
<b-numberinput
v-model="tileSizes.y"
:min="inputMinShape.y"
:max="inputMaxShape.y"
></b-numberinput>
</b-field>
<b-field v-if="'z' in inputMinShape" label="Tile size(Z)">
<b-numberinput
v-model="tileSizes.z"
Expand All @@ -60,20 +56,16 @@
</b-field>
</div>
<div style="width: 30%">
<b-field v-if="'x' in inputMinShape" label="Tile overlap(X)">
<b-field
v-if="'x' in inputMinShape && 'y' in inputMinShape"
label="Tile overlap(XY)"
>
<b-numberinput
v-model="tileOverlap.x"
:min="0"
:max="inputMaxShape.x"
></b-numberinput>
</b-field>
<b-field v-if="'y' in inputMinShape" label="Tile overlap(Y)">
<b-numberinput
v-model="tileOverlap.y"
:min="0"
:max="inputMaxShape.y"
></b-numberinput>
</b-field>
<b-field v-if="'z' in inputMinShape" label="Tile overlap(Z)">
<b-numberinput
v-model="tileOverlap.z"
Expand Down Expand Up @@ -233,7 +225,16 @@ export default {
},
fixedTileSize() {
if (this.rdf) {
return this.rdf.inputs[0].shape instanceof Array;
const inputSpec = this.rdf.inputs[0];
const dims = this.tritonConfig.input[0]["dims"];
if (dims !== undefined && !dims.includes(-1)) {
return dims;
}
if (inputSpec.shape instanceof Array) {
return inputSpec.shape;
} else {
return false;
}
} else {
return false;
}
Expand All @@ -242,11 +243,10 @@ export default {
if (this.rdf) {
const axes = this.rdf.inputs[0].axes; // something like "zyx"
let minShape; // something like [16, 64, 64]
const shape = this.rdf.inputs[0].shape;
if (shape instanceof Array) {
minShape = shape;
if (this.fixedTileSize === false) {
minShape = this.rdf.inputs[0].shape.min;
} else {
minShape = shape.min;
minShape = this.fixedTileSize;
}
// return something like {x: 64, y: 64, z: 16}
const res = axes.split("").reduce((acc, cur, i) => {
Expand All @@ -262,12 +262,11 @@ export default {
if (this.rdf) {
const axes = this.rdf.inputs[0].axes; // something like "zyx"
let maxShape; // something like [16, 64, 64]
const shape = this.rdf.inputs[0].shape;
if (shape instanceof Array) {
maxShape = shape;
if (this.fixedTileSize !== false) {
maxShape = this.fixedTileSize;
} else {
// array of undefined
maxShape = shape.min.map(() => undefined);
maxShape = this.rdf.inputs[0].shape.min.map(() => undefined);
}
return axes.split("").reduce((acc, cur, i) => {
acc[cur] = maxShape[i];
Expand All @@ -278,13 +277,35 @@ export default {
}
}
},
watch: {
tileSizes: {
handler(oldObj, newObj) {
if (newObj.y !== newObj.x) {
this.tileSizes.y = newObj.x; // keep x and y the same
}
console.log(oldObj, newObj);
},
deep: true
},
tileOverlap: {
handler(oldObj, newObj) {
if (newObj.y !== newObj.x) {
this.tileOverlap.y = newObj.x; // keep x and y the same
}
console.log(oldObj, newObj);
},
deep: true
}
},
methods: {
async turnOn() {
this.switch = true;
this.setInfoPanel("Initializing...", true);
await this.loadImJoy();
await this.loadTritonClient();
await this.loadRdf();
await this.loadTritonConfig();
this.setDefaultTileSize();
this.setDefaultOverlap();
await this.detectInputEndianness();
Expand All @@ -308,10 +329,15 @@ export default {
setDefaultTileSize() {
const tileSizes = Object.assign({}, this.inputMinShape);
if (!this.fixedTileSize) {
const axes = this.rdf.inputs[0].axes;
if (this.fixedTileSize === false) {
const xyFactor = 4;
tileSizes.x = xyFactor * this.inputMinShape.x;
tileSizes.y = xyFactor * this.inputMinShape.y;
} else {
axes.split("").map((a, i) => {
tileSizes[a] = this.fixedTileSize[i];
});
}
this.tileSizes = tileSizes;
},
Expand All @@ -321,7 +347,7 @@ export default {
const outputSpec = this.rdf.outputs[0];
const axes = inputSpec.axes;
let overlap = {};
if (outputSpec.halo) {
if (outputSpec.halo && this.fixedTileSize === false) {
axes.split("").map((a, i) => {
if (outputSpec.axes.includes(a) && a !== "z") {
overlap[a] = 2 * outputSpec.halo[i];
Expand Down Expand Up @@ -386,12 +412,29 @@ export default {
let outImg = await this.submitTensor(paddedTensor);
await this.api.log("Output tile shape: " + outImg._rshape);
const outTensor = ImjoyToTfJs(outImg);
const cropedTensor = padder.crop(outTensor, padArr);
return cropedTensor;
const isImg2Img =
this.rdf.outputs[0].axes.includes("x") &&
this.rdf.outputs[0].axes.includes("y");
let result = outTensor;
if (isImg2Img) {
const cropedTensor = padder.crop(outTensor, padArr);
result = cropedTensor;
}
return result;
},
async runTiles(tensor, inputSpec, outputSpec) {
const padder = new ImgPadder(inputSpec, outputSpec, 0);
let padder;
if (this.fixedTileSize === false) {
padder = new ImgPadder(
undefined,
inputSpec.shape.min,
inputSpec.shape.step,
0
);
} else {
padder = new ImgPadder(this.fixedTileSize, undefined, undefined, 0);
}
const tileSize = inputSpec.axes.split("").map(a => this.tileSizes[a]);
const overlap = inputSpec.axes.split("").map(a => this.tileOverlap[a]);
console.log("tile size:", tileSize, "overlap:", overlap);
Expand All @@ -402,6 +445,10 @@ export default {
await this.api.log("Number of tiles: " + inTiles.length);
const outTiles = [];
for (let i = 0; i < inTiles.length; i++) {
this.setInfoPanel(
`Running the model... (${i + 1}/${inTiles.length})`,
true
);
const tile = inTiles[i];
console.log(tile);
tile.slice(tensor);
Expand Down Expand Up @@ -514,6 +561,13 @@ export default {
this.triton = await server.get_service("triton-client");
},
async loadTritonConfig() {
const nickname = this.resourceItem.nickname;
const url = `https://ai.imjoy.io/triton/v2/models/${nickname}/config`;
const config = await fetch(url).then(res => res.json());
this.tritonConfig = config;
},
async loadImJoy() {
function waitForImjoy(timeout = 10000) {
return new Promise((resolve, reject) => {
Expand Down
47 changes: 19 additions & 28 deletions src/imgProcess.js
Original file line number Diff line number Diff line change
Expand Up @@ -368,23 +368,23 @@ export async function getNpyEndianness(url) {
}

export class ImgPadder {
constructor(inputSpec, outputSpec, padValue = 0) {
this.inputSpec = inputSpec;
this.outputSpec = outputSpec;
constructor(fixedPaddedShape, padMin, padStep, padValue = 0) {
this.fixedPaddedShape = fixedPaddedShape;
this.padMin = padMin;
this.padStep = padStep;
this.padValue = padValue;
}

getPaddedShape(shape) {
const specShape = this.inputSpec.shape;
let paddedShape = [];
if (specShape instanceof Array) {
if (this.fixedPaddedShape) {
// Explicit shape
paddedShape = specShape;
paddedShape = this.fixedPaddedShape;
} else {
// Implicit shape
// infer from the min and step
const min = specShape.min;
const step = specShape.step;
const min = this.padMin;
const step = this.padStep;
for (let d = 0; d < shape.length; d++) {
if (step[d] === 0) {
paddedShape.push(shape[d]);
Expand Down Expand Up @@ -427,27 +427,18 @@ export class ImgPadder {

crop(tensor, pad, halo = undefined) {
let res;
const isImg2Img =
this.outputSpec.axes.includes("x") && this.outputSpec.axes.includes("y");
if (isImg2Img) {
// img-to-img model
if (halo) {
res = tf.slice(
tensor,
pad.map((p, i) => p[0] + halo[i]),
tensor.shape.map((s, i) => s - pad[i][0] - pad[i][1] - halo[i] * 2)
);
} else {
res = tf.slice(
tensor,
pad.map(p => p[0]),
tensor.shape.map((s, i) => s - pad[i][0] - pad[i][1])
);
}
if (halo) {
res = tf.slice(
tensor,
pad.map((p, i) => p[0] + halo[i]),
tensor.shape.map((s, i) => s - pad[i][0] - pad[i][1] - halo[i] * 2)
);
} else {
// other model, e.g. classification
// no crop
res = tensor;
res = tf.slice(
tensor,
pad.map(p => p[0]),
tensor.shape.map((s, i) => s - pad[i][0] - pad[i][1])
);
}
res._rdtype = tensor._rdtype;
return res;
Expand Down

0 comments on commit cd64674

Please sign in to comment.