Add new export LLM config #11028
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/11028
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4372d61 with merge base c2aa614.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@jackzhxng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991)
This pull request was exported from Phabricator. Differential Revision: D75263991
Pull Request resolved: #11028 @imported-using-ghimport Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991/) ghstack-source-id: 288636930
Merged e46a59a into gh/jackzhxng/10/base.
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11028 by @jackzhxng
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/10/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/10/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/10/orig
@diff-train-skip-merge
Co-authored-by: Jack Zhang <[email protected]>
SMOLLM2 = "smollm2" | ||
|
||
|
||
class PreqMode(str, Enum): |
worth tagging as deprecated
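One way the deprecation could be surfaced (a sketch, not the PR's actual change; the member names and values below are assumptions for illustration): keep the enum for backward compatibility, note the deprecation in its docstring, and warn when the option is actually set.

```python
import warnings
from enum import Enum


class PreqMode(str, Enum):
    """Deprecated: legacy prequantization modes, kept only for backward
    compatibility. Prefer loading TorchAo-prequantized checkpoints instead."""

    PREQ_8DA4W = "8da4w"  # illustrative members; actual values may differ
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


def warn_if_preq_set(preq_mode) -> None:
    """Hypothetical hook a config's __post_init__ could call."""
    if preq_mode is not None:
        warnings.warn(
            "preq_mode is deprecated; load TorchAo-prequantized weights instead.",
            DeprecationWarning,
            stacklevel=2,
        )
```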
    model_class: Which model to export.
    params: Model parameters, such as n_layers, hidden_size, etc.
        If left empty will use defaults specified in model_args.py.
    checkpoint: Path to the checkpoint file.
Can this be an HF path as well?
No, but at the moment if you specify a non-llama model_class and don't specify a checkpoint, it will download from HF. Worth adding this as a comment.
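To make that behavior concrete, here is a minimal sketch of the described fallback, assuming a hypothetical resolve_checkpoint helper (the real export flow may structure this differently):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseConfig:
    model_class: str = "llama3"
    checkpoint: Optional[str] = None  # local path; may be left unset


def _download_from_hf(model_class: str) -> str:
    # Placeholder: in the real flow this would fetch weights from the
    # Hugging Face Hub (e.g. via huggingface_hub.snapshot_download).
    return f"/tmp/hf_downloads/{model_class}"


def resolve_checkpoint(config: BaseConfig) -> Optional[str]:
    """Return a local checkpoint path, falling back to an HF download for
    non-llama model classes when no checkpoint is given."""
    if config.checkpoint:
        return config.checkpoint
    if not config.model_class.startswith("llama"):
        return _download_from_hf(config.model_class)
    return None  # llama models keep their defaults or require an explicit path
```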
    tokenizer_path: Path to the tokenizer file.
    metadata: Json string containing metadata information.
        e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
    use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
LoRA in our case is really tied to the QAT model that we released, right? It is not independently applicable to any model? If so, I think we want to tie this to QAT checkpoints specifically for llama3_2.
Is that the case? If so I'll add something to the post_init to verify and update the comment
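A sketch of what such a __post_init__ check might look like; field names like use_qat and the exact model_class value are assumptions rather than the PR's final code:

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    model_class: str = "llama3_2"
    use_lora: int = 0      # LoRA rank; 0 means no LoRA
    use_qat: bool = False

    def __post_init__(self):
        # Assumption: LoRA is only meaningful with the released QAT
        # checkpoints for llama3_2, so reject other combinations early.
        if self.use_lora > 0:
            if not self.use_qat:
                raise ValueError("use_lora requires a QAT checkpoint (use_qat=True).")
            if self.model_class != "llama3_2":
                raise ValueError("use_lora is currently only supported for llama3_2.")
```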
    metadata: Json string containing metadata information.
        e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
    use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
    fairseq2: For legacy internal use cases, this is safe to ignore.
prefix with _ to differentiate between this and other supported params
    preq_mode: Legacy option to specify how prequantized weights are loaded.
        Going forward, ExecuTorch supports loading weights prequantized through
        TorchAo as-is, without any special handling.
    preq_group_size: Legacy option to specify the group size of prequantized weights.
I think we probably want to couple these together. If you are loading a pre-quantized checkpoint, the group size cannot be set independently, right? So maybe having a separate dataclass that captures all the params and maps it by name is better. Although I presume you have to keep this for BC?
Yeah, just keeping it for BC. I think ideally we are moving away from preq to TorchAo, right?
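If the coupling were done, a grouped dataclass might look roughly like this (a sketch only; the existing flat fields would presumably stay around for BC as discussed, and the enum members and companion option are assumptions):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class PreqMode(str, Enum):
    PREQ_8DA4W = "8da4w"                      # illustrative member values
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


@dataclass
class PreqConfig:
    """Groups the legacy prequantization options so they can only be set
    together, rather than as independent top-level fields."""

    mode: PreqMode
    group_size: int = 32
    embedding_quantize: Optional[str] = None  # assumed companion option

    def __post_init__(self):
        if self.group_size <= 0:
            raise ValueError("group_size must be positive for prequantized weights.")
```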
@dataclass
class ModelConfig:
    """
    Configurations not necessarily specific to the model, but are needed to
Then why call it ModelConfig?
Maybe call it LoweringConfig? Or ExportConfig, although not all options are export related.
This is more like: what other modifications do I want to make to the model in eager that aren't specific to the model itself (e.g. NOT checkpoint, model architecture, tokenizer) and can be shared across different models.
        doesn't actually have anything to do with the kv_cache at the moment.
    expand_rope_table: Temporary workaround to expand sin/cos table in head
        dim to take vectorized path in optimized kernels.
    use_attention_sink: Whether to use attention sink to support multi-round
I would also consider pruning options that we don't intend to support. Attention sink at the moment does not have a performant implementation and I would rather hide it somewhere than expose it. Reduces maintenance burden.
    max_context_length: Maximum of context for the model to remember.
    output_dir: Output dir to save the exported .pte file to.
    output_name: File name to override the exported .pte file.
    so_library: Shared library to specify custom quantized operators.
I don't think we need so_library anymore. Remove it in a follow-up.
    output_dir: Output dir to save the exported .pte file to.
    output_name: File name to override the exported .pte file.
    so_library: Shared library to specify custom quantized operators.
    export_only: Whether to stop right after torch.export() and
was this for debug?
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
we probably can remove these two
    pt2e_quantize: Quantization mode using pt2e, which is an alternative
        to TorchAo that uses backend-aware graph mode quantization rather
        than source transformation quantization.
I would want to hide these details from users
class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"
maybe favor this pattern over the ones in pt2e_quantizer
class TestValidConstruction(unittest.TestCase):

    def test_valid_llm_config(self):
        LlmConfig(
Are these configs constructible from JSON? I think that would be the best.
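For illustration, JSON construction could be wired up with a small from_json classmethod; the dataclasses below are simplified stand-ins, not the PR's actual schema (libraries like dacite or omegaconf would also work):

```python
import json
from dataclasses import dataclass, field


# Simplified stand-ins for the nested configs; the real LlmConfig has more
# sections and fields.
@dataclass
class BaseConfig:
    model_class: str = "llama3"
    checkpoint: str = ""


@dataclass
class ExportConfig:
    max_seq_length: int = 128
    output_name: str = "model.pte"


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    export: ExportConfig = field(default_factory=ExportConfig)

    @classmethod
    def from_json(cls, raw: str) -> "LlmConfig":
        d = json.loads(raw)
        return cls(
            base=BaseConfig(**d.get("base", {})),
            export=ExportConfig(**d.get("export", {})),
        )


cfg = LlmConfig.from_json(
    '{"base": {"model_class": "smollm2"}, "export": {"max_seq_length": 256}}'
)
assert cfg.base.model_class == "smollm2" and cfg.export.max_seq_length == 256
```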
Stack from ghstack (oldest at bottom):
Differential Revision: D75263991