forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
moe.py
836 lines (724 loc) · 34.6 KB
/
moe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass
from enum import IntEnum
from typing import List, Optional, Type, Union
import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
str_dtype_to_trt)
from tensorrt_llm.layers.lora import LoraParams
from .._common import default_net, default_trtnet
from .._utils import int32_array
from ..functional import (AllReduceFusionParams, _add_plugin_info,
_create_tensor, allreduce, cast, concat, constant,
div, expand, gather_nd, is_gated_activation,
non_gated_version, nonzero, repeat_interleave,
scatter_nd, shape, softmax, split, sum, topk)
from ..layers import MLP, GatedMLP
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..plugin import TRT_LLM_PLUGIN_NAMESPACE
from ..quantization import QuantMode
from ..quantization.functional import quantize
from .linear import RowLinear
# Maps an activation-function name to the integer id that `_moe_plugin` passes
# to the MixtureOfExperts plugin as its "activation_type" field.
activation_str_to_int_map = {
    # [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
    "gelu": 0,
    "gelu_new": 0,  # shares id 0 with "gelu"
    "relu": 1,
    "silu": 2,
    "swiglu": 3,
    "geglu": 4,
    "identity": 5,
}
@dataclass
class MoeConfig:
    """Settings describing a Mixture-of-Experts layer.

    Attributes are documented next to their declarations below. A
    default-constructed instance (``num_experts == 0`` and ``top_k == 0``)
    describes a model without MoE.
    """

    class ExpertScaleNormalizationMode(IntEnum):
        # Integer ids are forwarded to the MoE plugin as "normalization_mode".
        NONE = 0
        RENORMALIZE = 1

    # Total number of experts; 0 means "not set".
    num_experts: int = 0
    # Number of experts selected per token; 0 means "not set".
    top_k: int = 0
    # How the scales of the selected experts are normalized.
    normalization_mode: ExpertScaleNormalizationMode = ExpertScaleNormalizationMode.RENORMALIZE
    # Integer mode flag; it is stored but not interpreted by this class.
    tp_mode: int = 0

    def validate(self) -> "MoeConfig":
        """Check that ``num_experts`` and ``top_k`` are both zero or both non-zero.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: if exactly one of the two fields is zero.
        """
        experts_unset = self.num_experts == 0
        top_k_unset = self.top_k == 0
        if experts_unset == top_k_unset:
            return self
        raise ValueError(
            "Both or neither MoeConfig's num_experts and top_k must be set to 0"
        )

    def has_moe(self) -> bool:
        """Return True when more than one expert is configured."""
        return self.num_experts > 1

    @classmethod
    def from_dict(cls, config: dict):
        """Build a config from a mapping of field name to value."""
        return cls(**config)

    def to_dict(self):
        """Return the fields of this config as a plain dict."""
        return asdict(self)
def _moe_plugin(moe_config,
                hidden_states,
                routing,
                finished,
                expert_weights_1,
                expert_weights_2,
                expert_bias_1,
                expert_bias_2,
                expert_scale_1,
                expert_scale_2,
                expert_scale_3,
                expert_scale_4,
                hidden_size,
                ffn_hidden_size,
                act_fn,
                dtype,
                weight_dtype,
                output_dtype,
                lora_params: LoraParams,
                lora_max_low_rank,
                quant_mode=QuantMode(0),
                tp_size=1,
                ep_size=1,
                tp_rank=0,
                ep_rank=0):
    """Add a ``MixtureOfExperts`` plugin layer to the default network.

    Args:
        moe_config: Supplies ``num_experts``, ``top_k`` and
            ``normalization_mode`` for the plugin fields.
        hidden_states: Input activations (first plugin input).
        routing: Router output (second plugin input).
        finished: Optional tensor; when not None it is appended to the plugin
            inputs and the "use_finished" field is set.
        expert_weights_1, expert_weights_2: Expert weights for the two GEMMs.
        expert_bias_1, expert_bias_2: Optional biases. They must be given
            together; their presence sets the "use_bias" field.
        expert_scale_1 .. expert_scale_4: Optional quantization scales.
            Scales 1 and 2 are required for weight-only or FP8 modes, scale 3
            for FP8, and scale 4 is only accepted with FP8 and an FP8 output
            dtype.
        hidden_size: Value of the "expert_hidden_size" field.
        ffn_hidden_size: Value of the "expert_inter_size" field.
        act_fn: Activation name; must be a key of ``activation_str_to_int_map``.
        dtype, weight_dtype, output_dtype: TensorRT dtypes, or dtype strings
            that are converted with ``str_dtype_to_trt``.
        lora_params: LoRA runtime parameters; only read when the network's
            plugin config has a LoRA plugin set.
        lora_max_low_rank: Value of the "max_low_rank" field (LoRA only).
        quant_mode: Quantization mode flags.
        tp_size, ep_size, tp_rank, ep_rank: Parallelism sizes and ranks,
            forwarded unchanged as plugin fields.

    Returns:
        The tensor wrapping output 0 of the created plugin layer.

    Raises:
        KeyError: if ``act_fn`` is not a known activation name.
        AssertionError: if the plugin creator is missing or a required
            bias/scale input is None.
    """

    def int32_field(name, value):
        # Every field of this plugin is serialized as INT32.
        return trt.PluginField(name, np.array(value, dtype=np.int32),
                               trt.PluginFieldType.INT32)

    def from_parameter(x):
        # Unwrap Parameters to their network tensor; pass anything else
        # (including None) through unchanged.
        return x.value if isinstance(x, Parameter) else x

    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)
    if isinstance(weight_dtype, str):
        weight_dtype = str_dtype_to_trt(weight_dtype)
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    expert_weights_1 = from_parameter(expert_weights_1)
    expert_weights_2 = from_parameter(expert_weights_2)
    expert_bias_1 = from_parameter(expert_bias_1)
    expert_bias_2 = from_parameter(expert_bias_2)
    expert_scale_1 = from_parameter(expert_scale_1)
    expert_scale_2 = from_parameter(expert_scale_2)
    expert_scale_3 = from_parameter(expert_scale_3)
    expert_scale_4 = from_parameter(expert_scale_4)

    # Presence flags are computed once with `is not None` so that the plugin
    # fields and the plugin input list below can never disagree (tensor
    # truthiness is not relied upon).
    has_bias = expert_bias_1 is not None
    has_finished = finished is not None

    use_lora = default_net().plugin_config.lora_plugin is not None

    # Create the plugin with our required state.
    # NOTE: the order of the fields below is kept exactly as in the original
    # implementation.
    pfc_inputs = [
        int32_field(
            "remove_input_padding",
            np.int32(default_net().plugin_config.remove_input_padding)),
        # We pass the full number of experts (not divided by ep_size) even for EP mode
        int32_field("number_of_experts", moe_config.num_experts),
        int32_field("top_k", moe_config.top_k),
        int32_field("expert_hidden_size", hidden_size),
        int32_field("expert_inter_size", ffn_hidden_size),
        int32_field("activation_type", activation_str_to_int_map[act_fn]),
        int32_field("type_id", [int(dtype)]),
        int32_field("weight_type_id", [int(weight_dtype)]),
        int32_field("output_type_id", [int(output_dtype)]),
        int32_field("quant_mode", [int(quant_mode)]),
        int32_field("use_finished", [int(has_finished)]),
        int32_field("use_bias", [int(has_bias)]),
        int32_field("tp_size", tp_size),
        int32_field("tp_rank", tp_rank),
        int32_field("ep_size", ep_size),
        int32_field("ep_rank", ep_rank),
        int32_field("normalization_mode", moe_config.normalization_mode),
        int32_field("force_determinism", [int(False)]),
        int32_field("use_lora", [int(use_lora)]),
    ]
    if use_lora:
        pfc_inputs += [
            int32_field("lora_type_id", [
                int(str_dtype_to_trt(default_net().plugin_config.lora_plugin))
            ]),
            int32_field("max_low_rank", lora_max_low_rank),
        ]

    pfc = trt.PluginFieldCollection(pfc_inputs)

    # Create the plugin with our constant inputs to the constructor
    plugin_creator = trt.get_plugin_registry().get_plugin_creator(
        'MixtureOfExperts', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plugin_creator is not None
    moe_plugin = plugin_creator.create_plugin("mixture_of_experts", pfc)

    # Instantiate the plugin with our specific inputs
    plugin_inputs = [hidden_states, routing, expert_weights_1, expert_weights_2]

    if has_bias:
        assert expert_bias_2 is not None
        plugin_inputs += [expert_bias_1, expert_bias_2]

    if has_finished:
        plugin_inputs += [finished]

    # Add conditional inputs
    if quant_mode.is_weight_only() or quant_mode.has_fp8_qdq():
        assert expert_scale_1 is not None
        assert expert_scale_2 is not None
        plugin_inputs += [expert_scale_1, expert_scale_2]

    # Add conditional inputs
    if quant_mode.has_fp8_qdq():
        assert expert_scale_3 is not None
        plugin_inputs += [expert_scale_3]

    if expert_scale_4 is not None:
        assert quant_mode.has_fp8_qdq()
        assert output_dtype == trt.fp8
        plugin_inputs += [expert_scale_4]

    if use_lora:
        # For each LoRA module the weight pointers come first, then the ranks.
        moe_h_4h_weight_ptrs = lora_params.get_runtime_params(
            0, "moe_h_to_4h").lora_weights_pointers
        moe_h_4h_lora_ranks = lora_params.get_runtime_params(
            0, "moe_h_to_4h").lora_ranks
        plugin_inputs += (moe_h_4h_weight_ptrs + moe_h_4h_lora_ranks)

        moe_4h_h_weight_ptrs = lora_params.get_runtime_params(
            0, "moe_4h_to_h").lora_weights_pointers
        moe_4h_h_lora_ranks = lora_params.get_runtime_params(
            0, "moe_4h_to_h").lora_ranks
        plugin_inputs += (moe_4h_h_weight_ptrs + moe_4h_h_lora_ranks)

        # The gate LoRA module is only passed for gated activations.
        if is_gated_activation(act_fn):
            moe_gate_weight_ptrs = lora_params.get_runtime_params(
                0, "moe_gate").lora_weights_pointers
            moe_gate_lora_ranks = lora_params.get_runtime_params(
                0, "moe_gate").lora_ranks
            plugin_inputs += (moe_gate_weight_ptrs + moe_gate_lora_ranks)

        host_request_types = lora_params.host_request_types
        plugin_inputs += [host_request_types]

        if default_net().plugin_config.remove_input_padding:
            plugin_inputs += [lora_params.host_context_lengths]

    plugin_inputs = [i.trt_tensor for i in plugin_inputs]
    layer = default_trtnet().add_plugin_v2(plugin_inputs, moe_plugin)
    _add_plugin_info(layer, plugin_creator, "mixture_of_experts", pfc)

    # Without strong typing, int8 inputs get an explicit dynamic range.
    if not default_net().strongly_typed:
        for ii in range(layer.num_inputs):
            if layer.get_input(ii).dtype == str_dtype_to_trt("int8"):
                layer.get_input(ii).set_dynamic_range(-127, 127)

    output = _create_tensor(layer.get_output(0), layer)
    return output
# This exists so that MOE can have the same name format as a regular MLP, just with different shaped weight tensors
class MOEWeightWrapper(Module):
    """Holds the stacked per-expert weight (and optional bias/scales) of one MoE GEMM.

    The leading dimension of every parameter is ``experts_per_node``.
    """

    def __init__(self, in_features: int, out_features: int,
                 experts_per_node: int, quant_mode: QuantMode,
                 dtype: Union[str, trt.DataType],
                 weight_dtype: Union[str, trt.DataType], has_bias: bool,
                 wrapper_tllm_to_externel_key_dict: dict, tp_size: int,
                 tp_dim: int):
        """Create the parameters for one expert GEMM.

        Args:
            in_features: Input width of each expert matrix.
            out_features: Output width of each expert matrix.
            experts_per_node: Number of experts stored in this wrapper.
            quant_mode: Decides which scale parameters exist and the weight layout.
            dtype: Dtype of bias and per-channel scale parameters.
            weight_dtype: Dtype of the weight parameter.
            has_bias: Whether a bias parameter is created.
            wrapper_tllm_to_externel_key_dict: Stored as ``tllm_to_externel_key_dict``.
            tp_size: Stored as-is.
            tp_dim: Stored as-is.
        """
        super().__init__()
        self.quant_mode = quant_mode
        self.dtype = dtype
        self.weight_dtype = weight_dtype
        self.has_bias = has_bias
        self.tllm_to_externel_key_dict = wrapper_tllm_to_externel_key_dict
        self.tp_size = tp_size
        self.tp_dim = tp_dim

        if quant_mode.is_weight_only():
            # We use a different shape here because the quantized weights have their own layout
            cols_per_byte = 2 if quant_mode.is_int4_weight_only() else 1
            self.expert_shape = (experts_per_node, in_features,
                                 out_features // cols_per_byte)
            self.per_channel_scale = Parameter(shape=(experts_per_node,
                                                      out_features),
                                               dtype=dtype)
        else:
            self.expert_shape = (experts_per_node, out_features, in_features)
            self.register_parameter('per_channel_scale', None)

        self.weight = Parameter(shape=self.expert_shape,
                                dtype=weight_dtype,
                                prefer_managed=True)

        if not has_bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(shape=(experts_per_node, out_features),
                                  dtype=dtype)

        if not quant_mode.has_fp8_qdq():
            self.register_parameter('activation_scaling_factor', None)
            self.register_parameter('weights_scaling_factor', None)
        else:
            self.activation_scaling_factor = Parameter(shape=(1, ),
                                                       dtype=trt.float32)
            self.weights_scaling_factor = Parameter(shape=(experts_per_node, 1),
                                                    dtype=trt.float32)

    def postprocess(self, tllm_key, weights, **kwargs):
        """Convert loaded checkpoint tensors into the layout this wrapper stores.

        Args:
            tllm_key: Name of the parameter being loaded.
            weights: A tensor, or a list of per-expert tensors.
            **kwargs: Ignored.

        Returns:
            The processed tensor, or for weight-only quantization a dict
            mapping parameter names to tensors (empty for the
            ``per_channel_scale`` key, which is produced together with the
            weight). ``None`` when the quant mode matches none of the
            handled cases.
        """
        if tllm_key.endswith("weight"):
            per_expert = [weights] if isinstance(weights,
                                                 torch.Tensor) else weights
            if "fc" in tllm_key:
                # The list holds two equally sized groups of per-expert
                # matrices; each group is stacked and the two stacks are
                # joined along dim -2.
                half = len(per_expert) // 2
                stacked_halves = [
                    torch.stack(per_expert[:half]),
                    torch.stack(per_expert[half:])
                ]
                weights = torch.cat(stacked_halves, dim=-2)
            elif "proj" in tllm_key:
                weights = torch.stack(per_expert)
            else:
                weights = per_expert
            weights = weights.to(str_dtype_to_torch(self.dtype))

        mode = self.quant_mode
        if not mode.has_any_quant():
            return weights

        if mode.is_weight_only():
            # The scale is emitted together with the quantized weight below.
            if "per_channel_scale" in tllm_key:
                return {}

            transposed = (weights.transpose(-1, -2)
                          if weights.dim() > 2 else weights.t())
            amax = transposed.abs().max(dim=-2)[0].to(transposed.dtype)
            is_int8 = mode.is_int8_weight_only()
            if is_int8:
                scale = amax / 128.
                scaled = (transposed / scale.unsqueeze(1)).round()
                qweight = torch.clamp(scaled, -128, 127).char()
            else:
                scale = amax / 8.
                scaled = (transposed / scale.unsqueeze(1)).round()
                qweight = torch.clamp(scaled, -8, 7).char()
                # Map negatives to their 4-bit two's-complement nibble, then
                # pack two neighbouring columns into one byte.
                qweight[qweight < 0] += 16
                qweight = qweight.view(torch.uint8)
                high_nibbles = qweight[:, :, 1::2] * 16
                qweight = (high_nibbles + qweight[:, :, ::2]).view(torch.int8)

            packed_dtype = torch.int8 if is_int8 else torch.quint4x2
            qweight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm(
                qweight.contiguous(), packed_dtype, torch.float16)
            scale_key = tllm_key.replace("weight", "per_channel_scale")
            return {
                tllm_key: qweight,
                scale_key: scale,
            }

        if mode.has_fp8_qdq():
            # Both scaling factors are stored as 448.0 / value
            # (448 is presumably the FP8 E4M3 maximum - TODO confirm).
            if tllm_key.endswith(
                ("activation_scaling_factor", "weights_scaling_factor")):
                return 448.0 / weights
            return weights

        return None
class MixtureOfExperts(Module):
    """Mixture-of-Experts layer that runs its experts through the MoE plugin.

    A float32 router produces per-expert routing values from the hidden
    states; the expert GEMMs are executed by ``_moe_plugin``.
    """

    def __init__(self,
                 moe_config: MoeConfig,
                 hidden_size: int,
                 ffn_hidden_size: int,
                 hidden_act: str,
                 mapping: Mapping = Mapping(),
                 bias: bool = True,
                 dtype=None,
                 tp_group: List[int] = None,
                 tp_size: int = 1,
                 quant_mode=QuantMode(0)):
        """Set up router, expert weights and parallelism bookkeeping.

        Args:
            moe_config: Provides ``num_experts`` and ``top_k``.
            hidden_size: Model hidden size.
            ffn_hidden_size: Full (unsharded) expert intermediate size.
            hidden_act: Activation function name.
            mapping: Parallelism mapping (MoE EP/TP sizes and ranks).
            bias: Whether the experts have biases.
            dtype: Activation dtype; also the weight dtype unless quantized.
            tp_group: Group used for the final allreduce, if any.
            tp_size: Tensor-parallel size guarding the final allreduce.
            quant_mode: Quantization mode flags.

        Raises:
            ValueError: if the expert count is not divisible by the EP size,
                the FFN size is not divisible by the MoE TP size, or bias is
                requested together with FP8.
        """
        super().__init__()

        # Every constructor argument is stored under its own name.
        self.moe_config = moe_config
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.mapping = mapping
        self.bias = bias
        self.dtype = dtype
        self.tp_group = tp_group
        self.tp_size = tp_size
        self.quant_mode = quant_mode

        # Derived values; the per-rank ones are narrowed below.
        self.num_experts = moe_config.num_experts
        self.top_k = moe_config.top_k
        self.expert_inter_size = ffn_hidden_size
        self.weight_dtype = dtype
        self.experts_per_node = self.num_experts

        if self.mapping.has_moe_ep():
            if self.num_experts % self.mapping.moe_ep_size != 0:
                raise ValueError(
                    f"MixtureOfExperts - Number of experts {self.num_experts} is not a multiple of EP size {self.mapping.moe_ep_size}"
                )
            self.experts_per_node = self.experts_per_node // self.mapping.moe_ep_size

        if self.mapping.has_moe_tp():
            if self.ffn_hidden_size % self.mapping.moe_tp_size != 0:
                raise ValueError(
                    f"MixtureOfExperts - FFN Hidden Size {self.ffn_hidden_size} is not a multiple of TP size {self.mapping.moe_tp_size}"
                )
            self.expert_inter_size = self.ffn_hidden_size // self.mapping.moe_tp_size

        if quant_mode.has_fp8_qdq() and self.bias:
            # TODO (dastokes) We will need to revisit this if we have a use case for it
            raise ValueError(
                f"MixtureOfExperts - Bias is not supported with FP8")

        if quant_mode.is_weight_only():
            self.weight_dtype = trt.int8
        elif quant_mode.has_fp8_qdq():
            self.weight_dtype = trt.fp8

        # External checkpoint names of the experts owned by this rank.
        rank_experts = self.mapping.ep_experts(self.num_experts)
        proj_keys = [f"experts.{expert}.w2" for expert in rank_experts]
        w3_keys = [f"experts.{expert}.w3" for expert in rank_experts]
        w1_keys = [f"experts.{expert}.w1" for expert in rank_experts]
        self.wrapper_tllm_to_externel_key_dict = {
            "mlp": "block_sparse_moe",
            "proj": proj_keys,
            "fc": w3_keys + w1_keys,
        }

        # Since output dimension is usually low (in the order of 10s), no TP at
        # all is more efficient as no allreduce required in the end.
        # Note that if we see models that have large number of experts, we may
        # need to consider add TP back here.
        # TODO: Arctic has large # experts, we may need to add TP back here.
        self.router = RowLinear(
            hidden_size,
            self.num_experts,
            bias=False,
            # Routing is sensitive since it conditions what experts are used
            dtype=trt.float32,
            tp_group=None,
            tp_size=1,
            strict_dtype=True)
        self.router.tllm_to_externel_key_dict = {
            "mlp": "block_sparse_moe",
            "router": "gate"
        }

        self.init_experts()

        self.max_low_rank = None

    def init_experts(self):
        """Create the ``fc`` and ``proj`` expert weight wrappers."""
        # Note we use horizontal fusion for gated activation to do the operation in one GEMM invocation
        # The left matrix is a linear projection (no activation applied)
        # The right matrix is the gating value (activation applied)
        # The naming convention is the inverse of GatedMLP, but the same as `tensorrt_llm/functional.py`
        if is_gated_activation(self.hidden_act):
            fc_out_size = self.expert_inter_size * 2
        else:
            fc_out_size = self.expert_inter_size

        self.fc = MOEWeightWrapper(self.hidden_size, fc_out_size,
                                   self.experts_per_node, self.quant_mode,
                                   self.dtype, self.weight_dtype, self.bias,
                                   self.wrapper_tllm_to_externel_key_dict,
                                   self.mapping.moe_tp_size, 0)
        self.proj = MOEWeightWrapper(self.expert_inter_size, self.hidden_size,
                                     self.experts_per_node, self.quant_mode,
                                     self.dtype, self.weight_dtype, self.bias,
                                     self.wrapper_tllm_to_externel_key_dict,
                                     self.mapping.moe_tp_size, 1)

    def forward(self,
                hidden_states,
                finished=None,
                lora_layer_params=None,
                reduce_fusion_params: Optional[AllReduceFusionParams] = None):
        """Route ``hidden_states`` and run the experts.

        The router always sees a float32 copy of the hidden states; the
        experts receive the original tensor.
        """
        router_lora = (None if lora_layer_params is None else
                       lora_layer_params.get_runtime_params(0, "moe_router"))

        routing = self.router(cast(hidden_states, trt.float32), router_lora)
        return self.forward_experts(hidden_states, routing, finished,
                                    lora_layer_params, reduce_fusion_params)

    def forward_experts(self, hidden_states, routing, finished,
                        lora_layer_params,
                        reduce_fusion_params: Optional[AllReduceFusionParams]):
        """Run the MoE plugin and, if configured, allreduce its output.

        In FP8 mode the input is quantized (unless already FP8) and the
        plugin receives three derived scales; otherwise the per-channel
        scales of ``fc`` and ``proj`` are passed through (they may be None).

        Raises:
            RuntimeError: in FP8 mode when ``self.dtype`` is FP8, because no
                output quantization scale is available.
        """
        if self.quant_mode.has_fp8_qdq():
            assert self.fc.weight.value.dtype == trt.fp8, (
                "mlp fc weight dtype should be fp8 in the fp8 quantization mode."
            )
            assert self.proj.weight.value.dtype == trt.fp8, (
                "mlp proj weight dtype should be fp8 in the fp8 quantization mode."
            )

            plugin_input = hidden_states
            if plugin_input.dtype != trt.fp8:
                plugin_input = quantize(
                    hidden_states, self.fc.activation_scaling_factor.value,
                    'fp8')

            plugin_dtype = trt.fp8
            plugin_weight_dtype = trt.fp8

            # scale_1: fc1 dequant, scale_2: fc2 quant, scale_3: fc2 dequant.
            scale_1 = self.fc.weights_scaling_factor.value * self.fc.activation_scaling_factor.value
            scale_2 = div(1.0, self.proj.activation_scaling_factor.value)
            scale_3 = self.proj.weights_scaling_factor.value * self.proj.activation_scaling_factor.value
            scale_4 = None

            plugin_output_dtype = self.dtype
            if plugin_output_dtype == trt.fp8 and scale_4 is None:
                raise RuntimeError(
                    "Cannot output FP8 value without knowing quantization parameter"
                )
        else:
            plugin_input = hidden_states
            plugin_dtype = self.dtype
            plugin_weight_dtype = self.weight_dtype
            plugin_output_dtype = self.dtype
            scale_1 = self.fc.per_channel_scale
            scale_2 = self.proj.per_channel_scale
            scale_3 = None
            scale_4 = None

        output = _moe_plugin(self.moe_config,
                             plugin_input,
                             routing,
                             expert_weights_1=self.fc.weight.value,
                             expert_weights_2=self.proj.weight.value,
                             expert_bias_1=self.fc.bias,
                             expert_bias_2=self.proj.bias,
                             expert_scale_1=scale_1,
                             expert_scale_2=scale_2,
                             expert_scale_3=scale_3,
                             expert_scale_4=scale_4,
                             finished=finished,
                             hidden_size=self.hidden_size,
                             ffn_hidden_size=self.expert_inter_size,
                             act_fn=self.hidden_act,
                             dtype=plugin_dtype,
                             weight_dtype=plugin_weight_dtype,
                             output_dtype=plugin_output_dtype,
                             lora_params=lora_layer_params,
                             lora_max_low_rank=self.max_low_rank,
                             quant_mode=self.quant_mode,
                             tp_size=self.mapping.moe_tp_size,
                             tp_rank=self.mapping.moe_tp_rank,
                             ep_size=self.mapping.moe_ep_size,
                             ep_rank=self.mapping.moe_ep_rank)

        needs_allreduce = self.tp_size > 1 and self.tp_group is not None
        if needs_allreduce:
            output = allreduce(output,
                               self.tp_group,
                               reduce_fusion_params=reduce_fusion_params)

        return output

    def load_weights(self, moe: "MixtureOfExperts"):
        '''
        Load weights from base MOE layer
        '''
        raise NotImplementedError("Subclass shall override this")

    def to(self,
           moe_cls: Type["MixtureOfExperts"],
           quant_config=None) -> "MixtureOfExperts":
        """Return this layer converted to ``moe_cls``.

        If the layer already is an instance of ``moe_cls`` it is returned
        unchanged. Otherwise a new ``moe_cls`` is built from this layer's
        init parameters, optionally quantized with ``quant_config``, loaded
        with this layer's weights, and given this layer's router.
        """
        from ..quantization.quantize import quantize as apply_quant_config

        if isinstance(self, moe_cls):
            return self

        converted = moe_cls(**get_init_params(self))
        # If config is not None, set quantization from config
        if quant_config is not None:
            apply_quant_config(converted, quant_config)

        converted.load_weights(self)
        # The router object is shared with the source layer, not copied.
        converted.router = self.router
        return converted
# Short alias for MixtureOfExperts; used below as the base class of MoeOOTB.
MOE = MixtureOfExperts
class MoeOOTB(MOE):
    """MoE variant that builds its experts from regular MLP/GatedMLP modules
    and dispatches tokens with plain network ops instead of the MoE plugin.
    """

    def init_experts(self):
        """Create one MLP (or GatedMLP) module per local expert.

        Raises:
            ValueError: if the quant mode is weight-only.
        """
        if self.quant_mode.is_weight_only():
            raise ValueError(
                f"OOTB MOE does not support weight only quantization now, current quant mode: {self.quant_mode}"
            )

        expert_cls = GatedMLP if is_gated_activation(self.hidden_act) else MLP

        # Experts are built without tensor parallelism of their own.
        tp_size = 1
        tp_group = None
        expert_modules = []
        for _ in range(self.experts_per_node):
            expert_modules.append(
                expert_cls(self.hidden_size, self.expert_inter_size,
                           non_gated_version(self.hidden_act), self.bias,
                           self.dtype, tp_group, tp_size, self.quant_mode))
        self.experts = ModuleList(expert_modules)

    def moe_to_expert_lora_params(self, lora_layer_params, expert_idx):
        """Repackage MoE LoRA params as MLP LoRA params for a single expert.

        The ``moe_*`` modules are re-exposed under the ``mlp_*`` names and
        ``expert_idx`` is passed on as ``weight_index``.

        Returns:
            A ``LoraParams`` instance, or ``None`` when ``lora_layer_params``
            is ``None``.
        """
        if lora_layer_params is None:
            return None

        def first_entry(module):
            # Ranks and weight pointers of the first entry for `module`.
            runtime = lora_layer_params.get_runtime_params(0, module)
            return runtime.lora_ranks[0], runtime.lora_weights_pointers[0]

        fc_ranks, fc_pointers = first_entry("moe_h_to_4h")
        proj_ranks, proj_pointers = first_entry("moe_4h_to_h")
        if is_gated_activation(self.hidden_act):
            gate_ranks, gate_pointers = first_entry("moe_gate")
        else:
            gate_ranks, gate_pointers = None, None

        ranks = {
            "mlp_h_to_4h_lora_ranks": fc_ranks,
            "mlp_4h_to_h_lora_ranks": proj_ranks,
            "mlp_gate_lora_ranks": gate_ranks,
        }
        pointers = {
            "mlp_h_to_4h_lora_weights_pointers": fc_pointers,
            "mlp_4h_to_h_lora_weights_pointers": proj_pointers,
            "mlp_gate_lora_weights_pointers": gate_pointers,
        }
        return LoraParams(
            lora_ranks=[ranks],
            lora_weights_pointers=[pointers],
            host_context_lengths=lora_layer_params.host_context_lengths,
            max_encoder_context_length=lora_layer_params.
            max_encoder_context_length,
            host_request_types=lora_layer_params.host_request_types,
            host_encoder_input_lengths=lora_layer_params.
            host_encoder_input_lengths,
            weight_index=expert_idx,
        )

    def forward_experts(self, hidden_states, routing, finished,
                        lora_layer_params,
                        reduce_fusion_params: Optional[AllReduceFusionParams]):
        """Dispatch tokens to the local experts and combine their outputs.

        For every local expert, the tokens whose top-k selection contains
        that expert are gathered, run through the expert, scattered back to
        their positions, scaled by the routing weight and accumulated.
        ``finished`` is accepted for interface compatibility but not used.

        Raises:
            RuntimeError: if ``lora_layer_params`` carries any of the
                ``mlp_*`` LoRA modules.
        """
        # TODO: https://nvbugspro.nvidia.com/bug/4781396 after this nvbug is fixed, we will remove this check.
        if lora_layer_params is not None:
            for module in ["mlp_h_to_4h", "mlp_4h_to_h", "mlp_gate"]:
                if lora_layer_params.get_runtime_params(0, module) is not None:
                    raise RuntimeError(
                        f"MoE OOTB does not support {module} LoRA module, please enable MoE plugin"
                    )

        renormalize = (self.moe_config.normalization_mode ==
                       MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)
        if renormalize:
            # Pick the top-k logits first, then softmax over just those.
            topk_values, topk_indices = topk(routing, self.top_k, dim=-1)
            topk_values = softmax(topk_values, -1)
        else:
            # Softmax over all experts, then pick the top-k probabilities.
            router_probs = softmax(routing, -1)
            topk_values, topk_indices = topk(router_probs, self.top_k, dim=-1)

        hidden_size = shape(hidden_states, -1)
        # [B*sq, hidden]
        tokens = hidden_states.view(concat([-1, hidden_size]))
        flat_topk_indices = topk_indices.view(
            concat([-1, shape(topk_indices, -1)]))
        flat_topk_values = topk_values.view(concat([-1,
                                                    shape(topk_values, -1)]))

        # Create output space
        zero_buffer = tokens * 0.0
        output = zero_buffer

        # When topk indices are equal to expert index, the expert will inference the tokens.
        # Bundle all indices and experts index, then do mask once.
        per_expert_topk = []
        per_expert_id = []
        for i, _ in enumerate(self.experts):
            # Under expert parallelism local expert i has a global index.
            index = (i + self.experts_per_node * self.mapping.moe_ep_rank
                     if self.mapping.has_moe_ep() else i)
            per_expert_topk.append(
                flat_topk_indices.view(concat([1, shape(flat_topk_indices)])))
            per_expert_id.append(constant(int32_array(index)))

        all_expert_indices = concat(per_expert_topk, dim=0)
        expert_ids = expand(
            concat(per_expert_id).view(concat([len(self.experts), 1, 1])),
            shape(all_expert_indices))

        # Create all experts mask
        selection = all_expert_indices == expert_ids

        # Routing weight per (expert, token): the sum of the selected top-k
        # values that belong to this expert.
        experts_weights = cast(
            sum(flat_topk_values * cast(selection, flat_topk_values.dtype),
                dim=-1,
                keepdim=True), self.dtype)

        # True where the token selected this expert in any top-k slot,
        # repeated across the hidden dimension.
        token_mask = cast(
            sum(cast(selection, flat_topk_values.dtype),
                dim=-1,
                keepdim=True), 'bool')
        token_mask = repeat_interleave(token_mask, shape(output, -1), 2)

        # split the mask and weights for each expert
        experts_mask = split(token_mask, 1, dim=0)
        expert_weights = split(experts_weights, 1, dim=0)

        for i, expert in enumerate(self.experts):
            index = (i + self.experts_per_node * self.mapping.moe_ep_rank
                     if self.mapping.has_moe_ep() else i)

            # get mask token index
            selected = nonzero(experts_mask[i].view(concat([-1, hidden_size])))
            selected = selected.transpose(1, 0)
            expert_input = gather_nd(tokens, selected, 0)
            expert_input = expert_input.view(concat([-1, hidden_size]),
                                             zero_is_placeholder=False)

            # Expert inference
            expert_output = expert(
                expert_input,
                lora_layer_params=self.moe_to_expert_lora_params(
                    lora_layer_params, index))

            # scatter expert output to real position
            scattered = scatter_nd(zero_buffer, selected,
                                   expert_output.view([-1])) * expert_weights[i]

            output += scattered

        output = output.view(shape(hidden_states))

        # NOTE(review): the guard checks self.tp_group but the reduction uses
        # self.mapping.tp_group, unlike the plugin path - confirm this is
        # intended.
        if self.tp_size > 1 and self.tp_group is not None:
            output = allreduce(output,
                               self.mapping.tp_group,
                               reduce_fusion_params=reduce_fusion_params)

        return output

    def load_weights(self, moe: MOE):
        """Copy the stacked expert parameters of ``moe`` into the per-expert MLPs.

        The fused ``fc`` tensor of ``moe`` holds the gate part in its first
        ``expert_inter_size`` entries and the fc part in its last
        ``expert_inter_size`` entries (along the output dimension).
        """
        is_gated_act = is_gated_activation(self.hidden_act)
        inter = self.expert_inter_size
        use_fp8 = self.quant_mode.has_fp8_qdq()

        for i, expert in enumerate(self.experts):
            # Gated weight pack in expert1 weights
            # expert_weights_1
            fused_fc_weight = moe.fc.weight.raw_value

            fc1_weight_scale = None
            fc1_activation_scale = None
            fc2_weight_scale = None
            fc2_activation_scale = None
            if use_fp8:
                fc1_weight_scale = moe.fc.weights_scaling_factor.raw_value
                fc1_activation_scale = moe.fc.activation_scaling_factor.raw_value
                fc2_weight_scale = moe.proj.weights_scaling_factor.raw_value
                fc2_activation_scale = moe.proj.activation_scaling_factor.raw_value

            if self.quant_mode.is_weight_only():
                # Weight-only layout keeps the output dimension last.
                expert.fc.weight.value = fused_fc_weight[i, :, -inter:]
                if is_gated_act:
                    expert.gate.weight.value = fused_fc_weight[i, :, :inter]
            else:
                expert.fc.weight.value = fused_fc_weight[i, -inter:, :]
                if is_gated_act:
                    expert.gate.weight.value = fused_fc_weight[i, :inter, :]

            if use_fp8:
                # Activation scales are shared; weight scales are per expert.
                expert.fc.activation_scaling_factor.value = fc1_activation_scale
                expert.fc.weights_scaling_factor.value = fc1_weight_scale[i]
                expert.proj.activation_scaling_factor.value = fc2_activation_scale
                expert.proj.weights_scaling_factor.value = fc2_weight_scale[i]
                if is_gated_act:
                    expert.gate.activation_scaling_factor.value = fc1_activation_scale
                    expert.gate.weights_scaling_factor.value = fc1_weight_scale[
                        i]

            # expert_weights_2
            proj_weight = moe.proj.weight.raw_value
            expert.proj.weight.value = proj_weight[i, :, :]

            if self.bias:
                fused_fc_bias = moe.fc.bias.raw_value
                expert.fc.bias.value = fused_fc_bias[i, -inter:]
                proj_bias = moe.proj.bias.raw_value
                expert.proj.bias.value = proj_bias[i, :]
                if is_gated_act:
                    expert.gate.bias.value = fused_fc_bias[i, :inter]