Skip to content

Commit

Permalink
add stepper
Browse files Browse the repository at this point in the history
  • Loading branch information
Pop John committed Jan 17, 2025
1 parent c93f448 commit 49032ca
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 85 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2024 Cisco Systems, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

import { AlgorithmType } from '../../../model-compression/models/enums/algorithms.enum';
import { PageKey } from '../enums/page-key.enum';

export interface CurrentRunningPageInfo {
page: PageKey;
algKey: string;
type: AlgorithmType | null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,26 @@
// SPDX-License-Identifier: Apache-2.0

import { Injectable } from '@angular/core';
import { BehaviorSubject, Observable, filter, map, skip, switchMap, take, tap } from 'rxjs';
import { BehaviorSubject, Observable, filter, skip, switchMap, take, tap } from 'rxjs';
import { ScriptDetails } from '../../../services/client/models/script/script-details.interface-dto';
import { ScriptActions } from '../../../state/core/script';
import { AlgorithmType } from '../../model-compression/models/enums/algorithms.enum';
import { ScriptStatusEnum } from '../../model-compression/models/enums/script-status.enum';
import { isNilOrEmptyString } from '../../shared/shared.utils';
import { PageKey } from '../models/enums/page-key.enum';
import { CurrentRunningPageInfo } from '../models/interfaces/current-running-page-info.interface';
import { ScriptFacadeService } from './script-facade.service';

@Injectable()
export class PageRunningScriptSpiningIndicatorService {
private _currentRunningPage: BehaviorSubject<PageKey> = new BehaviorSubject<PageKey>(PageKey.NONE);
private _currentRunningPageInfo = new BehaviorSubject<CurrentRunningPageInfo>({
page: PageKey.NONE,
algKey: '',
type: null
});

get currentRunningPage$(): Observable<PageKey> {
return this._currentRunningPage.asObservable();
get currentRunningPageInfo$(): Observable<CurrentRunningPageInfo> {
return this._currentRunningPageInfo.asObservable();
}

constructor(private scriptFacadeService: ScriptFacadeService) {}
Expand All @@ -48,45 +53,50 @@ export class PageRunningScriptSpiningIndicatorService {
take(1),
filter((scriptDetails): scriptDetails is ScriptDetails => !isNilOrEmptyString(scriptDetails?.algKey))
)
),
map((scriptDetails: ScriptDetails) => scriptDetails.type)
)
)
.subscribe((type: AlgorithmType) => {
switch (type) {
.subscribe((scriptDetails: ScriptDetails) => {
let page: PageKey;

switch (scriptDetails.type) {
case AlgorithmType.PRUNING:
case AlgorithmType.QUANTIZATION:
this._currentRunningPage.next(PageKey.MODEL_COMPRESSION);
page = PageKey.MODEL_COMPRESSION;
break;
case AlgorithmType.MACHINE_UNLEARNING: {
this._currentRunningPage.next(PageKey.MACHINE_UNLEARNING);
case AlgorithmType.MACHINE_UNLEARNING:
page = PageKey.MACHINE_UNLEARNING;
break;
}
case AlgorithmType.AWQ: {
this._currentRunningPage.next(PageKey.AWQ);
case AlgorithmType.AWQ:
page = PageKey.AWQ;
break;
}
case AlgorithmType.TRAIN: {
this._currentRunningPage.next(PageKey.MODEL_TRAINING);
case AlgorithmType.TRAIN:
page = PageKey.MODEL_TRAINING;
break;
}
case AlgorithmType.MULTIFLOW: {
this._currentRunningPage.next(PageKey.MODEL_SPECIALIZATION);
case AlgorithmType.MULTIFLOW:
page = PageKey.MODEL_SPECIALIZATION;
break;
}
case AlgorithmType.DIFFUSION_MODEL: {
this._currentRunningPage.next(PageKey.DIFFUSION_MODEL);
case AlgorithmType.DIFFUSION_MODEL:
page = PageKey.DIFFUSION_MODEL;
break;
}
default: {
this._currentRunningPage.next(PageKey.NONE);
default:
page = PageKey.NONE;
break;
}
}

this._currentRunningPageInfo.next({
page,
algKey: scriptDetails.algKey ?? '',
type: scriptDetails.type ?? null
});
});

this.scriptFacadeService.scriptStatus$.subscribe((status: ScriptStatusEnum | null) => {
if (status !== ScriptStatusEnum.RUNNING && status !== ScriptStatusEnum.STOPPING) {
this._currentRunningPage.next(PageKey.NONE);
this._currentRunningPageInfo.next({
page: PageKey.NONE,
algKey: '',
type: null
});
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,83 @@
SPDX-License-Identifier: Apache-2.0
-->

<p class="heading-primary-title title">Diffusion Model</p>
<p class="heading-primary-title title">Diffusion Model (PTQ4DiT)</p>

<div class="page-wrapper" [formGroup]="form">
<div class="left">
<ng-container [formGroup]="form">
<mat-card class="ms-card algorithm-card">
<p class="heading-sub-section-title">Algorithm</p>
<div class="form-field-container" formGroupName="algorithm">
<mat-form-field appearance="outline" subscriptSizing="dynamic">
<mat-select formControlName="alg">
@for (algorithm of DIFFUSION_MODEL_ALGORITHMS_LIST; track algorithm.key) {
<mat-option [value]="algorithm.key">
{{ algorithm.value }}
</mat-option>
}
</mat-select>
</mat-form-field>
</div>
</mat-card>

<ms-panel-parameters #panelParameters controlKey="params" [algorithm]="selectedAlgorithm"></ms-panel-parameters>
</ng-container>

<div>
<button mat-raised-button color="primary" (click)="submit()" [disabled]="isScriptActive || form.invalid">
Run
</button>
</div>
<mat-stepper orientation="horizontal" class="ms-stepper">
<mat-step>
<ng-template matStepLabel>
<div class="step-custom-label">
<div class="step-name">Step 1: Calibration Data</div>
@let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if
(currentInfo?.algKey === DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET) {
<div class="step-loader"><ms-spining-indicator></ms-spining-indicator></div>
}
</div>
</ng-template>

<ng-template matStepContent>
<div class="panel-parameters">
<ms-panel-parameters
#panelParametersCalibrationSet
controlKey="params_calibration_set"
[algorithm]="DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET"></ms-panel-parameters>
</div>

<div class="mt-6 flex">
<div class="mr-2">
<button mat-stroked-button matStepperNext>Next</button>
</div>
<div>
<button
mat-raised-button
color="primary"
(click)="submit(DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET, 'params_calibration_set')"
[disabled]="isScriptActive || form.invalid">
Run
</button>
</div>
</div>
</ng-template>
</mat-step>

<mat-step>
<ng-template matStepLabel>
<div class="step-custom-label">
<div class="step-name">Step 2: Quantization</div>
@let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if
(currentInfo?.algKey === DiffusionModelAlgorithmsEnum.PTQ4DIT_QUANT_SAMPLE) {
<div class="step-loader"><ms-spining-indicator></ms-spining-indicator></div>
}
</div>
</ng-template>

<ng-template matStepContent>
<div class="panel-parameters">
<ms-panel-parameters
#panelParametersQuantSample
controlKey="params_quant_sample"
[algorithm]="DiffusionModelAlgorithmsEnum.PTQ4DIT_QUANT_SAMPLE"></ms-panel-parameters>
</div>

<div class="mt-6 flex">
<div class="mr-2">
<button mat-stroked-button matStepperPrevious>Previous</button>
</div>
<div>
<button
mat-raised-button
color="primary"
(click)="submit(DiffusionModelAlgorithmsEnum.PTQ4DIT_QUANT_SAMPLE, 'params_quant_sample')"
[disabled]="isScriptActive || form.invalid">
Run
</button>
</div>
</div>
</ng-template>
</mat-step>
</mat-stepper>
</div>
<div class="right">
<ms-terminal-xterm-with-toolbar></ms-terminal-xterm-with-toolbar>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,24 @@
.title {
margin-bottom: 13px;
}

.mat-stepper-horizontal {
background: none;

.mat-horizontal-content-container {
padding-bottom: 2px;
}
}

.step-custom-label {
display: flex;
justify-content: center;
align-items: center;
.step-loader {
margin-left: 10px;
}
}

.panel-parameters {
margin-top: 20px;
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import { ScriptFacadeService } from 'src/app/modules/core/services';
import { DiffusionModelAlgorithmsEnum } from 'src/app/modules/model-compression/models/enums/algorithms.enum';
import { isScriptActive } from 'src/app/modules/model-compression/models/enums/script-status.enum';
import { MsPanelParametersComponent } from 'src/app/modules/shared/components/ms-panel-parameters/ms-panel-parameters.component';
import { ScriptConfigsDto } from 'src/app/services/client/models/script/script-configs.interface-dto';
import { ScriptActions } from 'src/app/state/core/script';
import { ScriptArguments } from '../../../../services/client/models/script/script-arguments.interface-dto';
import { ScriptConfigsDto } from '../../../../services/client/models/script/script-configs.interface-dto';
import { ScriptActions } from '../../../../state/core/script';
import { PageRunningScriptSpiningIndicatorService } from '../../../core/services/page-running-script-spinning-indicator.service';

@UntilDestroy()
@Component({
Expand All @@ -33,11 +35,12 @@ import { ScriptActions } from 'src/app/state/core/script';
export class DifussionModelComponent {
readonly DiffusionModelAlgorithmsEnum: typeof DiffusionModelAlgorithmsEnum = DiffusionModelAlgorithmsEnum;

@ViewChild('panelParameters', { static: false }) panelParametersComponent!: MsPanelParametersComponent;
@ViewChild('panelParametersCalibrationSet', { static: false })
panelParametersCalibrationSet!: MsPanelParametersComponent;
@ViewChild('panelParametersQuantSample', { static: false }) panelParametersQuantSample!: MsPanelParametersComponent;

form!: FormGroup;
isScriptActive: boolean = false;
selectedAlgorithm = DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET;

DIFFUSION_MODEL_ALGORITHMS_LIST: { key: DiffusionModelAlgorithmsEnum; value: string }[] = [
{ key: DiffusionModelAlgorithmsEnum.PTQ4DIT_GET_CALIBRATION_SET, value: 'Get calibration set' },
Expand All @@ -46,7 +49,8 @@ export class DifussionModelComponent {

constructor(
private fb: FormBuilder,
private scriptFacadeService: ScriptFacadeService
private scriptFacadeService: ScriptFacadeService,
public pageRunningScriptSpiningIndicatorService: PageRunningScriptSpiningIndicatorService
) {}

ngOnInit() {
Expand All @@ -55,44 +59,34 @@ export class DifussionModelComponent {
}

private initForm() {
this.form = this.fb.group({
algorithm: this.fb.group({
alg: [this.selectedAlgorithm]
})
});

this.form
.get('algorithm.alg')
?.valueChanges.pipe(untilDestroyed(this))
.subscribe((value) => {
this.selectedAlgorithm = value;
});
this.form = this.fb.group({});
}

private listenToScriptStateChanges(): void {
this.scriptFacadeService.scriptStatus$.pipe(untilDestroyed(this)).subscribe((state) => {
this.isScriptActive = isScriptActive(state);

if (isScriptActive(state)) {
if (this.isScriptActive) {
this.form.disable();
} else {
this.form.enable();
}
});
}

submit() {
submit(algorithm: DiffusionModelAlgorithmsEnum, type: string) {
if (this.isScriptActive) {
return;
}

const { algorithm } = this.form.getRawValue();
const parameters: ScriptArguments =
type === 'params_calibration_set'
? this.panelParametersCalibrationSet.parametersFormatted
: this.panelParametersQuantSample.parametersFormatted;

const configs: ScriptConfigsDto = {
...algorithm,
params: {
...this.panelParametersComponent.parametersFormatted
}
alg: algorithm,
params: parameters
};

this.scriptFacadeService.dispatch(ScriptActions.callScript({ configs }));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import { FormsModule, ReactiveFormsModule } from '@angular/forms';
import { MatButtonModule } from '@angular/material/button';
import { MatCardModule } from '@angular/material/card';
import { MatSelectModule } from '@angular/material/select';
import { MatStepperModule } from '@angular/material/stepper';
import { MsPanelParametersComponent } from '../shared/components/ms-panel-parameters/ms-panel-parameters.component';
import { MsSpiningIndicatorComponent } from '../shared/components/ms-spining-indicator/ms-spining-indicator.component';
import { MsTerminalXtermWithToolbarComponent } from '../shared/components/ms-terminal/components/ms-terminal-xterm-with-toolbar/ms-terminal-xterm-with-toolbar.component';
import { DifussionModelComponent } from './components/difussion-model/difussion-model.component';
import { DiffusionModelRoutingModule } from './diffusion-model-routing.module';
Expand All @@ -38,7 +40,9 @@ import { DiffusionModelRoutingModule } from './diffusion-model-routing.module';
FormsModule,
ReactiveFormsModule,
MatSelectModule,
MatCardModule
MatCardModule,
MatStepperModule,
MsSpiningIndicatorComponent
]
})
export class DiffusionModelModule {}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
@if (isTrainModelsPageRouteVisible) {
<div class="train-models">
<a [routerLink]="['/' + RoutesList.MODEL_TRAINING.ROOT]">Train models</a>
@if (pageRunningScriptSpiningIndicatorService.currentRunningPage$ | async; as currentRunningPageKey) { @if
(currentRunningPageKey === PageKey.MODEL_TRAINING) {
@let currentInfo = pageRunningScriptSpiningIndicatorService.currentRunningPageInfo$ | async; @if
(currentInfo?.page === PageKey.MODEL_TRAINING) {
<ms-spining-indicator class="ml-2"></ms-spining-indicator>
} }
}
</div>
}
</div>
Expand Down
Loading

0 comments on commit 49032ca

Please sign in to comment.