LoRA 的概念和背景

LoRA(Low Rank Adaption)是一种改变模型有效权重的方法,通常用来对一个较大的、已经训练好的模型进行低成本微调。在大预训练模型时代,一个“小”模型运行起来动辄都要消耗数 10G 的显存,在 LoRA 出现之前,普通人想对其进行定制几乎是不可能的。但是如果我们考虑到以下事实,这件事就变得看起来有点可行了:

  • 当一个模型的权重变化很小时,模型作为一个参数化的函数的行为变化一般都很小(比如我们可以对 DQL 里面的 target network 进行 soft update)。
  • 任意一个 \(M\times N\) 的矩阵,我们都可以用一系列 \(M \times K \cdot K \times N\) 的矩阵乘积来拟合。
  • 在 neuroscience 和模型的可解释性研究里,有大量结果表明不少的网络的表征实际上可以很好地在低维空间内被刻画。
    从上面几点可以看出,如果把 \(M \times K \cdot K \times N\) 中的 \(K\) 设置为 1,然后把这样的输入输出符合结构的“挂件”附着在一些已经训练好的模型上,我们也许可以在对原模型行为伤害很小的情况下对其进行微调。同时,由于 \(M \times 1 \cdot 1 \times N\) 实际上的独立参数远比完整的模型小,这种微调所需要的资源也小得多。这就是 LoRA 的思想。实践证明这个思路确实非常有效。

最小 LoRA 实现

由于 LoRA 的概念非常简单,对模型的结构几乎没有任何要求,而在绝大多数模型中,参数最后都可以归结为一个个的矩阵,所以我们应该是可以写几个通用的方法对任意模型进行 LoRA 的。这里我们以 PyTorch 为例看看这样的东西应该怎么写。
对一个模型的参数通常只需要考虑两件事:前向计算和反向传播。在 PyTorch 的 Module 中,前者由 forward 后者由 backward 执行。要对已经存在的模型进行这两个函数的修改是不现实的,更好的方法是我们改变他们的参数。所以添加 LoRA 的第一步应该是遍历整个模型,找出所有可以被加挂件的参数。PyTorch 中 nn.Module 已经自带了一个这种方法 apply(), 一般用于模型的参数初始化,会递归地遍历当前 module 和所有的子 module。它在这里刚好可以被我们所用。因此我们写一个传入 apply 的函数,检测其所有参数,找到其中包含 nn.Linear 的部分。

import torch.nn.utils.parametrize as P

def inject_lora(layer):
    if type(layer) == nn.Linear:
        # 对模型的 weight 加 LoRA 挂件

接下来我们只要对这个 nn.Linear 的部分做修改即可。torch.nn.utils.parameterize 刚好提供了这样的工具:register_parametrizationregister_parametrization 第一个参数是用于接受修改的模型,第二个是被修改的模型的参数属性名,第三个则是对这个模型中指定的参数进行修改的部分了,通常是另一个 nn.Module。在 LoRA 中,可以让这第三个参数变成一个接受当前模型的 nn.Module 的子类。在它检查到 full rank 的参数以后,构造出一个 low rank 但是输入输出符合的 LoRA nn.Module 出来。注意这个 nn.Moduleforward 方法是用于输出原来的 nn.Module 的参数,而不是参与原来的 nn.Moduleforward 计算的 (这里 PyTorch 还允许我们用 right_inverse 方法来预处理原来的 nn.Module 的参数,但是 LoRA 用不到就不展开了).
显然在 LoRA 里这个挂件 nn.Modulenn.Linear 原模块 的 forward 的方法就是 \(W' = W+ AB\) 了。

import torch.nn.utils.parametrize as P

class LoRA(nn.Module):
    def __init__(self, layer, rank=1):
        super().__init__()
        self.fan_out, self.fan_in = layer.weight.shape
        self.rank = rank
        self.A = nn.Parameter(torch.zeros(self.fan_out, self.rank))
        self.B = nn.Parameter(torch.zeros(self.rank, self.fan_in))
    def __repr__(self):
        return f'LoRA({self.fan_in}->{self.rank}->{self.fan_out})'

    def forward(self, W):
        # 这里的 W 就是原来的 weight
        W_p =  W + torch.matmul(self.A, self.B).view(W.shape) # 修改 W
        return W_p
      

def inject_lora(layer):
    if type(layer) == nn.Linear:
    # 对模型的 weight 加 LoRA 挂件
        P.register_parametrization(layer, "weight", LoRA(layer))

用上面这个代码对一个普通的 nn.Linear 进行修改:

m2 = nn.Linear(8,10)
m2.apply(inject_lora)

得到

ParametrizedLinear(
  in_features=8, out_features=10, bias=True
  (parametrizations): ModuleDict((weight): ParametrizationList((0): LoRA(8->1->10)
    )))

由于 PyTorch 保证了会在前向计算中自动考虑这种参数化层的参数(参数的参数)的梯度,所以我们不用管其中的梯度计算问题。

LoRA 的融合和去除

在上面的实现中,如果训练以后我们希望将 LoRA 的权重融合进原始模型,可以利用 P.remove_parametrizations 来修改原始权重,只要在调用 remove_parametrizations 设置 leave_parametrized=True 即可。此时 LoRA.forward 的输出将变成原始权重的参数。LoRA 的去除也一样简单, 将这个选项设置为 False 即可。

跨参数的 LoRA

由于上面的参数化方法仅适用于单个参数。所以如果要 LoRA 的部分跨了单个操作就行不通了。由于不能修改 forward, 而跨参数的 LoRA 实际上就是给计算图加 shortcuts, 所以应该是无法直接实现的,不过如果考虑到一些固有结构,比如 nn.Sequential, 还是可以做一些操作,在这里就不展开了。