Name and Version
sh-3.2# llama-cli --version
version: 4758 (5fa07c2)
built with riscv64-tizen-linux-gnu-gcc (Tizen/RISC-V/imafdcv/Standalone-20230621) 13.1.0 for riscv64-tizen-linux-gnu
Operating systems
Linux
GGML backends
CPU
Hardware
RISC-V ISA Simulator (https://github.com/OpenXiangShan/NEMU)
Models
The errors are not specific to one model; they occur with most of the models I tested, including the following:
DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf
llama-3.2-1b-instruct-q8_0.gguf
nano-mistral-q4_0.gguf
tiny-llm-q8_0.gguf
Problem description & steps to reproduce
When I run llama-simple (or llama-cli) built with __riscv_v_intrinsic defined (the default for the llama.cpp RISC-V cross compile), the generated tokens are broken, e.g.:
llama-simple -m tiny-llm-q8_0.gguf
(output)
Hello my name is.,etsperled.raHeperrical
plantaping]
plantaping]
plantaping]
plantaping]
plantcluding")cketsming
main: decoded 32 tokens in 0.70 s, speed: 46.00 t/s
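For reference, the RVV build above comes from a cross compile roughly like the sketch below; the CMake options and -march string are illustrative assumptions rather than my literal command line (only the compiler names match the toolchain shown above).

# sketch: default cross compile, RVV intrinsics left enabled (assumed flags)
cmake -B build-rvv \
      -DCMAKE_C_COMPILER=riscv64-tizen-linux-gnu-gcc \
      -DCMAKE_CXX_COMPILER=riscv64-tizen-linux-gnu-g++ \
      -DGGML_NATIVE=OFF \
      -DCMAKE_C_FLAGS="-march=rv64gcv" \
      -DCMAKE_CXX_FLAGS="-march=rv64gcv"
cmake --build build-rvv --target llama-simple llama-cli -j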
If I run the non-RVV version of llama-simple (or llama-cli), built without __riscv_v_intrinsic (i.e. with -U__riscv_v_intrinsic), the generated tokens are not broken. I suspect there might be a bug in the RISC-V RVV intrinsic code in ggml, e.g.:
llama-simple -m tiny-llm-q8_0.gguf
(output)
Hello my name is so much more than I am, I am so happy to be able to get a new one.
I am a new one. I am a new one
main: decoded 32 tokens in 0.77 s, speed: 41.80 t/s
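The non-RVV comparison build and a quick A/B check look roughly like this; again a sketch with assumed paths, where the only essential difference from the build above is -U__riscv_v_intrinsic, which makes ggml fall back to its scalar code paths.

# sketch: same configure as above, but with the intrinsic macro undefined (assumed flags)
cmake -B build-no-rvv \
      -DCMAKE_C_COMPILER=riscv64-tizen-linux-gnu-gcc \
      -DCMAKE_CXX_COMPILER=riscv64-tizen-linux-gnu-g++ \
      -DGGML_NATIVE=OFF \
      -DCMAKE_C_FLAGS="-march=rv64gcv -U__riscv_v_intrinsic" \
      -DCMAKE_CXX_FLAGS="-march=rv64gcv -U__riscv_v_intrinsic"
cmake --build build-no-rvv --target llama-simple -j
# run the same model through both builds (llama-simple samples greedily, so each
# run is deterministic) and compare the generated text
./build-rvv/bin/llama-simple -m tiny-llm-q8_0.gguf > out-rvv.txt
./build-no-rvv/bin/llama-simple -m tiny-llm-q8_0.gguf > out-scalar.txt
diff out-rvv.txt out-scalar.txt
# optionally confirm the RVV binary really contains vector instructions
# (assumes cross binutils that can decode RVV):
# riscv64-tizen-linux-gnu-objdump -d build-rvv/bin/llama-simple | grep -c vsetvli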
First Bad Commit
No response
Relevant log output
llama_model_loader: loaded meta data with 32 key-value pairs and 12 tensors from ./model.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = Tiny LLM
llama_model_loader: - kv 3: general.size_label str = 13M
llama_model_loader: - kv 4: general.license str = mit
llama_model_loader: - kv 5: general.dataset.count u32 = 1
llama_model_loader: - kv 6: general.dataset.0.name str = Fineweb
llama_model_loader: - kv 7: general.dataset.0.organization str = HuggingFaceFW
llama_model_loader: - kv 8: general.dataset.0.repo_url str = https://huggingface.co/HuggingFaceFW/...
llama_model_loader: - kv 9: general.tags arr[str,1] = ["text-generation"]
llama_model_loader: - kv 10: llama.block_count u32 = 1
llama_model_loader: - kv 11: llama.context_length u32 = 1024
llama_model_loader: - kv 12: llama.embedding_length u32 = 192
llama_model_loader: - kv 13: llama.feed_forward_length u32 = 1024
llama_model_loader: - kv 14: llama.attention.head_count u32 = 2
llama_model_loader: - kv 15: llama.attention.head_count_kv u32 = 1
llama_model_loader: - kv 16: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 17: llama.vocab_size u32 = 32000
llama_model_loader: - kv 18: llama.rope.dimension_count u32 = 96
llama_model_loader: - kv 19: tokenizer.ggml.model str = llama
llama_model_loader: - kv 20: tokenizer.ggml.pre str = default
llama_model_loader: - kv 21: tokenizer.ggml.tokens arr[str,32000] = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv 22: tokenizer.ggml.scores arr[f32,32000] = [-1000.000000, -1000.000000, -1000.00...
llama_model_loader: - kv 23: tokenizer.ggml.token_type arr[i32,32000] = [3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv 24: tokenizer.ggml.bos_token_id u32 = 1
llama_model_loader: - kv 25: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 26: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 27: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 28: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 29: tokenizer.ggml.add_space_prefix bool = true
llama_model_loader: - kv 30: general.quantization_version u32 = 2
llama_model_loader: - kv 31: general.file_type u32 = 7
llama_model_loader: - type f32: 3 tensors
llama_model_loader: - type q8_0: 9 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q8_0
print_info: file size = 13.16 MiB (8.50 BPW)
init_tokenizer: initializing tokenizer for type 1
load: control token: 0 '<unk>' is not marked as EOG
load: control token: 2 '</s>' is not marked as EOG
load: control token: 1 '<s>' is not marked as EOG
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 3
load: token to piece cache size = 0.1684 MB
print_info: arch = llama
print_info: vocab_only = 0
print_info: n_ctx_train = 1024
print_info: n_embd = 192
print_info: n_layer = 1
print_info: n_head = 2
print_info: n_head_kv = 1
print_info: n_rot = 96
print_info: n_swa = 0
print_info: n_embd_head_k = 96
print_info: n_embd_head_v = 96
print_info: n_gqa = 2
print_info: n_embd_k_gqa = 96
print_info: n_embd_v_gqa = 96
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-05
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: n_ff = 1024
print_info: n_expert = 0
print_info: n_expert_used = 0
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = 0
print_info: rope scaling = linear
print_info: freq_base_train = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 1024
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = ?B
print_info: model params = 12.99 M
print_info: general.name = Tiny LLM
print_info: vocab type = SPM
print_info: n_vocab = 32000
print_info: n_merges = 0
print_info: BOS token = 1 '<s>'
print_info: EOS token = 2 '</s>'
print_info: UNK token = 0 '<unk>'
print_info: LF token = 13 '<0x0A>'
print_info: EOG token = 2 '</s>'
print_info: max token length = 48
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: layer 0 assigned to device CPU
load_tensors: layer 1 assigned to device CPU
load_tensors: tensor 'token_embd.weight' (q8_0) (and 11 others) cannot be used with preferred buffer type CPU_AARCH64, using CPU instead
load_tensors: CPU_Mapped model buffer size = 13.16 MiB
......
llama_init_from_model: n_batch is less than GGML_KQ_MASK_PAD - increasing to 64
llama_init_from_model: n_seq_max = 1
llama_init_from_model: n_ctx = 64
llama_init_from_model: n_ctx_per_seq = 64
llama_init_from_model: n_batch = 64
llama_init_from_model: n_ubatch = 64
llama_init_from_model: flash_attn = 0
llama_init_from_model: freq_base = 10000.0
llama_init_from_model: freq_scale = 1
llama_init_from_model: n_ctx_per_seq (64) < n_ctx_train (1024) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 64, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 1, can_shift = 1
llama_kv_cache_init: layer 0: n_embd_k_gqa = 96, n_embd_v_gqa = 96
llama_kv_cache_init: CPU KV buffer size = 0.02 MiB
llama_init_from_model: KV self size = 0.02 MiB, K (f16): 0.01 MiB, V (f16): 0.01 MiB
llama_init_from_model: CPU output buffer size = 0.12 MiB
llama_init_from_model: CPU compute buffer size = 7.86 MiB
llama_init_from_model: graph nodes = 38
llama_init_from_model: graph splits = 1
<s> Hello my name is.,etsperled.raHeperrical
plantaping]
plantaping]
plantaping]
plantaping]
plantcluding")cketsming
main: decoded 32 tokens in 0.70 s, speed: 46.00 t/s
llama_perf_sampler_print: sampling time = 1.82 ms / 32 runs ( 0.06 ms per token, 17621.15 tokens per second)
llama_perf_context_print: load time = 227.70 ms
llama_perf_context_print: prompt eval time = 22.88 ms / 5 tokens ( 4.58 ms per token, 218.52 tokens per second)
llama_perf_context_print: eval time = 641.11 ms / 31 runs ( 20.68 ms per token, 48.35 tokens per second)
llama_perf_context_print: total time = 900.52 ms / 36 tokens
[VD-XS] src/profiling/profiling_control.c reset_inst_counters 50 Have taken checkpoint at 3659994548 guest instructions (abs_inst_count 3659994760)
[VD-XS] /home/jongchul/VD-XS/NEMU/src/isa/riscv64/include/../instr/special.h execute 53 trap 0x0 cpu.pc = 0x104a4 s.pc = 0x104b6
real 0m59.748s
user 0m58.636s
sys 0m1.104s