Skip to content

Commit

Permalink
style: test ann2snn module rebota
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffreyWong20 committed Jan 28, 2025
1 parent f97fe2c commit a2a203e
Showing 1 changed file with 45 additions and 64 deletions.
109 changes: 45 additions & 64 deletions test/passes/module/transforms/ann2snn/test_ann2snn_module_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,71 +77,52 @@
},
}
mg, _ = ann2snn_module_transform_pass(mg, convert_pass_args)
# convert_pass_args = {
# "by": "type",
# "embedding": {
# "config": {
# "name": "zip_tf",
# },
# },
# "linear": {
# "config": {
# "name": "unfold_bias",
# "level": 32,
# "neuron_type": "ST-BIF",
# },
# },
# "conv2d": {
# "config": {
# "name": "zip_tf",
# "level": 32,
# "neuron_type": "ST-BIF",
# },
# },
# "layernorm": {
# "config": {
# "name": "zip_tf",
# },
# },
# "relu": {
# "manual_instantiate": True,
# "config": {
# "name": "identity",
# },
# },
# "lsqinteger": {
# "manual_instantiate": True,
# "config": {
# "name": "st_bif",
# # Default values. These would be replaced by the values from the LSQInteger module, so it has no effect.
# # "q_threshold": 1,
# # "level": 32,
# # "sym": True,
# },
# },
# }
# mg, _ = ann2snn_module_transform_pass(mg, convert_pass_args)

convert_pass_args = {
"by": "type",
"embedding": {
"config": {
"name": "zip_tf",
},
},
"linear": {
"config": {
"name": "unfold_bias",
"level": 32,
"neuron_type": "ST-BIF",
},
},
"conv2d": {
"config": {
"name": "zip_tf",
"level": 32,
"neuron_type": "ST-BIF",
},
},
"layernorm": {
"config": {
"name": "zip_tf",
},
},
"relu": {
"manual_instantiate": True,
"config": {
"name": "identity",
},
},
"lsqinteger": {
"manual_instantiate": True,
"config": {
"name": "st_bif",
# Default values. These would be replaced by the values from the LSQInteger module, so it has no effect.
# "q_threshold": 1,
# "level": 32,
# "sym": True,
},
},
}
mg, _ = ann2snn_module_transform_pass(mg, convert_pass_args)

# f = open(f"spiking_model_arch.txt", "w")
# f.write(str(mg))
# f.close()

# return mg


# import datasets as hf_datasets
# sst2 = hf_datasets.load_dataset("gpt3mix/sst2")
# train_df = sst2["train"]
# dev_df = sst2["validation"]
# test_df = sst2["test"]

# max_seq_len = 50
# epochs = 10
# batch_size = 32
# lr = 2e-5
# patience = 5
# max_grad_norm = 10
# if_save_model = False
# checkpoint = None

# mg = test_ann2snn_module_transform_pass()

0 comments on commit a2a203e

Please sign in to comment.