diff --git a/frontend/src/app/modules/core/models/interfaces/current-running-page-info.interface.ts b/frontend/src/app/modules/core/models/interfaces/current-running-page-info.interface.ts new file mode 100644 index 0000000..35f13c4 --- /dev/null +++ b/frontend/src/app/modules/core/models/interfaces/current-running-page-info.interface.ts @@ -0,0 +1,24 @@ +// Copyright 2024 Cisco Systems, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { AlgorithmType } from '../../../model-compression/models/enums/algorithms.enum'; +import { PageKey } from '../enums/page-key.enum'; + +export interface CurrentRunningPageInfo { + page: PageKey; + algKey: string; + type: AlgorithmType | null; +} diff --git a/frontend/src/app/modules/core/services/page-running-script-spinning-indicator.service.ts b/frontend/src/app/modules/core/services/page-running-script-spinning-indicator.service.ts index 7b8e97b..7ae9c69 100644 --- a/frontend/src/app/modules/core/services/page-running-script-spinning-indicator.service.ts +++ b/frontend/src/app/modules/core/services/page-running-script-spinning-indicator.service.ts @@ -15,21 +15,26 @@ // SPDX-License-Identifier: Apache-2.0 import { Injectable } from '@angular/core'; -import { BehaviorSubject, Observable, filter, map, skip, switchMap, take, tap } from 'rxjs'; +import { BehaviorSubject, Observable, filter, skip, switchMap, take, tap } from 'rxjs'; import { ScriptDetails } from '../../../services/client/models/script/script-details.interface-dto'; import { ScriptActions } from '../../../state/core/script'; import { AlgorithmType } from '../../model-compression/models/enums/algorithms.enum'; import { ScriptStatusEnum } from '../../model-compression/models/enums/script-status.enum'; import { isNilOrEmptyString } from '../../shared/shared.utils'; import { PageKey } from '../models/enums/page-key.enum'; +import { CurrentRunningPageInfo } from '../models/interfaces/current-running-page-info.interface'; import { ScriptFacadeService } from './script-facade.service'; @Injectable() export class PageRunningScriptSpiningIndicatorService { - private _currentRunningPage: BehaviorSubject = new BehaviorSubject(PageKey.NONE); + private _currentRunningPageInfo = new BehaviorSubject({ + page: PageKey.NONE, + algKey: '', + type: null + }); - get currentRunningPage$(): Observable { - return this._currentRunningPage.asObservable(); + get currentRunningPageInfo$(): Observable { + return this._currentRunningPageInfo.asObservable(); } constructor(private scriptFacadeService: ScriptFacadeService) {} @@ -48,45 +53,50 @@ export class PageRunningScriptSpiningIndicatorService { take(1), filter((scriptDetails): scriptDetails is ScriptDetails => !isNilOrEmptyString(scriptDetails?.algKey)) ) - ), - map((scriptDetails: ScriptDetails) => scriptDetails.type) + ) ) - .subscribe((type: AlgorithmType) => { - switch (type) { + .subscribe((scriptDetails: ScriptDetails) => { + let page: PageKey; + + switch (scriptDetails.type) { case AlgorithmType.PRUNING: case AlgorithmType.QUANTIZATION: - this._currentRunningPage.next(PageKey.MODEL_COMPRESSION); + page = PageKey.MODEL_COMPRESSION; break; - case AlgorithmType.MACHINE_UNLEARNING: { - this._currentRunningPage.next(PageKey.MACHINE_UNLEARNING); + case AlgorithmType.MACHINE_UNLEARNING: + page = PageKey.MACHINE_UNLEARNING; break; - } - case AlgorithmType.AWQ: { - this._currentRunningPage.next(PageKey.AWQ); + case AlgorithmType.AWQ: + page = PageKey.AWQ; break; - } - case AlgorithmType.TRAIN: { - this._currentRunningPage.next(PageKey.MODEL_TRAINING); + case AlgorithmType.TRAIN: + page = PageKey.MODEL_TRAINING; break; - } - case AlgorithmType.MULTIFLOW: { - this._currentRunningPage.next(PageKey.MODEL_SPECIALIZATION); + case AlgorithmType.MULTIFLOW: + page = PageKey.MODEL_SPECIALIZATION; break; - } - case AlgorithmType.DIFFUSION_MODEL: { - this._currentRunningPage.next(PageKey.DIFFUSION_MODEL); + case AlgorithmType.DIFFUSION_MODEL: + page = PageKey.DIFFUSION_MODEL; break; - } - default: { - this._currentRunningPage.next(PageKey.NONE); + default: + page = PageKey.NONE; break; - } } + + this._currentRunningPageInfo.next({ + page, + algKey: scriptDetails.algKey ?? '', + type: scriptDetails.type ?? null + }); }); this.scriptFacadeService.scriptStatus$.subscribe((status: ScriptStatusEnum | null) => { if (status !== ScriptStatusEnum.RUNNING && status !== ScriptStatusEnum.STOPPING) { - this._currentRunningPage.next(PageKey.NONE); + this._currentRunningPageInfo.next({ + page: PageKey.NONE, + algKey: '', + type: null + }); } }); } diff --git a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.html b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.html index 02e6694..cc34058 100644 --- a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.html +++ b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.html @@ -15,34 +15,83 @@ SPDX-License-Identifier: Apache-2.0 --> -

Diffusion Model

+

Diffusion Model (PTQ4DiT)

- - -

Algorithm

-
- - - @for (algorithm of DIFFUSION_MODEL_ALGORITHMS_LIST; track algorithm.key) { - - {{ algorithm.value }} - - } - - -
-
- - -
- -
- -
+ + + +
+
Step 1: Calibration Data
+ @let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if + (currentInfo?.algKey === DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET) { +
+ } +
+
+ + +
+ +
+ +
+
+ +
+
+ +
+
+
+
+ + + +
+
Step 2: Quantization
+ @let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if + (currentInfo?.algKey === DiffusionModelAlgorithmsEnum.PTQ4DIT_QUANT_SAMPLE) { +
+ } +
+
+ + +
+ +
+ +
+
+ +
+
+ +
+
+
+
+
diff --git a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.scss b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.scss index 8572532..6316706 100644 --- a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.scss +++ b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.scss @@ -17,3 +17,24 @@ .title { margin-bottom: 13px; } + +.mat-stepper-horizontal { + background: none; + + .mat-horizontal-content-container { + padding-bottom: 2px; + } +} + +.step-custom-label { + display: flex; + justify-content: center; + align-items: center; + .step-loader { + margin-left: 10px; + } +} + +.panel-parameters { + margin-top: 20px; +} diff --git a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.ts b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.ts index ce54b11..ea6a0d4 100644 --- a/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.ts +++ b/frontend/src/app/modules/diffusion-model/components/difussion-model/difussion-model.component.ts @@ -21,8 +21,10 @@ import { ScriptFacadeService } from 'src/app/modules/core/services'; import { DiffusionModelAlgorithmsEnum } from 'src/app/modules/model-compression/models/enums/algorithms.enum'; import { isScriptActive } from 'src/app/modules/model-compression/models/enums/script-status.enum'; import { MsPanelParametersComponent } from 'src/app/modules/shared/components/ms-panel-parameters/ms-panel-parameters.component'; -import { ScriptConfigsDto } from 'src/app/services/client/models/script/script-configs.interface-dto'; -import { ScriptActions } from 'src/app/state/core/script'; +import { ScriptArguments } from '../../../../services/client/models/script/script-arguments.interface-dto'; +import { ScriptConfigsDto } from '../../../../services/client/models/script/script-configs.interface-dto'; +import { ScriptActions } from '../../../../state/core/script'; +import { PageRunningScriptSpiningIndicatorService } from '../../../core/services/page-running-script-spinning-indicator.service'; @UntilDestroy() @Component({ @@ -33,11 +35,12 @@ import { ScriptActions } from 'src/app/state/core/script'; export class DifussionModelComponent { readonly DiffusionModelAlgorithmsEnum: typeof DiffusionModelAlgorithmsEnum = DiffusionModelAlgorithmsEnum; - @ViewChild('panelParameters', { static: false }) panelParametersComponent!: MsPanelParametersComponent; + @ViewChild('panelParametersCalibrationSet', { static: false }) + panelParametersCalibrationSet!: MsPanelParametersComponent; + @ViewChild('panelParametersQuantSample', { static: false }) panelParametersQuantSample!: MsPanelParametersComponent; form!: FormGroup; isScriptActive: boolean = false; - selectedAlgorithm = DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET; DIFFUSION_MODEL_ALGORITHMS_LIST: { key: DiffusionModelAlgorithmsEnum; value: string }[] = [ { key: DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET, value: 'Get calibration set' }, @@ -46,7 +49,8 @@ export class DifussionModelComponent { constructor( private fb: FormBuilder, - private scriptFacadeService: ScriptFacadeService + private scriptFacadeService: ScriptFacadeService, + public pageRunningScriptSpiningIndicatorService: PageRunningScriptSpiningIndicatorService ) {} ngOnInit() { @@ -55,25 +59,14 @@ export class DifussionModelComponent { } private initForm() { - this.form = this.fb.group({ - algorithm: this.fb.group({ - alg: [this.selectedAlgorithm] - }) - }); - - this.form - .get('algorithm.alg') - ?.valueChanges.pipe(untilDestroyed(this)) - .subscribe((value) => { - this.selectedAlgorithm = value; - }); + this.form = this.fb.group({}); } private listenToScriptStateChanges(): void { this.scriptFacadeService.scriptStatus$.pipe(untilDestroyed(this)).subscribe((state) => { this.isScriptActive = isScriptActive(state); - if (isScriptActive(state)) { + if (this.isScriptActive) { this.form.disable(); } else { this.form.enable(); @@ -81,18 +74,19 @@ export class DifussionModelComponent { }); } - submit() { + submit(algorithm: DiffusionModelAlgorithmsEnum, type: string) { if (this.isScriptActive) { return; } - const { algorithm } = this.form.getRawValue(); + const parameters: ScriptArguments = + type === 'params_calibration_set' + ? this.panelParametersCalibrationSet.parametersFormatted + : this.panelParametersQuantSample.parametersFormatted; const configs: ScriptConfigsDto = { - ...algorithm, - params: { - ...this.panelParametersComponent.parametersFormatted - } + alg: algorithm, + params: parameters }; this.scriptFacadeService.dispatch(ScriptActions.callScript({ configs })); diff --git a/frontend/src/app/modules/diffusion-model/diffusion-model.module.ts b/frontend/src/app/modules/diffusion-model/diffusion-model.module.ts index 49eb5a7..31cc983 100644 --- a/frontend/src/app/modules/diffusion-model/diffusion-model.module.ts +++ b/frontend/src/app/modules/diffusion-model/diffusion-model.module.ts @@ -20,7 +20,9 @@ import { FormsModule, ReactiveFormsModule } from '@angular/forms'; import { MatButtonModule } from '@angular/material/button'; import { MatCardModule } from '@angular/material/card'; import { MatSelectModule } from '@angular/material/select'; +import { MatStepperModule } from '@angular/material/stepper'; import { MsPanelParametersComponent } from '../shared/components/ms-panel-parameters/ms-panel-parameters.component'; +import { MsSpiningIndicatorComponent } from '../shared/components/ms-spining-indicator/ms-spining-indicator.component'; import { MsTerminalXtermWithToolbarComponent } from '../shared/components/ms-terminal/components/ms-terminal-xterm-with-toolbar/ms-terminal-xterm-with-toolbar.component'; import { DifussionModelComponent } from './components/difussion-model/difussion-model.component'; import { DiffusionModelRoutingModule } from './diffusion-model-routing.module'; @@ -38,7 +40,9 @@ import { DiffusionModelRoutingModule } from './diffusion-model-routing.module'; FormsModule, ReactiveFormsModule, MatSelectModule, - MatCardModule + MatCardModule, + MatStepperModule, + MsSpiningIndicatorComponent ] }) export class DiffusionModelModule {} diff --git a/frontend/src/app/modules/shared/components/ms-panel-model/ms-panel-model.component.html b/frontend/src/app/modules/shared/components/ms-panel-model/ms-panel-model.component.html index 3a535ac..3c78da6 100644 --- a/frontend/src/app/modules/shared/components/ms-panel-model/ms-panel-model.component.html +++ b/frontend/src/app/modules/shared/components/ms-panel-model/ms-panel-model.component.html @@ -20,10 +20,10 @@ @if (isTrainModelsPageRouteVisible) {
Train models - @if (pageRunningScriptSpiningIndicatorService.currentRunningPage$ | async; as currentRunningPageKey) { @if - (currentRunningPageKey === PageKey.MODEL_TRAINING) { + @let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if + (currentInfo?.page === PageKey.MODEL_TRAINING) { - } } + }
}
diff --git a/frontend/src/app/modules/shared/components/ms-sidenav/components/ms-sidenav-item/ms-sidenav-item.component.html b/frontend/src/app/modules/shared/components/ms-sidenav/components/ms-sidenav-item/ms-sidenav-item.component.html index 90697b6..3cdde29 100644 --- a/frontend/src/app/modules/shared/components/ms-sidenav/components/ms-sidenav-item/ms-sidenav-item.component.html +++ b/frontend/src/app/modules/shared/components/ms-sidenav/components/ms-sidenav-item/ms-sidenav-item.component.html @@ -37,11 +37,11 @@
- @if (pageRunningScriptSpiningIndicatorService.currentRunningPage$ | async; as currentRunningPageKey) { @if - (isExpanded && currentRunningPageKey === item.key) { + @let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if (isExpanded && + currentInfo?.page === item.key) {
- } } + }
diff --git a/frontend/src/app/modules/shared/components/ms-sidenav/ms-sidenav.component.ts b/frontend/src/app/modules/shared/components/ms-sidenav/ms-sidenav.component.ts index f1422d8..a407812 100644 --- a/frontend/src/app/modules/shared/components/ms-sidenav/ms-sidenav.component.ts +++ b/frontend/src/app/modules/shared/components/ms-sidenav/ms-sidenav.component.ts @@ -36,7 +36,7 @@ import { SidenavConstants } from './models/constants/sidenav.constants'; imports: [MsSidenavItemComponent, MatIconModule, MatDividerModule], animations: [ trigger('expandCollapse', [ - state('expanded', style({ width: '270px' })), + state('expanded', style({ width: '240px' })), state('collapsed', style({ width: '40px' })), transition('expanded <=> collapsed', animate('300ms ease-in-out')) ]) diff --git a/frontend/src/app/styles/theme/_ms-theme-common.scss b/frontend/src/app/styles/theme/_ms-theme-common.scss index 78bc081..f93e27c 100644 --- a/frontend/src/app/styles/theme/_ms-theme-common.scss +++ b/frontend/src/app/styles/theme/_ms-theme-common.scss @@ -129,6 +129,7 @@ $ms-typography: mat.m2-define-rem-typography-config( @include mat.progress-spinner-theme($theme); @include mat.slide-toggle-theme($theme); @include mat.expansion-theme($theme); + @include mat.stepper-theme($theme); @include mat.form-field-density(-4); @include mat.icon-button-density(-4); diff --git a/frontend/src/app/styles/theme/components/_ms-components.scss b/frontend/src/app/styles/theme/components/_ms-components.scss index 314d767..79c2d7a 100644 --- a/frontend/src/app/styles/theme/components/_ms-components.scss +++ b/frontend/src/app/styles/theme/components/_ms-components.scss @@ -28,3 +28,4 @@ @forward './ms-dialog'; @forward './ms-drawer'; @forward './ms-select'; +@forward './ms-stepper'; diff --git a/frontend/src/app/styles/theme/components/_ms-stepper.scss b/frontend/src/app/styles/theme/components/_ms-stepper.scss new file mode 100644 index 0000000..bdc6638 --- /dev/null +++ b/frontend/src/app/styles/theme/components/_ms-stepper.scss @@ -0,0 +1,30 @@ +// Copyright 2024 Cisco Systems, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +@use '../../base/typography'; +@use '@angular/material' as mat; + +:root { + .ms-stepper { + @include mat.stepper-overrides( + ( + header-selected-state-label-text-size: 0.875rem, + header-selected-state-label-text-weight: 500, + header-height: 60px + ) + ); + } +}