Preface
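While implementing a model recently, I needed the ABAE algorithm to extract the aspects of sentences, so I am recording my notes here. The algorithm was proposed in the paper "An Unsupervised Neural Attention Model for Aspect Extraction", and there is an article that gives a fairly detailed Chinese translation of it. The paper's authors released their code, written with Theano and Keras; I reproduced it in PyTorch.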
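For reference, here is a minimal NumPy sketch of the ABAE forward pass and training loss as I read them from the paper: attention weights d_i = e_wi^T M y_s over the word embeddings, a sentence embedding z_s, an aspect distribution p_t = softmax(W z_s + b), a reconstruction r_s = T^T p_t from the aspect embedding matrix T, a hinge loss against negative sentences, and an orthogonality penalty on T. It is not my PyTorch implementation; the function names are hypothetical, L2-normalizing z_s, r_s and the negatives before the dot products is an assumption of this sketch, and lam=0.1 is only an illustrative weight.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def l2_normalize(x, axis=-1, eps=1e-8):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def abae_forward(E_s, M, W, b, T):
    """Run one sentence through ABAE (hypothetical helper).

    E_s: (n, d) embeddings of the sentence's n words
    M:   (d, d) attention matrix
    W, b: (K, d), (K,) layer producing aspect probabilities
    T:   (K, d) aspect embedding matrix (K = aspect_size)
    """
    y_s = E_s.mean(axis=0)        # context vector: average word embedding
    d_i = E_s @ M @ y_s           # attention logits e_i^T M y_s, shape (n,)
    a = softmax(d_i)              # attention weights over the words
    z_s = a @ E_s                 # attention-weighted sentence embedding (d,)
    p_t = softmax(W @ z_s + b)    # distribution over the K aspects
    r_s = T.T @ p_t               # reconstruction from aspect embeddings (d,)
    return z_s, p_t, r_s

def abae_loss(z_s, r_s, negatives, T, lam=0.1):
    """Hinge reconstruction loss plus lam * orthogonality penalty.

    negatives: (m, d) averaged word embeddings of m random negative sentences
    """
    z, r, n = l2_normalize(z_s), l2_normalize(r_s), l2_normalize(negatives)
    # sum over negatives of max(0, 1 - r.z + r.n_i)
    hinge = np.maximum(0.0, 1.0 - r @ z + n @ r).sum()
    T_n = l2_normalize(T, axis=1)
    # Frobenius norm of T_n T_n^T - I pushes aspect embeddings apart
    ortho = np.linalg.norm(T_n @ T_n.T - np.eye(T.shape[0]))
    return hinge + lam * ortho

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, K = 16, 14                                  # embedding size, aspect_size
    E_s = rng.normal(size=(5, d))                  # a toy 5-word sentence
    M, W, T = (rng.normal(size=s) for s in [(d, d), (K, d), (K, d)])
    z_s, p_t, r_s = abae_forward(E_s, M, W, np.zeros(K), T)
    print("top aspect:", p_t.argmax())
    print("loss:", abae_loss(z_s, r_s, rng.normal(size=(3, d)), T))
```

In a real run the word embeddings are initialized from word2vec and T from k-means centroids of those embeddings, per the paper; the random matrices above only exercise the shapes.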
Code
PyTorch implementation of the ABAE model: https://github.com/iamwinter/UMPR/blob/pre-training/pre-training/abae.py
The README of the original authors' code provides download links for the datasets.
Experiment log
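I set the hyperparameter aspect_size=14, so ABAE learns embeddings for 14 aspects. For each aspect, I compute the cosine similarity between its embedding and the word2vec embedding of every word in the vocabulary, and take the 10 most similar words as that aspect's representative words. The results are as follows: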
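The representative-word step can be sketched in NumPy as follows, assuming T is the (14, d) aspect embedding matrix and E the (V, d) word2vec matrix whose row j belongs to vocab[j]. The function name and the optional exclude argument are hypothetical additions; the log shows special tokens such as '<PAD>' and '<UNK>' among the top words, and exclude is one way to keep them out of the list.

```python
import numpy as np

def top_words_per_aspect(T, E, vocab, k=10, exclude=(), eps=1e-8):
    """Return, for each aspect, the k vocabulary words most cosine-similar to it.

    T: (K, d) aspect embeddings
    E: (V, d) word embeddings; row j corresponds to vocab[j]
    exclude: words never to return (e.g. '<PAD>', '<UNK>')
    """
    T_n = T / (np.linalg.norm(T, axis=1, keepdims=True) + eps)
    E_n = E / (np.linalg.norm(E, axis=1, keepdims=True) + eps)
    sims = T_n @ E_n.T                       # (K, V) cosine similarities
    excluded = set(exclude)
    skip = [j for j, w in enumerate(vocab) if w in excluded]
    sims[:, skip] = -np.inf                  # push excluded words to the end
    top = np.argsort(-sims, axis=1)[:, :k]   # indices of the k largest per aspect
    return [[vocab[j] for j in row] for row in top]
```

Sorting by descending similarity is equivalent to sorting by ascending cosine distance, so either phrasing selects the same 10 words.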
2021-04-28 09:57:59 INFO: train sentences: 279885
2021-04-28 09:57:59 INFO: test sentences: 1490
Loading word2vec from /home/zhaojinglong/Projects/UMPR/pretrain/dataset/restaura
2021-04-28 09:57:59 INFO: vocabulary size: 9000
2021-04-28 09:57:59 INFO: Loading training dataset
2021-04-28 09:58:40 INFO: Start to train.
ABAE training epoch 0: 100%|████████████████| 2187/2187 [00:31<00:00, 70.34it/s]
2021-04-28 09:59:12 INFO: Epoch 0; train loss 0.290156; Valid loss: 0.140530; number of per category:[61, 86, 52, 11, 2, 263, 605, 0, 3, 32, 7, 18, 335, 15]
ABAE training epoch 1: 100%|████████████████| 2187/2187 [00:31<00:00, 68.86it/s]
2021-04-28 09:59:44 INFO: Epoch 1; train loss 0.087966; Valid loss: 0.072155; number of per category:[47, 67, 86, 41, 1, 475, 438, 1, 0, 36, 7, 8, 271, 12]
ABAE training epoch 2: 100%|████████████████| 2187/2187 [00:30<00:00, 72.87it/s]
2021-04-28 10:00:14 INFO: Epoch 2; train loss 0.049858; Valid loss: 0.045330; number of per category:[53, 42, 92, 34, 2, 620, 332, 2, 2, 29, 10, 10, 247, 15]
ABAE training epoch 3: 100%|████████████████| 2187/2187 [00:31<00:00, 68.89it/s]
2021-04-28 10:00:46 INFO: Epoch 3; train loss 0.033686; Valid loss: 0.034566; number of per category:[50, 55, 72, 41, 23, 611, 379, 1, 0, 25, 13, 9, 198, 13]
ABAE training epoch 4: 100%|████████████████| 2187/2187 [00:33<00:00, 65.91it/s]
2021-04-28 10:01:19 INFO: Epoch 4; train loss 0.025960; Valid loss: 0.025766; number of per category:[60, 57, 62, 29, 33, 677, 399, 0, 1, 20, 11, 3, 128, 10]
ABAE training epoch 5: 100%|████████████████| 2187/2187 [00:28<00:00, 77.23it/s]
2021-04-28 10:01:47 INFO: Epoch 5; train loss 0.023936; Valid loss: 0.027931; number of per category:[57, 82, 56, 4, 36, 644, 469, 0, 0, 32, 11, 2, 96, 1]
ABAE training epoch 6: 100%|████████████████| 2187/2187 [00:34<00:00, 63.30it/s]
2021-04-28 10:02:22 INFO: Epoch 6; train loss 0.020974; Valid loss: 0.021833; number of per category:[46, 73, 52, 5, 44, 651, 514, 0, 0, 14, 25, 2, 61, 3]
ABAE training epoch 7: 100%|████████████████| 2187/2187 [00:32<00:00, 66.39it/s]
2021-04-28 10:02:55 INFO: Epoch 7; train loss 0.017763; Valid loss: 0.021184; number of per category:[31, 70, 45, 3, 56, 471, 547, 150, 0, 5, 15, 2, 80, 15]
ABAE training epoch 8: 100%|████████████████| 2187/2187 [00:33<00:00, 64.74it/s]
2021-04-28 10:03:29 INFO: Epoch 8; train loss 0.015974; Valid loss: 0.012010; number of per category:[41, 70, 36, 5, 39, 574, 427, 214, 0, 8, 21, 2, 36, 17]
ABAE training epoch 9: 100%|████████████████| 2187/2187 [00:29<00:00, 74.80it/s]
2021-04-28 10:03:59 INFO: Epoch 9; train loss 0.012756; Valid loss: 0.012025; number of per category:[42, 87, 43, 13, 43, 540, 482, 178, 1, 6, 15, 2, 35, 3]
ABAE training epoch 10: 100%|███████████████| 2187/2187 [00:34<00:00, 64.18it/s]
2021-04-28 10:04:33 INFO: Epoch 10; train loss 0.011524; Valid loss: 0.013158; number of per category:[29, 63, 47, 5, 21, 657, 536, 7, 1, 6, 16, 3, 65, 34]
ABAE training epoch 11: 100%|███████████████| 2187/2187 [00:31<00:00, 68.83it/s]
2021-04-28 10:05:05 INFO: Epoch 11; train loss 0.010627; Valid loss: 0.010079; number of per category:[41, 56, 47, 5, 24, 524, 566, 158, 2, 5, 10, 7, 29, 16]
ABAE training epoch 12: 100%|███████████████| 2187/2187 [00:27<00:00, 78.85it/s]
2021-04-28 10:05:32 INFO: Epoch 12; train loss 0.008745; Valid loss: 0.008166; number of per category:[50, 74, 49, 4, 55, 644, 560, 6, 1, 10, 15, 5, 12, 5]
ABAE training epoch 13: 100%|███████████████| 2187/2187 [00:31<00:00, 69.72it/s]
2021-04-28 10:06:04 INFO: Epoch 13; train loss 0.007257; Valid loss: 0.008417; number of per category:[36, 53, 66, 0, 45, 644, 601, 5, 1, 4, 9, 10, 7, 9]
ABAE training epoch 14: 100%|███████████████| 2187/2187 [00:32<00:00, 67.93it/s]
2021-04-28 10:06:36 INFO: Epoch 14; train loss 0.007551; Valid loss: 0.005902; number of per category:[34, 47, 61, 4, 52, 642, 615, 6, 8, 7, 4, 6, 1, 3]
2021-04-28 10:06:36 DEBUG: Aspect: 0: ['breadcrumb', 'accordingly', 'gum', 'papaya', 'constant', 'pastis', 'houston', 'fishing', 'maki', 'churrasco']
2021-04-28 10:06:36 DEBUG: Aspect: 1: ['restaurant', 'restuarant', 'eatery', 'resturant', 'restaraunt', 'orleans', 'deli', 'establishment', 'yorkers', 'williamsburg']
2021-04-28 10:06:36 DEBUG: Aspect: 2: ['common', 'understandable', 'reality', 'grace', 'palace', 'proper', 'deserve', 'result', 'issue', 'bizarre']
2021-04-28 10:06:36 DEBUG: Aspect: 3: ['waitstaff', 'staff', 'server', 'sommelier', 'service', 'accommodating', 'funny', 'incredibly', 'manner', 'waiter']
2021-04-28 10:06:36 DEBUG: Aspect: 4: ['elmo', 'nose', 'norma', 'slab', 'luck', 'rosa', 'dal', 'definetely', 'atomosphere', 'accept']
2021-04-28 10:06:36 DEBUG: Aspect: 5: ['sensual', 'deliciousness', 'reference', 'anti', '<PAD>', 'ony', 'chocalate', '<NUM>', 'whats', 'kim']
2021-04-28 10:06:36 DEBUG: Aspect: 6: ['latte', 'luke', 'brulee', 'beginning', 'bloody', 'crabcakes', 'baked', 'experience', 'pistachio', 'memory']
2021-04-28 10:06:36 DEBUG: Aspect: 7: ['dust', 'ventilation', 'chocalate', 'reply', 'whats', 'ony', 'mardi', '<UNK>', 'smelly', 'reference']
2021-04-28 10:06:36 DEBUG: Aspect: 8: ['greatly', 'terribly', 'smelled', 'posted', 'dissappointed', 'reminded', 'remained', 'gotten', 'sorely', 'never']
2021-04-28 10:06:36 DEBUG: Aspect: 9: ['lighting', 'design', 'music', 'indoor', 'soft', 'pink', 'decoration', 'faux', 'fireplace', 'band']
2021-04-28 10:06:36 DEBUG: Aspect: 10: ['vindaloo', 'bubbly', 'mine', 'daughter', 'yr', 'paul', 'catfish', 'mom', 'dad', 'father']
2021-04-28 10:06:36 DEBUG: Aspect: 11: ['replaced', 'poured', 'shoulder', 'literally', 'gooey', 'captain', 'stuck', 'maitre', 'foot', 'brown']
2021-04-28 10:06:36 DEBUG: Aspect: 12: ['eggplant', 'lightly', 'onion', 'deliciously', 'lettuce', 'savory', 'broth', 'risotto', 'enchilada', 'basil']
2021-04-28 10:06:36 DEBUG: Aspect: 13: ['food', 'samosa', 'usual', 'pad', 'cuisine', 'significantly', 'coupon', 'kuma', 'mix', 'adobe']
2021-04-28 10:06:37 INFO: Trained model "/home/zhaojinglong/Projects/UMPR/pretrain/model/ABAE.pt" has been saved.
2021-04-28 10:06:37 INFO: Please choose a category from following list for each aspect.
2021-04-28 10:06:37 INFO: {0: 'Food', 1: 'Staff', 2: 'Ambience', 3: 'Price', 4: 'Anecdotes', 5: 'Miscellaneous'}
Input index(0~5) to aspect 0:5
5
Input index(0~5) to aspect 1:2
2
Input index(0~5) to aspect 2:2
2
Input index(0~5) to aspect 3:1
1
Input index(0~5) to aspect 4:4
4
Input index(0~5) to aspect 5:0
0
Input index(0~5) to aspect 6:0
0
Input index(0~5) to aspect 7:2
2
Input index(0~5) to aspect 8:4
4
Input index(0~5) to aspect 9:2
2
Input index(0~5) to aspect 10:0
0
Input index(0~5) to aspect 11:4
4
Input index(0~5) to aspect 12:0
0
Input index(0~5) to aspect 13:0
0
Evaluate: 100%|███████████████████████████████████| 1/1 [00:00<00:00, 49.45it/s]
2021-04-28 10:12:29 INFO: Accuracy: 0.529530
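The evaluation can be sketched as below. In every epoch line, the "number of per category" counts sum to 1490, the number of test sentences, so they appear to count test sentences by their highest-probability aspect. The accuracy then maps each sentence's top aspect to the category entered for it above and compares that with the gold label. This NumPy sketch assumes p_t is an (N, 14) array of aspect probabilities and that each test sentence has exactly one gold category index; the function names are hypothetical, and the mapping dictionary copies the indices typed in the log.

```python
import numpy as np

# Aspect -> category indices as typed in the log
# (0 Food, 1 Staff, 2 Ambience, 3 Price, 4 Anecdotes, 5 Miscellaneous).
ASPECT_TO_CATEGORY = {0: 5, 1: 2, 2: 2, 3: 1, 4: 4, 5: 0, 6: 0,
                      7: 2, 8: 4, 9: 2, 10: 0, 11: 4, 12: 0, 13: 0}

def per_aspect_counts(p_t):
    """Number of sentences whose most probable aspect is each index; p_t: (N, K)."""
    return np.bincount(p_t.argmax(axis=1), minlength=p_t.shape[1])

def category_accuracy(p_t, gold, aspect_to_category=ASPECT_TO_CATEGORY):
    """Fraction of sentences whose top aspect maps to the gold category.

    gold: (N,) gold category index per sentence (one label per sentence assumed)
    """
    pred_aspect = p_t.argmax(axis=1)
    pred_category = np.array([aspect_to_category[a] for a in pred_aspect])
    return float((pred_category == np.asarray(gold)).mean())
```

Because no aspect was mapped to Price (index 3), a sentence whose gold label is Price can never be counted as correct under this mapping.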