[论文笔记随手] Training with Weighted Sum of Denoising Score Matching Objectives

news/2024/5/18 16:29:33 标签: 其他, 论文

[note] Training with Weighted Sum of Denoising Score Matching Objectives

利用 去噪分数匹配目标的加权和 进行训练,去噪指的是使用sde的方法就不需要自行补充噪声了。

本文的目的是解释如何对原始数据进行扰动。 from https://yang-song.github.io/blog/2021/score/

一、理论

首先,挑选一个随机过程(SDE)对原始数据分布 p 0 p_0 p0进行扰动得到扰动后数据的概率密度分布 p t p_t pt

本文选择的随机过程为:
d x = σ t d w ,   t ∈ [ 0 , 1 ] d{\bf x} = \sigma^td{\bf w}, \ t\in[0,1] dx=σtdw, t[0,1]
在这种情况下,扰动后数据的概率密度分布 p t p_t pt,在原始数据下的条件概率分布为:
p 0 t ( x ( t ) ∣ x ( 0 ) ) = N ( x ( t ) ; x ( 0 ) , 1 2 log ⁡ σ ( σ 2 t − 1 ) I ) p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\bigg(\mathbf{x}(t); \mathbf{x}(0), \frac{1}{2\log \sigma}(\sigma^{2t} - 1) \mathbf{I}\bigg) p0t(x(t)x(0))=N(x(t);x(0),2logσ1(σ2t1)I)
关于这个函数的解释是,使用参数$ \frac{1}{2\log \sigma}(\sigma^{2t} - 1) 作 为 我 们 的 权 重 函 数 , 即 作为我们的权重函数,即 \lambda(t) = \frac{1}{2 \log \sigma}(\sigma^{2t} - 1)$.

当参数 σ \sigma σ变得非常大的时候,其中的先验分布 p t = 1 p_{t=1} pt=1,也就是最终扰动后的数据分布就可以变成一个正太分布:
∫ p 0 ( y ) N ( x ; y , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) d y ≈ N ( x ; 0 , 1 2 log ⁡ σ ( σ 2 − 1 ) I ) , \int p_0(\mathbf{y})\mathcal{N}\bigg(\mathbf{x}; \mathbf{y}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg) d \mathbf{y} \approx \mathbf{N}\bigg(\mathbf{x}; \mathbf{0}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg), p0(y)N(x;y,2logσ1(σ21)I)dyN(x;0,2logσ1(σ21)I),
直观地说,这个SDE通过一个变种函数 1 2   l o g   σ ( σ 2 t − 1 ) \frac1{2\ log\ \sigma}(\sigma^{2t}-1) 2 log σ1(σ2t1)帮助我们捕获了高斯扰动的数据变量集合(连续统continuum),即 x ( t ) x(t) x(t)。这个数据变量集合可以帮助我们逐渐将原始数据分布 p 0 p_0 p0变成了一个简单的高斯分布 p 1 p_1 p1,也就是t=1时候的分布。

二、代码实现

1) 对t进行连续采样

 # 对时间特征t进行均匀采样
 random_t = torch.rand(x.x.shape[0]//30, device=device) * (1. - eps) + eps # 防止采样到0

2)定义权重函数

可以看到,这里定义的权重函数就是作者在上面提到的 λ ( t ) \lambda(t) λ(t)函数。

def marginal_prob_std(t, sigma):
    # t = torch.tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))

3)对数据进行扰动

# 表征时间的特征t, 从0到1上进行均匀采样
random_t = torch.rand(batchsize, device=device) * (1. - eps) + eps # 这里的eps是为了防止采样到t=0

# 构造一个与原始数据结构一样的向量,并在[0,1)上进行均匀采样。
z = torch.randn_like(x.x)

# 利用前面均匀采样的时间特征t,求得权重函数的值,这个权重函数的目的就是为了使得t=1时的扰动数据达到一个正太分布的结果。重复30遍的目的是因为一轮训练中设置的batch_size = 30
std = marginal_prob_std_func(random_t).repeat(1, 30).view(-1, 1)

# 这里将噪声与标准差相乘,
perturbed_x = copy.deepcopy(x)
perturbed_x.x += z * std

4)利用扰动的数据进行训练

需要补充一下,为了训练积分函数模型,目前的目标函数变成了下面这个样子:
E t ∈ u ( 0 , T ) E p t ( x ) [ λ ( t ) ∣ ∣ ∇ x l o g   p t ( x ) − s θ ( x , t ) ∣ ∣ 2 2 ] \mathbb{E}_{t\in u(0,T)}\mathbb{E}_{p_t(x)}[\lambda(t)||\nabla_xlog\ p_t(x)-s_\theta(x,t)||_2^2] Etu(0,T)Ept(x)[λ(t)xlog pt(x)sθ(x,t)22]

这里是最基本的目标函数的样子:
E p ( x ) [ ∣ ∣ ∇ x l o g   p ( x )   −   s θ ( x ) ∣ ∣ 2 2 ]   =   ∫   p ( x ) ∣ ∣ ∇ x   l o g   p ( x )   −   s θ ( x ) ∣ ∣ 2 2 d x . \mathbb{E}_{p(x)}[{||\nabla_xlog\ p(x)\ -\ s_\theta(x)||}_2^2]\ =\ \int\ p(x)||\nabla_x\ log\ p(x)\ -\ s_\theta(x)||_2^2dx. Ep(x)[xlog p(x)  sθ(x)22] =  p(x)x log p(x)  sθ(x)22dx.
为了估计这个目标函数,需要如下估计,即使用Score Matching的方法进行估计(Hyvärinen 2005):

可以看到,去估计如下的目标函数是可以达到的。

E p d a t a ( x ) [ 1 2 ∣ ∣ s θ ( x ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x ) ) ] \mathbb{E}_{p_{data}(x)}[\frac12||s_\theta(x)||_2^2+trace(\nabla_xs_\theta(x))] Epdata(x)[21sθ(x)22+trace(xsθ(x))]

具体上,体现在代码上,用的是如下的公式:
1 N ∑ i = 1 N [ 1 2 ∣ ∣ s θ ( x i ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x i ) ) ] ≈ 1 N ∑ i = 1 N [ 1 2 ∣ ∣ s θ ( x i ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x i ) ) \frac1N\sum^N_{i=1}[\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))] \\ \approx \frac1N \sum_{i=1}^N [\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i)) N1i=1N[21sθ(xi)22+trace(xsθ(xi))]N1i=1N[21sθ(xi)22+trace(xsθ(xi))

# 计算积分函数的值
output = model(perturbed_x, random_t)
# score matching的损失函数,与上式不一致的原因在于,本文的目标函数中还有一个参数\lambda(t),所以表现为如下的形式。
loss_ = torch.mean(torch.sum(((output * std + z)**2).view(batch_size, -1)), dim=-1)
# 一轮训练之后,将score matching的目标函数的结果返回
return loss_

🙋‍♂️ 我有一个问题,这个目标函数是怎么推理得到的呀? 🤔


http://www.niftyadmin.cn/n/572799.html

相关文章

RK3399平台开发系列讲解(进程调度篇)14.6、等待队列结构体的抽象与关系

=>返回专栏总目录<= 文章目录 一、数据结构定义1、列表头2、列表项3、双向链表结构二、作用三、字段详解1、spinlock_t lock;2、srtuct list_head_t task_list;四、

RK3568平台开发系列讲解(触摸屏篇)Android11 触摸芯片移植

🚀返回专栏总目录 文章目录 一、硬件原理图分析二、配置设备树三、内核配置四、触摸屏验证沉淀、分享、成长,让自己和他人都能有所收获!😄 📢 本章节我们来配置触摸,mipi 屏幕的触摸芯片是 ft5x06。 一、硬件原理图分析 瑞芯微提供的 Android11 源码里面自带 ft5x06 …

交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数

分类问题中&#xff0c;交叉熵函数是比较常用也是比较基础的损失函数&#xff0c;原来就是了解&#xff0c;但一直搞不懂他是怎么来的&#xff1f;为什么交叉熵能够表征真实样本标签和预测概率之间的差值&#xff1f;趁着这次学习把这些概念系统学习了一下。 首先说起交叉熵&am…

利用Node.js延时执行脚本

利用Node.js延时执行脚本 本文代码的目的是为了延时执行linux脚本 setTimeout(()>{var spawn require(child_process).spawn;free spawn(./go.sh);// 捕获标准输出并将其打印到控制台 free.stdout.on(data, function (data) { console.log(standard output:\n data); });…

RK3568平台开发系列讲解(驱动篇)Linux 设备和分类

🚀返回专栏总目录 文章目录 一、Linux 设备和分类二、设备节点和设备号2.1、设备节点2.2、设备编号2.3、获取和释放设备编号三、设备的注册和注销沉淀、分享、成长,让自己和他人都能有所收获!😄 一、Linux 设备和分类 Linux 系统中的设备可以分为字符设备、块设备和网络…

sourthtree与gitlab之间的连接

sourthtree与gitlab之间的连接 sourthtree是一款免费的git界面软件&#xff0c;下载安装该工具&#xff0c;安装的过程需要注册账号&#xff0c;用google邮箱账号或者其它可用账号 注册&#xff0c;然后进行与服务端gitlab项目关联之前&#xff0c;需要进行把本地创建的公钥上传…

脚踏实地的好好学习深度学习 笔记一 线性回归

李宏毅 深度学习 笔记一 from https://www.bilibili.com/video/BV15b411g7Wd?fromsearch&seid2645809784486608148 脚踏实地学习&#xff0c;安安心心科研 这里是目录呀李宏毅 深度学习 笔记一一、利用宝可梦的案例来引出回归模型二、损失函数三、Gradient Desent梯度的公…

RK3399平台开发系列讲解(中断篇)13.6、中断irq_desc_tree描述

文章目录平台内核版本安卓版本RK3399Linux4.4Android7.1