Skip to content

Commit 503af8b

Browse files
committed
[do not land] printing shapes for autoquant
Summary: Generate shapes for micro-benchmarking in `benchmarks/benchmark_aq.py`, but this doesn't seem very helpful for predicting the performance of llama2: https://gist.github.com/jerryzh168/efc0cb1be0a8a29c9edcd87cc01652f6 Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 0b66ff0 commit 503af8b

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

benchmarks/benchmark_aq.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,39 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
133133
WARMUP = 20
134134
RUNS = 100
135135

136+
torch._dynamo.reset()
137+
m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
138+
benchmark_model(m_bf16, WARMUP, example_inputs)
139+
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
140+
141+
torch._dynamo.reset()
136142
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
137143
benchmark_model(m_ref, WARMUP, example_inputs)
138144
ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
139145

146+
torch._dynamo.reset()
140147
m = torch.compile(m, mode='max-autotune', fullgraph=True)
141148
benchmark_model(m, WARMUP, example_inputs)
142149
elapsed_time = benchmark_model(m, RUNS, example_inputs)
143150

144-
145-
m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
146-
benchmark_model(m_bf16, WARMUP, example_inputs)
147-
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
148-
149151
print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
150152

151153
if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
152-
all_shapes = [
153-
(20, 2048, 2048),
154-
]
154+
# all_shapes = set([
155+
# (20, 2048, 2048),
156+
# ])
157+
all_shapes = set([
158+
(6, 12288, 4096),
159+
(6, 4096, 4096),
160+
(6, 11008, 4096),
161+
(6, 4096, 11008),
162+
(6, 32000, 4096),
163+
(1, 12288, 4096),
164+
(1, 4096, 4096),
165+
(1, 11008, 4096),
166+
(1, 4096, 11008),
167+
(1, 32000, 4096),
168+
])
155169

156170
print("_int8da_int8w_api")
157171
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

torchao/_models/llama/generate.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,16 @@ def main(
234234

235235
# do autoquantization
236236
model.finalize_autoquant()
237+
238+
from torchao.quantization.autoquant import AUTOQUANT_CACHE
239+
shapes = []
240+
for k in AUTOQUANT_CACHE.keys():
241+
act = k[1]
242+
w = k[2]
243+
M, K = act
244+
N = w[0]
245+
shapes.append((M, N, K))
246+
print("all shapes:", set(shapes))
237247
else:
238248
if not TORCH_VERSION_AT_LEAST_2_5:
239249
unwrap_tensor_subclass(model)
@@ -375,10 +385,11 @@ def callback(x):
375385
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
376386
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
377387
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
388+
parser.add_argument('--print_autoquant_m_n_k', action='store_true', help='Whether to print the M, N, K shapes in AUTOQUANT_CACHE for micro benchmarking.')
378389
parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
379390

380391
args = parser.parse_args()
381392
main(
382393
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
383-
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
394+
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.print_autoquant_m_n_k, args.write_result
384395
)

0 commit comments

Comments (0)