四十三-从理论到实践开发聊天机器人2
上一篇文章实现了一个能够训练1000个样本的聊天机器人,并且效果还可以,但是因为对样本全量加载,所以当用大量样本来做训练时,内存就撑不住了,总是Out of memory,本篇文章解决了这个问题,方法是把全量加载样本改成了批量加载,这样样本量再大,内存也不会无限增加了
上篇回顾¶
上篇文章《自己动手做聊天机器人 四十二-(重量级长文)从理论到实践开发自己的聊天机器人》使用带attention的seq2seq模型实现一般聊天机器人,经过10个小时对1000条样本的训练,达到了比较好的效果,代码分享在这里
但是存在一个问题,当把样本量加大的时候内存随之增长,如果样本量达到万级别,内存占用已经达到了10G,样本量如果到几十万几百万,内存已经不知道能到多少了,这个主要问题是每次迭代都是把样本全量加载到内存并一次性训练完再更新模型,另外还有一个问题就是词表是基于样本生成的,没有做任何限制,导致样本大词表就大,那么模型就很大,所以占据内存也更大,所以我做了一版优化,在自己机器上尝试训练20w的样本内存占用不到1G,希望大家能找到更大量的样本来帮我充分测试,我这里有三千万的聊天语料可以使用,欢迎大家尝试,获取方式请见《自己动手做聊天机器人 二十九-重磅:近1GB的三千万聊天语料供出》。
优化方案¶
首先我们把全量加载样本改成批量加载,修改原来的train()函数,核心部分如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | # 训练很多次迭代,每隔10次打印一次loss,可以看情况直接ctrl+c停止 previous_losses = [] for step in xrange(20000): sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set, 1000) input_feed = {} for l in xrange(input_seq_len): input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l] for l in xrange(output_seq_len): input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l] input_feed[target_weights[l].name] = sample_target_weights[l] input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32) [loss_ret, _] = sess.run([loss, update], input_feed) if step % 10 == 0: print 'step=', step, 'loss=', loss_ret, 'learning_rate=', learning_rate.eval() if len(previous_losses) > 5 and loss_ret > max(previous_losses[-5:]): sess.run(learning_rate_decay_op) previous_losses.append(loss_ret) # 模型持久化 saver.save(sess, './model/demo') |
这里的get_samples(train_set, 1000)是批量获取样本,其中1000是每次获取的样本量,这个函数增加了如下逻辑:
1 2 3 4 5 6 7 8 | if batch_num >= len(train_set): batch_train_set = train_set else: random_start = random.randint(0, len(train_set)-batch_num) batch_train_set = train_set[random_start:random_start+batch_num] for sample in batch_train_set: raw_encoder_input.append([PAD_ID] * (input_seq_len - len(sample[0])) + sample[0]) raw_decoder_input.append([GO_ID] + sample[1] + [PAD_ID] * (output_seq_len - len(sample[1]) - 1)) |
也就是说每次都在全量样本中随机位置抽取连续的1000条样本
另外,在加载样本词表时做了词的最小频率的限制,如下:
1 2 3 4 5 6 7 8 9 | def load_file_list(self, file_list, min_freq): ...... for index, item in enumerate(sorted_list): word = item[1] if item[0] < min_freq: break self.word2id_dict[word] = self.START_ID + index self.id2word_dict[self.START_ID + index] = word return index |
试验效果¶
经过如上改造,我把样本量扩展到21w,运行起来内存占用不到1G,最新的代码请见最新更新的: