Skip to content

Commit

Permalink
add concatenate operation 3d (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
syt123450 committed Sep 15, 2018
1 parent 81eaedb commit c2590c9
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/assets/image/Concatenate.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { CloseButton } from "../../elements/CloseButton";
import { LineGroupGeometry } from "../../elements/LineGroupGeometry";
import { BasicMaterialOpacity } from "../../utils/Constant";
import { MergeLineGroupController } from "./MergeLineGroupController";
import { MergeLineGroupController } from "./MergedLineGroupController";

function MergedLayer(config) {

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { MapTransitionFactory } from "../../animation/MapTransitionTween";
import { CloseButtonRatio } from "../../utils/Constant";
import { MergedAggregation } from "../../elements/MergedAggregation";
import { MergedFeatureMap } from "../../elements/MergedFeatureMap";
import {StrategyFactory} from "./strategy/StrategyFactory";
import {StrategyFactory} from "../merge/strategy/StrategyFactory";

function MergedLayer3d(config) {

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/layer/merge/Add.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {MergedLayer3d} from "./MergedLayer3d";
import {MergedLayer3d} from "../abstract/MergedLayer3d";

function Add(layerList) {

Expand Down
47 changes: 44 additions & 3 deletions src/layer/merge/Concatenate.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,52 @@
import { MergedLayer3d } from "../abstract/MergedLayer3d";

function Concatenate(layerList) {

let operatorType = "concatenate";

validate(layerList);

}
return createMergedLayer(layerList);

function validate(layerList) {

let depth;

if (layerList.length > 0) {
depth = layerList[0].layerDimension;
} else {
console.error("Merge Layer missing elements.");
}

for (let i = 0; i < layerList.length; i++) {

if (layerList[i].layerDimension !== depth) {
console.error("Can not add layer with different depth.");
}

}

Concatenate.prototype = {
}

};
function createMergedLayer(layerList) {

if (layerList[0].layerDimension === 1) {

} else if (layerList[0].layerDimension === 2) {

} else if (layerList[0].layerDimension === 3) {

return new MergedLayer3d({
operator: operatorType,
mergedElements: layerList
});

} else {
console.error("Do not support layer concatenate operation more than 4 dimension.");
}

}

}

export { Concatenate };
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
function AddStrategy3d(mergedElements) {
function Add3d(mergedElements) {

this.mergedElements = mergedElements;
this.layerIndex = undefined;

}

AddStrategy3d.prototype = {
Add3d.prototype = {

setLayerIndex: function(layerIndex) {
this.layerIndex = layerIndex;
},

validate: function() {

let inputShape;

if (this.mergedElements.length > 0) {
inputShape = this.mergedElements[0].outputShape;
} else {
console.error("Merge Layer missing elements.");
}
let inputShape = this.mergedElements[0].outputShape;

for (let i = 0; i < this.mergedElements.length; i++) {

Expand Down Expand Up @@ -124,4 +118,4 @@ AddStrategy3d.prototype = {

};

export { AddStrategy3d };
export { Add3d };
137 changes: 137 additions & 0 deletions src/layer/merge/strategy/Concatenate3d.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
function Concatenate3d(mergedElements) {

this.mergedElements = mergedElements;
this.layerIndex = undefined;

}

Concatenate3d.prototype = {

setLayerIndex: function(layerIndex) {
this.layerIndex = layerIndex;
},

validate: function() {

let inputShape = this.mergedElements[0].outputShape;

for (let i = 0; i < this.mergedElements.length; i++) {

let layerShape = this.mergedElements[i].outputShape;
if (layerShape[0] !== inputShape[0] || layerShape[1] !== inputShape[1]) {
return false;
}

}

return true;

},

getShape: function() {

let width = this.mergedElements[0].outputShape[0];
let height = this.mergedElements[0].outputShape[1];

let depth = 0;
for (let i = 0; i < this.mergedElements.length; i++) {

depth += this.mergedElements[i].outputShape[2];

}

return [width, height, depth];

},

getRelativeElements: function(selectedElement) {

let curveElements = [];
let straightElements = [];

if (selectedElement.elementType === "aggregationElement") {

let request = {
all: true
};

for (let i = 0; i < this.mergedElements.length; i++) {
let relativeResult = this.mergedElements[i].provideRelativeElements(request);
let relativeElements = relativeResult.elementList;
if (this.mergedElements[i].layerIndex === this.layerIndex - 1) {

for (let j = 0; j < relativeElements.length; j++) {
straightElements.push(relativeElements[j]);
}

} else {

if (relativeResult.isOpen) {
for (let j = 0; j < relativeElements.length; j++) {
straightElements.push(relativeElements[j]);
}
} else {
for (let j = 0; j < relativeElements.length; j++) {
curveElements.push(relativeElements[j]);
}
}

}
}

} else if (selectedElement.elementType === "featureMap") {

let fmIndex = selectedElement.fmIndex;

let relativeLayer;

for (let i = 0; i < this.mergedElements.length; i++) {

let layerDepth = this.mergedElements[i].outputShape[2];
if (layerDepth >= fmIndex) {
relativeLayer = this.mergedElements[i];
break;
} else {
fmIndex -= layerDepth;
}

}

let request = {
index: fmIndex
};

let relativeResult = relativeLayer.provideRelativeElements(request);
let relativeElements = relativeResult.elementList;
if (relativeLayer.layerIndex === this.layerIndex - 1) {

for (let i = 0; i < relativeElements.length; i++) {
straightElements.push(relativeElements[i]);
}

} else {

if (relativeResult.isOpen) {
for (let i = 0; i < relativeElements.length; i++) {
straightElements.push(relativeElements[i]);
}
} else {
for (let i = 0; i < relativeElements.length; i++) {
curveElements.push(relativeElements[i]);
}
}

}

}

return {
straight: straightElements,
curve: curveElements
};

}

};

export { Concatenate3d };
Empty file.
7 changes: 5 additions & 2 deletions src/layer/merge/strategy/StrategyFactory.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {AddStrategy3d} from "./AddStrategy3d";
import {Add3d} from "./Add3d";
import {Concatenate3d} from "./Concatenate3d";

let StrategyFactory = (function() {

Expand All @@ -7,7 +8,9 @@ let StrategyFactory = (function() {
if (dimension === 3) {

if (operator === "add") {
return new AddStrategy3d(mergedElements);
return new Add3d(mergedElements);
} else if (operator === "concatenate") {
return new Concatenate3d(mergedElements);
}

}
Expand Down
3 changes: 2 additions & 1 deletion src/tensorspace.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import { PixelReshape } from "./layer/pixel/PixelReshape";
import { PixelOutput } from "./layer/pixel/PixelOutput";

import { Add } from "./layer/merge/Add";
import { Concatenate } from "./layer/merge/Concatenate";

let layers = {
Input1d: Input1d,
Expand Down Expand Up @@ -84,4 +85,4 @@ let model = {
PixelSequential: PixelSequential
};

export {model, layers, Add};
export {model, layers, Add, Concatenate};
4 changes: 2 additions & 2 deletions test/test.html
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@

let layer2 = new TSP.layers.Conv2d({
kernelSize: 2,
filters: 3,
filters: 5,
strides: 1,
padding: "same"
});

let addLayer = TSP.Add([layer1, layer2]);
let addLayer = TSP.Concatenate([layer1, layer2]);

// console.log(addLayer);
//
Expand Down

0 comments on commit c2590c9

Please sign in to comment.