@@ -470,14 +470,156 @@ class LlmConfig:
     backend: BackendConfig = field(default_factory=BackendConfig)
 
     @classmethod
-    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":
+    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         """
         To support legacy purposes, this function converts CLI args from
         argparse to an LlmConfig, which is used by the LLM export process.
         """
         llm_config = LlmConfig()
 
-        # TODO: conversion code.
+        # BaseConfig
+        if hasattr(args, "model"):
+            llm_config.base.model_class = ModelType(args.model)
+        if hasattr(args, "params"):
+            llm_config.base.params = args.params
+        if hasattr(args, "checkpoint"):
+            llm_config.base.checkpoint = args.checkpoint
+        if hasattr(args, "checkpoint_dir"):
+            llm_config.base.checkpoint_dir = args.checkpoint_dir
+        if hasattr(args, "tokenizer_path"):
+            llm_config.base.tokenizer_path = args.tokenizer_path
+        if hasattr(args, "metadata"):
+            llm_config.base.metadata = args.metadata
+        if hasattr(args, "use_lora"):
+            llm_config.base.use_lora = args.use_lora
+        if hasattr(args, "fairseq2"):
+            llm_config.base.fairseq2 = args.fairseq2
+
+        # PreqMode settings
+        if hasattr(args, "preq_mode") and args.preq_mode:
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+        if hasattr(args, "preq_group_size"):
+            llm_config.base.preq_group_size = args.preq_group_size
+        if hasattr(args, "preq_embedding_quantize"):
+            llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+
+        # ModelConfig
+        if hasattr(args, "dtype_override"):
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+        if hasattr(args, "enable_dynamic_shape"):
+            llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+        if hasattr(args, "use_shared_embedding"):
+            llm_config.model.use_shared_embedding = args.use_shared_embedding
+        if hasattr(args, "use_sdpa_with_kv_cache"):
+            llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+        if hasattr(args, "expand_rope_table"):
+            llm_config.model.expand_rope_table = args.expand_rope_table
+        if hasattr(args, "use_attention_sink"):
+            llm_config.model.use_attention_sink = args.use_attention_sink
+        if hasattr(args, "output_prune_map"):
+            llm_config.model.output_prune_map = args.output_prune_map
+        if hasattr(args, "input_prune_map"):
+            llm_config.model.input_prune_map = args.input_prune_map
+        if hasattr(args, "use_kv_cache"):
+            llm_config.model.use_kv_cache = args.use_kv_cache
+        if hasattr(args, "quantize_kv_cache"):
+            llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+        if hasattr(args, "local_global_attention"):
+            llm_config.model.local_global_attention = args.local_global_attention
+
+        # ExportConfig
+        if hasattr(args, "max_seq_length"):
+            llm_config.export.max_seq_length = args.max_seq_length
+        if hasattr(args, "max_context_length"):
+            llm_config.export.max_context_length = args.max_context_length
+        if hasattr(args, "output_dir"):
+            llm_config.export.output_dir = args.output_dir
+        if hasattr(args, "output_name"):
+            llm_config.export.output_name = args.output_name
+        if hasattr(args, "so_library"):
+            llm_config.export.so_library = args.so_library
+        if hasattr(args, "export_only"):
+            llm_config.export.export_only = args.export_only
+
+        # QuantizationConfig
+        if hasattr(args, "quantization_mode"):
+            llm_config.quantization.qmode = args.quantization_mode
+        if hasattr(args, "embedding_quantize"):
+            llm_config.quantization.embedding_quantize = args.embedding_quantize
+        if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+        if hasattr(args, "group_size"):
+            llm_config.quantization.group_size = args.group_size
+        if hasattr(args, "use_spin_quant") and args.use_spin_quant:
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+        if hasattr(args, "use_qat"):
+            llm_config.quantization.use_qat = args.use_qat
+        if hasattr(args, "calibration_tasks"):
+            llm_config.quantization.calibration_tasks = args.calibration_tasks
+        if hasattr(args, "calibration_limit"):
+            llm_config.quantization.calibration_limit = args.calibration_limit
+        if hasattr(args, "calibration_seq_length"):
+            llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+        if hasattr(args, "calibration_data"):
+            llm_config.quantization.calibration_data = args.calibration_data
+
+        # BackendConfig - XNNPack
+        if hasattr(args, "xnnpack"):
+            llm_config.backend.xnnpack.enabled = args.xnnpack
+        if hasattr(args, "xnnpack_extended_ops"):
+            llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+
+        # CoreML
+        if hasattr(args, "coreml"):
+            llm_config.backend.coreml.enabled = args.coreml
+        llm_config.backend.coreml.enable_state = getattr(
+            args, "coreml_enable_state", False
+        )
+        llm_config.backend.coreml.preserve_sdpa = getattr(
+            args, "coreml_preserve_sdpa", False
+        )
+        if hasattr(args, "coreml_quantize") and args.coreml_quantize:
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+        if hasattr(args, "coreml_ios"):
+            llm_config.backend.coreml.ios = args.coreml_ios
+        if hasattr(args, "coreml_compute_units"):
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
+
+        # Vulkan
+        if hasattr(args, "vulkan"):
+            llm_config.backend.vulkan.enabled = args.vulkan
+
+        # QNN
+        if hasattr(args, "qnn"):
+            llm_config.backend.qnn.enabled = args.qnn
+        if hasattr(args, "use_qnn_sha"):
+            llm_config.backend.qnn.use_sha = args.use_qnn_sha
+        if hasattr(args, "soc_model"):
+            llm_config.backend.qnn.soc_model = args.soc_model
+        if hasattr(args, "optimized_rotation_path"):
+            llm_config.backend.qnn.optimized_rotation_path = (
+                args.optimized_rotation_path
+            )
+        if hasattr(args, "num_sharding"):
+            llm_config.backend.qnn.num_sharding = args.num_sharding
+
+        # MPS
+        if hasattr(args, "mps"):
+            llm_config.backend.mps.enabled = args.mps
+
+        # DebugConfig
+        if hasattr(args, "profile_memory"):
+            llm_config.debug.profile_memory = args.profile_memory
+        if hasattr(args, "profile_path"):
+            llm_config.debug.profile_path = args.profile_path
+        if hasattr(args, "generate_etrecord"):
+            llm_config.debug.generate_etrecord = args.generate_etrecord
+        if hasattr(args, "generate_full_logits"):
+            llm_config.debug.generate_full_logits = args.generate_full_logits
+        if hasattr(args, "verbose"):
+            llm_config.debug.verbose = args.verbose
 
 
         return llm_config
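The converter copies whatever attributes happen to exist on the `argparse.Namespace` onto the matching sub-config, so a partially populated namespace (for example, one built in a test) still converts cleanly. Below is a minimal usage sketch; the parser is a hypothetical stand-in for the exporter's real CLI definition, and only the attribute names are taken from the diff above.

```python
import argparse

# Hypothetical parser for illustration only -- the exporter defines its own.
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--use_kv_cache", action="store_true")
parser.add_argument("--xnnpack", action="store_true")

args = parser.parse_args(
    ["--checkpoint", "model.pt", "--max_seq_length", "2048",
     "--use_kv_cache", "--xnnpack"]
)

llm_config = LlmConfig.from_args(args)

# Each flag lands on its corresponding sub-config.
assert llm_config.base.checkpoint == "model.pt"
assert llm_config.export.max_seq_length == 2048
assert llm_config.model.use_kv_cache is True
assert llm_config.backend.xnnpack.enabled is True
```

The per-field `hasattr` guards, as opposed to a bulk `vars(args)` copy, keep the mapping explicit, and wrapping raw string flags in enums such as `ModelType` and `Pt2eQuantize` fails fast with a `ValueError` on unrecognized values instead of propagating them into the export.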