support CodeGeeX4
Judd committed Jul 6, 2024
1 parent 493846f commit 8173f62
Showing 7 changed files with 107 additions and 43 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [ggml](https://github.com/ggerganov/ggml)

**What's New:**

* 2024-07-05: CodeGeeX4
* 2024-07-04: InternLM 2.5 with tool calling
* 2024-07-03: Phi3 mini (June 2024 Update)
* 2024-07-01: LLM-Compiler
6 changes: 5 additions & 1 deletion convert.py
@@ -47,6 +47,7 @@ class ModelType(Enum):
CODEGEEX2 = 4
CharacterGLM = 5
CHATGLM4 = 6
CODEGEEX4 = 7

InternLM = 0x100
InternLM2 = 0x101
@@ -3392,7 +3393,7 @@ def main():
vocab_dir = Path(args.model_name_or_path) if args.vocab_dir == '' else Path(args.vocab_dir)
tokenizer_model_file_exists = False

if (config._name_or_path == 'THUDM/glm-4-9b-chat') or (config._name_or_path == 'THUDM/glm4-9b-chat'):
if config._name_or_path in ['THUDM/glm-4-9b-chat', 'THUDM/glm4-9b-chat', 'THUDM/codegeex4-all-9b']:
vocab = load_vocab_from_tiktok_mergeable_ranks(vocab_dir / 'tokenizer.model')
else:
tokenizer_model_file_exists = (vocab_dir / 'tokenizer.model').exists()
@@ -3429,6 +3430,9 @@ def main():
ChatGLM2Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
else:
ChatGLMConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
elif arch == 'codegeex4':
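# CodeGeeX4 shares the GLM-4 architecture, so the GLM-4 converter is
# reused wholesale with only the model-type tag swapped.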
ChatGLM4Converter.MODEL_TYPE = ModelType.CODEGEEX4
ChatGLM4Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
elif arch == 'characterglm':
CharacterGLMConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
elif arch == 'InternLMForCausalLM':
2 changes: 2 additions & 0 deletions docs/models.md
@@ -38,6 +38,8 @@
Note: Use additional key-value pair arguments to specify characters, `--kv user_name "..." bot_name "..." user_info "..." bot_info "..."`.

* [x] GLM-4: [Chat-9B-128k](https://huggingface.co/THUDM/glm-4-9b-chat), [Chat-9B-1M](https://huggingface.co/THUDM/glm-4-9b-chat-1m)
* [x] CodeGeeX4: [9B](https://huggingface.co/THUDM/codegeex4-all-9b) (`-a CodeGeeX4`)


* InternLM (`InternLMForCausalLM`, `InternLM2ForCausalLM`)
* [x] v1: [Chat-7B](https://huggingface.co/internlm/internlm-chat-7b), [Chat-7B v1.1](https://huggingface.co/internlm/internlm-chat-7b-v1_1), [Chat-20B](https://huggingface.co/internlm/internlm-chat-20b)
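Given the new `-a CodeGeeX4` switch documented above, a plausible conversion invocation (assuming the converter's usual `-i`/`-t`/`-o` flags, which this commit does not show) would be: `python convert.py -i THUDM/codegeex4-all-9b -a CodeGeeX4 -t q8_0 -o codegeex4.bin`.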
12 changes: 10 additions & 2 deletions models/chatglm.cpp
@@ -587,7 +587,15 @@ class GLMInterceptor : public ChunkInterceptor
}

if (find_meta)
{
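// While probing for metadata (find_meta), buffer incoming chunks instead of
// emitting them; a space in the buffered text ends the probe, so flush the
// whole buffer downstream and resume normal streaming.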
oss << chunk;
if (oss.str().find(' ') != std::string::npos)
{
streamer->put_chunk(true, oss.str());
oss.str("");
find_meta = false;
}
}
else
streamer->put_chunk(first, chunk);
}
@@ -613,8 +621,8 @@ class ConditionalGeneration : public v2::ConditionalGeneration
{
public:
ConditionalGeneration() = default;
ConditionalGeneration(const Config &config)
: v2::ConditionalGeneration(config, MODEL_TYPE_GLM4)
ConditionalGeneration(const Config &config, ModelType type = MODEL_TYPE_GLM4)
: v2::ConditionalGeneration(config, type)
{
for (int i = 0; i < config.num_hidden_layers; i++)
{
84 changes: 84 additions & 0 deletions models/codegeex.cpp
@@ -0,0 +1,84 @@
namespace v2
{
struct Config : public glm::v2::Config
{
};

class ChatHistoryEncoder : public BaseHistoryEncoder
{
public:
void do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
};

static ChatHistoryEncoder _chat_encoder;

class Tokenizer : public glm::v2::Tokenizer
{
public:
Tokenizer(const Config &config) : glm::v2::Tokenizer::Tokenizer(config, &_chat_encoder)
{
sys_prompt = "# language: Python";
}
};

class ConditionalGeneration : public glm::v2::ConditionalGeneration
{
public:
ConditionalGeneration() = default;
ConditionalGeneration(const Config &config)
: glm::v2::ConditionalGeneration(config, MODEL_TYPE_CODEGEEX2)
{
}
};

void ChatHistoryEncoder::do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
{
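// CodeGeeX2 takes a plain-text prompt: the system prompt (a language hint
// such as "# language: Python") followed by the user input, newline-separated.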
std::string combined = tokenizer->get_system_prompt() + "\n" + user + "\n";
tokenizer->encode(combined, ids);
}
}

namespace v4
{
typedef glm::v4::Config Config;

class Tokenizer : public glm::v4::Tokenizer
{
public:
Tokenizer(const Config &config) : glm::v4::Tokenizer(config)
{}

size_t load(tokenizer::DataReader *buffer, int n_vocab) override
{
size_t r = glm::v4::Tokenizer::load(buffer, n_vocab);
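// The four FIM tokens are appended after GLM-4's special tokens; the +5
// offset presumably skips the remaining GLM-4 specials that follow
// <|observation|> in the vocabulary.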
int special_id = observation_token_id + 5;
code_prefix_token_id = special_id++;
code_middle_token_id = special_id++;
code_suffix_token_id = special_id++;
cursor_token_id = special_id++;
tp->AddAddedToken("<|code_prefix|>", code_prefix_token_id);
tp->AddAddedToken("<|code_middle|>", code_middle_token_id);
tp->AddAddedToken("<|code_suffix|>", code_suffix_token_id);
tp->AddAddedToken("<|cursor|>", cursor_token_id);
return r;
}
public:
int code_prefix_token_id;
int code_middle_token_id;
int code_suffix_token_id;
int cursor_token_id;
};

class ConditionalGeneration : public glm::v4::ConditionalGeneration
{
public:
ConditionalGeneration(const Config &config)
: glm::v4::ConditionalGeneration(config, MODEL_TYPE_CODEGEEX4)
{
}

// FIXME: this model does not actually seem to support tool calling
// https://github.com/THUDM/CodeGeeX4/issues/8
ChunkInterceptor *get_interceptor(void) override { return nullptr; }
};
}
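
The four new special tokens suggest fill-in-the-middle support. A minimal sketch of how an infill prompt could be assembled with them, assuming the common prefix/suffix-to-middle layout (the actual prompt format used by CodeGeeX4 is not shown in this commit, and encode_fim is a hypothetical helper, not part of the codebase):

// hypothetical helper: build a FIM prompt from code before/after the cursor
static void encode_fim(v4::Tokenizer &tok, const std::string &prefix,
                       const std::string &suffix, std::vector<int> &ids)
{
    ids.push_back(tok.code_prefix_token_id);
    tok.encode(prefix, ids);                  // code before the cursor
    ids.push_back(tok.code_suffix_token_id);
    tok.encode(suffix, ids);                  // code after the cursor
    ids.push_back(tok.code_middle_token_id);  // the model generates the middle
}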
36 changes: 0 additions & 36 deletions models/codegeex_v2.cpp

This file was deleted.

9 changes: 5 additions & 4 deletions src/models.cpp
@@ -84,6 +84,7 @@ namespace chatllm
MODEL_TYPE_CODEGEEX2 = 4,
MODEL_TYPE_CHARACTERGLM = 5,
MODEL_TYPE_GLM4 = 6,
MODEL_TYPE_CODEGEEX4 = 7,

MODEL_TYPE_INTERNLM = 0x100,
MODEL_TYPE_INTERNLM2= 0x101, // extended model, supporting 7B & 20B
@@ -212,6 +213,8 @@ namespace chatllm
return "GLM-4";
case MODEL_TYPE_CODEGEEX2:
return "CodeGeeX2";
case MODEL_TYPE_CODEGEEX4:
return "CodeGeeX4";
case MODEL_TYPE_CHARACTERGLM:
return "CharacterGLM";
case MODEL_TYPE_INTERNLM:
@@ -1146,10 +1149,7 @@ namespace chatllm

namespace codegeex
{
namespace v2
{
#include "../models/codegeex_v2.cpp"
}
#include "../models/codegeex.cpp"
}

namespace internlm
@@ -1504,6 +1504,7 @@ namespace chatllm
CASE(CODEGEEX2, codegeex::v2, 1) \
CASE(CHARACTERGLM, characterglm, 1) \
CASE(GLM4, glm::v4, 1) \
CASE(CODEGEEX4, codegeex::v4, 1) \
\
CASE(INTERNLM, internlm::v1, 1) \
CASE(INTERNLM2, internlm::v2, 1) \
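
The old codegeex_v2.cpp was wrapped in `namespace v2` at the include site; the new codegeex.cpp declares its own v2 and v4 namespaces internally, so models.cpp now includes it directly under `namespace codegeex`. A minimal illustration of this include-into-namespace pattern (file names here are made up for the example):

// codegeex_like.cpp -- meant to be textually included, no include guards
namespace v2 { struct Config { int dim; }; }
namespace v4 { struct Config { int dim; }; }

// models_like.cpp
namespace codegeex {
#include "codegeex_like.cpp"   // everything lands under codegeex::v2 / codegeex::v4
}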
