diff --git a/README.md b/README.md index e232bab90..8b6c78f5b 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 | 流场高分辨率重构 | [2D 湍流流场重构](https://aistudio.baidu.com/projectdetail/4493261?contributionType=1) | 数据驱动 | cycleGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)
[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://arxiv.org/abs/2007.15324)| | 流场高分辨率重构 | [基于Voronoi嵌入辅助深度学习的稀疏传感器全局场重建](https://aistudio.baidu.com/projectdetail/5807904) | 数据驱动 | CNN | 监督学习 | [Data1](https://drive.google.com/drive/folders/1K7upSyHAIVtsyNAqe6P8TY1nS5WpxJ2c)
[Data2](https://drive.google.com/drive/folders/1pVW4epkeHkT2WHZB7Dym5IURcfOP4cXu)
[Data3](https://drive.google.com/drive/folders/1xIY_jIu-hNcRY-TTf4oYX1Xg4_fx8ZvD) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) | | 流场预测 | [Catheter](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/catheter/) | 数据驱动 | FNO | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/291940) | [Paper](https://www.science.org/doi/pdf/10.1126/sciadv.adj1741) | +| 流场预测 | [Catheter](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/confild/) | 数据驱动 | CONFILD | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/9736790) | [Paper](https://doi.org/10.1038/s41467-024-54712-1) | | 求解器耦合 | [CFD-GCN](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/cfdgcn) | 数据驱动 | GCN | 监督学习 | [Data](https://aistudio.baidu.com/aistudio/datasetdetail/184778)
[Mesh](https://paddle-org.bj.bcebos.com/paddlescience/datasets/CFDGCN/meshes.tar) | [Paper](https://arxiv.org/abs/2007.04439)| | 受力分析 | [1D 欧拉梁变形](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/euler_beam) | 机理驱动 | MLP | 无监督学习 | - | - | | 受力分析 | [2D 平板变形](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/biharmonic2d) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/2108.07243) | diff --git a/docs/index.md b/docs/index.md index c9e288ee5..ee3f0913b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -124,6 +124,7 @@ | 流场高分辨率重构 | [2D 湍流流场重构](https://aistudio.baidu.com/projectdetail/4493261?contributionType=1) | 数据驱动 | cycleGAN | 监督学习 | [Train Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat)
[Eval Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat) | [Paper](https://arxiv.org/abs/2007.15324)| | 流场高分辨率重构 | [基于Voronoi嵌入辅助深度学习的稀疏传感器全局场重建](https://aistudio.baidu.com/projectdetail/5807904) | 数据驱动 | CNN | 监督学习 | [Data1](https://drive.google.com/drive/folders/1K7upSyHAIVtsyNAqe6P8TY1nS5WpxJ2c)
[Data2](https://drive.google.com/drive/folders/1pVW4epkeHkT2WHZB7Dym5IURcfOP4cXu)
[Data3](https://drive.google.com/drive/folders/1xIY_jIu-hNcRY-TTf4oYX1Xg4_fx8ZvD) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) | | 流场预测 | [Catheter](https://aistudio.baidu.com/projectdetail/5379212) | 数据驱动 | FNO | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/291940) | [Paper](https://www.science.org/doi/pdf/10.1126/sciadv.adj1741) | + | 流场预测 | [CONFILD](https://aistudio.baidu.com/projectdetail/5379212) | 数据驱动 | CONFILD | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/9736790) | [Paper](https://doi.org/10.1038/s41467-024-54712-1) | | 求解器耦合 | [CFD-GCN](./zh/examples/cfdgcn.md) | 数据驱动 | GCN | 监督学习 | [Data](https://aistudio.baidu.com/aistudio/datasetdetail/184778)
[Mesh](https://paddle-org.bj.bcebos.com/paddlescience/datasets/CFDGCN/meshes.tar) | [Paper](https://arxiv.org/abs/2007.04439)| | 受力分析 | [1D 欧拉梁变形](./zh/examples/euler_beam.md) | 机理驱动 | MLP | 无监督学习 | - | - | | 受力分析 | [2D 平板变形](./zh/examples/biharmonic2d.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/2108.07243) | diff --git a/docs/zh/examples/confild.md b/docs/zh/examples/confild.md new file mode 100644 index 000000000..734da52db --- /dev/null +++ b/docs/zh/examples/confild.md @@ -0,0 +1,103 @@ +# AI辅助的时空湍流生成:条件神经场潜在扩散模型(CoNFILD) + +Distributed under a Creative Commons Attribution license 4.0 (CC BY). + +## 1. 背景简介 +### 1.1 论文信息 +| 年份 | 期刊 | 作者 | 引用数 | 论文PDF与补充材料 | +|----------------|---------------------|--------------------------------------------------------------------------------------------------|--------|----------------------------------------------------------------------------------------------------| +| 2024年1月3日 | Nature Communications | Pan Du, Meet Hemant Parikh, Xiantao Fan, Xin-Yang Liu, Jian-Xun Wang | 15 | [论文链接](https://doi.org/10.1038/s41467-024-54712-1)
[代码仓库](https://github.com/jx-wang-s-group/CoNFILD) | + +### 1.2 作者介绍 +- **通讯作者**:Jian-Xun Wang(王建勋)
所属机构:美国圣母大学航空航天与机械工程系、康奈尔大学机械与航空航天工程系
研究方向:湍流建模、生成式AI、物理信息机器学习
+ +- **其他作者**:
Pan Du、Meet Hemant Parikh(共同一作):圣母大学博士生,研究方向为生成式模型与计算流体力学
Xiantao Fan、Xin-Yang Liu:圣母大学研究助理,负责数值模拟与数据生成 + +### 1.3 模型&复现代码 +| 问题类型 | 在线运行 | 神经网络架构 | 评估指标 | +|------------------------|----------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------| +| 时空湍流生成 | [aistudio](https://aistudio.baidu.com/project/edit/9736790) | 条件神经场+潜在扩散模型 | MSE: 0.041(速度场) | + +=== "模型训练命令" +```bash +python confild.py mode=train +``` + +=== "预训练模型快速评估" + +``` sh +python confild.py mode=eval +``` + +## 2. 问题定义 +### 2.1 研究背景 +湍流模拟在航空航天、海洋工程等领域至关重要,但传统方法如直接数值模拟(DNS)和大涡模拟(LES)计算成本高昂,难以应用于高雷诺数或实时场景。现有深度学习模型多基于确定性框架,难以捕捉湍流的混沌特性,且在复杂几何域中表现受限。 + +### 2.2 核心挑战 +1. **高维数据**:三维时空湍流数据维度高达 \(O(10^9)\),传统生成模型内存需求巨大。 +2. **随机性建模**:需同时捕捉湍流的多尺度统计特性与瞬时动态。 +3. **几何适应性**:需支持不规则计算域与自适应网格。 + +### 2.3 创新方法 +提出**条件神经场潜在扩散模型(CoNFILD)**,通过三阶段框架解决上述挑战: +1. **神经场编码**:将高维流场压缩为低维潜在表示,压缩比达0.002%-0.017%。 +2. **潜在扩散**:在潜在空间进行概率扩散过程,学习湍流统计分布。 +3. **零样本条件生成**:结合贝叶斯推理,无需重新训练即可实现传感器重建、超分辨率等任务。 + +![图1 CoNFILD框架](./confild.png) +*框架示意图:CNF编码器将流场映射到潜在空间,扩散模型生成新潜在样本,解码器重建物理场* + +## 3. 模型构建 +### 3.1 条件神经场(CNF) +- **架构**:基于SIREN网络,采用正弦激活函数捕捉周期性特征。 +- **数学表示**: + $$ + \mathscr{E}(\mathbf{X},\mathbf{L}) = \text{SIREN}(\mathbf{x}) + \text{FILM}(\mathbf{L}) + $$ + 其中FILM(Feature-wise Linear Modulation)通过潜在向量\(\mathbf{L}\)调节每层偏置。 + +### 3.2 潜在扩散模型 +- **前向过程**:逐步添加高斯噪声,潜在表示\(\mathbf{z}_0 \rightarrow \mathbf{z}_T\)。 +- **逆向过程**:训练U-Net预测噪声,通过迭代去噪生成新样本: + $$ + \mathbf{z}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{z}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{z}_t, t) \right) + \sigma_t \epsilon + $$ + +### 3.3 零样本条件生成 +- **贝叶斯后验采样**:基于稀疏观测\(\Psi\),通过梯度修正潜在空间采样: + $$ + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t|\Psi) \approx \nabla_{\mathbf{z}_t} \log p(\Psi|\mathbf{z}_t) + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t) + $$ + +## 4. 问题求解 +### 4.1 数据集准备 +数据文件说明如下: +``` +data # CNF的训练数据集 +| +|-- data.npy # 要拟合的数据 +| +|-- coords.npy # 查询坐标 +``` + +在加载数据之后,需要进行normalization,以便于训练。具体代码如下: +```python linenums="39" +example/confild/confild.py:59:135 +``` + +### 4.2 CoNFiLD 模型 +CoNFiLD 模型基于贝叶斯后验采样,将稀疏传感器测量数据作为条件输入。通过训练好的无条件扩散模型作为先验,在扩散后验采样过程中,考虑测量噪声引入的不确定性。利用状态到观测映射,根据条件向量与流场的关系,通过调整无条件得分函数,引导生成与传感器数据一致的全时空流场实现重构,并且能提供重构的不确定性估计。代码如下: + +```python linenums="39" +ppsci/arch/confild.py:304:420 +``` +为了在计算时,准确快速地访问具体变量的值,我们在这里指定网络模型的输入变量名是 ["confild_x", "latent_z"],输出变量名是 ["confild_output"],这些命名与后续代码保持一致。 + +4.3 模型训练、评估 +完成上述设置之后,只需要将上述实例化的对象按照文档进行组合,然后启动训练、评估。 +```python linenums="39" +examples/confild/confild.py:218:503 +``` + +## 5. 实验结果 +![](https://ai-studio-static-online.cdn.bcebos.com/1f81af1d579b4b41a525f867ac0fde19d59fb6fc44f8406aa84345c6015938c9) diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml new file mode 100644 index 000000000..0f5bbf6bf --- /dev/null +++ b/examples/confild/conf/confild_case1.yaml @@ -0,0 +1,119 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 64 + test_batch_size: 256 + epochs: 9800 + mutil_GPU: 1 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case1/confild_case1/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case1/latent_case1/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 3 + hidden_features: 128 + in_coord_features: 2 + in_latent_features: 128 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 16000 + lumped: True + N_features: 128 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case1 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case1 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [918, 2] + latents_shape: [1, 128] + log_freq: 20 + batch_size: 64 + +Uncondiction_INFER: + batch_size : 16 + test_batch_size : 16 + time_length : 128 + latent_length : 128 + image_size : 128 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + steps: 1000 + noise_schedule: "cosine" + +Data: + data_path: data/Case1/case1_data.npy + coor_path: data/Case1/case1_coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_elbow_flow diff --git a/examples/confild/conf/confild_case2.yaml b/examples/confild/conf/confild_case2.yaml new file mode 100644 index 000000000..ab6208c09 --- /dev/null +++ b/examples/confild/conf/confild_case2.yaml @@ -0,0 +1,103 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_confild_case2 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 10 + test_batch_size: 10 + epochs: 44500 + mutil_GPU: 1 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case2/confild_case2/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case2/latent_case2/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 4 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 1200 + lumped: False + N_features: 256 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case2 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case2 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [400, 100, 2] + latents_shape: [1, 1, 256] + log_freq: 20 + batch_size: 40 + +Data: + data_path: /home/aistudio/work/extracted/data/Case2/data.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_channel_flow diff --git a/examples/confild/conf/confild_case3.yaml b/examples/confild/conf/confild_case3.yaml new file mode 100644 index 000000000..140ced107 --- /dev/null +++ b/examples/confild/conf/confild_case3.yaml @@ -0,0 +1,104 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_confild_case3 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 100 + test_batch_size: 100 + epochs: 4800 + mutil_GPU: 2 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case3/confild_case3/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case3/latent_case3/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 117 + out_features: 2 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 2880 + lumped: True + N_features: 256 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case3 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case3 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [10884, 2] + latents_shape: [1, 256] + log_freq: 20 + batch_size: 100 + +Data: + data_path: /home/aistudio/work/extracted/data/Case3/data.npy + coor_path: /home/aistudio/work/extracted/data/Case3/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_periodic_hill_flow diff --git a/examples/confild/conf/confild_case4.yaml b/examples/confild/conf/confild_case4.yaml new file mode 100644 index 000000000..42dd6470d --- /dev/null +++ b/examples/confild/conf/confild_case4.yaml @@ -0,0 +1,104 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_confild_case4 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 4 + test_batch_size: 4 + epochs: 20000 + mutil_GPU: 2 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case4/confild_case4/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case4/latent_case4/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 15 + out_features: 3 + hidden_features: 384 + in_coord_features: 3 + in_latent_features: 384 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 1200 + lumped: True + N_features: 384 + dims: 3 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case4 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case4 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [58483, 3] + latents_shape: [1, 384] + log_freq: 20 + batch_size: 4 + +Data: + data_path: /home/aistudio/work/extracted/data/Case4/data.npy + coor_path: /home/aistudio/work/extracted/data/Case4/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_3d_flow diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml new file mode 100644 index 000000000..027eef920 --- /dev/null +++ b/examples/confild/conf/un_confild_case1.yaml @@ -0,0 +1,89 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: eval +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 +save_path: ${output_dir}/result.npy + +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 10000 + lr : 5.e-5 + weight_decay: 0.0 + final_lr: 1.e-5 + microbatch: -1 + max_steps: 10000 + +EVAL: + mutil_GPU: 1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + time_length : 128 + latent_length : 128 + test_batch_size: 16 + +UNET: + image_size : 128 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /home/aistudio/work/data/Case1/diffusion/ema.pdparams + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/data/Case1/data.npy + coor_path: /home/aistudio/work/data/Case1/coords.npy + load_data_fn: load_elbow_flow + normalizer: + method: "-11" + dim: 0 + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 3 + hidden_features: 128 + in_coord_features: 2 + in_latent_features: 128 + normalizer_params_path: /home/aistudio/work/data/Case1/cnf/normalizer_params.pdparams + model_path: /home/aistudio/work/PaddleScience/examples/confild/cnf_model.pdparams + +DATA: + max_val: 1.0 + min_val: -1.0 + train_data: "/home/aistudio/work/data/Case1/train_data.npy" + valid_data: "/home/aistudio/work/data/Case1/valid_data.npy" diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml new file mode 100644 index 000000000..4356abaf1 --- /dev/null +++ b/examples/confild/conf/un_confild_case2.yaml @@ -0,0 +1,88 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_un_confild_case2 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: eval +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 +save_path: ${output_dir}/result.npy + +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 10000 + lr : 5.e-5 + weight_decay: 0.0 + final_lr: 1.e-5 + microbatch: -1 + max_steps: 10000 + +EVAL: + mutil_GPU: 1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + time_length : 256 + latent_length : 256 + test_batch_size: 16 + +UNET: + image_size : 256 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /home/aistudio/work/data/Case2/diffusion/ema.pdparams + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/data/Case2/data.npy + load_data_fn: load_channel_flow + normalizer: + method: "-11" + dim: 0 + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 4 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + normalizer_params_path: /home/aistudio/work/data/Case2/cnf/normalizer_params.pdparams + model_path: /home/aistudio/work/PaddleScience/examples/confild/cnf_model.pdparams + +DATA: + max_val: 1.0 + min_val: -1.0 + train_data: "/home/aistudio/work/data/Case2/train_data.npy" + valid_data: "/home/aistudio/work/data/Case2/valid_data.npy" \ No newline at end of file diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml new file mode 100644 index 000000000..75628cdae --- /dev/null +++ b/examples/confild/conf/un_confild_case3.yaml @@ -0,0 +1,89 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_un_confild_case3 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: eval +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 +save_path: ${output_dir}/result.npy + +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 10000 + lr : 5.e-5 + weight_decay: 0.0 + final_lr: 1.e-5 + microbatch: -1 + max_steps: 10000 + +EVAL: + mutil_GPU: 2 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + time_length : 256 + latent_length : 256 + test_batch_size: 16 + +UNET: + image_size : 256 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /home/aistudio/work/data/Case3/diffusion/ema.pdparams + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/data/Case3/data.npy + coor_path: /home/aistudio/work/data/Case3/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_periodic_hill_flow + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 117 + out_features: 2 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + normalizer_params_path: /home/aistudio/work/data/Case3/cnf/normalizer_params.pdparams + model_path: /home/aistudio/work/PaddleScience/examples/confild/cnf_model.pdparams + +DATA: + min_val: -1.0 + max_val: 1.0 + train_data: "/home/aistudio/work/data/Case3/train_data.npy" + valid_data: "/home/aistudio/work/data/Case3/valid_data.npy" diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml new file mode 100644 index 000000000..f3113f37c --- /dev/null +++ b/examples/confild/conf/un_confild_case4.yaml @@ -0,0 +1,90 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + dir: ./outputs_un_confild_case4 + job: + name: ${mode} + chdir: false + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + dir: ${hydra.run.dir} + subdir: ./ + +mode: eval +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 +save_path: ${output_dir}/result.npy + +TRAIN: + batch_size : 8 + test_batch_size : 8 + ema_rate: "0.9999" + lr_anneal_steps: 10000 + lr : 5.e-5 + weight_decay: 0.0 + final_lr: 1.e-5 + microbatch: -1 + max_steps: 10000 + +EVAL: + mutil_GPU: 2 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + time_length : 256 + latent_length : 256 + test_batch_size: 8 + +UNET: + + image_size : 384 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: "1, 1, 2, 2, 4, 4" + ema_path: /home/aistudio/work/data/Case4/diffusion/ema.pdparams + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/data/Case4/data.npy + coor_path: /home/aistudio/work/data/Case4/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_3d_flow + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 15 + out_features: 3 + hidden_features: 384 + in_coord_features: 3 + in_latent_features: 384 + normalizer_params_path: /home/aistudio/work/data/Case4/cnf/normalizer_params.pdparams + model_path: /home/aistudio/work/PaddleScience/examples/confild/cnf_model.pdparams + +DATA: + min_val: -1.0 + max_val: 1.0 + train_data: "/home/aistudio/work/data/Case4/train_data.npy" + valid_data: "/home/aistudio/work/data/Case4/valid_data.npy" diff --git a/examples/confild/confild.py b/examples/confild/confild.py new file mode 100644 index 000000000..bb8896336 --- /dev/null +++ b/examples/confild/confild.py @@ -0,0 +1,631 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import paddle +from omegaconf import DictConfig +from paddle.distributed import fleet +from paddle.io import DataLoader, DistributedBatchSampler + +import ppsci +from ppsci.arch import LatentContainer, SIRENAutodecoder_film +from ppsci.utils import logger + + +def load_elbow_flow(path): + return np.load(f"{path}")[1:] + + +def load_channel_flow( + path, + t_start=0, + t_end=1200, + t_every=1, +): + return np.load(f"{path}")[t_start:t_end:t_every] + + +def load_periodic_hill_flow(path): + data = np.load(f"{path}") + return data + + +def load_3d_flow(path): + data = np.load(f"{path}") + return data + + +def rMAE(prediction, target, dims=(1, 2)): + return paddle.abs(x=prediction - target).mean(axis=dims) / paddle.abs( + x=target + ).mean(axis=dims) + + +class Normalizer_ts(object): + def __init__(self, params=[], method="-11", dim=None): + self.params = params + self.method = method + self.dim = dim + + def fit_normalize(self, data): + assert type(data) == paddle.Tensor + if len(self.params) == 0: + if self.method == "-11" or self.method == "01": + if self.dim is None: + self.params = paddle.max(x=data), paddle.min(x=data) + else: + self.params = ( + paddle.max(keepdim=True, x=data, axis=self.dim), + paddle.argmax(keepdim=True, x=data, axis=self.dim), + )[0], ( + paddle.min(keepdim=True, x=data, axis=self.dim), + paddle.argmin(keepdim=True, x=data, axis=self.dim), + )[ + 0 + ] + elif self.method == "ms": + if self.dim is None: + self.params = paddle.mean(x=data, axis=self.dim), paddle.std( + x=data, axis=self.dim + ) + else: + self.params = paddle.mean( + x=data, axis=self.dim, keepdim=True + ), paddle.std(x=data, axis=self.dim, keepdim=True) + elif self.method == "none": + self.params = None + return self.fnormalize(data, self.params, self.method) + + def normalize(self, new_data): + if not new_data.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fnormalize(new_data, self.params, self.method) + + def denormalize(self, new_data_norm): + if not new_data_norm.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fdenormalize(new_data_norm, self.params, self.method) + + def get_params(self): + if self.method == "ms": + print("returning mean and std") + elif self.method == "01": + print("returning max and min") + elif self.method == "-11": + print("returning max and min") + elif self.method == "none": + print("do nothing") + return self.params + + @staticmethod + def fnormalize(data, params, method): + if method == "-11": + return (data - params[1]) / (params[0] - params[1]) * 2 - 1 + elif method == "01": + return (data - params[1]) / (params[0] - params[1]) + elif method == "ms": + return (data - params[0]) / params[1] + elif method == "none": + return data + + @staticmethod + def fdenormalize(data_norm, params, method): + if method == "-11": + return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] + elif method == "01": + return data_norm * (params[0] - params[1]) + params[1] + elif method == "ms": + return data_norm * params[1] + params[0] + elif method == "none": + return data_norm + + +class BasicSet(paddle.io.Dataset): + def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> None: + super().__init__() + self.fois = fois.numpy() + self.total_samples = tuple(fois.shape)[0] + self.coords = coord.numpy() + self.global_indices = ( + global_indices + if global_indices is not None + else np.arange(self.total_samples) + ) + + def __len__(self): + return self.total_samples + + def __getitem__(self, idx): + global_idx = self.global_indices[idx] + if hasattr(self, "extra_in"): + extra_id = idx % tuple(self.fois.shape)[1] + idb = idx // tuple(self.fois.shape)[1] + return ( + (self.coords, self.extra_in[extra_id]), + self.fois[idb, extra_id], + global_idx, + ) + else: + return self.coords, self.fois[idx], global_idx + + +def getdata(cfg): + if cfg.Data.load_data_fn == "load_3d_flow": + fois = load_3d_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_elbow_flow": + fois = load_elbow_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_channel_flow": + fois = load_channel_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_periodic_hill_flow": + fois = load_periodic_hill_flow(cfg.Data.data_path) + else: + fois = np.load(cfg.Data.data_path) + + spatio_shape = fois.shape[1:-1] + spatio_axis = list( + range(fois.ndim if isinstance(fois, np.ndarray) else fois.dim()) + )[1:-1] + + ###### read data - coordinate ###### + if cfg.Data.get("coor_path", None) is None: + coord = [np.linspace(0, 1, i) for i in spatio_shape] + coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1) + else: + coord = np.load(cfg.Data.coor_path) + coord = coord.astype("float32") + fois = fois.astype("float32") + + ###### convert to tensor ###### + fois = paddle.to_tensor(fois) if not isinstance(fois, paddle.Tensor) else fois + coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord + N_samples = fois.shape[0] + + ###### normalizer ###### + in_normalizer = Normalizer_ts(**cfg.Data.normalizer) + in_normalizer.fit_normalize( + coord if cfg.Latent.lumped else coord.flatten(0, cfg.Latent.dims - 1) + ) + out_normalizer = Normalizer_ts(**cfg.Data.normalizer) + out_normalizer.fit_normalize( + fois if cfg.Latent.lumped else fois.flatten(0, cfg.Latent.dims) + ) + normed_coords = in_normalizer.normalize(coord) # 训练集就是测试集 + normed_fois = out_normalizer.normalize(fois) + + return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer + + +def signal_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, +): + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + latents_model = LatentContainer(**cfg.Latent) + + train_dataset = BasicSet(train_normed_fois, normed_coords, train_indices) + test_dataset = BasicSet(test_normed_fois, normed_coords, test_indices) + + criterion = paddle.nn.MSELoss() + + # set loader + train_loader = DataLoader( + dataset=train_dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True + ) + test_loader = DataLoader( + dataset=test_dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False + ) + # set optimizer + cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model) + latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)( + latents_model + ) + losses = [] + + for epoch in range(cfg.TRAIN.epochs): + cnf_model.train() + latents_model.train() + train_loss = [] + for batch_coords, batch_fois, idx in train_loader: + idx = {"latent_x": idx} + batch_latent = latents_model(idx) + if isinstance(batch_coords, list): + batch_coords = [coord for coord in batch_coords] + data = { + "confild_x": batch_coords, + "latent_z": batch_latent["latent_z"], + } + batch_output = cnf_model(data) + loss = criterion(batch_output["confild_output"], batch_fois) + + cnf_optimizer.clear_grad(set_to_zero=False) + latents_optimizer.clear_grad(set_to_zero=False) + + loss.backward() + + cnf_optimizer.step() + latents_optimizer.step() + train_loss.append(loss) + epoch_loss = paddle.stack(x=train_loss).mean().item() + losses.append(epoch_loss) + + print("epoch {}, train loss {}".format(epoch + 1, epoch_loss)) + if epoch % 100 == 0: + test_error = [] + cnf_model.eval() + latents_model.eval() + with paddle.no_grad(): + for test_coords, test_fois, idx in test_loader: + if isinstance(test_coords, list): + test_coords = [coord for coord in test_coords] + prediction = out_normalizer.denormalize( + cnf_model( + { + "confild_x": test_coords, + "latent_z": latents_model({"latent_x": idx})[ + "latent_z" + ], + } + )["confild_output"] + ) + target = out_normalizer.denormalize(test_fois) + error = rMAE(prediction=prediction, target=target, dims=spatio_axis) + test_error.append(error) + test_error = paddle.concat(x=test_error).mean(axis=0) + print("test MAE: ", test_error) + import os + + if not os.path.exists(cfg.output_dir): + os.makedirs(cfg.output_dir, exist_ok=True) + + paddle.save( + cnf_model.state_dict(), f"{cfg.output_dir}/cnf_model_{epoch}.pdparams" + ) + paddle.save( + latents_model.state_dict(), + f"{cfg.output_dir}/latents_model_{epoch}.pdparams", + ) + + plt.figure(figsize=(10, 6)) + plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") + + plt.title("Training Loss over Epochs") + plt.xlabel("Epochs") + plt.xticks(rotation=45) + plt.ylabel("Loss") + + plt.legend() + + plt.grid(True) + + plt.savefig("case.png") + + +def mutil_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, +): + fleet.init(is_collective=True) + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + cnf_model = fleet.distributed_model(cnf_model) + latents_model = LatentContainer(**cfg.Latent) + latents_model = fleet.distributed_model(latents_model) + + # set optimizer + cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model) + cnf_optimizer = fleet.distributed_optimizer(cnf_optimizer) + latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)( + latents_model + ) + latents_optimizer = fleet.distributed_optimizer(latents_optimizer) + + train_dataset = BasicSet(train_normed_fois, normed_coords, train_indices) + test_dataset = BasicSet(test_normed_fois, normed_coords, test_indices) + + train_sampler = DistributedBatchSampler( + train_dataset, cfg.TRAIN.batch_size, shuffle=True, drop_last=True + ) + train_loader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=cfg.TRAIN.mutil_GPU, + use_shared_memory=False, + ) + test_sampler = DistributedBatchSampler( + test_dataset, cfg.TRAIN.test_batch_size, drop_last=True + ) + test_loader = DataLoader( + test_dataset, + batch_sampler=test_sampler, + num_workers=cfg.TRAIN.mutil_GPU, + use_shared_memory=False, + ) + + criterion = paddle.nn.MSELoss() + losses = [] + + for epoch in range(cfg.TRAIN.epochs): + cnf_model.train() + latents_model.train() + train_loss = [] + for batch_coords, batch_fois, idx in train_loader: + idx = {"latent_x": idx} + batch_latent = latents_model(idx) + if isinstance(batch_coords, list): + batch_coords = [coord for coord in batch_coords] + data = { + "confild_x": batch_coords, + "latent_z": batch_latent["latent_z"], + } + batch_output = cnf_model(data) + loss = criterion(batch_output["confild_output"], batch_fois) + + cnf_optimizer.clear_grad(set_to_zero=False) + latents_optimizer.clear_grad(set_to_zero=False) + + loss.backward() + + cnf_optimizer.step() + latents_optimizer.step() + train_loss.append(loss) + epoch_loss = paddle.stack(x=train_loss).mean().item() + losses.append(epoch_loss) + print("epoch {}, train loss {}".format(epoch + 1, epoch_loss)) + if epoch % 100 == 0: + test_error = [] + cnf_model.eval() + latents_model.eval() + with paddle.no_grad(): + for test_coords, test_fois, idx in test_loader: + if isinstance(test_coords, list): + test_coords = [coord for coord in test_coords] + prediction = out_normalizer.denormalize( + cnf_model( + { + "confild_x": test_coords, + "latent_z": latents_model({"latent_x": idx})[ + "latent_z" + ], + } + )["confild_output"] + ) + target = out_normalizer.denormalize(test_fois) + error = rMAE(prediction=prediction, target=target, dims=spatio_axis) + test_error.append(error) + test_error = paddle.concat(x=test_error).mean(axis=0) + print("test MAE: ", test_error) + + import os + + if not os.path.exists(cfg.output_dir): + os.makedirs(cfg.output_dir, exist_ok=True) + + paddle.save( + cnf_model.state_dict(), f"{cfg.output_dir}/cnf_model_{epoch}.pdparams" + ) + paddle.save( + latents_model.state_dict(), + f"{cfg.output_dir}/latents_model_{epoch}.pdparams", + ) + + plt.figure(figsize=(10, 6)) + plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") + + plt.title("Training Loss over Epochs") + plt.xlabel("Epochs") + plt.xticks(rotation=45) + plt.ylabel("Loss") + + plt.legend() + + plt.grid(True) + + plt.savefig("case.png") + + +def train(cfg): + world_size = cfg.TRAIN.mutil_GPU + normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) + train_normed_fois = normed_fois + test_normed_fois = normed_fois + train_indices = list(range(N_samples)) + test_indices = list(range(N_samples)) + + if world_size > 1: + import paddle.distributed as dist + + dist.init_parallel_env() + mutil_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, + ) + else: + signal_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, + ) + + +def evaluate(cfg: DictConfig): + # set data + # normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) + normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg) + # [918,2] + # [16000,918,3] + print(normed_coords.shape) + print(normed_fois.shape) + t_std = cfg.get("t_std", 15698) # Use configurable parameter, default to 15698 + normed_fois = normed_fois[t_std:] + # exit() + if len(normed_coords.shape) + 1 == len(normed_fois.shape): + normed_coords = paddle.tile( + normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape) + ) + + idx = paddle.to_tensor(np.arange(normed_fois.shape[0]), dtype="int64") + # set model + confild = SIRENAutodecoder_film(**cfg.CONFILD) + latent = LatentContainer(**cfg.Latent) + logger.info( + "Loading pretrained model from {}".format( + cfg.EVAL.confild_pretrained_model_path + ) + ) + ppsci.utils.save_load.load_pretrain( + confild, + cfg.EVAL.confild_pretrained_model_path, + ) + logger.info( + "Loading pretrained model from {}".format(cfg.EVAL.latent_pretrained_model_path) + ) + ppsci.utils.save_load.load_pretrain( + latent, + cfg.EVAL.latent_pretrained_model_path, + ) + latent_test_pred = latent({"latent_x": idx}) + y_test_pred = [] + for i in range(normed_coords.shape[0]): + ouput = confild( + { + "confild_x": normed_coords[i], + "latent_z": latent_test_pred["latent_z"][i], + } + )["confild_output"].numpy() + y_test_pred.append(ouput) + + y_test_pred = paddle.to_tensor(np.array(y_test_pred)) + y_test_pred = out_normalizer.denormalize(y_test_pred) + y_test = out_normalizer.denormalize(normed_fois) + + def calc_err(): + var_name = ["u", "v", "p"] + for i, var in enumerate(var_name): + u_true_mean = paddle.mean(y_test[:, :, i], axis=0) + u_label_mean = paddle.mean(y_test_pred[:, :, i], axis=0) + u_true_std = paddle.std(y_test[:, :, i], axis=0) + u_label_std = paddle.std(y_test_pred[:, :, i], axis=0) + avg_discrepancy = paddle.mean(paddle.abs(u_true_mean - u_label_mean)) + std_discrepancy = paddle.mean(paddle.abs(u_true_std - u_label_std)) + print(f"Average Discrepancy [{var}] Value: {avg_discrepancy:.3f}") + print(f"Standard Deviation of [{var}] Magnitude: {std_discrepancy:.4f}") + + calc_err() + + +def inference(cfg): + normed_coords, normed_fois, _, _, _ = getdata(cfg) + if len(normed_coords.shape) + 1 == len(normed_fois.shape): + normed_coords = paddle.tile( + normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape) + ) + + fois_len = normed_fois.shape[0] + + idxs = np.arange(fois_len) + from deploy import python_infer + + latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent) + input_dict = {"latent_x": idxs} + output_dict = latent_predictor.predict(input_dict, cfg.INFER.batch_size) + + cnf_predictor = python_infer.GeneralPredictor(cfg.INFER.Confild) + input_dict = { + "confild_x": normed_coords.numpy(), + "latent_z": list(output_dict.values())[0], + } + output_dict = cnf_predictor.predict(input_dict, cfg.INFER.batch_size) + + logger.info("Result is {}".format(output_dict["fetch_name_0"])) + + +def export(cfg): + # set model + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + latent_model = LatentContainer(**cfg.Latent) + # initialize solver + latnet_solver = ppsci.solver.Solver( + latent_model, + pretrained_model_path=cfg.INFER.Latent.INFER.pretrained_model_path, + ) + cnf_solver = ppsci.solver.Solver( + cnf_model, + pretrained_model_path=cfg.INFER.Confild.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None], "int64", name=key) for key in latent_model.input_keys}, + ] + cnf_input_spec = [ + { + cnf_model.input_keys[0]: InputSpec( + [None] + list(cfg.INFER.Confild.INFER.coord_shape), + "float32", + name=cnf_model.input_keys[0], + ), + cnf_model.input_keys[1]: InputSpec( + [None] + list(cfg.INFER.Confild.INFER.latents_shape), + "float32", + name=cnf_model.input_keys[1], + ), + } + ] + cnf_solver.export(cnf_input_spec, cfg.INFER.Confild.INFER.export_path) + latnet_solver.export(input_spec, cfg.INFER.Latent.INFER.export_path) + + +@hydra.main(version_base=None, config_path="./conf", config_name="confild_case1.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + elif cfg.mode == "infer": + inference(cfg) + elif cfg.mode == "export": + export(cfg) + else: + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'infer', 'export'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/confild/resample.py b/examples/confild/resample.py new file mode 100644 index 000000000..b6bb79202 --- /dev/null +++ b/examples/confild/resample.py @@ -0,0 +1,120 @@ +from abc import ABC, abstractmethod + +import numpy as np +import paddle as th +import paddle.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size): + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.to_tensor(indices_np, dtype="int64") + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.to_tensor(weights_np, dtype="float32") + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + + batch_sizes = [ + th.to_tensor([0], dtype=th.int32, place=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.to_tensor([len(local_ts)], dtype=th.int32, place=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py new file mode 100644 index 000000000..eb9bd9e9a --- /dev/null +++ b/examples/confild/un_confild.py @@ -0,0 +1,726 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import enum +import functools +import math +import os +from abc import ABC, abstractmethod + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import paddle +from omegaconf import DictConfig +from resample import LossAwareSampler, UniformSampler + +from ppsci.arch import ( + LossType, + ModelMeanType, + ModelVarType, + SIRENAutodecoder_film, + SpacedDiffusion, + UNetModel, +) +from ppsci.utils import logger + + +def mean_flat(tensor): + return tensor.mean(axis=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + paddle.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2) + ) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + res = paddle.to_tensor(arr, dtype=timesteps.dtype)[timesteps] + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +train_losses = [] +valid_losses = [] + + +def create_model( + image_size, + num_channels, + num_res_blocks, + dims=2, + out_channels=1, + channel_mult=None, + learn_sigma=False, + class_cond=False, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, +): + if channel_mult is None: + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + if isinstance(channel_mult, str): + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + image_size=image_size, + in_channels=out_channels, + model_channels=num_channels, + out_channels=(out_channels if not learn_sigma else 2 * out_channels), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(1000 if class_cond else None), + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + dims=dims, + ) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + if schedule_name == "linear": + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def space_timesteps(num_timesteps, section_counts): + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +def create_gaussian_diffusion( + *, + steps=1000, + learn_sigma=False, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", +): + betas = get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = LossType.RESCALED_MSE + else: + loss_type = LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X + ), + model_var_type=( + (ModelVarType.FIXED_LARGE if not sigma_small else ModelVarType.FIXED_SMALL) + if not learn_sigma + else ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + ) + + +def load_elbow_flow(path): + return np.load(f"{path}")[1:] + + +def load_channel_flow( + path, + t_start=0, + t_end=1200, + t_every=1, +): + return np.load(f"{path}")[t_start:t_end:t_every] + + +def load_periodic_hill_flow(path): + data = np.load(f"{path}") + return data + + +def load_3d_flow(path): + data = np.load(f"{path}") + return data + + +class Normalizer_ts(object): + def __init__(self, params=[], method="-11", dim=None): + self.params = params + self.method = method + self.dim = dim + + def fit_normalize(self, data): + assert type(data) == paddle.Tensor + if len(self.params) == 0: + if self.method == "-11" or self.method == "01": + if self.dim is None: + self.params = paddle.max(x=data), paddle.min(x=data) + else: + self.params = ( + paddle.max(keepdim=True, x=data, axis=self.dim), + paddle.argmax(keepdim=True, x=data, axis=self.dim), + )[0], ( + paddle.min(keepdim=True, x=data, axis=self.dim), + paddle.argmin(keepdim=True, x=data, axis=self.dim), + )[ + 0 + ] + elif self.method == "ms": + if self.dim is None: + self.params = paddle.mean(x=data, axis=self.dim), paddle.std( + x=data, axis=self.dim + ) + else: + self.params = paddle.mean( + x=data, axis=self.dim, keepdim=True + ), paddle.std(x=data, axis=self.dim, keepdim=True) + elif self.method == "none": + self.params = None + return self.fnormalize(data, self.params, self.method) + + def normalize(self, new_data): + if not new_data.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fnormalize(new_data, self.params, self.method) + + def denormalize(self, new_data_norm): + if not new_data_norm.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fdenormalize(new_data_norm, self.params, self.method) + + def get_params(self): + if self.method == "ms": + print("returning mean and std") + elif self.method == "01": + print("returning max and min") + elif self.method == "-11": + print("returning max and min") + elif self.method == "none": + print("do nothing") + return self.params + + @staticmethod + def fnormalize(data, params, method): + if method == "-11": + return (data - params[1]) / (params[0] - params[1]) * 2 - 1 + elif method == "01": + return (data - params[1]) / (params[0] - params[1]) + elif method == "ms": + return (data - params[0]) / params[1] + elif method == "none": + return data + + @staticmethod + def fdenormalize(data_norm, params, method): + if method == "-11": + return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] + elif method == "01": + return data_norm * (params[0] - params[1]) + params[1] + elif method == "ms": + return data_norm * params[1] + params[0] + elif method == "none": + return data_norm + + +def create_slim(cfg): + ###### read data - fois ###### + if cfg.CNF.load_data_fn == "load_3d_flow": + fois = load_3d_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_elbow_flow": + fois = load_elbow_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_channel_flow": + fois = load_channel_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_periodic_hill_flow": + fois = load_periodic_hill_flow(cfg.CNF.data_path) + else: + fois = np.load(cfg.CNF.data_path) + + spatio_shape = fois.shape[1:-1] + + ###### read data - coordinate ###### + if cfg.CNF.coor_path is None: + coord = [np.linspace(0, 1, i) for i in spatio_shape] + coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1) + else: + coord = np.load(cfg.CNF.coor_path) + coord = coord.astype("float32") + fois = fois.astype("float32") + + ###### convert to tensor ###### + fois = paddle.to_tensor(fois) if not isinstance(fois, paddle.Tensor) else fois + coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord + N_samples = fois.shape[0] + + ###### normalizer ###### + in_normalizer = Normalizer_ts(**cfg.CNF.normalizer) + out_normalizer = Normalizer_ts(**cfg.CNF.normalizer) + norm_params = paddle.load(cfg.CNF.normalizer_params_path) + in_normalizer.params = norm_params["x_normalizer_params"] + out_normalizer.params = norm_params["y_normalizer_params"] + + cnf_model = SIRENAutodecoder_film(**cfg.CNF.CONFILD) + + return cnf_model, in_normalizer, out_normalizer, coord + + +def dl_iter(dl): + while True: + yield from dl + + +def train(cfg): + batch_size = cfg.TRAIN.batch_size + test_batch_size = cfg.TRAIN.test_batch_size + ema_rate = cfg.TRAIN.ema_rate + ema_rate = ( + [ema_rate] + if isinstance(ema_rate, float) + else [float(x) for x in ema_rate.split(",")] + ) + + lr_anneal_steps = cfg.TRAIN.lr_anneal_steps + final_lr = cfg.TRAIN.final_lr + step = 0 + resume_step = 0 + microbatch = cfg.TRAIN.microbatch if cfg.TRAIN.microbatch > 0 else batch_size + + train_data = np.load(cfg.DATA.train_data) + valid_data = np.load(cfg.DATA.valid_data) + print( + f"Train data shape: {train_data.shape}, range: [{train_data.min():.3f}, {train_data.max():.3f}]" + ) + print( + f"Valid data shape: {valid_data.shape}, range: [{valid_data.min():.3f}, {valid_data.max():.3f}]" + ) + + max_val, min_val = np.max(train_data, keepdims=True), np.min( + train_data, keepdims=True + ) + norm_train_data = -1 + (train_data - min_val) * 2.0 / (max_val - min_val) + norm_valid_data = -1 + (valid_data - min_val) * 2.0 / (max_val - min_val) + + print( + f"After normalization: train range: [{norm_train_data.min():.3f}, {norm_train_data.max():.3f}]" + ) + + norm_train_data = paddle.to_tensor(norm_train_data[:, None, ...]) + norm_valid_data = paddle.to_tensor(norm_valid_data[:, None, ...]) + + dl_train = dl_iter( + paddle.io.DataLoader( + paddle.io.TensorDataset(norm_train_data), + batch_size=batch_size, + shuffle=True, + ) + ) + dl_valid = dl_iter( + paddle.io.DataLoader( + paddle.io.TensorDataset(norm_valid_data), + batch_size=test_batch_size, + shuffle=True, + ) + ) + + unet_model = create_model( + image_size=cfg.UNET.image_size, + num_channels=cfg.UNET.num_channels, + num_res_blocks=cfg.UNET.num_res_blocks, + num_heads=cfg.UNET.num_heads, + num_head_channels=cfg.UNET.num_head_channels, + attention_resolutions=cfg.UNET.attention_resolutions, + channel_mult=cfg.UNET.channel_mult, + ) + print( + f"Model created with {sum(p.numel() for p in unet_model.parameters()):,} parameters" + ) + + diff_model = create_gaussian_diffusion( + steps=cfg.Diff.steps, noise_schedule=cfg.Diff.noise_schedule + ) + print( + f"Diffusion model created with {cfg.Diff.steps} steps, noise schedule: {cfg.Diff.noise_schedule}" + ) + + opt = paddle.optimizer.AdamW( + parameters=unet_model.parameters(), + learning_rate=cfg.TRAIN.lr, + weight_decay=cfg.TRAIN.weight_decay, + ) + print( + f"Optimizer initialized with lr={cfg.TRAIN.lr}, weight_decay={cfg.TRAIN.weight_decay}" + ) + + schedule_sampler = UniformSampler(diff_model) + + ema_params = [] + for _ in range(len(ema_rate)): + ema_param_dict = {} + for name, param in unet_model.named_parameters(): + ema_param_dict[name] = copy.deepcopy(param.detach()) + ema_params.append(ema_param_dict) + + global train_losses, valid_losses + train_losses.clear() + valid_losses.clear() + + valid_interval = 50 + max_steps = cfg.TRAIN.max_steps if hasattr(cfg.TRAIN, "max_steps") else 10000 + print( + f"Starting training with max_steps={max_steps}, lr_anneal_steps={lr_anneal_steps}" + ) + + while step + resume_step < max_steps: + cond = {} + train_batch = next(dl_train) + + unet_model.train() + opt.clear_grad() + + step_losses = [] + + for i in range(0, len(train_batch), microbatch): + micro = train_batch[i : i + microbatch] + micro_cond = {k: v[i : i + microbatch] for k, v in cond.items()} + + t, weights = schedule_sampler.sample(len(micro)) + + compute_losses = functools.partial( + diff_model.training_losses, + unet_model, + paddle.stack(micro), + t, + model_kwargs=micro_cond, + ) + losses = compute_losses() + + if isinstance(schedule_sampler, LossAwareSampler): + schedule_sampler.update_with_local_losses(t, losses["loss"].detach()) + + loss = (losses["loss"] * weights).mean() + + if step == 0 and i == 0: + print( + f"First loss computation - loss: {loss.item():.6f}, losses keys: {list(losses.keys())}" + ) + if "mse" in losses: + print(f"MSE loss: {losses['mse'].mean().item():.6f}") + if "vb" in losses: + print(f"VB loss: {losses['vb'].mean().item():.6f}") + + step_losses.append(loss.item()) + + if i == 0: + log_loss_dict( + diff_model, + t, + { + k: v * weights + for k, v in losses.items() + if isinstance(v, paddle.Tensor) + }, + is_valid=False, + ) + + loss.backward() + + paddle.nn.utils.clip_grad_norm_(unet_model.parameters(), max_norm=1.0) + + grad_norm, param_norm = _compute_norms(unet_model) + opt.step() + + if step_losses: + avg_step_loss = sum(step_losses) / len(step_losses) + train_losses.append(avg_step_loss) + + if step % 50 == 0: + current_lr = opt.get_lr() + print( + f"Step {step}: Loss={avg_step_loss:.6f}, GradNorm={grad_norm:.6f}, ParamNorm={param_norm:.6f}, LR={current_lr:.2e}" + ) + + _update_ema(ema_rate, ema_params, unet_model) + + if lr_anneal_steps is not None and lr_anneal_steps != 0: + _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr) + + step += 1 + + if step % valid_interval == 0: + unet_model.eval() + with paddle.no_grad(): + valid_batch = next(dl_valid) + all_valid_losses = [] + + for i in range(0, len(valid_batch), microbatch): + micro = valid_batch[i : i + microbatch] + micro_cond = {k: v[i : i + microbatch] for k, v in cond.items()} + + t, weights = schedule_sampler.sample(len(micro)) + + compute_losses = functools.partial( + diff_model.training_losses, + unet_model, + paddle.stack(micro), + t, + model_kwargs=micro_cond, + valid=True, + ) + losses = compute_losses() + + if "loss" in losses: + all_valid_losses.append( + (losses["loss"] * weights).mean().item() + ) + + if len(all_valid_losses) > 0: + avg_valid_loss = sum(all_valid_losses) / len(all_valid_losses) + valid_losses.append(avg_valid_loss) + print( + f"Step {step}: Train Loss: {train_losses[-1]:.6f}, Valid Loss: {avg_valid_loss:.6f}" + ) + + unet_model.train() + + paddle.save(unet_model.state_dict(), "unet.pdparams") + + plot_losses() + + +def plot_losses(): + if len(train_losses) == 0 or len(valid_losses) == 0: + print("没有足够的数据来绘制损失曲线") + return + + plt.figure(figsize=(10, 6)) + + # 绘制训练损失 + plt.plot(train_losses, label="Training Loss", alpha=0.8) + + valid_interval = 100 + valid_steps = [(i + 1) * valid_interval for i in range(len(valid_losses))] + plt.plot(valid_steps, valid_losses, label="Validation Loss", alpha=0.8, marker="o") + + plt.xlabel("Training Steps") + plt.ylabel("Loss") + plt.title("Training and Validation Loss") + plt.legend() + plt.grid(True) + plt.tight_layout() + + # 保存图像 + plt.savefig("loss_curve.png", dpi=300, bbox_inches="tight") + + # 显示图像 + plt.show() + + +def _compute_norms(model, grad_scale=1.0): + grad_norm = 0.0 + param_norm = 0.0 + for p in model.parameters(): + with paddle.no_grad(): + param_norm += paddle.norm(p, p=2, dtype=paddle.float32).item() ** 2 + if p.grad is not None: + grad_norm += paddle.norm(p.grad, p=2, dtype=paddle.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + +def _update_ema(ema_rate, ema_params, source_model): + for rate, target_params_dict in zip(ema_rate, ema_params): + for name, target_param in target_params_dict.items(): + source_param = dict(source_model.named_parameters())[name] + updated = target_param.detach() * rate + source_param.detach() * (1 - rate) + target_param.set_value(updated) + + +def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr): + if not lr_anneal_steps: + return + frac_done = (step + resume_step) / lr_anneal_steps + new_lr = final_lr * (frac_done) + lr * (1 - frac_done) + opt.set_lr(new_lr) + + +def log_loss_dict(diffusion, ts, losses, is_valid=False, add_to_list=True): + for key, values in losses.items(): + mean_loss = values.mean().item() + logger.info(f"{key}: {mean_loss:.6f}") + + if key == "loss" and add_to_list: + if is_valid: + valid_losses.append(mean_loss) + else: + train_losses.append(mean_loss) + + +def evaluate(cfg): + unet_model = create_model( + image_size=cfg.UNET.image_size, + num_channels=cfg.UNET.num_channels, + num_res_blocks=cfg.UNET.num_res_blocks, + num_heads=cfg.UNET.num_heads, + num_head_channels=cfg.UNET.num_head_channels, + attention_resolutions=cfg.UNET.attention_resolutions, + ) + + unet_model.set_state_dict(paddle.load(cfg.UNET.ema_path)) + + diff_model = create_gaussian_diffusion( + steps=cfg.Diff.steps, noise_schedule=cfg.Diff.noise_schedule + ) + + sample_fn = diff_model.p_sample_loop + gen_latents = sample_fn( + unet_model, + (cfg.EVAL.test_batch_size, 1, cfg.EVAL.time_length, cfg.EVAL.latent_length), + )[:, 0] + + max_val, min_val = ( + cfg.DATA.max_val, + cfg.DATA.min_val, + ) # np.load(cfg.DATA.max_val), np.load(cfg.DATA.min_val) + max_val, min_val = paddle.to_tensor(max_val), paddle.to_tensor(min_val) + gen_latents = (gen_latents + 1) * (max_val - min_val) / 2.0 + min_val + + nf, in_normalizer, out_normalizer, coord = create_slim(cfg) + nf.set_state_dict(paddle.load(cfg.CNF.model_path)) + coord = in_normalizer.normalize(coord) + + batch_size = 1 + n_samples = gen_latents.shape[0] + gen_fields = [] + + for sample_index in range(n_samples): + for i in range(gen_latents.shape[1] // batch_size): + new_latents = gen_latents[ + sample_index, i * batch_size : (i + 1) * batch_size + ] + if len(coord.shape) > 2: + new_latents = new_latents[:, None, None] + else: + new_latents = new_latents[:, None] + input_data = {"confild_x": coord, "latent_z": new_latents} + out = nf(input_data)["confild_output"] + out = out_normalizer.denormalize(out) + gen_fields.append(out.detach().cpu().numpy()) + + gen_fields = np.concatenate(gen_fields) + + np.save(cfg.save_path, gen_fields) + + +@hydra.main( + version_base=None, config_path="./conf", config_name="un_confild_case1.yaml" +) +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index 098893c38..81c0b2222 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - Aneurysm: zh/examples/aneurysm.md - BubbleNet: zh/examples/bubble.md - CFDGCN: zh/examples/cfdgcn.md + - CONFILD: zh/examples/confild.md - CVit(Advection): zh/examples/adv_cvit.md - CVit(NS): zh/examples/ns_cvit.md - Cylinder2D_unsteady: zh/examples/cylinder2d_unsteady.md diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index 498a380cc..4fae03916 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -24,6 +24,7 @@ from ppsci.arch.amgnet import AMGNet # isort:skip from ppsci.arch.base import Arch # isort:skip from ppsci.arch.cfdgcn import CFDGCN # isort:skip +from ppsci.arch.confild import LatentContainer, LossType, SIRENAutodecoder_film, SpacedDiffusion, UNetModel, ModelVarType, ModelMeanType # isort:skip from ppsci.arch.smc_reac import SuzukiMiyauraModel # isort:skip from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip @@ -107,11 +108,15 @@ "GraphCastNet", "HEDeepONets", "LorenzEmbedding", + "LatentContainer", "LatentNO", "LatentNO_time", "LNO", + "LossType", "MLP", "ModelList", + "ModelVarType", + "ModelMeanType", "ModifiedMLP", "NowcastNet", "PhyCRNet", @@ -120,6 +125,8 @@ "PrecipNet", "RosslerEmbedding", "SFNONet", + "SIRENAutodecoder_film", + "SpacedDiffusion", "SPINN", "TADF", "TFNO1dNet", @@ -127,6 +134,7 @@ "TFNO3dNet", "Transformer", "UNetEx", + "UNetModel", "UNONet", "USCNN", "VelocityDiscriminator", diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py new file mode 100644 index 000000000..b5a808909 --- /dev/null +++ b/ppsci/arch/confild.py @@ -0,0 +1,1871 @@ +import enum +import math +from abc import abstractmethod +from collections import OrderedDict + +import numpy as np +import paddle + +DEFAULT_W0 = 30.0 + + +###################### ConFILD Model ####################### +class Swish(paddle.nn.Layer): + """ + Swish activation function: f(x) = x * sigmoid(x). + + A smooth, non-monotonic activation function that has been shown to work + better than ReLU on deeper models across a number of challenging datasets. + """ + + def __init__(self): + super().__init__() + self.Sigmoid = paddle.nn.Sigmoid() + + def forward(self, x): + """ + Apply Swish activation. + + Args: + x (paddle.Tensor): Input tensor. + + Returns: + paddle.Tensor: Output tensor with same shape as input. + """ + return x * self.Sigmoid(x) + + +class Sine(paddle.nn.Layer): + """ + Sine activation function for SIREN (Sinusoidal Representation Networks). + + Args: + w0 (float, optional): Frequency parameter for sine activation. Defaults to DEFAULT_W0 (30.0). + """ + + def __init__(self, w0=DEFAULT_W0): + self.w0 = w0 + super().__init__() + + def forward(self, input): + """ + Apply sine activation with frequency modulation. + + Args: + input (paddle.Tensor): Input tensor. + + Returns: + paddle.Tensor: sin(w0 * input). + """ + return paddle.sin(x=self.w0 * input) + + +def sine_init(m, w0=DEFAULT_W0): + """ + Weight initialization for SIREN hidden layers. + + Initializes weights uniformly in [-√(6/n)/w0, √(6/n)/w0] where n is input dimension. + This initialization is critical for maintaining stable signal propagation in SIREN networks. + + Args: + m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute). + w0 (float, optional): Frequency parameter. Defaults to DEFAULT_W0. + """ + with paddle.no_grad(): + if hasattr(m, "weight"): + num_input = m.weight.shape[-1] + m.weight.uniform_( + min=-math.sqrt(6 / num_input) / w0, max=math.sqrt(6 / num_input) / w0 + ) + + +def first_layer_sine_init(m): + """ + Weight initialization for SIREN first layer. + + Initializes weights uniformly in [-1/n, 1/n] where n is input dimension. + Different from hidden layers to handle raw coordinate inputs properly. + + Args: + m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute). + """ + with paddle.no_grad(): + if hasattr(m, "weight"): + num_input = m.weight.shape[-1] + m.weight.uniform_(min=-1 / num_input, max=1 / num_input) + + +def __check_Linear_weight(m): + if isinstance(m, paddle.nn.Linear): + if hasattr(m, "weight"): + return True + return False + + +def init_weights_normal(m): + if __check_Linear_weight(m): + init_KaimingNormal = paddle.nn.initializer.KaimingNormal( + nonlinearity="relu", negative_slope=0.0 + ) + init_KaimingNormal(m.weight) + + +def init_weights_selu(m): + if __check_Linear_weight(m): + num_input = m.weight.shape[-1] + init_Normal = paddle.nn.initializer.Normal(std=1 / math.sqrt(num_input)) + init_Normal(m.weight) + + +def init_weights_elu(m): + if __check_Linear_weight(m): + num_input = m.weight.shape[-1] + init_Normal = paddle.nn.initializer.Normal( + std=math.sqrt(1.5505188080679277) / math.sqrt(num_input) + ) + init_Normal(m.weight) + + +def init_weights_xavier(m): + if __check_Linear_weight(m): + init_XavierNormal = paddle.nn.initializer.XavierNormal() + init_XavierNormal(m.weight) + + +NLS_AND_INITS = { + "sine": (Sine(), sine_init, first_layer_sine_init), + "relu": (paddle.nn.ReLU(), init_weights_normal, None), + "sigmoid": (paddle.nn.Sigmoid(), init_weights_xavier, None), + "tanh": (paddle.nn.Tanh(), init_weights_xavier, None), + "selu": (paddle.nn.SELU(), init_weights_selu, None), + "softplus": (paddle.nn.Softplus(), init_weights_normal, None), + "elu": (paddle.nn.ELU(), init_weights_elu, None), + "swish": (Swish(), init_weights_xavier, None), +} + + +class BatchLinear(paddle.nn.Linear): + """ + Batch-wise linear transformation layer that supports manual parameter injection. + + This layer extends paddle.nn.Linear to allow passing parameters explicitly, + which is useful for meta-learning and hypernetwork applications. + + Args: + in_features (int): Size of input features. + out_features (int): Size of output features. + + Note: + - Weight shape: (out_features, in_features) + - Bias shape: (out_features,) + """ + + __doc__ = paddle.nn.Linear.__doc__ + + def forward(self, input, params=None): + """ + Forward pass with optional external parameters. + + Args: + input (paddle.Tensor): Input tensor of shape (..., in_features). + params (OrderedDict, optional): External parameters dict containing 'weight' and optionally 'bias'. + If None, uses internal parameters. Defaults to None. + + Returns: + paddle.Tensor: Output tensor of shape (..., out_features). + """ + if params is None: + params = OrderedDict(self.named_parameters()) + bias = params.get("bias", None) + weight = params["weight"] + + output = paddle.matmul(x=input, y=weight) + if bias is not None: + output += bias.unsqueeze(axis=-2) + return output + + +class FeatureMapping: + """ + Feature mapping class for Fourier Feature Networks. + + Supports multiple mapping strategies including Gaussian random Fourier features, + positional encoding, and radial basis functions (RBF) for improving coordinate-based + neural network representations. + + Reference: + Tancik et al. "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains" + """ + + def __init__( + self, + in_features, + mode="basic", + gaussian_mapping_size=256, + gaussian_rand_key=0, + gaussian_tau=1.0, + pe_num_freqs=4, + pe_scale=2, + pe_init_scale=1, + pe_use_nyquist=True, + pe_lowest_dim=None, + rbf_out_features=None, + rbf_range=1.0, + rbf_std=0.5, + ): + """ + Initialize feature mapping. + + Args: + in_features (int): Number of input features. + mode (str, optional): Mapping mode. Options: "basic", "gaussian", "positional", "rbf". Defaults to "basic". + gaussian_mapping_size (int, optional): Output dimension for Gaussian mapping. Defaults to 256. + gaussian_rand_key (int, optional): Random seed for Gaussian mapping. Defaults to 0. + gaussian_tau (float, optional): Standard deviation for Gaussian mapping. Defaults to 1.0. + pe_num_freqs (int, optional): Number of frequency bands for positional encoding. Defaults to 4. + pe_scale (int, optional): Base scale for frequencies in positional encoding. Defaults to 2. + pe_init_scale (int, optional): Initial scale multiplier for positional encoding. Defaults to 1. + pe_use_nyquist (bool, optional): Use Nyquist frequency to determine num_freqs. Defaults to True. + pe_lowest_dim (int, optional): Lowest dimension for Nyquist calculation. Defaults to None. + rbf_out_features (int, optional): Number of RBF centers. Defaults to None. + rbf_range (float, optional): Range for RBF center initialization. Defaults to 1.0. + rbf_std (float, optional): Standard deviation for RBF kernels. Defaults to 0.5. + """ + self.mode = mode + if mode == "basic": + self.B = np.eye(in_features) + elif mode == "gaussian": + rng = np.random.default_rng(gaussian_rand_key) + self.B = rng.normal( + loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features) + ) + elif mode == "positional": + if pe_use_nyquist and pe_lowest_dim: + pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim) + self.B = pe_init_scale * np.vstack( + [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)] + ) + self.dim = tuple(self.B.shape)[0] * 2 + elif mode == "rbf": + self.centers = paddle.nn.Parameter( + paddle.empty(shape=(rbf_out_features, in_features), dtype="float32") + ) + self.sigmas = paddle.nn.Parameter( + paddle.empty(shape=rbf_out_features, dtype="float32") + ) + init_Uniform = paddle.nn.initializer.Uniform( + low=-1 * rbf_range, high=rbf_range + ) + init_Uniform(self.centers) + init_Constant = paddle.nn.initializer.Constant(value=rbf_std) + init_Constant(self.sigmas) + + def __call__(self, input): + if self.mode in ["basic", "gaussian", "positional"]: + return self.fourier_mapping(input, self.B) + elif self.mode == "rbf": + return self.rbf_mapping(input) + + def get_num_frequencies_nyquist(self, samples): + nyquist_rate = 1 / (2 * (2 * 1 / samples)) + return int(math.floor(math.log(nyquist_rate, 2))) + + @staticmethod + def fourier_mapping(x, B): + """ + Apply Fourier feature mapping: [sin(2πxB^T), cos(2πxB^T)]. + + Args: + x (paddle.Tensor): Input coordinates of shape (..., in_features). + B (np.ndarray): Frequency matrix of shape (mapping_size, in_features). + + Returns: + paddle.Tensor: Fourier features of shape (..., 2 * mapping_size). + """ + if B is None: + return x + else: + B = paddle.to_tensor(data=B, dtype="float32", place=x.place) + x_proj = 2.0 * np.pi * x @ B.T + return paddle.concat( + x=[paddle.sin(x=x_proj), paddle.cos(x=x_proj)], axis=-1 + ) + + def rbf_mapping(self, x): + size = tuple(x.shape)[:-1] + tuple(self.centers.shape) + x = x.unsqueeze(axis=-2).expand(shape=size) + distances = paddle.pow(x - self.centers, 2).sum(axis=-1) * self.sigmas + return self.gaussian(distances) + + @staticmethod + def gaussian(alpha): + phi = paddle.exp(x=-1 * paddle.pow(alpha, 2)) + return phi + + +class SIRENAutodecoder_film(paddle.nn.Layer): + """ + SIREN (Sinusoidal Representation Networks) with FiLM conditioning for autodecoding. + + This architecture uses sine activations and latent code modulation (FiLM) for + implicit neural representations. It takes both coordinate inputs and latent codes, + making it suitable for learning multiple shapes/scenes with a single network. + + Reference: + Sitzmann et al. "Implicit Neural Representations with Periodic Activation Functions" (NeurIPS 2020) + + Args: + input_keys (Tuple[str, ...], optional): Keys to get input tensors from dict. First key for coordinates, second for latents. + output_keys (Tuple[str, ...], optional): Keys to save output tensors into dict. + in_coord_features (int, optional): Number of input coordinate features (e.g., 2 for 2D, 3 for 3D). + in_latent_features (int, optional): Number of latent features for conditioning. + out_features (int, optional): Number of output features (e.g., 3 for RGB). + num_hidden_layers (int, optional): Number of hidden layers. + hidden_features (int, optional): Number of hidden layer features. + outermost_linear (bool, optional): Whether to use linear layer at output. Defaults to False. + nonlinearity (str, optional): Activation function. Options: "sine", "relu", "tanh", etc. Defaults to "sine". + weight_init (Callable, optional): Custom weight initialization function. Defaults to None. + bias_init (Callable, optional): Custom bias initialization function. Defaults to None. + premap_mode (str, optional): Feature mapping mode before network. Options: "gaussian", "positional", "rbf". Defaults to None. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.SIRENAutodecoder_film( + ... input_keys=["coords", "latents"], + ... output_keys=("output",), + ... in_coord_features=2, + ... in_latent_features=128, + ... out_features=3, + ... num_hidden_layers=10, + ... hidden_features=128, + ... ) + >>> input_data = { + ... "coords": paddle.randn([1000, 2]), + ... "latents": paddle.randn([1000, 128]) + ... } + >>> out_dict = model(input_data) + >>> print(out_dict["output"].shape) + [1000, 3] + """ + + def __init__( + self, + input_keys, + output_keys, + in_coord_features, + in_latent_features, + out_features, + num_hidden_layers, + hidden_features, + outermost_linear=False, + nonlinearity="sine", + weight_init=None, + bias_init=None, + premap_mode=None, + **kwargs, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.premap_mode = premap_mode + if self.premap_mode is not None: + self.premap_layer = FeatureMapping( + in_coord_features, mode=premap_mode, **kwargs + ) + in_coord_features = self.premap_layer.dim + self.first_layer_init = None + self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity] + if weight_init is not None: + self.weight_init = weight_init + else: + self.weight_init = nl_weight_init + self.net1 = paddle.nn.LayerList( + sublayers=[BatchLinear(in_coord_features, hidden_features)] + + [ + BatchLinear(hidden_features, hidden_features) + for i in range(num_hidden_layers) + ] + + [BatchLinear(hidden_features, out_features)] + ) + self.net2 = paddle.nn.LayerList( + sublayers=[ + BatchLinear(in_latent_features, hidden_features, bias_attr=False) + for i in range(num_hidden_layers + 1) + ] + ) + if self.weight_init is not None: + self.net1.apply(self.weight_init) + self.net2.apply(self.weight_init) + if first_layer_init is not None: + self.net1[0].apply(first_layer_init) + self.net2[0].apply(first_layer_init) + if bias_init is not None: + self.net2.apply(bias_init) + + def forward(self, input_data): + coords = input_data[self.input_keys[0]] + latents = input_data[self.input_keys[1]] + if self.premap_mode is not None: + x = self.premap_layer(coords) + else: + x = coords + + for i in range(len(self.net1) - 1): + x = self.net1[i](x) + self.net2[i](latents) + x = self.nl(x) + x = self.net1[-1](x) + return {self.output_keys[0]: x} + + def disable_gradient(self): + for param in self.parameters(): + param.stop_gradient = True + + +class LatentContainer(paddle.nn.Layer): + """ + Learnable latent code container for autodecoding applications. + + This module stores and retrieves per-sample latent codes, which can be used + for representing multiple instances (shapes, scenes) with a single decoder network. + Supports multi-GPU training and different dimensional arrangements. + + Reference: + Park et al. "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" (CVPR 2019) + + Args: + input_keys (Tuple[str, ...], optional): Key to get batch indices from dict. Defaults to ("input",). + output_keys (Tuple[str, ...], optional): Key to save latent codes into dict. Defaults to ("output",). + N_samples (int, optional): Total number of samples/instances in dataset. Defaults to None. + N_features (int, optional): Dimension of latent codes. Defaults to None. + dims (int, optional): Number of spatial dimensions (for proper broadcasting). Defaults to None. + lumped (bool, optional): If True, adds single dimension; if False, adds dims dimensions. Defaults to False. + + Examples: + >>> import ppsci + >>> import paddle + >>> model = ppsci.arch.LatentContainer( + ... N_samples=1600, + ... N_features=128, + ... dims=2, + ... lumped=True + ... ) + >>> batch_indices = paddle.randint(0, 1600, [32], dtype='int64') + >>> input_dict = {"input": batch_indices} + >>> out_dict = model(input_dict) + >>> print(out_dict["output"].shape) + [32, 1, 128] + """ + + def __init__( + self, + input_keys=("input",), + output_keys=("output",), + N_samples=None, + N_features=None, + dims=None, + lumped=False, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.dims = [1] * dims if not lumped else [1] + self.expand_dims = " ".join(["1" for _ in range(dims)]) if not lumped else "1" + self.expand_dims = f"N f -> N {self.expand_dims} f" + self.latents = self.create_parameter( + shape=(N_samples, N_features), + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0.0), + ) + + def forward(self, batch_ids): + x = batch_ids[self.input_keys[0]] + selected_latents = paddle.gather(self.latents, x) + if len(selected_latents.shape) > 1: + getShape = ( + [tuple(selected_latents.shape)[0]] + + self.dims + + [tuple(selected_latents.shape)[1]] + ) + else: + getShape = [-1] + self.dims + expanded_latents = selected_latents.reshape(getShape) + return {self.output_keys[0]: expanded_latents} + + +###################### GaussianDiffusion Model ####################### +class ModelVarType(enum.Enum): + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + res = paddle.to_tensor(data=arr, dtype=paddle.float32)[timesteps] + while len(tuple(res.shape)) < len(broadcast_shape): + res = res[..., None] + return res.expand(shape=broadcast_shape) + + +def split(x, num_or_sections, axis=0): + if isinstance(num_or_sections, int): + return paddle.split(x, x.shape[axis] // num_or_sections, axis) + else: + return paddle.split(x, num_or_sections, axis) + + +class ModelMeanType(enum.Enum): + PREVIOUS_X = enum.auto() + START_X = enum.auto() + EPSILON = enum.auto() + + +def mean_flat(tensor): + return paddle.mean(tensor, axis=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, paddle.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + ( + x + if isinstance(x, paddle.Tensor) + else paddle.to_tensor(x, dtype=tensor.dtype, place=tensor.place) + ) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + paddle.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2) + ) + + +class GaussianDiffusion: + """ + Gaussian diffusion process for denoising diffusion probabilistic models (DDPM). + + Implements the forward diffusion process q(x_t|x_0) and reverse denoising process p(x_{t-1}|x_t). + Supports various parameterizations (epsilon, x_0, x_{t-1}) and variance schedules. + + Reference: + Ho et al. "Denoising Diffusion Probabilistic Models" (NeurIPS 2020) + Nichol & Dhariwal "Improved Denoising Diffusion Probabilistic Models" (ICML 2021) + + Args: + betas (np.ndarray): Noise schedule β_t for t=0,...,T-1. + model_mean_type (ModelMeanType): Parameterization of model output. + model_var_type (ModelVarType): Variance parameterization (fixed or learned). + loss_type (LossType): Loss function type (MSE, KL, etc.). + rescale_timesteps (bool, optional): Rescale timesteps to [0, 1000]. Defaults to False. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(tuple(betas.shape)) == 1, "betas must be 1-D" + assert (betas > 0).astype("bool").all() and (betas <= 1).astype("bool").all() + + self.num_timesteps = int(tuple(betas.shape)[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert tuple(self.alphas_cumprod_prev.shape) == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = paddle.randn(x_start.shape) + + sqrt_alpha_cumprod_t = _extract_into_tensor( + self.sqrt_alphas_cumprod, t, x_start.shape + ) + sqrt_one_minus_alpha_cumprod_t = _extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + + return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert tuple(x_t.shape) == tuple(xprev.shape) + return ( + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, tuple(x_t.shape)) + * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, + t, + tuple(x_t.shape), + ) + * x_t + ) + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert tuple(x_t.shape) == tuple(eps.shape) + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, tuple(x_t.shape)) + * x_t + - _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, tuple(x_t.shape) + ) + * eps + ) + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + if model_kwargs is None: + model_kwargs = {} + B, C = tuple(x.shape)[:2] + assert tuple(t.shape) == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert tuple(model_output.shape) == (B, C * 2, *tuple(x.shape)[2:]) + model_output, model_var_values = split( + x=model_output, num_or_sections=C, axis=1 + ) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = paddle.exp(x=model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, tuple(x.shape) + ) + max_log = _extract_into_tensor(np.log(self.betas), t, tuple(x.shape)) + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = paddle.exp(x=model_log_variance) + else: + model_variance, model_log_variance = { + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, tuple(x.shape)) + model_log_variance = _extract_into_tensor( + model_log_variance, t, tuple(x.shape) + ) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clip(min=-1, max=1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + assert ( + tuple(model_mean.shape) + == tuple(model_log_variance.shape) + == tuple(pred_xstart.shape) + == tuple(x.shape) + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def q_posterior_mean_variance(self, x_start, x_t, t): + assert tuple(x_start.shape) == tuple(x_t.shape) + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, tuple(x_t.shape)) + * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, tuple(x_t.shape)) * x_t + ) + posterior_variance = _extract_into_tensor( + self.posterior_variance, t, tuple(x_t.shape) + ) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, tuple(x_t.shape) + ) + assert ( + tuple(posterior_mean.shape)[0] + == tuple(posterior_variance.shape)[0] + == tuple(posterior_log_variance_clipped.shape)[0] + == tuple(x_start.shape)[0] + ) + return (posterior_mean, posterior_variance, posterior_log_variance_clipped) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.astype(dtype="float32") * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + if model_kwargs is None: + model_kwargs = {} + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = p_mean_var["mean"].astype(dtype="float32") + p_mean_var[ + "variance" + ] * gradient.astype(dtype="float32") + return new_mean + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + if model_kwargs is None: + model_kwargs = {} + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - paddle.sqrt(1 - alpha_bar) * cond_fn( + x, self._scale_timesteps(t), **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = paddle.randn(shape=x.shape, dtype=x.dtype) + nonzero_mask = ( + (t != 0) + .astype(dtype="float32") + .reshape([-1, *([1] * (len(tuple(x.shape)) - 1))]) + ) + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = ( + out["mean"] + nonzero_mask * paddle.exp(x=0.5 * out["log_variance"]) * noise + ) + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = paddle.randn(shape=shape) + indices = list(range(self.num_timesteps))[::-1] + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + for i in indices: + t = paddle.to_tensor(data=[i] * shape[0]) + with paddle.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def training_losses( + self, model, x_start, t, model_kwargs=None, noise=None, valid=False + ): + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = paddle.randn(x_start.shape) + + x_t = self.q_sample(x_start=x_start, t=t, noise=noise) + # terms = {} + # model_output = model(x_t, t) + + # # Handle different model outputs + # if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + # assert model_output.shape[1] == 2 * x_start.shape[1], "Output channels must be 2x input channels" + # model_output, model_var_values = split(model_output, 2, axis=1) + + # Calculate the MSE loss for epsilon prediction + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=True, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = split(model_output, C, axis=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = paddle.concat( + [model_output.detach(), model_var_values], axis=1 + ) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=True, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + + if valid == False: + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + terms["valid_mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["valid_mse"] + terms["vb"] + else: + terms["loss"] = terms["valid_mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = paddle.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = paddle.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = paddle.log(cdf_plus.clip(min=1e-12)) + log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clip(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = paddle.where( + x < -0.999, + log_cdf_plus, + paddle.where( + x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clip(min=1e-12)) + ), + ) + assert log_probs.shape == x.shape + return log_probs + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * ( + 1.0 + paddle.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * paddle.pow(x, 3))) + ) + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class SpacedDiffusion(GaussianDiffusion): + """ + Accelerated diffusion process that skips timesteps for faster sampling. + + Implements DDIM-style sampling by using a subset of timesteps from the original + diffusion process, enabling faster inference without retraining the model. + + Reference: + Song et al. "Denoising Diffusion Implicit Models" (ICLR 2021) + + Args: + use_timesteps (Sequence[int]): Collection of timesteps to retain from original process + (e.g., [0, 10, 20, ..., 1000] for 100-step sampling). + **kwargs: Additional arguments for base GaussianDiffusion (betas, model_mean_type, etc.). + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + base_diffusion = GaussianDiffusion(**kwargs) + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = paddle.to_tensor( + data=self.timestep_map, dtype=ts.dtype # , place=ts.place + ) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.astype(dtype="float32") * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +###################### UNET Model ####################### +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return paddle.nn.Conv1D(*args, **kwargs) + elif dims == 2: + return paddle.nn.Conv2D(*args, **kwargs) + elif dims == 3: + return paddle.nn.Conv3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + return paddle.nn.Linear(*args, **kwargs) + + +class TimestepBlock(paddle.nn.Layer): + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + pass + + +class ResBlock(TimestepBlock): + """ + Residual block with timestep embedding for diffusion models. + + Implements a residual connection with two convolutional layers, timestep conditioning, + and optional up/downsampling. Supports FiLM-style adaptive normalization. + + Args: + channels (int): Number of input channels. + emb_channels (int): Number of timestep embedding channels. + dropout (float): Dropout probability. + out_channels (int, optional): Number of output channels. Defaults to channels. + use_conv (bool, optional): Use conv for skip connection if channels differ. Defaults to False. + use_scale_shift_norm (bool, optional): Use FiLM-style conditioning. Defaults to False. + dims (int, optional): Spatial dimensions (1D/2D/3D). Defaults to 2. + use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False. + up (bool, optional): Apply upsampling. Defaults to False. + down (bool, optional): Apply downsampling. Defaults to False. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.in_layers = paddle.nn.Sequential( + normalization(channels), + paddle.nn.Silu(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.updown = up or down + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = paddle.nn.Identity() + self.emb_layers = paddle.nn.Sequential( + paddle.nn.Silu(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = paddle.nn.Sequential( + normalization(self.out_channels), + paddle.nn.Silu(), + paddle.nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + if self.out_channels == channels: + self.skip_connection = paddle.nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).astype(h.dtype) + while len(tuple(emb_out.shape)) < len(tuple(h.shape)): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + (scale, shift) = paddle.chunk(x=emb_out, chunks=2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class TimestepEmbedSequential(paddle.nn.Sequential, TimestepBlock): + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +NUM_CLASSES = 1000 + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return paddle.nn.AvgPool1D(*args, **kwargs, exclusive=False) + elif dims == 2: + return paddle.nn.AvgPool2D(*args, **kwargs, exclusive=False) + elif dims == 3: + return paddle.nn.AvgPool3D(*args, **kwargs, exclusive=False) + raise ValueError(f"unsupported dimensions: {dims}") + + +class Downsample(paddle.nn.Layer): + """ + Spatial downsampling layer (2x reduction). + + Can use either strided convolution or average pooling for downsampling. + + Args: + channels (int): Number of input channels. + use_conv (bool): Use strided conv (True) or avg pooling (False). + dims (int, optional): Spatial dimensions. Defaults to 2. + out_channels (int, optional): Number of output channels. Defaults to channels. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + """Apply downsampling.""" + assert tuple(x.shape)[1] == self.channels + return self.op(x) + + +class Upsample(paddle.nn.Layer): + """ + Spatial upsampling layer (2x expansion). + + Uses nearest-neighbor interpolation followed by optional convolution. + + Args: + channels (int): Number of input channels. + use_conv (bool): Apply convolution after upsampling. + dims (int, optional): Spatial dimensions. Defaults to 2. + out_channels (int, optional): Number of output channels. Defaults to channels. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + """Apply upsampling.""" + assert tuple(x.shape)[1] == self.channels + if self.dims == 3: + x = paddle.nn.functional.interpolate( + x=x, + size=(tuple(x.shape)[2], tuple(x.shape)[3] * 2, tuple(x.shape)[4] * 2), + mode="nearest", + ) + else: + x = paddle.nn.functional.interpolate(x=x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +def count_flops_attn(model, _x, y): + b, c, *spatial = tuple(y[0].shape) + num_spatial = int(np.prod(spatial)) + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += paddle.to_tensor(data=[matmul_ops], dtype="float64") + + +class QKVAttentionLegacy(paddle.nn.Layer): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + bs, width, length = tuple(qkv.shape) + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + # split_size: 为 int 时 torch 表示块的大小,paddle 表示块的个数 + (q, k, v) = split(qkv.reshape((bs * self.n_heads, ch * 3, length)), ch, 1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale) + weight = paddle.nn.functional.softmax( + x=weight.astype(dtype="float32"), axis=-1 + ).astype(weight.dtype) + a = paddle.einsum("bts,bcs->bct", weight, v) + return a.reshape((bs, -1, length)) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(paddle.nn.Layer): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + bs, width, length = tuple(qkv.shape) + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + (q, k, v) = qkv.chunk(chunks=3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( # 非复数 + "bct,bcs->bts", + (q * scale).reshape([bs * self.n_heads, ch, length]), + (k * scale).reshape([bs * self.n_heads, ch, length]), + ) + weight = paddle.nn.functional.softmax( + x=weight.astype(dtype="float32"), axis=-1 + ).astype(weight.dtype) + a = paddle.einsum( + "bts,bcs->bct", weight, v.reshape((bs * self.n_heads, ch, length)) + ) + return a.reshape((bs, -1, length)) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class GroupNorm32(paddle.nn.GroupNorm): + def forward(self, x): + return super().forward(x.astype(dtype="float32")).astype(x.dtype) + + +def normalization(channels): + return GroupNorm32(32, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +def checkpoint(func, inputs, params, flag): + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with paddle.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [stop_gradient(x, stop=False) for x in ctx.input_tensors] + with paddle.enable_grad(): + shallow_copies = [x.reshape(x.shape) for x in ctx.input_tensors] + # print(shallow_copies) + output_tensors = ctx.run_function(*shallow_copies) + input_grads = paddle.grad( + outputs=output_tensors, + inputs=ctx.input_tensors + ctx.input_params, + grad_outputs=output_grads, + allow_unused=True, + # retain_graph=True, create_graph=False + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + + # 确保将input_grads转换为元组,然后与(None, None)连接 + # PyLayer要求backward方法返回元组类型 + # if input_grads: + return tuple(input_grads) + # else: + # return (None, None) + + +def stop_gradient(input, stop): + input.stop_gradient = stop + return input + + +class AttentionBlock(paddle.nn.Layer): + """ + Self-attention block for spatial feature maps. + + Applies multi-head self-attention over spatial locations in feature maps, + allowing the model to capture long-range dependencies. + + Args: + channels (int): Number of input/output channels. + num_heads (int, optional): Number of attention heads. Defaults to 1. + num_head_channels (int, optional): Channels per head (overrides num_heads). Defaults to -1. + use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False. + use_new_attention_order (bool, optional): Use optimized attention implementation. Defaults to False. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + self.attention = QKVAttention(self.num_heads) + else: + self.attention = QKVAttentionLegacy(self.num_heads) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = tuple(x.shape) + x = x.reshape((b, c, -1)) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape((b, c, *spatial)) + + +def convert_module_to_f16(l): + if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)): + l.weight.data = l.weight.data.astype(dtype="float16") + if l.bias is not None: + l.bias.data = l.bias.data.astype(dtype="float16") + + +def convert_module_to_f32(l): + if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)): + l.weight.data = l.weight.data.astype(dtype="float32") + if l.bias is not None: + l.bias.data = l.bias.data.astype(dtype="float32") + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings for diffusion models. + + Similar to positional encodings in transformers, but for continuous timesteps. + Uses sinusoids of exponentially increasing frequencies. + + Args: + timesteps (paddle.Tensor): Timestep values of shape (batch_size,). + dim (int): Embedding dimension. + max_period (int, optional): Maximum period for sinusoids. Defaults to 10000. + + Returns: + paddle.Tensor: Timestep embeddings of shape (batch_size, dim). + """ + half = dim // 2 + freqs = paddle.exp( + x=-math.log(max_period) + * paddle.arange(start=0, end=half, dtype="float32") + / half + ) # .to(paddle.CUDAPlace(0)) + args = timesteps[:, None].astype(dtype="float32") * freqs[None] + embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) + if dim % 2: + embedding = paddle.concat( + x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1 + ) + return embedding + + +class UNetModel(paddle.nn.Layer): + """ + Full UNet model with attention and timestep embedding for diffusion models. + + Implements a U-Net architecture with residual blocks, self-attention at multiple resolutions, + and timestep conditioning via adaptive normalization (FiLM). Designed for denoising diffusion + probabilistic models (DDPM) and can be conditioned on class labels. + + Reference: + Ronneberger et al. "U-Net: Convolutional Networks for Biomedical Image Segmentation" (MICCAI 2015) + Dhariwal & Nichol "Diffusion Models Beat GANs on Image Synthesis" (NeurIPS 2021) + + Args: + image_size (int): Input image size (maintained for interface compatibility). + in_channels (int): Number of channels in input tensor. + model_channels (int): Base channel count for model (multiplied by channel_mult). + out_channels (int): Number of channels in output tensor. + num_res_blocks (int): Number of residual blocks per downsampling level. + attention_resolutions (list/tuple): Downsample factors where to apply attention (e.g., [4, 8, 16]). + dropout (float, optional): Dropout probability in residual blocks. Defaults to 0.0. + channel_mult (tuple, optional): Channel multipliers per level (e.g., (1, 2, 4, 8)). Defaults to (1, 2, 4, 8). + conv_resample (bool, optional): Use learned convolutional up/downsampling. Defaults to True. + dims (int, optional): Data dimensionality (1=1D, 2=2D, 3=3D). Defaults to 2. + num_classes (int, optional): Number of classes for class-conditional generation. Defaults to None. + use_checkpoint (bool, optional): Enable gradient checkpointing to save memory. Defaults to False. + use_fp16 (bool, optional): Use float16 precision for forward pass. Defaults to False. + num_heads (int, optional): Number of attention heads in each attention block. Defaults to 1. + num_head_channels (int, optional): Fixed channels per head (overrides num_heads if set). Defaults to -1. + num_heads_upsample (int, optional): Attention heads for upsampling blocks. Defaults to -1 (use num_heads). + use_scale_shift_norm (bool, optional): Use FiLM-style conditioning in ResBlocks. Defaults to False. + resblock_updown (bool, optional): Use ResBlocks for up/downsampling instead of conv layers. Defaults to False. + use_new_attention_order (bool, optional): Use optimized QKV attention implementation. Defaults to False. + + Examples: + >>> import ppsci + >>> import paddle + >>> model = ppsci.arch.UNetModel( + ... image_size=64, + ... in_channels=3, + ... model_channels=128, + ... out_channels=3, + ... num_res_blocks=2, + ... attention_resolutions=[8, 16], + ... channel_mult=(1, 2, 4, 8), + ... num_heads=4, + ... ) + >>> x = paddle.randn([4, 3, 64, 64]) + >>> t = paddle.randint(0, 1000, [4]) + >>> out = model(x, t) + >>> print(out.shape) + [4, 3, 64, 64] + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = "float16" if use_fp16 else "float32" + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + time_embed_dim = model_channels * 4 + self.time_embed = paddle.nn.Sequential( + linear(model_channels, time_embed_dim), + paddle.nn.Silu(), + linear(time_embed_dim, time_embed_dim), + ) + if self.num_classes is not None: + self.label_emb = paddle.nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=time_embed_dim + ) + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = paddle.nn.LayerList( + sublayers=[ + TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1)) + ] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [] + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.output_blocks = paddle.nn.LayerList(sublayers=[]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [] + layers.append( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + self.out = paddle.nn.Sequential( + normalization(ch), + paddle.nn.Silu(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.num_classes is not None: + assert tuple(y.shape) == (tuple(x.shape)[0],) + emb = emb + self.label_emb(y) + h = x.astype(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = paddle.concat(x=[h, hs.pop()], axis=1) + h = module(h, emb) + h = h.astype(x.dtype) + return self.out(h)