feat: extract the common module of Transformer #115
JYMiracle305 wants to merge 1 commit into master
Conversation
```cpp
first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(gpt2_config))
    .with_submodule(TransformerFirstStage::kWPELayerName,
                    BuildPositionEmbeddingSpec(gpt2_config.block_size, gpt2_config.n_embd));
spec.with_submodule("first_stage", first_stage);
```
```cpp
namespace infini_train::nn {

void ModuleRegistry::Register(std::type_index type, ModuleCreator creator) { registry_[type] = std::move(creator); }
```
```cpp
auto tok_emb = (*modules_[kWTELayerName])({x1});

// Add position embedding only for models that use absolute position encoding
if (config_.attention_type == AttentionType::kStandard) {
```
Isn't WTE being tied to attention_type here?
This can stay as-is for now, but semantically it seems to move a responsibility that belongs to the spec into the common layer. The two current models are covered, but if an exception shows up later we should not add a branch here; that logic should be pushed down into the spec.
example/llama3/main.cc
```diff
  // ManualSeed(42);

- LLaMA3Config model_config = LLaMA3Config();
+ nn::TransformerConfig model_config;
```
The defaults declared by `nn::TransformerConfig model_config;` all follow the GPT-2 architecture (use_bias, use_rope, and so on are GPT-2 style), so the else branch below actually constructs a GPT-2 model.

Added static initialization methods; each main.cc now calls the initializer for its own model.
```cpp
// ========== GPT2 Model Definition ==========
// Uses LayerNorm, GELU activation, standard multi-head attention
class GPT2 : public nn::TransformerLayer {
```
The Layer/Block naming seems inverted relative to Megatron: there, Layer means one transformer block and Block means a chain of transformer blocks. It doesn't affect correctness, just a matter of naming...
Also, is it appropriate for the GPT2 model to inherit directly from TransformerLayer? In Megatron, GPTModel inherits directly from nn.Module, and one of its members, self.decoder, is built as the corresponding TransformerBlock (Megatron's name, matching TransformerLayer in this PR).
```cpp
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
    AttentionType attention_type_;
```
```cpp
// Architecture choices
AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType mlp_type = MLPType::kGELU;                       // MLP activation type
```
```cpp
namespace infini_train::nn {

class RMSNorm : public infini_train::nn::CloneableModule<RMSNorm> {
```
Worth discussing whether modules for fused operators like this should be split out into their own files. First, as in Megatron, the rmsnorm/layernorm selection logic could be wrapped in a unified norm module; second, once flash attention is hooked in later, there will also be algorithm-selection logic of the same kind. I'd lean toward pulling it out.

The initial design did split these out; since there are only a few right now, I kept them together. We can change the other places first and split these files at the end.
```cpp
modules_[kCFcLayerName] = build_module(config, spec.submodules_.at(kCFcLayerName));

// For SwiGLU, add second projection
if (spec.submodules_.count(kCFc2LayerName) > 0) {
```
```cpp
// ========== LLaMA3 Model Definition ==========
// Uses RMSNorm, SwiGLU activation, GQA attention, RoPE positional encoding
class LLaMA3 : public nn::TransformerLayer {
```
If we align with Megatron, gpt2/llama3 both essentially use GPTModel, so I don't think we need two separate class definitions here either; it could just be called DecoderOnlyTransformer or something similar?

I've raised a few points for now; a lot of things suddenly came up, so I probably won't be able to finish the review in the next couple of days.
This PR abstracts a common construction architecture for Transformer-family models, unifying the GPT2 and LLaMA3 build processes into a single flow.
Directory structure
…/core/
├── models/decode_only_transformer/
│   ├── layer_specs.h/.cc   # declarations and implementations of the model-building functions
│   └── model.h             # declarations of components such as RMSNorm/NewGELU/SwiGLU
└── transformer/
    ├── spec_utils.h/.cc           # ModuleSpec construction helpers and module-registration macros
    ├── transformer_block.h/.cc    # basic components such as TransformerBlock, with their registrations
    ├── transformer_builders.h/.cc # spec-builder declarations and implementations (BuildNormSpec, BuildMLPSpec, etc.)
    ├── transformer_config.h       # TransformerConfig struct, replacing the original GPT2Config and LLaMA3Config
    └── transformer_layer.h/.cc    # TransformerFirstStage/Chunk/LastStage, replacing the original GPT2FirstStage/Chunk/LastStage and LLaMA3FirstStage/Chunk/LastStage
Core mechanism
The ModuleSpec data structure declares a module's type and parameters. Concrete module implementations are registered centrally through ModuleRegistry, and at model-build time build_module() instantiates them dynamically, resolving each spec against its registered implementation.