【pytorch】Animatable 3D Gaussian+源码解读(一)_animatable gaussians-程序员宅基地

技术标签: 3d  3DGS  pytorch  windows  

概述

创新点:

  1. 多人场景 无遮挡处理
  2. 以3DGS进行表达

方法:
在这里插入图片描述

环境配置

基本和3DGS的配置差不多…

pip install torch==1.13.1+cu117 torchvision --index-url https://download.pytorch.org/whl/cu117
pip install hydra-core==1.3.2
pip install pytorch-lightning==2.1.2
pip install imageio
pip install ./submodules/diff-gaussian-rasterization
pip install ./submodules/simple-knn
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

子模块记得 git clone --recursive 多次踩坑…
tinycudann建议先git到本地…

数据准备

数据集下载

|---data
|   |---Gala
|   |---PeopleSnapshot
|   |---smpl

下载以后挨个解压,整理成格式

单个数据集结构:
在这里插入图片描述

0-6代表不同相机序号——同一视频的不同视角

每个文件夹均有对应的300帧+300mask。

在这里插入图片描述

bg定义了7个相机视角下的背景,目前还不知道有什么用,后续看看…

在这里插入图片描述

camera文件夹下是7个文本文件,描述相机参数。具体的表示在【pytorch】Animatable 3D Gaussian+源码解读(二)中分析。

。

pose文件夹下是300个文本文件,分别描述300帧中300个pose。

此外,Gala数据集中还有一个名为model的文件夹需要注意:

在这里插入图片描述
应该是定义了人体标准模板,包括网格、tpose、蒙皮权重等…

model_path (str) : The path to the folder that holds the vertices, tpose matrix, binding weights and indexes.

数据集具体是怎么利用的我们再结合代码来看看…

代码解读

先跟着debug走一遍流程… 然后再以面向对象的思路把握全局

Define Gaussian

首先定义场景表达元素:高斯球

train.py:

model = NeRFModel(opt)

nerf_model.py:

class NeRFModel(pl.LightningModule):
    def __init__(self, opt):
        super(NeRFModel, self).__init__()
        self.save_hyperparameters() # 储存init中输入的所有超参
        self.model = hydra.utils.instantiate(opt.deformer)

此处,opt.deformer==Gala模型:

Since the public dataset [1] contains few pose and shadow changes, we create a new dataset named GalaBasketball in order to show the
performance of our method under complex motion and dynamic
shadows
.

正文开始:

    """
    Attributes:
        parents (list[J]) : Indicate the parent joint for each joint, -1 for root joint.
        bone_count (int) : The count of joints including root joint, i.e. J.
        joints (torch.Tensor[J-1, 3]) : The translations of each joint relative to the parent joint, except for root joint.
        tpose_w2l_mats (torch.Tensor[J, 3]) : The skeleton to local transform matrixes for each joint.
    """

初始化函数:

        """
        Init joints and offset matrixes from files.

        Args:
            model_path (str) : The path to the folder that holds the vertices, tpose matrix, binding weights and indexes.
            num_players (int) : Number of players.  # 多人场景
        """

由此模型超参数为:

在这里插入图片描述

首先是一些人体基本操作——读取相关数据:

        model_file = os.path.join(model_path, "mesh.txt")
        vertices, normals, uvs, bone_weights, bone_indices = read_skinned_mesh_data(
            model_file) 

        tpose_file = os.path.join(model_path, "tpose.txt")
        tpose_positions, tpose_rotations, tpose_scales = read_bone_joints(
            tpose_file)

        tpose_mat_file = os.path.join(model_path, "tpose_mat.txt")
        tpose_w2l_mats = read_bone_joint_mats(tpose_mat_file)

        joint_parent_file = os.path.join(model_path, "jointParent.txt")
        self.joint_parent_idx = read_bone_parent_indices(joint_parent_file)

        self.bone_count = tpose_positions.shape[0]
        self.vertex_count = vertices.shape[0]

        print("mesh loaded:")
        print("total vertices: " + str(vertices.shape[0]))
        print("num of joints: " + str(self.bone_count))

read_skinned_mesh_data(“mesh.txt”)函数读取顶点、蒙皮权重、UV坐标;
read_bone_joints(“tpose.txt”)函数读取关节数据;
read_bone_joint_mats(“tpose_mat.txt”)读取world-to-local转化矩阵;
read_bone_parent_indices(“jointParent.txt”)读取关节父子关系。

多人扩维模板复制:

        self.register_buffer('v_template', vertices[None, ...].repeat(
            [self.num_players, 1, 1]))
        uvs = uvs * 2. - 1.
        self.register_buffer('uvs', uvs[None, ...].repeat(
            [self.num_players, 1, 1])) 
        bone_weights = torch.Tensor(
            np.load(os.path.join(model_path, "weights.npy")))[None, ...].repeat([self.num_players, 1, 1])
        self.register_buffer("bone_weights", bone_weights)

1.register_buffer:定义一组参数,该组参数的特别之处在于:模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
2.[None,…]:多一维

		
		# 关节位置申请加入训练
        self.J = nn.Parameter(
            tpose_positions[None, ...].repeat([self.num_players, 1, 1]))
        self.tpose_w2l_mats = tpose_w2l_mats

        # 顶点归一化
        minmax = [self.v_template[0].min(
            dim=0).values * 1.05,  self.v_template[0].max(dim=0).values * 1.05]
        self.register_buffer('normalized_vertices',
                             (self.v_template - minmax[0]) / (minmax[1] - minmax[0]))

        # distCUDA2 from simple_knn 计算点云中的每个点到与其最近的K个点的平均距离的平方
        dist2 = torch.clamp_min(
            distCUDA2(vertices.float().cuda()), 0.0000001)[..., None].repeat([1, 3])

然后开始处理要训练的高斯:

定义顶点偏移:

using unconstrained per-vertex displacement can easily cause the optimization process to diverge in dynamic scenes.Therefore, we also model a parameter field for vertex displacement. F.

	# x0 →  δx
    if use_point_displacement:
        self.displacements = nn.Parameter(
            torch.zeros_like(self.v_template))
    else:
        # 使用encoder
        self.displacementEncoder = DisplacementEncoder(
            encoder=encoder_type, num_players=num_players)

多种编码方式:uv encoder、hash encoder…

Since our animatable 3D Gaussian representation is initialized by a standard human body model, the centers of 3D Gaussians are uniformly distributed near the human surface. We only need to sample at fixed positions near the surface of the human body in the parameter fields. This allows for significant compression of the hash table for the hash encoding [36]. Thus, we choose the hash encoding to model our parameter field to reduce the time and storage consumption.

class DisplacementEncoder(nn.Module):
    def __init__(self, encoder="uv", num_players=1):
        super().__init__()
        self.num_players = num_players
        if encoder == "uv":
            self.input_channels = 2
            self.encoder = UVEncoder(
                3, num_players=num_players)
        elif encoder == "hash":
            self.input_channels = 3
            self.encoder = HashEncoder(
                3, num_players=num_players)
        elif encoder == "triplane":
            self.input_channels = 3
            self.encoder = TriPlaneEncoder(
                3, num_players=num_players)
        else:
            raise Exception("encoder does not exist!")

这里先选择hash编码,使用tcnn

class HashEncoder(nn.Module):
    def __init__(self, num_channels, num_players=1):
        super().__init__()
        self.networks = []
        self.num_players = num_players
        for i in range(num_players):
            self.networks.append(tcnn.NetworkWithInputEncoding(
                n_input_dims=3,
                n_output_dims=num_channels,
                encoding_config={
                    "otype": "HashGrid",
                    "n_levels": 16,
                    "n_features_per_level": 4,
                    "log2_hashmap_size": 17,
                    "base_resolution": 4,
                    "per_level_scale": 1.5,
                },
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                }
            ))
        self.networks = nn.ModuleList(self.networks)

定义颜色、透明度、缩放、旋转:

rendering based on 3D Gaussian rasterization can only backpropagate the gradient to a finite number of Gaussians in a single iteration, which leads to a slow or even divergent optimization process for dynamic scenes. To address this issue, we suggest sampling spherical harmonic coefficients SH for each vertex from a continuous parameter field, which is able to affect all neighboring Gaussians in a single optimization.

Optionally, we provide UV-encoded spherical harmonic coefficients, allowing fast processing of custom human models with UV coordinate mappings. UV encoding potentially achieves higher reconstruction quality compared to hash encoding.

        n = self.v_template.shape[1] * num_players # 总顶点数
        # x0 → SH
        if use_point_color:
            self.shs_dc = nn.Parameter(torch.zeros(
                [n, 1, 3]))
            self.shs_rest = nn.Parameter(torch.zeros(
                [n, (max_sh_degree + 1) ** 2 - 1, 3]))
        else:
        	# 使用encoder
            self.shEncoder = SHEncoder(max_sh_degree=max_sh_degree,
                                       encoder=encoder_type, num_players=num_players)
        self.opacity = nn.Parameter(inverse_sigmoid(
            0.2 * torch.ones((n, 1), dtype=torch.float)))
        self.scales = nn.Parameter(
            torch.log(torch.sqrt(dist2)).repeat([num_players, 1]))
        rotations = torch.zeros([n, 4])
        rotations[:, 0] = 1
        self.rotations = nn.Parameter(rotations)

遮挡处理:x0, γ(t) → ao

We propose a time-dependent ambient occlusion module to address the issue of dynamic shadows in specific scenes.

        if enable_ambient_occlusion:
            self.aoEncoder = AOEncoder(
                encoder=encoder_type, max_freq=max_freq, num_players=num_players)
        self.register_buffer("aos", torch.ones_like(self.opacity))

在这里插入图片描述

we also employ hash encoding for the ambient occlusion ao, since shadows should be continuously modeled in space

class AOEncoder(nn.Module):
    def __init__(self, encoder="uv", num_players=1, max_freq=4):
        super().__init__()
        self.num_players = num_players
        self.max_freq = max_freq
        if encoder == "uv":
            self.input_channels = 2
            self.encoder = UVTimeEncoder(
                1, num_players=num_players, time_dim=max_freq*2 + 1)
        elif encoder == "hash":
            self.input_channels = 3
            self.encoder = HashTimeEncoder(
                1, num_players=num_players, time_dim=max_freq*2 + 1)
        else:
            raise Exception("encoder does not exist!")
class HashTimeEncoder(nn.Module):
    def __init__(self, num_channels, time_dim=9, num_players=1):
        super().__init__()
        self.networks = []
        self.time_nets = []
        self.num_players = num_players
        for i in range(num_players):
            self.networks.append(tcnn.Encoding(
                n_input_dims=3,
                encoding_config={
                    "otype": "HashGrid",
                    "n_levels": 16,
                    "n_features_per_level": 4,
                    "log2_hashmap_size": 19,
                    "base_resolution": 4,
                    "per_level_scale": 1.5,
                },
            ))
            self.time_nets.append(tcnn.Network(
                n_input_dims=self.networks[i].n_output_dims + time_dim,
                n_output_dims=num_channels,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                }
            ))
        self.networks = nn.ModuleList(self.networks)
        self.time_nets = nn.ModuleList(self.time_nets)

至此,高斯球定义完成。
在这里插入图片描述
在这里插入图片描述

【pytorch】Animatable 3D Gaussian+源码解读(二)将进一步介绍数据集的处理细节。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XYQjiayou/article/details/136271978

智能推荐

从零开始搭建Hadoop_创建一个hadoop项目-程序员宅基地

文章浏览阅读331次。第一部分:准备工作1 安装虚拟机2 安装centos73 安装JDK以上三步是准备工作,至此已经完成一台已安装JDK的主机第二部分:准备3台虚拟机以下所有工作最好都在root权限下操作1 克隆上面已经有一台虚拟机了,现在对master进行克隆,克隆出另外2台子机;1.1 进行克隆21.2 下一步1.3 下一步1.4 下一步1.5 根据子机需要,命名和安装路径1.6 ..._创建一个hadoop项目

心脏滴血漏洞HeartBleed CVE-2014-0160深入代码层面的分析_heartbleed代码分析-程序员宅基地

文章浏览阅读1.7k次。心脏滴血漏洞HeartBleed CVE-2014-0160 是由heartbeat功能引入的,本文从深入码层面的分析该漏洞产生的原因_heartbleed代码分析

java读取ofd文档内容_ofd电子文档内容分析工具(分析文档、签章和证书)-程序员宅基地

文章浏览阅读1.4k次。前言ofd是国家文档标准,其对标的文档格式是pdf。ofd文档是容器格式文件,ofd其实就是压缩包。将ofd文件后缀改为.zip,解压后可看到文件包含的内容。ofd文件分析工具下载:点我下载。ofd文件解压后,可以看到如下内容: 对于xml文件,可以用文本工具查看。但是对于印章文件(Seal.esl)、签名文件(SignedValue.dat)就无法查看其内容了。本人开发一款ofd内容查看器,..._signedvalue.dat

基于FPGA的数据采集系统(一)_基于fpga的信息采集-程序员宅基地

文章浏览阅读1.8w次,点赞29次,收藏313次。整体系统设计本设计主要是对ADC和DAC的使用,主要实现功能流程为:首先通过串口向FPGA发送控制信号,控制DAC芯片tlv5618进行DA装换,转换的数据存在ROM中,转换开始时读取ROM中数据进行读取转换。其次用按键控制adc128s052进行模数转换100次,模数转换数据存储到FIFO中,再从FIFO中读取数据通过串口输出显示在pc上。其整体系统框图如下:图1:FPGA数据采集系统框图从图中可以看出,该系统主要包括9个模块:串口接收模块、按键消抖模块、按键控制模块、ROM模块、D.._基于fpga的信息采集

微服务 spring cloud zuul com.netflix.zuul.exception.ZuulException GENERAL-程序员宅基地

文章浏览阅读2.5w次。1.背景错误信息:-- [http-nio-9904-exec-5] o.s.c.n.z.filters.post.SendErrorFilter : Error during filteringcom.netflix.zuul.exception.ZuulException: Forwarding error at org.springframework.cloud..._com.netflix.zuul.exception.zuulexception

邻接矩阵-建立图-程序员宅基地

文章浏览阅读358次。1.介绍图的相关概念  图是由顶点的有穷非空集和一个描述顶点之间关系-边(或者弧)的集合组成。通常,图中的数据元素被称为顶点,顶点间的关系用边表示,图通常用字母G表示,图的顶点通常用字母V表示,所以图可以定义为:  G=(V,E)其中,V(G)是图中顶点的有穷非空集合,E(G)是V(G)中顶点的边的有穷集合1.1 无向图:图中任意两个顶点构成的边是没有方向的1.2 有向图:图中..._给定一个邻接矩阵未必能够造出一个图

随便推点

MDT2012部署系列之11 WDS安装与配置-程序员宅基地

文章浏览阅读321次。(十二)、WDS服务器安装通过前面的测试我们会发现,每次安装的时候需要加域光盘映像,这是一个比较麻烦的事情,试想一个上万个的公司,你天天带着一个光盘与光驱去给别人装系统,这将是一个多么痛苦的事情啊,有什么方法可以解决这个问题了?答案是肯定的,下面我们就来简单说一下。WDS服务器,它是Windows自带的一个免费的基于系统本身角色的一个功能,它主要提供一种简单、安全的通过网络快速、远程将Window..._doc server2012上通过wds+mdt无人值守部署win11系统.doc

python--xlrd/xlwt/xlutils_xlutils模块可以读xlsx吗-程序员宅基地

文章浏览阅读219次。python–xlrd/xlwt/xlutilsxlrd只能读取,不能改,支持 xlsx和xls 格式xlwt只能改,不能读xlwt只能保存为.xls格式xlutils能将xlrd.Book转为xlwt.Workbook,从而得以在现有xls的基础上修改数据,并创建一个新的xls,实现修改xlrd打开文件import xlrdexcel=xlrd.open_workbook('E:/test.xlsx') 返回值为xlrd.book.Book对象,不能修改获取sheett_xlutils模块可以读xlsx吗

关于新版本selenium定位元素报错:‘WebDriver‘ object has no attribute ‘find_element_by_id‘等问题_unresolved attribute reference 'find_element_by_id-程序员宅基地

文章浏览阅读8.2w次,点赞267次,收藏656次。运行Selenium出现'WebDriver' object has no attribute 'find_element_by_id'或AttributeError: 'WebDriver' object has no attribute 'find_element_by_xpath'等定位元素代码错误,是因为selenium更新到了新的版本,以前的一些语法经过改动。..............._unresolved attribute reference 'find_element_by_id' for class 'webdriver

DOM对象转换成jQuery对象转换与子页面获取父页面DOM对象-程序员宅基地

文章浏览阅读198次。一:模态窗口//父页面JSwindow.showModalDialog(ifrmehref, window, 'dialogWidth:550px;dialogHeight:150px;help:no;resizable:no;status:no');//子页面获取父页面DOM对象//window.showModalDialog的DOM对象var v=parentWin..._jquery获取父window下的dom对象

什么是算法?-程序员宅基地

文章浏览阅读1.7w次,点赞15次,收藏129次。算法(algorithm)是解决一系列问题的清晰指令,也就是,能对一定规范的输入,在有限的时间内获得所要求的输出。 简单来说,算法就是解决一个问题的具体方法和步骤。算法是程序的灵 魂。二、算法的特征1.可行性 算法中执行的任何计算步骤都可以分解为基本可执行的操作步,即每个计算步都可以在有限时间里完成(也称之为有效性) 算法的每一步都要有确切的意义,不能有二义性。例如“增加x的值”,并没有说增加多少,计算机就无法执行明确的运算。 _算法

【网络安全】网络安全的标准和规范_网络安全标准规范-程序员宅基地

文章浏览阅读1.5k次,点赞18次,收藏26次。网络安全的标准和规范是网络安全领域的重要组成部分。它们为网络安全提供了技术依据,规定了网络安全的技术要求和操作方式,帮助我们构建安全的网络环境。下面,我们将详细介绍一些主要的网络安全标准和规范,以及它们在实际操作中的应用。_网络安全标准规范