
Commit 954f0e9

Convert args to LlmConfig (#11513)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11081 by @jackzhxng
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/12/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/12/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/11/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/12/orig
@diff-train-skip-merge

Co-authored-by: Jack Zhang <[email protected]>
1 parent: 78446d0 · commit: 954f0e9

2 files changed: 145 additions, 3 deletions

examples/apple/mps/scripts/mps_example.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def get_model_config(args):
     return model_config


-if __name__ == "__main__":
+if __name__ == "__main__":  # noqa: C901
    args = parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:

examples/models/llama/config/llm_config.py

Lines changed: 144 additions & 2 deletions
@@ -470,14 +470,156 @@ class LlmConfig:
     backend: BackendConfig = field(default_factory=BackendConfig)

     @classmethod
-    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":
+    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         """
         To support legacy purposes, this function converts CLI args from
         argparse to an LlmConfig, which is used by the LLM export process.
         """
         llm_config = LlmConfig()

-        # TODO: conversion code.
+        # BaseConfig
+        if hasattr(args, "model"):
+            llm_config.base.model_class = ModelType(args.model)
+        if hasattr(args, "params"):
+            llm_config.base.params = args.params
+        if hasattr(args, "checkpoint"):
+            llm_config.base.checkpoint = args.checkpoint
+        if hasattr(args, "checkpoint_dir"):
+            llm_config.base.checkpoint_dir = args.checkpoint_dir
+        if hasattr(args, "tokenizer_path"):
+            llm_config.base.tokenizer_path = args.tokenizer_path
+        if hasattr(args, "metadata"):
+            llm_config.base.metadata = args.metadata
+        if hasattr(args, "use_lora"):
+            llm_config.base.use_lora = args.use_lora
+        if hasattr(args, "fairseq2"):
+            llm_config.base.fairseq2 = args.fairseq2
+
+        # PreqMode settings
+        if hasattr(args, "preq_mode") and args.preq_mode:
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+        if hasattr(args, "preq_group_size"):
+            llm_config.base.preq_group_size = args.preq_group_size
+        if hasattr(args, "preq_embedding_quantize"):
+            llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+
+        # ModelConfig
+        if hasattr(args, "dtype_override"):
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+        if hasattr(args, "enable_dynamic_shape"):
+            llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+        if hasattr(args, "use_shared_embedding"):
+            llm_config.model.use_shared_embedding = args.use_shared_embedding
+        if hasattr(args, "use_sdpa_with_kv_cache"):
+            llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+        if hasattr(args, "expand_rope_table"):
+            llm_config.model.expand_rope_table = args.expand_rope_table
+        if hasattr(args, "use_attention_sink"):
+            llm_config.model.use_attention_sink = args.use_attention_sink
+        if hasattr(args, "output_prune_map"):
+            llm_config.model.output_prune_map = args.output_prune_map
+        if hasattr(args, "input_prune_map"):
+            llm_config.model.input_prune_map = args.input_prune_map
+        if hasattr(args, "use_kv_cache"):
+            llm_config.model.use_kv_cache = args.use_kv_cache
+        if hasattr(args, "quantize_kv_cache"):
+            llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+        if hasattr(args, "local_global_attention"):
+            llm_config.model.local_global_attention = args.local_global_attention
+
+        # ExportConfig
+        if hasattr(args, "max_seq_length"):
+            llm_config.export.max_seq_length = args.max_seq_length
+        if hasattr(args, "max_context_length"):
+            llm_config.export.max_context_length = args.max_context_length
+        if hasattr(args, "output_dir"):
+            llm_config.export.output_dir = args.output_dir
+        if hasattr(args, "output_name"):
+            llm_config.export.output_name = args.output_name
+        if hasattr(args, "so_library"):
+            llm_config.export.so_library = args.so_library
+        if hasattr(args, "export_only"):
+            llm_config.export.export_only = args.export_only
+
+        # QuantizationConfig
+        if hasattr(args, "quantization_mode"):
+            llm_config.quantization.qmode = args.quantization_mode
+        if hasattr(args, "embedding_quantize"):
+            llm_config.quantization.embedding_quantize = args.embedding_quantize
+        if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+        if hasattr(args, "group_size"):
+            llm_config.quantization.group_size = args.group_size
+        if hasattr(args, "use_spin_quant") and args.use_spin_quant:
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+        if hasattr(args, "use_qat"):
+            llm_config.quantization.use_qat = args.use_qat
+        if hasattr(args, "calibration_tasks"):
+            llm_config.quantization.calibration_tasks = args.calibration_tasks
+        if hasattr(args, "calibration_limit"):
+            llm_config.quantization.calibration_limit = args.calibration_limit
+        if hasattr(args, "calibration_seq_length"):
+            llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+        if hasattr(args, "calibration_data"):
+            llm_config.quantization.calibration_data = args.calibration_data
+
+        # BackendConfig - XNNPack
+        if hasattr(args, "xnnpack"):
+            llm_config.backend.xnnpack.enabled = args.xnnpack
+        if hasattr(args, "xnnpack_extended_ops"):
+            llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+
+        # CoreML
+        if hasattr(args, "coreml"):
+            llm_config.backend.coreml.enabled = args.coreml
+        llm_config.backend.coreml.enable_state = getattr(
+            args, "coreml_enable_state", False
+        )
+        llm_config.backend.coreml.preserve_sdpa = getattr(
+            args, "coreml_preserve_sdpa", False
+        )
+        if hasattr(args, "coreml_quantize") and args.coreml_quantize:
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+        if hasattr(args, "coreml_ios"):
+            llm_config.backend.coreml.ios = args.coreml_ios
+        if hasattr(args, "coreml_compute_units"):
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
+
+        # Vulkan
+        if hasattr(args, "vulkan"):
+            llm_config.backend.vulkan.enabled = args.vulkan
+
+        # QNN
+        if hasattr(args, "qnn"):
+            llm_config.backend.qnn.enabled = args.qnn
+        if hasattr(args, "use_qnn_sha"):
+            llm_config.backend.qnn.use_sha = args.use_qnn_sha
+        if hasattr(args, "soc_model"):
+            llm_config.backend.qnn.soc_model = args.soc_model
+        if hasattr(args, "optimized_rotation_path"):
+            llm_config.backend.qnn.optimized_rotation_path = (
+                args.optimized_rotation_path
+            )
+        if hasattr(args, "num_sharding"):
+            llm_config.backend.qnn.num_sharding = args.num_sharding
+
+        # MPS
+        if hasattr(args, "mps"):
+            llm_config.backend.mps.enabled = args.mps
+
+        # DebugConfig
+        if hasattr(args, "profile_memory"):
+            llm_config.debug.profile_memory = args.profile_memory
+        if hasattr(args, "profile_path"):
+            llm_config.debug.profile_path = args.profile_path
+        if hasattr(args, "generate_etrecord"):
+            llm_config.debug.generate_etrecord = args.generate_etrecord
+        if hasattr(args, "generate_full_logits"):
+            llm_config.debug.generate_full_logits = args.generate_full_logits
+        if hasattr(args, "verbose"):
+            llm_config.debug.verbose = args.verbose

         return llm_config
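Usage sketch (not part of this diff): a minimal example of handing a parsed argparse namespace to LlmConfig.from_args(). The import path and the attribute values below are assumptions for illustration; real export scripts build the namespace via their own parse_args(), and any field missing from the namespace simply keeps its LlmConfig default.

import argparse

# Assumed import path, inferred from the file location in this diff; adjust to
# however the package is installed in your environment.
from executorch.examples.models.llama.config.llm_config import LlmConfig

# Stand-in for what a real parse_args() would return; the attribute names
# mirror the hasattr() checks in from_args(), and the values are illustrative.
args = argparse.Namespace(
    model="llama3_2",  # must be a valid ModelType value
    checkpoint="/path/to/consolidated.00.pth",
    use_kv_cache=True,
    max_seq_length=2048,
    xnnpack=True,
)

llm_config = LlmConfig.from_args(args)

# Attributes present on args are reflected in the config; everything else
# falls back to the dataclass defaults.
print(llm_config.model.use_kv_cache)       # True
print(llm_config.export.max_seq_length)    # 2048
print(llm_config.backend.xnnpack.enabled)  # True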
