From b19d5204242d9206790331866c9c2d911a3c9a6a Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Sun, 21 Sep 2025 18:34:22 +0800 Subject: [PATCH 01/11] =?UTF-8?q?=E3=80=90Hackathon=209th=20No.105?= =?UTF-8?q?=E3=80=91CoNFiLD=20=E8=AE=BA=E6=96=87=E5=A4=8D=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/zh/examples/confild.md | 369 ++++++ examples/confild/conf/confild_case1.yaml | 124 ++ examples/confild/conf/confild_case2.yaml | 108 ++ examples/confild/conf/confild_case3.yaml | 108 ++ examples/confild/conf/confild_case4.yaml | 108 ++ examples/confild/confild.py | 862 ++++++++++++++ ppsci/arch/__init__.py | 7 + ppsci/arch/confild.py | 1385 ++++++++++++++++++++++ 8 files changed, 3071 insertions(+) create mode 100644 docs/zh/examples/confild.md create mode 100644 examples/confild/conf/confild_case1.yaml create mode 100644 examples/confild/conf/confild_case2.yaml create mode 100644 examples/confild/conf/confild_case3.yaml create mode 100644 examples/confild/conf/confild_case4.yaml create mode 100644 examples/confild/confild.py create mode 100644 ppsci/arch/confild.py diff --git a/docs/zh/examples/confild.md b/docs/zh/examples/confild.md new file mode 100644 index 0000000000..96e167182c --- /dev/null +++ b/docs/zh/examples/confild.md @@ -0,0 +1,369 @@ +# AI辅助的时空湍流生成:条件神经场潜在扩散模型(CoNFILD) + +Distributed under a Creative Commons Attribution license 4.0 (CC BY). + +## 1. 背景简介 +### 1.1 论文信息 +| 年份 | 期刊 | 作者 | 引用数 | 论文PDF与补充材料 | +|----------------|---------------------|--------------------------------------------------------------------------------------------------|--------|----------------------------------------------------------------------------------------------------| +| 2024年1月3日 | Nature Communications | Pan Du, Meet Hemant Parikh, Xiantao Fan, Xin-Yang Liu, Jian-Xun Wang | 15 | [论文链接](https://doi.org/10.1038/s41467-024-54712-1)
[代码仓库](https://github.com/jx-wang-s-group/CoNFILD) | + +### 1.2 作者介绍 +- **通讯作者**:Jian-Xun Wang(王建勋)
所属机构:美国圣母大学航空航天与机械工程系、康奈尔大学机械与航空航天工程系
研究方向:湍流建模、生成式AI、物理信息机器学习
+ +- **其他作者**:
Pan Du、Meet Hemant Parikh(共同一作):圣母大学博士生,研究方向为生成式模型与计算流体力学
Xiantao Fan、Xin-Yang Liu:圣母大学研究助理,负责数值模拟与数据生成 + +### 1.3 模型&复现代码 +| 问题类型 | 在线运行 | 神经网络架构 | 评估指标 | +|------------------------|----------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------| +| 时空湍流生成 | [aistudio](https://aistudio.baidu.com/projectdetail/8933946) | 条件神经场+潜在扩散模型 | MSE: 0.041(速度场) | + +=== "模型训练命令" +```bash +git clone https://github.com/PaddlePaddle/PaddleScience.git +cd PaddleScience/examples/confild +python confild.py mode=train +``` + +=== "预训练模型快速评估" + +``` sh +python confild.py mode=eval +``` + +## 2. 问题定义 +### 2.1 研究背景 +湍流模拟在航空航天、海洋工程等领域至关重要,但传统方法如直接数值模拟(DNS)和大涡模拟(LES)计算成本高昂,难以应用于高雷诺数或实时场景。现有深度学习模型多基于确定性框架,难以捕捉湍流的混沌特性,且在复杂几何域中表现受限。 + +### 2.2 核心挑战 +1. **高维数据**:三维时空湍流数据维度高达 \(O(10^9)\),传统生成模型内存需求巨大。 +2. **随机性建模**:需同时捕捉湍流的多尺度统计特性与瞬时动态。 +3. **几何适应性**:需支持不规则计算域与自适应网格。 + +### 2.3 创新方法 +提出**条件神经场潜在扩散模型(CoNFILD)**,通过三阶段框架解决上述挑战: +1. **神经场编码**:将高维流场压缩为低维潜在表示,压缩比达0.002%-0.017%。 +2. **潜在扩散**:在潜在空间进行概率扩散过程,学习湍流统计分布。 +3. **零样本条件生成**:结合贝叶斯推理,无需重新训练即可实现传感器重建、超分辨率等任务。 + +![图1 CoNFILD框架](./confild.png) +*框架示意图:CNF编码器将流场映射到潜在空间,扩散模型生成新潜在样本,解码器重建物理场* + +## 3. 模型构建 +### 3.1 条件神经场(CNF) +- **架构**:基于SIREN网络,采用正弦激活函数捕捉周期性特征。 +- **数学表示**: + $$ + \mathscr{E}(\mathbf{X},\mathbf{L}) = \text{SIREN}(\mathbf{x}) + \text{FILM}(\mathbf{L}) + $$ + 其中FILM(Feature-wise Linear Modulation)通过潜在向量\(\mathbf{L}\)调节每层偏置。 + +### 3.2 潜在扩散模型 +- **前向过程**:逐步添加高斯噪声,潜在表示\(\mathbf{z}_0 \rightarrow \mathbf{z}_T\)。 +- **逆向过程**:训练U-Net预测噪声,通过迭代去噪生成新样本: + $$ + \mathbf{z}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{z}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{z}_t, t) \right) + \sigma_t \epsilon + $$ + +### 3.3 零样本条件生成 +- **贝叶斯后验采样**:基于稀疏观测\(\Psi\),通过梯度修正潜在空间采样: + $$ + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t|\Psi) \approx \nabla_{\mathbf{z}_t} \log p(\Psi|\mathbf{z}_t) + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t) + $$ + +## 4. 问题求解 +### 4.1 数据集准备 +数据文件说明如下: +``` +data # CNF的训练数据集 +| +|-- data.npy # 要拟合的数据 +| +|-- coords.npy # 查询坐标 +``` + +在加载数据之后,需要进行normalization,以便于训练。具体代码如下: +```python +class Normalizer_ts(object): + def __init__(self, params=[], method="-11", dim=None): + self.params = params + self.method = method + self.dim = dim + + def fit_normalize(self, data): + assert type(data) == paddle.Tensor + if len(self.params) == 0: + if self.method == "-11" or self.method == "01": + if self.dim is None: + self.params = paddle.max(x=data), paddle.min(x=data) + else: + self.params = ( + paddle.max(keepdim=True, x=data, axis=self.dim), + paddle.argmax(keepdim=True, x=data, axis=self.dim), + )[0], ( + paddle.min(keepdim=True, x=data, axis=self.dim), + paddle.argmin(keepdim=True, x=data, axis=self.dim), + )[ + 0 + ] + elif self.method == "ms": + if self.dim is None: + self.params = paddle.mean(x=data, axis=self.dim), paddle.std( + x=data, axis=self.dim + ) + else: + self.params = paddle.mean( + x=data, axis=self.dim, keepdim=True + ), paddle.std(x=data, axis=self.dim, keepdim=True) + elif self.method == "none": + self.params = None + return self.fnormalize(data, self.params, self.method) + + def normalize(self, new_data): + if not new_data.place == self.params[0].place: + self.params = self.params[0].to(new_data.place), self.params[1].to( + new_data.place + ) + return self.fnormalize(new_data, self.params, self.method) + + def denormalize(self, new_data_norm): + if not new_data_norm.place == self.params[0].place: + self.params = self.params[0].to(new_data_norm.place), self.params[1].to( + new_data_norm.place + ) + return self.fdenormalize(new_data_norm, self.params, self.method) + + def get_params(self): + if self.method == "ms": + print("returning mean and std") + elif self.method == "01": + print("returning max and min") + elif self.method == "-11": + print("returning max and min") + elif self.method == "none": + print("do nothing") + return self.params + + @staticmethod + def fnormalize(data, params, method): + if method == "-11": + return (data - params[1].to(data.place)) / ( + params[0].to(data.place) - params[1].to(data.place) + ) * 2 - 1 + elif method == "01": + return (data - params[1].to(data.place)) / ( + params[0].to(data.place) - params[1].to(data.place) + ) + elif method == "ms": + return (data - params[0].to(data.place)) / params[1].to(data.place) + elif method == "none": + return data + + @staticmethod + def fdenormalize(data_norm, params, method): + if method == "-11": + return (data_norm + 1) / 2 * ( + params[0].to(data_norm.place) - params[1].to(data_norm.place) + ) + params[1].to(data_norm.place) + elif method == "01": + return data_norm * ( + params[0].to(data_norm.place) - params[1].to(data_norm.place) + ) + params[1].to(data_norm.place) + elif method == "ms": + return data_norm * params[1].to(data_norm.place) + params[0].to( + data_norm.place + ) + elif method == "none": + return data_norm +``` + +### 4.2 CoNFiLD 模型 +CoNFiLD 模型基于贝叶斯后验采样,将稀疏传感器测量数据作为条件输入。通过训练好的无条件扩散模型作为先验,在扩散后验采样过程中,考虑测量噪声引入的不确定性。利用状态到观测映射,根据条件向量与流场的关系,通过调整无条件得分函数,引导生成与传感器数据一致的全时空流场实现重构,并且能提供重构的不确定性估计。代码如下: + +```python +class SIRENAutodecoder_film(paddle.nn.Layer): + """ + siren network with author decoding + + Args: + input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict. + output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict. + in_coord_features (int, optional): Number of input coordinates features + in_latent_features (int, optional): Number of input latent features + out_features (int, optional): Number of output features + num_hidden_layers (int, optional): Number of hidden layers + hidden_features (int, optional): Number of hidden features + outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False. + nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine". + weight_init (Callable, optional): Weight initialization function. Defaults to None. + bias_init (Callable, optional): Bias initialization function. Defaults to None. + premap_mode (str, optional): Feature mapping mode. Defaults to None. + + Examples: + >>> model = ppsci.arch.SIRENAutodecoder_film( + input_keys=["input1", "input2"], + output_keys=("output",), + in_coord_features=2, + in_latent_features=128, + out_features=3, + num_hidden_layers=10, + hidden_features=128, + ) + >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])} + >>> out_dict = model(input_data) + >>> for k, v in out_dict.items(): + ... print(k, v.shape) + output [22, 918, 3] + """ + + def __init__( + self, + input_keys, + output_keys, + in_coord_features, + in_latent_features, + out_features, + num_hidden_layers, + hidden_features, + outermost_linear=False, + nonlinearity="sine", + weight_init=None, + bias_init=None, + premap_mode=None, + **kwargs, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.premap_mode = premap_mode + if self.premap_mode is not None: + self.premap_layer = FeatureMapping( + in_coord_features, mode=premap_mode, **kwargs + ) + in_coord_features = self.premap_layer.dim + self.first_layer_init = None + self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity] + if weight_init is not None: + self.weight_init = weight_init + else: + self.weight_init = nl_weight_init + self.net1 = paddle.nn.LayerList( + sublayers=[BatchLinear(in_coord_features, hidden_features)] + + [ + BatchLinear(hidden_features, hidden_features) + for i in range(num_hidden_layers) + ] + + [BatchLinear(hidden_features, out_features)] + ) + self.net2 = paddle.nn.LayerList( + sublayers=[ + BatchLinear(in_latent_features, hidden_features, bias_attr=False) + for i in range(num_hidden_layers + 1) + ] + ) + if self.weight_init is not None: + self.net1.apply(self.weight_init) + self.net2.apply(self.weight_init) + if first_layer_init is not None: + self.net1[0].apply(first_layer_init) + self.net2[0].apply(first_layer_init) + if bias_init is not None: + self.net2.apply(bias_init) + + def forward(self, input_data): + coords = input_data[self.input_keys[0]] + latents = input_data[self.input_keys[1]] + if self.premap_mode is not None: + x = self.premap_layer(coords) + else: + x = coords + + for i in range(len(self.net1) - 1): + x = self.net1[i](x) + self.net2[i](latents) + x = self.nl(x) + x = self.net1[-1](x) + return {self.output_keys[0]: x} + + def disable_gradient(self): + for param in self.parameters(): + param.stop_gradient = not False +``` +为了在计算时,准确快速地访问具体变量的值,我们在这里指定网络模型的输入变量名是 ["confild_x", "latent_z"],输出变量名是 ["confild_output"],这些命名与后续代码保持一致。 + +4.3 模型训练、评估 +完成上述设置之后,只需要将上述实例化的对象按照文档进行组合,然后启动训练、评估。 +```python +def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer): + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + latents_model = LatentContainer(**cfg.Latent) + + dataset = basic_set(normed_fois, normed_coords) + criterion = paddle.nn.MSELoss() + + # set loader + train_loader = DataLoader( + dataset=dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True + ) + test_loader = DataLoader( + dataset=dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False + ) + # set optimizer + cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model) + latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)( + latents_model + ) + + for i in range(cfg.TRAIN.epochs): + cnf_model.train() + latents_model.train() + if i != 0: + cnf_optimizer.step() + cnf_optimizer.clear_grad(set_to_zero=False) + train_loss = [] + for batch_coords, batch_fois, idx in train_loader: + idx = {"latent_x": idx} + batch_latent = latents_model(idx) + if isinstance(batch_coords, list): + batch_coords = [i for i in batch_coords] + data = { + "confild_x": batch_coords, + "latent_z": batch_latent["latent_z"], + } + batch_output = cnf_model(data) + loss = criterion(batch_output["confild_output"], batch_fois) + latents_optimizer.clear_grad(set_to_zero=False) + loss.backward() + latents_optimizer.step() + train_loss.append(loss.item()) + epoch_loss = paddle.stack(x=train_loss).mean() + print("epoch {}, train loss {}".format(i + 1, epoch_loss)) + if i % 100 == 0: + test_error = [] + cnf_model.eval() + latents_model.eval() + with paddle.no_grad(): + for test_coords, test_fois, idx in test_loader: + if isinstance(test_coords, list): + test_coords = [i for i in test_coords] + prediction = out_normalizer.denormalize( + cnf_model( + { + "confild_x": test_coords, + "latent_z": latents_model({"latent_x": idx})[ + "latent_z" + ], + } + ) + ) + target = out_normalizer.denormalize(test_fois) + error = rMAE(prediction=prediction, target=target, dims=spatio_axis) + test_error.append(error) + test_error = paddle.concat(x=test_error).mean(axis=0) + print("test MAE: ", test_error) + if i % 1000 == 0: + paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") + paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") +``` + +## 5. 实验结果 diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml new file mode 100644 index 0000000000..522cda9b98 --- /dev/null +++ b/examples/confild/conf/confild_case1.yaml @@ -0,0 +1,124 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_confild_case1 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 +alis: False + +TRAIN: + batch_size: 64 + test_batch_size: 256 + epochs: 9800 + mutil_GPU: 1 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case1/confild_case1/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case1/latent_case1/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 3 + hidden_features: 128 + in_coord_features: 2 + in_latent_features: 128 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 16000 + lumped: True + N_features: 128 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case1 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case1 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [918, 2] + latents_shape: [1, 128] + log_freq: 20 + batch_size: 64 + +Uncondiction_INFER: + batch_size : 16 + test_batch_size : 16 + time_length : 128 + latent_length : 128 + image_size : 128 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + steps: 1000 + noise_schedule: "cosine" + +Data: + data_path: ../case1/data.npy + coor_path: ../case1/coor.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_elbow_flow diff --git a/examples/confild/conf/confild_case2.yaml b/examples/confild/conf/confild_case2.yaml new file mode 100644 index 0000000000..26b0f3c217 --- /dev/null +++ b/examples/confild/conf/confild_case2.yaml @@ -0,0 +1,108 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case2/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_confild_case2 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 40 + test_batch_size: 40 + epochs: 44500 + mutil_GPU: 1 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case2/confild_case2/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case2/latent_case2/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 4 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 1200 + lumped: False + N_features: 256 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case2 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case2 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [400, 100, 2] + latents_shape: [1, 1, 256] + log_freq: 20 + batch_size: 40 + +Data: + data_path: ../case2/data.npy + coor_path: ../case2/coor.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_channel_flow diff --git a/examples/confild/conf/confild_case3.yaml b/examples/confild/conf/confild_case3.yaml new file mode 100644 index 0000000000..2b7ea04ff6 --- /dev/null +++ b/examples/confild/conf/confild_case3.yaml @@ -0,0 +1,108 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case3/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_confild_case3 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 100 + test_batch_size: 100 + epochs: 4800 + mutil_GPU: 2 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case3/confild_case3/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case3/latent_case3/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 117 + out_features: 2 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 2880 + lumped: True + N_features: 256 + dims: 2 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case3 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case3 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [10884, 2] + latents_shape: [1, 256] + log_freq: 20 + batch_size: 100 + +Data: + data_path: ../case3/data.npy + coor_path: ../case3/coor.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_periodic_hill_flow diff --git a/examples/confild/conf/confild_case4.yaml b/examples/confild/conf/confild_case4.yaml new file mode 100644 index 0000000000..3e1491a3f1 --- /dev/null +++ b/examples/confild/conf/confild_case4.yaml @@ -0,0 +1,108 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case4/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_confild_case4 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +TRAIN: + batch_size: 4 + test_batch_size: 4 + epochs: 20000 + mutil_GPU: 2 + lr: + cnf: 1.e-4 + latents: 1.e-5 + +EVAL: + confild_pretrained_model_path: ./outputs_confild_case4/confild_case4/epoch_99999 + latent_pretrained_model_path: ./outputs_confild_case4/latent_case4/epoch_99999 + +CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 15 + out_features: 3 + hidden_features: 384 + in_coord_features: 3 + in_latent_features: 384 + +Latent: + input_keys: ["latent_x"] + output_keys: ["latent_z"] + N_samples: 1200 + lumped: True + N_features: 384 + dims: 3 + +INFER: + Latent: + INFER: + pretrained_model_path: null + export_path: ./inference/latent_case4 + pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams + onnx_path: ${INFER.Latent.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + log_freq: 20 + Confild: + INFER: + pretrained_model_path: null + export_path: ./inference/confild_case4 + pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel + pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams + onnx_path: ${INFER.Confild.INFER.export_path}.onnx + device: gpu + engine: native + precision: fp32 + ir_optim: true + min_subgraph_size: 5 + gpu_mem: 2000 + gpu_id: 0 + max_batch_size: 1024 + num_cpu_threads: 10 + coord_shape: [58483, 3] + latents_shape: [1, 384] + log_freq: 20 + batch_size: 4 + +Data: + data_path: ../case4/data.npy + coor_path: ../case4/coor.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_3d_flow diff --git a/examples/confild/confild.py b/examples/confild/confild.py new file mode 100644 index 0000000000..098b741d3d --- /dev/null +++ b/examples/confild/confild.py @@ -0,0 +1,862 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import math +import hydra +import matplotlib.pyplot as plt +import numpy as np +import paddle +from omegaconf import DictConfig +from paddle.distributed import fleet +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +import ppsci +from ppsci.arch import UNetModel +from ppsci.arch import LatentContainer +from ppsci.arch import SIRENAutodecoder_film +from ppsci.arch import SpacedDiffusion +from ppsci.arch import ModelVarType +from ppsci.arch import ModelMeanType +from ppsci.utils import logger + +def load_elbow_flow(path): + return np.load(f"{path}")[1:] + + +def load_channel_flow( + path, + t_start=0, + t_end=1200, + t_every=1, +): + return np.load(f"{path}")[t_start:t_end:t_every] + + +def load_periodic_hill_flow(path): + data = np.load(f"{path}") + return data + + +def load_3d_flow(path): + data = np.load(f"{path}") + return data + + +def rMAE(prediction, target, dims=(1, 2)): + return paddle.abs(x=prediction - target).mean(axis=dims) / paddle.abs( + x=target + ).mean(axis=dims) + + +class Normalizer_ts(object): + def __init__(self, params=[], method="-11", dim=None): + self.params = params + self.method = method + self.dim = dim + + def fit_normalize(self, data): + assert type(data) == paddle.Tensor + if len(self.params) == 0: + if self.method == "-11" or self.method == "01": + if self.dim is None: + self.params = paddle.max(x=data), paddle.min(x=data) + else: + self.params = ( + paddle.max(keepdim=True, x=data, axis=self.dim), + paddle.argmax(keepdim=True, x=data, axis=self.dim), + )[0], ( + paddle.min(keepdim=True, x=data, axis=self.dim), + paddle.argmin(keepdim=True, x=data, axis=self.dim), + )[ + 0 + ] + elif self.method == "ms": + if self.dim is None: + self.params = paddle.mean(x=data, axis=self.dim), paddle.std( + x=data, axis=self.dim + ) + else: + self.params = paddle.mean( + x=data, axis=self.dim, keepdim=True + ), paddle.std(x=data, axis=self.dim, keepdim=True) + elif self.method == "none": + self.params = None + return self.fnormalize(data, self.params, self.method) + + def normalize(self, new_data): + if not new_data.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fnormalize(new_data, self.params, self.method) + + def denormalize(self, new_data_norm): + if not new_data_norm.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fdenormalize(new_data_norm, self.params, self.method) + + def get_params(self): + if self.method == "ms": + print("returning mean and std") + elif self.method == "01": + print("returning max and min") + elif self.method == "-11": + print("returning max and min") + elif self.method == "none": + print("do nothing") + return self.params + + @staticmethod + def fnormalize(data, params, method): + if method == "-11": + return (data - params[1]) / ( + params[0] - params[1] + ) * 2 - 1 + elif method == "01": + return (data - params[1]) / ( + params[0] - params[1] + ) + elif method == "ms": + return (data - params[0]) / params[1] + elif method == "none": + return data + + @staticmethod + def fdenormalize(data_norm, params, method): + if method == "-11": + return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] + elif method == "01": + return data_norm * ( + params[0] - params[1] + ) + params[1] + elif method == "ms": + return data_norm * params[1] + params[0] + elif method == "none": + return data_norm + + +# build data +def getdata(cfg): + ###### read data - fois ###### + if cfg.Data.load_data_fn == "load_3d_flow": + input_data = load_3d_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_elbow_flow": + input_data = load_elbow_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_channel_flow": + input_data = load_channel_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_periodic_hill_flow": + input_data = load_periodic_hill_flow(cfg.Data.data_path) + else: + input_data = np.load(cfg.Data.data_path) + + spatio_shape = input_data.shape[1:-1] + spatio_axis = list( + range( + input_data.ndim if isinstance(input_data, np.ndarray) else input_data.dim() + ) + )[1:-1] + + ###### read data - coordinate ###### + if cfg.Data.coor_path is None: + coord = [np.linspace(0, 1, i) for i in spatio_shape] + coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1) + else: + coord = np.load(cfg.Data.coor_path) + coord = coord.astype("float32") + input_data = input_data.astype("float32") + + ###### convert to tensor ###### + input_data = ( + paddle.to_tensor(input_data) + if not isinstance(input_data, paddle.Tensor) + else input_data + ) + coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord + N_samples = input_data.shape[0] + + ###### normalizer ###### + in_normalizer = Normalizer_ts(**cfg.Data.normalizer) + in_normalizer.fit_normalize( + coord if cfg.Latent.lumped else coord.flatten(0, cfg.Latent.dims - 1) + ) + out_normalizer = Normalizer_ts(**cfg.Data.normalizer) + out_normalizer.fit_normalize( + input_data if cfg.Latent.lumped else input_data.flatten(0, cfg.Latent.dims) + ) + normed_coords = in_normalizer.normalize(coord) + normed_fois = out_normalizer.normalize(input_data) + + return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer + ###### 添加数据集划分 ###### + # split_ratio = cfg.Data.get("split_ratio", 0.8) # 默认为80%训练集 + # seed = cfg.Data.get("shuffle_seed", 42) # 随机种子 + + # # 生成随机索引并划分 + # np.random.seed(seed) + # total_samples = N_samples + # indices = np.random.permutation(total_samples) + # split_idx = int(total_samples * split_ratio) + + # # 划分训练集和测试集 + # train_indices = indices[:split_idx] + # test_indices = indices[split_idx:] + + # # 根据索引获取训练集和测试集数据 + # train_normed_fois = normed_fois[train_indices] + # test_normed_fois = normed_fois[test_indices] + + # return ( + # normed_coords, + # train_normed_fois, # 训练集数据 + # test_normed_fois, # 测试集数据 + # spatio_axis, + # out_normalizer, + # train_indices, # 训练集索引(用于latent模型) + # test_indices # 测试集索引 + # ) + + +class basic_set(paddle.io.Dataset): + def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> None: + super().__init__() + self.fois = fois.numpy() + self.total_samples = tuple(fois.shape)[0] + self.coords = coord.numpy() + # 存储全局索引 + self.global_indices = global_indices if global_indices is not None else np.arange(self.total_samples) + + def __len__(self): + return self.total_samples + + def __getitem__(self, idx): + # 使用全局索引 + global_idx = self.global_indices[idx] + if hasattr(self, "extra_in"): + extra_id = idx % tuple(self.fois.shape)[1] + idb = idx // tuple(self.fois.shape)[1] + return (self.coords, self.extra_in[extra_id]), self.fois[idb, extra_id], global_idx + else: + return self.coords, self.fois[idx], global_idx + + +def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices): + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + latents_model = LatentContainer(**cfg.Latent) + + # 创建训练集和测试集,传入全局索引 + train_dataset = basic_set(train_normed_fois, normed_coords, train_indices) + test_dataset = basic_set(test_normed_fois, normed_coords, test_indices) + + criterion = paddle.nn.MSELoss() + + # set loader + train_loader = DataLoader( + dataset=train_dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True + ) + test_loader = DataLoader( + dataset=test_dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False + ) + # set optimizer + cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model) + latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)( + latents_model + ) + losses = [] + + for i in range(cfg.TRAIN.epochs): + cnf_model.train() + latents_model.train() + if i != 0: + cnf_optimizer.step() + cnf_optimizer.clear_grad(set_to_zero=False) + train_loss = [] + for batch_coords, batch_fois, idx in train_loader: + idx = {"latent_x": idx} + batch_latent = latents_model(idx) + if isinstance(batch_coords, list): + batch_coords = [i for i in batch_coords] + data = { + "confild_x": batch_coords, + "latent_z": batch_latent["latent_z"], + } + batch_output = cnf_model(data) + loss = criterion(batch_output["confild_output"], batch_fois) + latents_optimizer.clear_grad(set_to_zero=False) + loss.backward() + latents_optimizer.step() + train_loss.append(loss) + epoch_loss = paddle.stack(x=train_loss).mean().item() + losses.append(epoch_loss) + print("epoch {}, train loss {}".format(i + 1, epoch_loss)) + if i % 100 == 0: + test_error = [] + cnf_model.eval() + latents_model.eval() + with paddle.no_grad(): + for test_coords, test_fois, idx in test_loader: + if isinstance(test_coords, list): + test_coords = [i for i in test_coords] + prediction = out_normalizer.denormalize( + cnf_model( + { + "confild_x": test_coords, + "latent_z": latents_model({"latent_x": idx})[ + "latent_z" + ], + } + )["confild_output"] + ) + target = out_normalizer.denormalize(test_fois) + error = rMAE(prediction=prediction, target=target, dims=spatio_axis) + test_error.append(error) + test_error = paddle.concat(x=test_error).mean(axis=0) + print("test MAE: ", test_error) + if i % 100 == 0: + paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") + paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") + # 绘制损失图 + plt.figure(figsize=(10, 6)) + plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") + + # 添加标题和标签 + plt.title("Training Loss over Epochs") + plt.xlabel("Epochs") + plt.xticks(rotation=45) + plt.ylabel("Loss") + + # 添加图例 + plt.legend() + + # 显示网格线 + plt.grid(True) + + # 保存为 PNG 格式 + plt.savefig("case.png") + + # 显示图形 + plt.show() + + +def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices): + fleet.init(is_collective=True) + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + cnf_model = fleet.distributed_model(cnf_model) + latents_model = LatentContainer(**cfg.Latent) + latents_model = fleet.distributed_model(latents_model) + + # set optimizer + cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model) + cnf_optimizer = fleet.distributed_optimizer(cnf_optimizer) + latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)( + latents_model + ) + latents_optimizer = fleet.distributed_optimizer(latents_optimizer) + + # 创建训练集和测试集,传入全局索引 + train_dataset = basic_set(train_normed_fois, normed_coords, train_indices) + test_dataset = basic_set(test_normed_fois, normed_coords, test_indices) + + train_sampler = DistributedBatchSampler( + train_dataset, cfg.TRAIN.batch_size, shuffle=True, drop_last=True + ) + train_loader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=cfg.TRAIN.mutil_GPU, + use_shared_memory=False, + ) + test_sampler = DistributedBatchSampler( + test_dataset, cfg.TRAIN.test_batch_size, drop_last=True + ) + test_loader = DataLoader( + test_dataset, + batch_sampler=test_sampler, + num_workers=cfg.TRAIN.mutil_GPU, + use_shared_memory=False, + ) + + criterion = paddle.nn.MSELoss() + losses = [] + + for i in range(cfg.TRAIN.epochs): + cnf_model.train() + latents_model.train() + if i != 0: + cnf_optimizer.step() + cnf_optimizer.clear_grad(set_to_zero=False) + train_loss = [] + for batch_coords, batch_fois, idx in train_loader: + idx = {"latent_x": idx} + batch_latent = latents_model(idx) + if isinstance(batch_coords, list): + batch_coords = [i for i in batch_coords] + data = { + "confild_x": batch_coords, + "latent_z": batch_latent["latent_z"], + } + batch_output = cnf_model(data) + loss = criterion(batch_output["confild_output"], batch_fois) + latents_optimizer.clear_grad(set_to_zero=False) + loss.backward() + latents_optimizer.step() + train_loss.append(loss) + epoch_loss = paddle.stack(x=train_loss).mean().item() + losses.append(epoch_loss) + print("epoch {}, train loss {}".format(i + 1, epoch_loss)) + if i % 100 == 0: + test_error = [] + cnf_model.eval() + latents_model.eval() + with paddle.no_grad(): + for test_coords, test_fois, idx in test_loader: + if isinstance(test_coords, list): + test_coords = [i for i in test_coords] + prediction = out_normalizer.denormalize( + cnf_model( + { + "confild_x": test_coords, + "latent_z": latents_model({"latent_x": idx})[ + "latent_z" + ], + } + )["confild_output"] + ) + target = out_normalizer.denormalize(test_fois) + error = rMAE(prediction=prediction, target=target, dims=spatio_axis) + test_error.append(error) + test_error = paddle.concat(x=test_error).mean(axis=0) + print("test MAE: ", test_error) + if i % 100 == 0: + paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") + paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") + # 绘制损失图 + plt.figure(figsize=(10, 6)) + plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") + + # 添加标题和标签 + plt.title("Training Loss over Epochs") + plt.xlabel("Epochs") + plt.xticks(rotation=45) + plt.ylabel("Loss") + + # 添加图例 + plt.legend() + + # 显示网格线 + plt.grid(True) + + # 保存为 PNG 格式 + plt.savefig("case.png") + + # 显示图形 + plt.show() + + +def train(cfg): + # 获取分割后的数据集 + # (normed_coords, + # train_normed_fois, + # test_normed_fois, + # spatio_axis, + # out_normalizer, + # train_indices, + # test_indices) = getdata(cfg) + normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) + train_normed_fois = normed_fois + test_normed_fois = normed_fois + train_indices = list(range(N_samples)) + test_indices = list(range(N_samples)) + + if cfg.TRAIN.mutil_GPU > 1: + import paddle.distributed as dist + dist.init_parallel_env() + mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, + spatio_axis, out_normalizer, train_indices, test_indices) + else: + signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, + spatio_axis, out_normalizer, train_indices, test_indices) + + +def evaluate(cfg: DictConfig): + # set data + # normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) + normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg) + + if len(normed_coords.shape) + 1 == len(normed_fois.shape): + normed_coords = paddle.tile( + normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape) + ) + + idx = paddle.to_tensor( + np.array([i for i in range(normed_fois.shape[0])]), dtype="int64" + ) + # set model + confild = SIRENAutodecoder_film(**cfg.CONFILD) + latent = LatentContainer(**cfg.Latent) + logger.info( + "Loading pretrained model from {}".format( + cfg.EVAL.confild_pretrained_model_path + ) + ) + ppsci.utils.save_load.load_pretrain( + confild, + cfg.EVAL.confild_pretrained_model_path, + ) + logger.info( + "Loading pretrained model from {}".format(cfg.EVAL.latent_pretrained_model_path) + ) + ppsci.utils.save_load.load_pretrain( + latent, + cfg.EVAL.latent_pretrained_model_path, + ) + latent_test_pred = latent({"latent_x": idx}) + y_test_pred = [] + for i in range(normed_coords.shape[0]): + y_test_pred.append( + confild( + { + "confild_x": normed_coords[i], + "latent_z": latent_test_pred["latent_z"][i], + } + )["confild_output"].numpy() + ) + y_test_pred = paddle.to_tensor(np.array(y_test_pred)) + + y_test_pred = out_normalizer.denormalize(y_test_pred) + y_test = out_normalizer.denormalize(normed_fois) + + logger.info("Result is {}".format(y_test.numpy())) + + +def inference(cfg): + # 获取分割后的数据集 + normed_coords, normed_fois, _, _, _ = getdata(cfg) + if len(normed_coords.shape) + 1 == len(normed_fois.shape): + normed_coords = paddle.tile( + normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape) + ) + + fois_len = normed_fois.shape[0] + idxs = np.array([i for i in range(fois_len)]) + from deploy import python_infer + + latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent) + input_dict = {"latent_x": idxs} + output_dict = latent_predictor.predict(input_dict, cfg.INFER.batch_size) + + cnf_predictor = python_infer.GeneralPredictor(cfg.INFER.Confild) + input_dict = { + "confild_x": normed_coords.numpy(), + "latent_z": list(output_dict.values())[0], + } + output_dict = cnf_predictor.predict(input_dict, cfg.INFER.batch_size) + + logger.info("Result is {}".format(output_dict["fetch_name_0"])) + + +def uncondiction_infer(cfg): + test_batch_size = cfg.Uncondiction_INFER.test_batch_size + time_length = cfg.Uncondiction_INFER.time_length + latent_length = cfg.Uncondiction_INFER.latent_length + image_size = cfg.Uncondiction_INFER.image_size + num_channels = cfg.Uncondiction_INFER.num_channels + num_res_blocks = cfg.Uncondiction_INFER.num_res_blocks + num_heads = cfg.Uncondiction_INFER.num_heads + num_head_channels = cfg.Uncondiction_INFER.num_head_channels + attention_resolutions = cfg.Uncondiction_INFER.attention_resolutions + steps = cfg.Uncondiction_INFER.steps + noise_schedule = cfg.Uncondiction_INFER.noise_schedule + + unet_model = create_model( + image_size=image_size, + num_channels=num_channels, + num_res_blocks=num_res_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + attention_resolutions=attention_resolutions, + ) + # ppsci.utils.save_load.load_pretrain( + # unet_model, + # # cfg.Uncondiction_INFER.ema_path, + # "/home/aistudio/ema_0.9999_550000.pdparams", + # ) + diff_model = create_gaussian_diffusion(steps=steps, noise_schedule=noise_schedule) + sample_fn = diff_model.p_sample_loop + gen_latents = sample_fn(unet_model, (test_batch_size, 1, time_length, latent_length))[ + :, 0 + ] + max_val, min_val = np.load("/home/aistudio/data_max.npy"), np.load("/home/aistudio/data_min.npy") + # max_val, min_val = np.load(cfg.Uncondiction_INFER.max_val), np.load(cfg.Uncondiction_INFER.min_val) + max_val, min_val = paddle.to_tensor(data=max_val), paddle.to_tensor(data=min_val) + gen_latents = (gen_latents + 1) * (max_val - min_val) / 2.0 + min_val + # 加载cnf模型 + print("加载cnf模型") + confild = SIRENAutodecoder_film(**cfg.CONFILD) + ppsci.utils.save_load.load_pretrain( + confild, + "https://dataset.bj.bcebos.com/PaddleScience/CoNFiLD/cnf_model_9700.pdparams",# cfg.EVAL.confild_pretrained_model_path, + ) + confild.eval() + coord = paddle.to_tensor(np.load("/home/aistudio/data/data321897/case1_coords.npy"), dtype="float32")#(np.load(f"{cfg.Data.coor_path}"), dtype='float32') + batch_size = 1 + n_samples = tuple(gen_latents.shape)[0] + out_normalizer = Normalizer_ts(**cfg.Data.normalizer) + + gen_fields = [] + print("开始生成") + for sample_index in range(n_samples): + print("第{}个样本", sample_index) + for i in range(tuple(gen_latents.shape)[1] // batch_size): + input_dict = { + "confild_x": coord, + "latent_z": gen_latents[sample_index, i * batch_size : (i + 1) * batch_size], + } + confild_output = confild(input_dict) + # print(confild_output) + gen_fields.append(out_normalizer.denormalize(confild_output["confild_output"]).detach() + .cpu() + .numpy()) + gen_fields = np.concatenate(gen_fields) + np.save("./", gen_fields)#cfg.Uncondiction_INFER.save_path + + +class LossType(enum.Enum): + MSE = enum.auto() + RESCALED_MSE = enum.auto() + KL = enum.auto() + RESCALED_KL = enum.auto() + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + if schedule_name == "linear": + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def space_timesteps(num_timesteps, section_counts): + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def create_gaussian_diffusion( + *, + steps=1000, + learn_sigma=False, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", +): + betas = get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = LossType.RESCALED_MSE + else: + loss_type = LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=ModelMeanType.EPSILON + if not predict_xstart + else ModelMeanType.START_X, + model_var_type=( + ModelVarType.FIXED_LARGE + if not sigma_small + else ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else ModelVarType.LEARNED_RANGE, + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + ) + + +NUM_CLASSES = 1000 + + +def create_model( + image_size, + num_channels, + num_res_blocks, + dims=2, + out_channels=1, + channel_mult=None, + learn_sigma=False, + class_cond=False, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, +): + if channel_mult is None: + if image_size == 512: + channel_mult = 0.5, 1, 1, 2, 2, 4, 4 + elif image_size == 256: + channel_mult = 1, 1, 2, 2, 4, 4 + elif image_size == 128: + channel_mult = 1, 1, 2, 3, 4 + elif image_size == 64: + channel_mult = 1, 2, 3, 4 + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + return UNetModel( + image_size=image_size, + in_channels=out_channels, + model_channels=num_channels, + out_channels=out_channels if not learn_sigma else 2 * out_channels, + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=NUM_CLASSES if class_cond else None, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + dims=dims, + ) + + +def export(cfg): + # set model + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + latent_model = LatentContainer(**cfg.Latent) + # initialize solver + latnet_solver = ppsci.solver.Solver( + latent_model, + pretrained_model_path=cfg.INFER.Latent.INFER.pretrained_model_path, + ) + cnf_solver = ppsci.solver.Solver( + cnf_model, + pretrained_model_path=cfg.INFER.Confild.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None], "int64", name=key) for key in latent_model.input_keys}, + ] + cnf_input_spec = [ + { + cnf_model.input_keys[0]: InputSpec( + [None] + list(cfg.INFER.Confild.INFER.coord_shape), + "float32", + name=cnf_model.input_keys[0], + ), + cnf_model.input_keys[1]: InputSpec( + [None] + list(cfg.INFER.Confild.INFER.latents_shape), + "float32", + name=cnf_model.input_keys[1], + ), + } + ] + cnf_solver.export(cnf_input_spec, cfg.INFER.Confild.INFER.export_path) + latnet_solver.export(input_spec, cfg.INFER.Latent.INFER.export_path) + + +@hydra.main(version_base=None, config_path="./conf", config_name="confild_case1.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + elif cfg.mode == "infer": + if cfg.alis == False: + inference(cfg) + else: + uncondiction_infer(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "uncondition_infer": + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'infer', 'export', 'uncondition_infer'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index 78f381b68e..346ed22b3d 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -22,6 +22,7 @@ from ppsci.arch.amgnet import AMGNet # isort:skip from ppsci.arch.base import Arch # isort:skip from ppsci.arch.cfdgcn import CFDGCN # isort:skip +from ppsci.arch.confild import LatentContainer, SIRENAutodecoder_film, SpacedDiffusion, UNetModel, ModelVarType, ModelMeanType # isort:skip from ppsci.arch.smc_reac import SuzukiMiyauraModel # isort:skip from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip @@ -98,11 +99,14 @@ "GraphCastNet", "HEDeepONets", "LorenzEmbedding", + "LatentContainer", "LatentNO", "LatentNO_time", "LNO", "MLP", "ModelList", + "ModelVarType", + "ModelMeanType", "ModifiedMLP", "NowcastNet", "PhyCRNet", @@ -111,12 +115,15 @@ "PrecipNet", "RosslerEmbedding", "SFNONet", + "SIRENAutodecoder_film", + "SpacedDiffusion", "SPINN", "TFNO1dNet", "TFNO2dNet", "TFNO3dNet", "Transformer", "UNetEx", + "UNetModel", "UNONet", "USCNN", "VelocityDiscriminator", diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py new file mode 100644 index 0000000000..3965080832 --- /dev/null +++ b/ppsci/arch/confild.py @@ -0,0 +1,1385 @@ +import math +import enum +from collections import OrderedDict +from abc import abstractmethod +import numpy as np +import paddle + +DEFAULT_W0 = 30.0 + + +class Swish(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.Sigmoid = paddle.nn.Sigmoid() + + def forward(self, x): + return x * self.Sigmoid(x) + + +class Sine(paddle.nn.Layer): + def __init__(self, w0=DEFAULT_W0): + self.w0 = w0 + super().__init__() + + def forward(self, input): + return paddle.sin(x=self.w0 * input) + + +def sine_init(m, w0=DEFAULT_W0): + with paddle.no_grad(): + if hasattr(m, "weight"): + num_input = m.weight.shape[-1] + m.weight.uniform_( + min=-math.sqrt(6 / num_input) / w0, max=math.sqrt(6 / num_input) / w0 + ) + + +def first_layer_sine_init(m): + with paddle.no_grad(): + if hasattr(m, "weight"): + num_input = m.weight.shape[-1] + m.weight.uniform_(min=-1 / num_input, max=1 / num_input) + + +def __check_Linear_weight(m): + if isinstance(m, paddle.nn.Linear): + if hasattr(m, "weight"): + return True + return False + + +def init_weights_normal(m): + if __check_Linear_weight(m): + init_KaimingNormal = paddle.nn.initializer.KaimingNormal( + nonlinearity="relu", negative_slope=0.0 + ) + init_KaimingNormal(m.weight) + + +def init_weights_selu(m): + if __check_Linear_weight(m): + num_input = m.weight.shape[-1] + init_Normal = paddle.nn.initializer.Normal(std=1 / math.sqrt(num_input)) + init_Normal(m.weight) + + +def init_weights_elu(m): + if __check_Linear_weight(m): + num_input = m.weight.shape[-1] + init_Normal = paddle.nn.initializer.Normal( + std=math.sqrt(1.5505188080679277) / math.sqrt(num_input) + ) + init_Normal(m.weight) + + +def init_weights_xavier(m): + if __check_Linear_weight(m): + init_XavierNormal = paddle.nn.initializer.XavierNormal() + init_XavierNormal(m.weight) + + +NLS_AND_INITS = { + "sine": (Sine(), sine_init, first_layer_sine_init), + "relu": (paddle.nn.ReLU(), init_weights_normal, None), + "sigmoid": (paddle.nn.Sigmoid(), init_weights_xavier, None), + "tanh": (paddle.nn.Tanh(), init_weights_xavier, None), + "selu": (paddle.nn.SELU(), init_weights_selu, None), + "softplus": (paddle.nn.Softplus(), init_weights_normal, None), + "elu": (paddle.nn.ELU(), init_weights_elu, None), + "swish": (Swish(), init_weights_xavier, None), +} + + +class BatchLinear(paddle.nn.Linear): + """ + This is a linear transformation implemented manually. It also allows maually input parameters. + for initialization, (in_features, out_features) needs to be provided. + weight is of shape (out_features*in_features) + bias is of shape (out_features) + + """ + + __doc__ = paddle.nn.Linear.__doc__ + + def forward(self, input, params=None): + if params is None: + params = OrderedDict(self.named_parameters()) + bias = params.get("bias", None) + weight = params["weight"] + + output = paddle.matmul(x=input, y=weight) + if bias is not None: + output += bias.unsqueeze(axis=-2) + return output + + +class FeatureMapping: + """ + This is feature mapping class for fourier feature networks + """ + + def __init__( + self, + in_features, + mode="basic", + gaussian_mapping_size=256, + gaussian_rand_key=0, + gaussian_tau=1.0, + pe_num_freqs=4, + pe_scale=2, + pe_init_scale=1, + pe_use_nyquist=True, + pe_lowest_dim=None, + rbf_out_features=None, + rbf_range=1.0, + rbf_std=0.5, + ): + """ + inputs: + in_freatures: number of input features + mapping_size: output features for Gaussian mapping + rand_key: random key for Gaussian mapping + tau: standard deviation for Gaussian mapping + num_freqs: number of frequencies for P.E. + scale = 2: base scale of frequencies for P.E. + init_scale: initial scale for P.E. + use_nyquist: use nyquist to calculate num_freqs or not. + + """ + self.mode = mode + if mode == "basic": + self.B = np.eye(in_features) + elif mode == "gaussian": + rng = np.random.default_rng(gaussian_rand_key) + self.B = rng.normal( + loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features) + ) + elif mode == "positional": + if pe_use_nyquist == "True" and pe_lowest_dim: + pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim) + self.B = pe_init_scale * np.vstack( + [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)] + ) + self.dim = tuple(self.B.shape)[0] * 2 + elif mode == "rbf": + self.centers = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.empty( + shape=(rbf_out_features, in_features), dtype="float32" + ) + ) + self.sigmas = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.empty(shape=rbf_out_features, dtype="float32") + ) + init_Uniform = paddle.nn.initializer.Uniform( + low=-1 * rbf_range, high=rbf_range + ) + init_Uniform(self.centers) + init_Constant = paddle.nn.initializer.Constant(value=rbf_std) + init_Constant(self.sigmas) + + def __call__(self, input): + if self.mode in ["basic", "gaussian", "positional"]: + return self.fourier_mapping(input, self.B) + elif self.mode == "rbf": + return self.rbf_mapping(input) + + def get_num_frequencies_nyquist(self, samples): + nyquist_rate = 1 / (2 * (2 * 1 / samples)) + return int(math.floor(math.log(nyquist_rate, 2))) + + @staticmethod + def fourier_mapping(x, B): + """ + x is the input, B is the reference information + """ + if B is None: + return x + else: + B = paddle.to_tensor(data=B, dtype="float32", place=x.place) + x_proj = 2.0 * np.pi * x @ B.T + return paddle.concat( + x=[paddle.sin(x=x_proj), paddle.cos(x=x_proj)], axis=-1 + ) + + def rbf_mapping(self, x): + size = tuple(x.shape)[:-1] + tuple(self.centers.shape) + x = x.unsqueeze(axis=-2).expand(shape=size) + distances = (x - self.centers).pow(y=2).sum(axis=-1) * self.sigmas + return self.gaussian(distances) + + @staticmethod + def gaussian(alpha): + phi = paddle.exp(x=-1 * alpha.pow(y=2)) + return phi + + +class SIRENAutodecoder_film(paddle.nn.Layer): + """ + siren network with author decoding + + Args: + input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict. + output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict. + in_coord_features (int, optional): Number of input coordinates features + in_latent_features (int, optional): Number of input latent features + out_features (int, optional): Number of output features + num_hidden_layers (int, optional): Number of hidden layers + hidden_features (int, optional): Number of hidden features + outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False. + nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine". + weight_init (Callable, optional): Weight initialization function. Defaults to None. + bias_init (Callable, optional): Bias initialization function. Defaults to None. + premap_mode (str, optional): Feature mapping mode. Defaults to None. + + Examples: + >>> model = ppsci.arch.SIRENAutodecoder_film( + input_keys=["input1", "input2"], + output_keys=("output",), + in_coord_features=2, + in_latent_features=128, + out_features=3, + num_hidden_layers=10, + hidden_features=128, + ) + >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])} + >>> out_dict = model(input_data) + >>> for k, v in out_dict.items(): + ... print(k, v.shape) + output [22, 918, 3] + """ + + def __init__( + self, + input_keys, + output_keys, + in_coord_features, + in_latent_features, + out_features, + num_hidden_layers, + hidden_features, + outermost_linear=False, + nonlinearity="sine", + weight_init=None, + bias_init=None, + premap_mode=None, + **kwargs, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.premap_mode = premap_mode + if self.premap_mode is not None: + self.premap_layer = FeatureMapping( + in_coord_features, mode=premap_mode, **kwargs + ) + in_coord_features = self.premap_layer.dim + self.first_layer_init = None + self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity] + if weight_init is not None: + self.weight_init = weight_init + else: + self.weight_init = nl_weight_init + self.net1 = paddle.nn.LayerList( + sublayers=[BatchLinear(in_coord_features, hidden_features)] + + [ + BatchLinear(hidden_features, hidden_features) + for i in range(num_hidden_layers) + ] + + [BatchLinear(hidden_features, out_features)] + ) + self.net2 = paddle.nn.LayerList( + sublayers=[ + BatchLinear(in_latent_features, hidden_features, bias_attr=False) + for i in range(num_hidden_layers + 1) + ] + ) + if self.weight_init is not None: + self.net1.apply(self.weight_init) + self.net2.apply(self.weight_init) + if first_layer_init is not None: + self.net1[0].apply(first_layer_init) + self.net2[0].apply(first_layer_init) + if bias_init is not None: + self.net2.apply(bias_init) + + def forward(self, input_data): + coords = input_data[self.input_keys[0]] + latents = input_data[self.input_keys[1]] + if self.premap_mode is not None: + x = self.premap_layer(coords) + else: + x = coords + + for i in range(len(self.net1) - 1): + x = self.net1[i](x) + self.net2[i](latents) + x = self.nl(x) + x = self.net1[-1](x) + return {self.output_keys[0]: x} + + def disable_gradient(self): + for param in self.parameters(): + param.stop_gradient = not False + + +class LatentContainer(paddle.nn.Layer): + """ + a model container that stores latents for multi GPU + + Args: + input_key (Tuple[str, ...], optional): Key to get the input tensor from the dict. Defaults to ("intput",). + output_key (Tuple[str, ...], optional): Key to save the output tensor into the dict. Defaults to ("output",). + N_samples (int, optional): Number of samples. Defaults to None. + N_features (int, optional): Number of features. Defaults to None. + dims (int, optional): Number of dimensions. Defaults to None. + lumped (bool, optional): Whether to lump the latents. Defaults to False. + + Examples: + >>> model = ppsci.arch.LatentContainer(N_samples=1600, N_features=128, dims=2, lumped=True) + >>> input_data = paddle.linspace(0, 1600, 1600, 'int64') + >>> input_dict = {"input": input_data} + >>> out_dict = model(input_dict) + >>> for k, v in out_dict.items(): + ... print(k, v.shape) + output [1600, 1, 128] + """ + + def __init__( + self, + input_keys=("input",), + output_keys=("output",), + N_samples=None, + N_features=None, + dims=None, + lumped=False, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.dims = [1] * dims if not lumped else [1] + self.expand_dims = " ".join(["1" for _ in range(dims)]) if not lumped else "1" + self.expand_dims = f"N f -> N {self.expand_dims} f" + self.latents = self.create_parameter( + shape=(N_samples, N_features), + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0.0), + ) + + def forward(self, batch_ids): + x = batch_ids[self.input_keys[0]] + selected_latents = paddle.gather(self.latents, x) + if len(selected_latents.shape) > 1: + getShape = ( + [tuple(selected_latents.shape)[0]] + + self.dims + + [tuple(selected_latents.shape)[1]] + ) + else: + getShape = [-1] + self.dims + expanded_latents = selected_latents.reshape(getShape) + return {self.output_keys[0]: expanded_latents} + + +class ModelVarType(enum.Enum): + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + res = ( + paddle.to_tensor(data=arr)[timesteps] + .astype(dtype="float32") + ) + while len(tuple(res.shape)) < len(broadcast_shape): + res = res[..., None] + return res.expand(shape=broadcast_shape) + + +def split(x, num_or_sections, axis=0): + if isinstance(num_or_sections, int): + return paddle.split(x, x.shape[axis]//num_or_sections, axis) + else: + return paddle.split(x, num_or_sections, axis) + + +class ModelMeanType(enum.Enum): + PREVIOUS_X = enum.auto() + START_X = enum.auto() + EPSILON = enum.auto() + + +class GaussianDiffusion: + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(tuple(betas.shape)) == 1, "betas must be 1-D" + assert (betas > 0).astype("bool").all() and (betas <= 1).astype("bool").all() + self.num_timesteps = int(tuple(betas.shape)[0]) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert tuple(self.alphas_cumprod_prev.shape) == (self.num_timesteps,) + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert tuple(x_t.shape) == tuple(xprev.shape) + return ( + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, tuple(x_t.shape)) + * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, + t, + tuple(x_t.shape), + ) + * x_t + ) + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert tuple(x_t.shape) == tuple(eps.shape) + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, tuple(x_t.shape)) + * x_t + - _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, tuple(x_t.shape) + ) + * eps + ) + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + if model_kwargs is None: + model_kwargs = {} + B, C = tuple(x.shape)[:2] + assert tuple(t.shape) == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert tuple(model_output.shape) == (B, C * 2, *tuple(x.shape)[2:]) + model_output, model_var_values = split( + x=model_output, num_or_sections=C, axis=1 + ) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = paddle.exp(x=model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, tuple(x.shape) + ) + max_log = _extract_into_tensor(np.log(self.betas), t, tuple(x.shape)) + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = paddle.exp(x=model_log_variance) + else: + model_variance, model_log_variance = { + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, tuple(x.shape)) + model_log_variance = _extract_into_tensor( + model_log_variance, t, tuple(x.shape) + ) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clip(min=-1, max=1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + assert ( + tuple(model_mean.shape) + == tuple(model_log_variance.shape) + == tuple(pred_xstart.shape) + == tuple(x.shape) + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def q_posterior_mean_variance(self, x_start, x_t, t): + assert tuple(x_start.shape) == tuple(x_t.shape) + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, tuple(x_t.shape)) + * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, tuple(x_t.shape)) * x_t + ) + posterior_variance = _extract_into_tensor( + self.posterior_variance, t, tuple(x_t.shape) + ) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, tuple(x_t.shape) + ) + assert ( + tuple(posterior_mean.shape)[0] + == tuple(posterior_variance.shape)[0] + == tuple(posterior_log_variance_clipped.shape)[0] + == tuple(x_start.shape)[0] + ) + return (posterior_mean, posterior_variance, posterior_log_variance_clipped) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.astype(dtype="float32") * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = p_mean_var["mean"].astype(dtype="float32") + p_mean_var[ + "variance" + ] * gradient.astype(dtype="float32") + return new_mean + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = paddle.randn(shape=x.shape, dtype=x.dtype) + nonzero_mask = ( + (t != 0).astype(dtype="float32").reshape([-1, *([1] * (len(tuple(x.shape)) - 1))]) + ) + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = ( + out["mean"] + nonzero_mask * paddle.exp(x=0.5 * out["log_variance"]) * noise + ) + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = paddle.randn(shape=shape) + indices = list(range(self.num_timesteps))[::-1] + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + for i in indices: + t = paddle.to_tensor(data=[i] * shape[0]) + with paddle.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + Args: + use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + base_diffusion = GaussianDiffusion(**kwargs) + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = paddle.to_tensor( + data=self.timestep_map, dtype=ts.dtype, place=ts.place + ) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.astype(dtype="float32") * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return paddle.nn.Conv1D(*args, **kwargs) + elif dims == 2: + return paddle.nn.Conv2D(*args, **kwargs) + elif dims == 3: + return paddle.nn.Conv3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + return paddle.nn.Linear(*args, **kwargs) + + +class TimestepBlock(paddle.nn.Layer): + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + pass + + +class ResBlock(TimestepBlock): + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.in_layers = paddle.nn.Sequential( + normalization(channels), + paddle.nn.Silu(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.updown = up or down + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = paddle.nn.Identity() + self.emb_layers = paddle.nn.Sequential( + paddle.nn.Silu(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = paddle.nn.Sequential( + normalization(self.out_channels), + paddle.nn.Silu(), + paddle.nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + if self.out_channels == channels: + self.skip_connection = paddle.nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).astype(h.dtype) + while len(tuple(emb_out.shape)) < len(tuple(h.shape)): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = paddle.chunk(x=emb_out, chunks=2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class TimestepEmbedSequential(paddle.nn.Sequential, TimestepBlock): + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +NUM_CLASSES = 1000 + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return paddle.nn.AvgPool1d(*args, **kwargs, exclusive=False) + elif dims == 2: + return paddle.nn.AvgPool2d(*args, **kwargs, exclusive=False) + elif dims == 3: + return paddle.nn.AvgPool3d(*args, **kwargs, exclusive=False) + raise ValueError(f"unsupported dimensions: {dims}") + + + +class Downsample(paddle.nn.Layer): + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert tuple(x.shape)[1] == self.channels + return self.op(x) + +class Upsample(paddle.nn.Layer): + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert tuple(x.shape)[1] == self.channels + if self.dims == 3: + x = paddle.nn.functional.interpolate( + x=x, + size=(tuple(x.shape)[2], tuple(x.shape)[3] * 2, tuple(x.shape)[4] * 2), + mode="nearest", + ) + else: + x = paddle.nn.functional.interpolate(x=x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +def count_flops_attn(model, _x, y): + b, c, *spatial = tuple(y[0].shape) + num_spatial = int(np.prod(spatial)) + matmul_ops = 2 * b * num_spatial**2 * c + model.total_ops += paddle.to_tensor(data=[matmul_ops], dtype="float64") + + +class QKVAttentionLegacy(paddle.nn.Layer): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + bs, width, length = tuple(qkv.shape) + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape((bs * self.n_heads, ch * 3, length)).split(3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale) + weight = paddle.nn.functional.softmax( + x=weight.astype(dtype="float32"), axis=-1 + ).astype(weight.dtype) + a = paddle.einsum("bts,bcs->bct", weight, v) + return a.reshape((bs, -1, length)) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(paddle.nn.Layer): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + bs, width, length = tuple(qkv.shape) + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(chunks=3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) + weight = paddle.nn.functional.softmax( + x=weight.astype(dtype="float32"), axis=-1 + ).astype(weight.dtype) + a = paddle.einsum( + "bts,bcs->bct", weight, v.reshape((bs * self.n_heads, ch, length)) + ) + return a.reshape((bs, -1, length)) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class GroupNorm32(paddle.nn.GroupNorm): + def forward(self, x): + return super().forward(x.astype(dtype="float32")).astype(x.dtype) + + +def normalization(channels): + return GroupNorm32(32, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +def checkpoint(func, inputs, params, flag): + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with paddle.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + """Class Method: *.requires_grad_, can not convert, please check whether it is torch.Tensor.*/Optimizer.*/nn.Module.*/torch.distributions.Distribution.*/torch.autograd.function.FunctionCtx.*/torch.profiler.profile.*/torch.autograd.profiler.profile.*, and convert manually""" + ctx.input_tensors = [paddle.stop_gradient(x, stop=False) for x in ctx.input_tensors] + with paddle.enable_grad(): + shallow_copies = [x.view_as(other=x) for x in ctx.input_tensors] + # print(shallow_copies) + output_tensors = ctx.run_function(*shallow_copies) + input_grads = paddle.grad( + outputs=output_tensors, + inputs=ctx.input_tensors + ctx.input_params, + grad_outputs=output_grads, + allow_unused=True, + retain_graph=True, create_graph=False + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def stop_gradient(self, *args, **kwargs): + return self + + +class AttentionBlock(paddle.nn.Layer): + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + self.attention = QKVAttention(self.num_heads) + else: + self.attention = QKVAttentionLegacy(self.num_heads) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) + + def _forward(self, x): + b, c, *spatial = tuple(x.shape) + x = x.reshape((b, c, -1)) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape((b, c, *spatial)) + + +def convert_module_to_f16(l): + if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)): + l.weight.data = l.weight.data.astype(dtype="float16") + if l.bias is not None: + l.bias.data = l.bias.data.astype(dtype="float16") + + +def convert_module_to_f32(l): + if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)): + l.weight.data = l.weight.data.astype(dtype="float32") + if l.bias is not None: + l.bias.data = l.bias.data.astype(dtype="float32") + + +def timestep_embedding(timesteps, dim, max_period=10000): + half = dim // 2 + freqs = paddle.exp( + x=-math.log(max_period) + * paddle.arange(start=0, end=half, dtype="float32") + / half + ).to(paddle.CUDAPlace(0)) + args = timesteps[:, None].astype(dtype="float32") * freqs[None] + embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) + if dim % 2: + embedding = paddle.concat( + x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1 + ) + return embedding + + +class UNetModel(paddle.nn.Layer): + """ + The full UNet model with attention and timestep embedding. + + Args: + image_size (int): Input image size (maintained for interface compatibility) + in_channels (int): Number of channels in input tensor + model_channels (int): Base channel count for model + out_channels (int): Number of channels in output tensor + num_res_blocks (int): Residual blocks per downsampling level + attention_resolutions (list/tuple): Downsample rates to apply attention (e.g., [4, 8]) + dropout (float, optional): Dropout probability. Default: 0.0 + channel_mult (tuple, optional): Channel multipliers per level. Default: (1, 2, 4, 8) + conv_resample (bool, optional): Use convolutional resampling. Default: True + dims (int, optional): Data dimensionality (1=1D, 2=2D, 3=3D). Default: 2 + num_classes (int, optional): Number of classes for conditional generation. Default: None + use_checkpoint (bool, optional): Enable gradient checkpointing. Default: False + use_fp16 (bool, optional): Use float16 precision. Default: False + num_heads (int, optional): Number of attention heads. Default: 1 + num_head_channels (int, optional): Fixed channels per head (overrides num_heads). Default: -1 + num_heads_upsample (int, optional): Heads for upsampling blocks. Default: -1 (use num_heads) + use_scale_shift_norm (bool, optional): Use FiLM-like conditioning. Default: False + resblock_updown (bool, optional): Use residual blocks for resampling. Default: False + use_new_attention_order (bool, optional): Use optimized attention pattern. Default: False + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = "float16" if use_fp16 else "float32" + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + time_embed_dim = model_channels * 4 + self.time_embed = paddle.nn.Sequential( + linear(model_channels, time_embed_dim), + paddle.nn.Silu(), + linear(time_embed_dim, time_embed_dim), + ) + if self.num_classes is not None: + self.label_emb = paddle.nn.Embedding( + num_embeddings=num_classes, embedding_dim=time_embed_dim + ) + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = paddle.nn.LayerList( + sublayers=[ + TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1)) + ] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.output_blocks = paddle.nn.LayerList(sublayers=[]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + self.out = paddle.nn.Sequential( + normalization(ch), + paddle.nn.Silu(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.num_classes is not None: + assert tuple(y.shape) == (tuple(x.shape)[0],) + emb = emb + self.label_emb(y) + h = x.astype(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = paddle.concat(x=[h, hs.pop()], axis=1) + h = module(h, emb) + h = h.astype(x.dtype) + return self.out(h) \ No newline at end of file From ed154e57921b52bf68b1531d1d09c7cd26216586 Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Sun, 5 Oct 2025 21:31:51 +0800 Subject: [PATCH 02/11] Add UNet model implementation and Gaussian diffusion methods --- examples/confild/conf/confild_case1.yaml | 1 - examples/confild/conf/un_confild_case1.yaml | 67 +++ examples/confild/conf/un_confild_case2.yaml | 68 +++ examples/confild/conf/un_confild_case3.yaml | 64 +++ examples/confild/conf/un_confild_case4.yaml | 65 +++ examples/confild/confild.py | 369 +++------------- examples/confild/un_confild.py | 446 ++++++++++++++++++++ ppsci/arch/confild.py | 180 +++++++- 8 files changed, 918 insertions(+), 342 deletions(-) create mode 100644 examples/confild/conf/un_confild_case1.yaml create mode 100644 examples/confild/conf/un_confild_case2.yaml create mode 100644 examples/confild/conf/un_confild_case3.yaml create mode 100644 examples/confild/conf/un_confild_case4.yaml create mode 100644 examples/confild/un_confild.py diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml index 522cda9b98..f80c7bb447 100644 --- a/examples/confild/conf/confild_case1.yaml +++ b/examples/confild/conf/confild_case1.yaml @@ -29,7 +29,6 @@ mode: infer # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 -alis: False TRAIN: batch_size: 64 diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml new file mode 100644 index 0000000000..6295c04ee3 --- /dev/null +++ b/examples/confild/conf/un_confild_case1.yaml @@ -0,0 +1,67 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_un_confild_case1 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + + +EVAL: + epochs: 9800 + mutil_GPU: 1 + microbatch: -1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + batch_size : 16 + test_batch_size : 16 + time_length : 128 + latent_length : 128 + +UNET: + + image_size : 128 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /add/ema/path/here + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: "ConditionalNeuralField/training_recipes/case1.yml" + +DATA: + max_val: 1.0 + min_val: -1.0 \ No newline at end of file diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml new file mode 100644 index 0000000000..bacc83b3a8 --- /dev/null +++ b/examples/confild/conf/un_confild_case2.yaml @@ -0,0 +1,68 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_un_confild_case2 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + + + +EVAL: + epochs: 9800 + mutil_GPU: 1 + microbatch: -1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + batch_size : 16 + test_batch_size : 16 + time_length : 256 + latent_length : 256 + +UNET: + + image_size : 256 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /add/ema/path/here + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: "ConditionalNeuralField/training_recipes/case2.yml" + +DATA: + max_val: 1.0 + min_val: -1.0 \ No newline at end of file diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml new file mode 100644 index 0000000000..22d7b7cc8e --- /dev/null +++ b/examples/confild/conf/un_confild_case3.yaml @@ -0,0 +1,64 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_un_confild_case3 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + +EVAL: + epochs: 9800 + mutil_GPU: 2 + microbatch: -1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + batch_size : 16 + test_batch_size : 16 + time_length : 256 + latent_length : 256 + +UNET: + + image_size : 256 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: null + ema_path: /add/ema/path/here + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: "ConditionalNeuralField/training_recipes/case3.yml" + + min_val: -1.0 diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml new file mode 100644 index 0000000000..009f0296eb --- /dev/null +++ b/examples/confild/conf/un_confild_case4.yaml @@ -0,0 +1,65 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + dir: ./outputs_un_confild_case4 + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2025 +output_dir: ${hydra:run.dir} +log_freq: 20 + + +EVAL: + epochs: 9800 + mutil_GPU: 2 + microbatch: -1 + lr : 5.e-5 + ema_rate: "0.9999" + log_interval: 1000 + save_interval: 10000 + lr_anneal_steps: 0 + batch_size : 8 + test_batch_size : 8 + time_length : 256 + latent_length : 256 + +UNET: + + image_size : 384 + num_channels: 128 + num_res_blocks: 2 + num_heads: 4 + num_head_channels: 64 + attention_resolutions: "32,16,8" + channel_mult: "1, 1, 2, 2, 4, 4" + ema_path: /add/ema/path/here + +Diff: + steps: 1000 + noise_schedule: "cosine" + +CNF: "ConditionalNeuralField/training_recipes/case4.yml" + + min_val: -1.0 diff --git a/examples/confild/confild.py b/examples/confild/confild.py index 098b741d3d..8003010630 100644 --- a/examples/confild/confild.py +++ b/examples/confild/confild.py @@ -32,6 +32,7 @@ from ppsci.arch import ModelMeanType from ppsci.utils import logger + def load_elbow_flow(path): return np.load(f"{path}")[1:] @@ -146,24 +147,48 @@ def fdenormalize(data_norm, params, method): return data_norm +class basic_set(paddle.io.Dataset): + def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> None: + super().__init__() + self.fois = fois.numpy() + self.total_samples = tuple(fois.shape)[0] + self.coords = coord.numpy() + # 存储全局索引 + self.global_indices = global_indices if global_indices is not None else np.arange(self.total_samples) + + def __len__(self): + return self.total_samples + + def __getitem__(self, idx): + # 使用全局索引 + global_idx = self.global_indices[idx] + if hasattr(self, "extra_in"): + extra_id = idx % tuple(self.fois.shape)[1] + idb = idx // tuple(self.fois.shape)[1] + return (self.coords, self.extra_in[extra_id]), self.fois[idb, extra_id], global_idx + else: + return self.coords, self.fois[idx], global_idx + + # build data def getdata(cfg): ###### read data - fois ###### if cfg.Data.load_data_fn == "load_3d_flow": - input_data = load_3d_flow(cfg.Data.data_path) + fois = load_3d_flow(cfg.Data.data_path) elif cfg.Data.load_data_fn == "load_elbow_flow": - input_data = load_elbow_flow(cfg.Data.data_path) + fois = load_elbow_flow(cfg.Data.data_path) elif cfg.Data.load_data_fn == "load_channel_flow": - input_data = load_channel_flow(cfg.Data.data_path) + fois = load_channel_flow(cfg.Data.data_path) elif cfg.Data.load_data_fn == "load_periodic_hill_flow": - input_data = load_periodic_hill_flow(cfg.Data.data_path) + fois = load_periodic_hill_flow(cfg.Data.data_path) else: - input_data = np.load(cfg.Data.data_path) + fois = np.load(cfg.Data.data_path) - spatio_shape = input_data.shape[1:-1] + # 计算空间形状和轴 + spatio_shape = fois.shape[1:-1] spatio_axis = list( range( - input_data.ndim if isinstance(input_data, np.ndarray) else input_data.dim() + fois.ndim if isinstance(fois, np.ndarray) else fois.dim() ) )[1:-1] @@ -174,16 +199,16 @@ def getdata(cfg): else: coord = np.load(cfg.Data.coor_path) coord = coord.astype("float32") - input_data = input_data.astype("float32") + fois = fois.astype("float32") ###### convert to tensor ###### - input_data = ( - paddle.to_tensor(input_data) - if not isinstance(input_data, paddle.Tensor) - else input_data + fois = ( + paddle.to_tensor(fois) + if not isinstance(fois, paddle.Tensor) + else fois ) coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord - N_samples = input_data.shape[0] + N_samples = fois.shape[0] ###### normalizer ###### in_normalizer = Normalizer_ts(**cfg.Data.normalizer) @@ -192,62 +217,12 @@ def getdata(cfg): ) out_normalizer = Normalizer_ts(**cfg.Data.normalizer) out_normalizer.fit_normalize( - input_data if cfg.Latent.lumped else input_data.flatten(0, cfg.Latent.dims) + fois if cfg.Latent.lumped else fois.flatten(0, cfg.Latent.dims) ) - normed_coords = in_normalizer.normalize(coord) - normed_fois = out_normalizer.normalize(input_data) + normed_coords = in_normalizer.normalize(coord)# 训练集就是测试集 + normed_fois = out_normalizer.normalize(fois) return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer - ###### 添加数据集划分 ###### - # split_ratio = cfg.Data.get("split_ratio", 0.8) # 默认为80%训练集 - # seed = cfg.Data.get("shuffle_seed", 42) # 随机种子 - - # # 生成随机索引并划分 - # np.random.seed(seed) - # total_samples = N_samples - # indices = np.random.permutation(total_samples) - # split_idx = int(total_samples * split_ratio) - - # # 划分训练集和测试集 - # train_indices = indices[:split_idx] - # test_indices = indices[split_idx:] - - # # 根据索引获取训练集和测试集数据 - # train_normed_fois = normed_fois[train_indices] - # test_normed_fois = normed_fois[test_indices] - - # return ( - # normed_coords, - # train_normed_fois, # 训练集数据 - # test_normed_fois, # 测试集数据 - # spatio_axis, - # out_normalizer, - # train_indices, # 训练集索引(用于latent模型) - # test_indices # 测试集索引 - # ) - - -class basic_set(paddle.io.Dataset): - def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> None: - super().__init__() - self.fois = fois.numpy() - self.total_samples = tuple(fois.shape)[0] - self.coords = coord.numpy() - # 存储全局索引 - self.global_indices = global_indices if global_indices is not None else np.arange(self.total_samples) - - def __len__(self): - return self.total_samples - - def __getitem__(self, idx): - # 使用全局索引 - global_idx = self.global_indices[idx] - if hasattr(self, "extra_in"): - extra_id = idx % tuple(self.fois.shape)[1] - idb = idx // tuple(self.fois.shape)[1] - return (self.coords, self.extra_in[extra_id]), self.fois[idb, extra_id], global_idx - else: - return self.coords, self.fois[idx], global_idx def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices): @@ -464,28 +439,24 @@ def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_ def train(cfg): - # 获取分割后的数据集 - # (normed_coords, - # train_normed_fois, - # test_normed_fois, - # spatio_axis, - # out_normalizer, - # train_indices, - # test_indices) = getdata(cfg) + # 获取GPU数量,检查是否是多卡训练 + world_size = cfg.TRAIN.mutil_GPU + # 获取数据 normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) train_normed_fois = normed_fois test_normed_fois = normed_fois train_indices = list(range(N_samples)) test_indices = list(range(N_samples)) - if cfg.TRAIN.mutil_GPU > 1: + + if world_size > 1: import paddle.distributed as dist dist.init_parallel_env() mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices) else: signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, - spatio_axis, out_normalizer, train_indices, test_indices) + spatio_axis, out_normalizer, train_indices, test_indices) def evaluate(cfg: DictConfig): @@ -565,243 +536,6 @@ def inference(cfg): logger.info("Result is {}".format(output_dict["fetch_name_0"])) -def uncondiction_infer(cfg): - test_batch_size = cfg.Uncondiction_INFER.test_batch_size - time_length = cfg.Uncondiction_INFER.time_length - latent_length = cfg.Uncondiction_INFER.latent_length - image_size = cfg.Uncondiction_INFER.image_size - num_channels = cfg.Uncondiction_INFER.num_channels - num_res_blocks = cfg.Uncondiction_INFER.num_res_blocks - num_heads = cfg.Uncondiction_INFER.num_heads - num_head_channels = cfg.Uncondiction_INFER.num_head_channels - attention_resolutions = cfg.Uncondiction_INFER.attention_resolutions - steps = cfg.Uncondiction_INFER.steps - noise_schedule = cfg.Uncondiction_INFER.noise_schedule - - unet_model = create_model( - image_size=image_size, - num_channels=num_channels, - num_res_blocks=num_res_blocks, - num_heads=num_heads, - num_head_channels=num_head_channels, - attention_resolutions=attention_resolutions, - ) - # ppsci.utils.save_load.load_pretrain( - # unet_model, - # # cfg.Uncondiction_INFER.ema_path, - # "/home/aistudio/ema_0.9999_550000.pdparams", - # ) - diff_model = create_gaussian_diffusion(steps=steps, noise_schedule=noise_schedule) - sample_fn = diff_model.p_sample_loop - gen_latents = sample_fn(unet_model, (test_batch_size, 1, time_length, latent_length))[ - :, 0 - ] - max_val, min_val = np.load("/home/aistudio/data_max.npy"), np.load("/home/aistudio/data_min.npy") - # max_val, min_val = np.load(cfg.Uncondiction_INFER.max_val), np.load(cfg.Uncondiction_INFER.min_val) - max_val, min_val = paddle.to_tensor(data=max_val), paddle.to_tensor(data=min_val) - gen_latents = (gen_latents + 1) * (max_val - min_val) / 2.0 + min_val - # 加载cnf模型 - print("加载cnf模型") - confild = SIRENAutodecoder_film(**cfg.CONFILD) - ppsci.utils.save_load.load_pretrain( - confild, - "https://dataset.bj.bcebos.com/PaddleScience/CoNFiLD/cnf_model_9700.pdparams",# cfg.EVAL.confild_pretrained_model_path, - ) - confild.eval() - coord = paddle.to_tensor(np.load("/home/aistudio/data/data321897/case1_coords.npy"), dtype="float32")#(np.load(f"{cfg.Data.coor_path}"), dtype='float32') - batch_size = 1 - n_samples = tuple(gen_latents.shape)[0] - out_normalizer = Normalizer_ts(**cfg.Data.normalizer) - - gen_fields = [] - print("开始生成") - for sample_index in range(n_samples): - print("第{}个样本", sample_index) - for i in range(tuple(gen_latents.shape)[1] // batch_size): - input_dict = { - "confild_x": coord, - "latent_z": gen_latents[sample_index, i * batch_size : (i + 1) * batch_size], - } - confild_output = confild(input_dict) - # print(confild_output) - gen_fields.append(out_normalizer.denormalize(confild_output["confild_output"]).detach() - .cpu() - .numpy()) - gen_fields = np.concatenate(gen_fields) - np.save("./", gen_fields)#cfg.Uncondiction_INFER.save_path - - -class LossType(enum.Enum): - MSE = enum.auto() - RESCALED_MSE = enum.auto() - KL = enum.auto() - RESCALED_KL = enum.auto() - - def is_vb(self): - return self == LossType.KL or self == LossType.RESCALED_KL - - -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - if schedule_name == "linear": - scale = 1000 / num_diffusion_timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return np.linspace( - beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 - ) - elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - -def space_timesteps(num_timesteps, section_counts): - if isinstance(section_counts, str): - if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) - for i in range(1, num_timesteps): - if len(range(0, num_timesteps, i)) == desired_count: - return set(range(0, num_timesteps, i)) - raise ValueError( - f"cannot create exactly {num_timesteps} steps with an integer stride" - ) - section_counts = [int(x) for x in section_counts.split(",")] - size_per = num_timesteps // len(section_counts) - extra = num_timesteps % len(section_counts) - start_idx = 0 - all_steps = [] - for i, section_count in enumerate(section_counts): - size = size_per + (1 if i < extra else 0) - if size < section_count: - raise ValueError( - f"cannot divide section of {size} steps into {section_count}" - ) - if section_count <= 1: - frac_stride = 1 - else: - frac_stride = (size - 1) / (section_count - 1) - cur_idx = 0.0 - taken_steps = [] - for _ in range(section_count): - taken_steps.append(start_idx + round(cur_idx)) - cur_idx += frac_stride - all_steps += taken_steps - start_idx += size - return set(all_steps) - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def create_gaussian_diffusion( - *, - steps=1000, - learn_sigma=False, - sigma_small=False, - noise_schedule="linear", - use_kl=False, - predict_xstart=False, - rescale_timesteps=False, - rescale_learned_sigmas=False, - timestep_respacing="", -): - betas = get_named_beta_schedule(noise_schedule, steps) - if use_kl: - loss_type = LossType.RESCALED_KL - elif rescale_learned_sigmas: - loss_type = LossType.RESCALED_MSE - else: - loss_type = LossType.MSE - if not timestep_respacing: - timestep_respacing = [steps] - return SpacedDiffusion( - use_timesteps=space_timesteps(steps, timestep_respacing), - betas=betas, - model_mean_type=ModelMeanType.EPSILON - if not predict_xstart - else ModelMeanType.START_X, - model_var_type=( - ModelVarType.FIXED_LARGE - if not sigma_small - else ModelVarType.FIXED_SMALL - ) - if not learn_sigma - else ModelVarType.LEARNED_RANGE, - loss_type=loss_type, - rescale_timesteps=rescale_timesteps, - ) - - -NUM_CLASSES = 1000 - - -def create_model( - image_size, - num_channels, - num_res_blocks, - dims=2, - out_channels=1, - channel_mult=None, - learn_sigma=False, - class_cond=False, - use_checkpoint=False, - attention_resolutions="16", - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - dropout=0, - resblock_updown=False, - use_fp16=False, - use_new_attention_order=False, -): - if channel_mult is None: - if image_size == 512: - channel_mult = 0.5, 1, 1, 2, 2, 4, 4 - elif image_size == 256: - channel_mult = 1, 1, 2, 2, 4, 4 - elif image_size == 128: - channel_mult = 1, 1, 2, 3, 4 - elif image_size == 64: - channel_mult = 1, 2, 3, 4 - else: - raise ValueError(f"unsupported image size: {image_size}") - else: - channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) - attention_ds = [] - for res in attention_resolutions.split(","): - attention_ds.append(image_size // int(res)) - return UNetModel( - image_size=image_size, - in_channels=out_channels, - model_channels=num_channels, - out_channels=out_channels if not learn_sigma else 2 * out_channels, - num_res_blocks=num_res_blocks, - attention_resolutions=tuple(attention_ds), - dropout=dropout, - channel_mult=channel_mult, - num_classes=NUM_CLASSES if class_cond else None, - use_checkpoint=use_checkpoint, - use_fp16=use_fp16, - num_heads=num_heads, - num_head_channels=num_head_channels, - num_heads_upsample=num_heads_upsample, - use_scale_shift_norm=use_scale_shift_norm, - resblock_updown=resblock_updown, - use_new_attention_order=use_new_attention_order, - dims=dims, - ) - - def export(cfg): # set model cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) @@ -846,15 +580,12 @@ def main(cfg: DictConfig): elif cfg.mode == "eval": evaluate(cfg) elif cfg.mode == "infer": - if cfg.alis == False: - inference(cfg) - else: - uncondiction_infer(cfg) + inference(cfg) elif cfg.mode == "export": export(cfg) - elif cfg.mode == "uncondition_infer": + else: raise ValueError( - f"cfg.mode should in ['train', 'eval', 'infer', 'export', 'uncondition_infer'], but got '{cfg.mode}'" + f"cfg.mode should in ['train', 'eval', 'infer', 'export'], but got '{cfg.mode}'" ) diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py new file mode 100644 index 0000000000..84012f5422 --- /dev/null +++ b/examples/confild/un_confild.py @@ -0,0 +1,446 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import math +import hydra +import matplotlib.pyplot as plt +import numpy as np +import paddle +from omegaconf import DictConfig +from paddle.distributed import fleet +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +import ppsci +from ppsci.arch import UNetModel +from ppsci.arch import LatentContainer +from ppsci.arch import SIRENAutodecoder_film +from ppsci.arch import SpacedDiffusion +from ppsci.arch import ModelVarType +from ppsci.arch import ModelMeanType +from ppsci.utils import logger + + +def create_model( + image_size, + num_channels, + num_res_blocks, + dims=2, + out_channels=1, + channel_mult=None, + learn_sigma=False, + class_cond=False, + use_checkpoint=False, + attention_resolutions="16", + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + dropout=0, + resblock_updown=False, + use_fp16=False, + use_new_attention_order=False, +): + if channel_mult is None: + if image_size == 512: + channel_mult = (0.5, 1, 1, 2, 2, 4, 4) + elif image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + image_size=image_size, + in_channels=out_channels, + model_channels=num_channels, + out_channels=(out_channels if not learn_sigma else 2*out_channels),#(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(1000 if class_cond else None), + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + dims=dims + ) + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def space_timesteps(num_timesteps, section_counts): + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +def create_gaussian_diffusion( + *, + steps=1000, + learn_sigma=False, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", +): + betas = get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = LossType.RESCALED_MSE + else: + loss_type = LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X + ), + model_var_type=( + ( + ModelVarType.FIXED_LARGE + if not sigma_small + else ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + ) + + +def load_elbow_flow(path): + return np.load(f"{path}")[1:] + + +def load_channel_flow( + path, + t_start=0, + t_end=1200, + t_every=1, +): + return np.load(f"{path}")[t_start:t_end:t_every] + + +def load_periodic_hill_flow(path): + data = np.load(f"{path}") + return data + + +def load_3d_flow(path): + data = np.load(f"{path}") + return data + + +class Normalizer_ts(object): + def __init__(self, params=[], method="-11", dim=None): + self.params = params + self.method = method + self.dim = dim + + def fit_normalize(self, data): + assert type(data) == paddle.Tensor + if len(self.params) == 0: + if self.method == "-11" or self.method == "01": + if self.dim is None: + self.params = paddle.max(x=data), paddle.min(x=data) + else: + self.params = ( + paddle.max(keepdim=True, x=data, axis=self.dim), + paddle.argmax(keepdim=True, x=data, axis=self.dim), + )[0], ( + paddle.min(keepdim=True, x=data, axis=self.dim), + paddle.argmin(keepdim=True, x=data, axis=self.dim), + )[ + 0 + ] + elif self.method == "ms": + if self.dim is None: + self.params = paddle.mean(x=data, axis=self.dim), paddle.std( + x=data, axis=self.dim + ) + else: + self.params = paddle.mean( + x=data, axis=self.dim, keepdim=True + ), paddle.std(x=data, axis=self.dim, keepdim=True) + elif self.method == "none": + self.params = None + return self.fnormalize(data, self.params, self.method) + + def normalize(self, new_data): + if not new_data.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fnormalize(new_data, self.params, self.method) + + def denormalize(self, new_data_norm): + if not new_data_norm.place == self.params[0].place: + self.params = self.params[0], self.params[1] + return self.fdenormalize(new_data_norm, self.params, self.method) + + def get_params(self): + if self.method == "ms": + print("returning mean and std") + elif self.method == "01": + print("returning max and min") + elif self.method == "-11": + print("returning max and min") + elif self.method == "none": + print("do nothing") + return self.params + + @staticmethod + def fnormalize(data, params, method): + if method == "-11": + return (data - params[1]) / ( + params[0] - params[1] + ) * 2 - 1 + elif method == "01": + return (data - params[1]) / ( + params[0] - params[1] + ) + elif method == "ms": + return (data - params[0]) / params[1] + elif method == "none": + return data + + @staticmethod + def fdenormalize(data_norm, params, method): + if method == "-11": + return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] + elif method == "01": + return data_norm * ( + params[0] - params[1] + ) + params[1] + elif method == "ms": + return data_norm * params[1] + params[0] + elif method == "none": + return data_norm + + +def create_slim(cfg): + world_size = cfg.multiGPU + ###### read data - fois ###### + if cfg.Data.load_data_fn == "load_3d_flow": + fois = load_3d_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_elbow_flow": + fois = load_elbow_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_channel_flow": + fois = load_channel_flow(cfg.Data.data_path) + elif cfg.Data.load_data_fn == "load_periodic_hill_flow": + fois = load_periodic_hill_flow(cfg.Data.data_path) + else: + fois = np.load(cfg.Data.data_path) + + # 计算空间形状和轴 + spatio_shape = fois.shape[1:-1] + spatio_axis = list( + range( + fois.ndim if isinstance(fois, np.ndarray) else fois.dim() + ) + )[1:-1] + + ###### read data - coordinate ###### + if cfg.Data.coor_path is None: + coord = [np.linspace(0, 1, i) for i in spatio_shape] + coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1) + else: + coord = np.load(cfg.Data.coor_path) + coord = coord.astype("float32") + fois = fois.astype("float32") + + ###### convert to tensor ###### + fois = ( + paddle.to_tensor(fois) + if not isinstance(fois, paddle.Tensor) + else fois + ) + coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord + N_samples = fois.shape[0] + + ###### normalizer ###### + in_normalizer = Normalizer_ts(**cfg.Data.normalizer) + out_normalizer = Normalizer_ts(**cfg.Data.normalizer) + # 使用最新的模型参数 + norm_params = paddle.load(f"{hyper_para.save_path}/normalizer_params.pt") + in_normalizer.params = norm_params["x_normalizer_params"] + out_normalizer.params = norm_params["y_normalizer_params"] + + cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) + + normed_coords = in_normalizer.normalize(coord)# 训练集就是测试集 + + return cnf_model, in_normalizer, out_normalizer, coord + + +def evaluate(cfg): + ## Create model and diffusion + unet_model = create_model(image_size=cfg.UNET.image_size, + num_channels=cfg.UNET.num_channels, + num_res_blocks=cfg.UNET.num_res_blocks, + num_heads=cfg.UNET.num_heads, + num_head_channels=cfg.UNET.num_head_channels, + attention_resolutions=cfg.UNET.attention_resolutions + ) + + unet_model.set_state_dict(paddle.load(cfg.UNET.ema_path)) + + diff_model = create_gaussian_diffusion(steps=cfg.Diff.steps, + noise_schedule=cfg.Diff.noise_schedule + ) + + sample_fn = diff_model.p_sample_loop + gen_latents = sample_fn(unet_model, (cfg.EVAL.test_batch_size, 1, cfg.EVAL.time_length, cfg.EVAL.latent_length))[:, 0] + + max_val, min_val = np.load(cfg.DATA.max_val), np.load(cfg.DATA.min_val) + max_val, min_val = paddle.to_tensor(max_val), paddle.to_tensor(min_val) + gen_latents = (gen_latents + 1)*(max_val - min_val)/2. + min_val + + # 获取模型 + nf, in_normalizer, out_normalizer, coord = create_slim(cfg) + nf.set_state_dict(paddle.load(cfg.CONFILD.ema_path)) + coord = in_normalizer.normalize(coord) + + batch_size = 1 # if you are limited by your GPU Memory, please change the batch_size variable accordingly + n_samples = gen_latents.shape[0] + gen_fields = [] + + for sample_index in range(n_samples): + for i in range(gen_latents.shape[1]//batch_size): + new_latents = gen_latents[sample_index, i*batch_size:(i+1)*batch_size] + # coord = in_normalizer.normalize(coord) + if len(coord.shape) > 2: + new_latents = new_latents[:, None, None] + else: + new_latents = new_latents[:, None] + out = nf(coord.to(new_latents.device), new_latents) + out = out_normalizer.denormalize(out) + gen_fields.append(out.detach().cpu().numpy()) + + gen_fields = np.concatenate(gen_fields) + + np.save(inp.save_path, gen_fields) + + # 绘制结果 + + +@hydra.main(version_base=None, config_path="./conf", config_name="un_confild_case1.yaml") +def main(cfg: DictConfig): + if cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError( + f"cfg.mode should in ['eval'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + # main() + # 构建create_model + my_model = create_model(image_size=128, + num_channels=128, + num_res_blocks=2, + num_heads=4, + num_head_channels=64, + attention_resolutions="32,16,8") + #保存参数 + paddle.save(my_model.state_dict(), "my_model.pdparams") \ No newline at end of file diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py index 3965080832..f50db439d3 100644 --- a/ppsci/arch/confild.py +++ b/ppsci/arch/confild.py @@ -7,7 +7,7 @@ DEFAULT_W0 = 30.0 - +###################### ConFILD Model ####################### class Swish(paddle.nn.Layer): def __init__(self): super().__init__() @@ -380,7 +380,7 @@ def forward(self, batch_ids): expanded_latents = selected_latents.reshape(getShape) return {self.output_keys[0]: expanded_latents} - +###################### GaussianDiffusion Model ####################### class ModelVarType(enum.Enum): LEARNED = enum.auto() @@ -412,6 +412,34 @@ class ModelMeanType(enum.Enum): EPSILON = enum.auto() +def mean_flat(tensor): + return paddle.mean(tensor, axis=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, paddle.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, paddle.Tensor) else paddle.to_tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + paddle.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2) + ) + + class GaussianDiffusion: def __init__( self, @@ -426,16 +454,22 @@ def __init__( self.model_var_type = model_var_type self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(tuple(betas.shape)) == 1, "betas must be 1-D" assert (betas > 0).astype("bool").all() and (betas <= 1).astype("bool").all() + self.num_timesteps = int(tuple(betas.shape)[0]) + alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert tuple(self.alphas_cumprod_prev.shape) == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) @@ -455,7 +489,22 @@ def __init__( * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) - + + def q_mean_variance(self, x_start, t): + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = paddle.randn(x_start.shape) + + sqrt_alpha_cumprod_t = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + sqrt_one_minus_alpha_cumprod_t = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + + return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise + def _predict_xstart_from_xprev(self, x_t, t, xprev): assert tuple(x_t.shape) == tuple(xprev.shape) return ( @@ -584,12 +633,37 @@ def _scale_timesteps(self, t): return t def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + if model_kwargs is None: + model_kwargs = {} gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = p_mean_var["mean"].astype(dtype="float32") + p_mean_var[ "variance" ] * gradient.astype(dtype="float32") return new_mean + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + if model_kwargs is None: + model_kwargs = {} + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + def p_sample( self, model, @@ -699,6 +773,63 @@ def p_sample_loop_progressive( yield out img = out["sample"] + def training_losses(self, model, x_start, t, noise=None): + if noise is None: + noise = paddle.randn(x_start.shape) + + x_t = self.q_sample(x_start=x_start, t=t, noise=noise) + + model_output = model(x_t, t) + + # Handle different model outputs + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape[1] == 2 * x_start.shape[1], "Output channels must be 2x input channels" + model_output, model_var_values = paddle.split(model_output, 2, axis=1) + + # Calculate the MSE loss for epsilon prediction + target = noise + mse_loss = mean_flat((target - model_output) ** 2) + + # Calculate the KL divergence loss if needed + vb_losses = 0 + if self.loss_type.is_vb(): + # Compute the KL divergence between q and p distributions + true_mean, true_log_variance_clipped, pred_xstart = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + p_mean, p_log_variance_clipped, _ = self.p_mean_variance( + model, x_t, t, clip_denoised=False, denoised_fn=None + ) + + kl = normal_kl(true_mean, true_log_variance_clipped, p_mean, p_log_variance_clipped) + vb_losses = mean_flat(kl) + + # Choose the loss based on loss_type + if self.loss_type == LossType.MSE: + losses = mse_loss + elif self.loss_type == LossType.RESCALED_MSE: + losses = mse_loss * self.num_timesteps + elif self.loss_type == LossType.KL: + losses = vb_losses + elif self.loss_type == LossType.RESCALED_KL: + losses = vb_losses * self.num_timesteps + else: + raise NotImplementedError(f"Unknown loss type: {self.loss_type}") + + return {"loss": losses.mean()} + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + class SpacedDiffusion(GaussianDiffusion): """ @@ -757,7 +888,7 @@ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): def __call__(self, x, ts, **kwargs): map_tensor = paddle.to_tensor( - data=self.timestep_map, dtype=ts.dtype, place=ts.place + data=self.timestep_map, dtype=ts.dtype#, place=ts.place ) new_ts = map_tensor[ts] if self.rescale_timesteps: @@ -765,6 +896,7 @@ def __call__(self, x, ts, **kwargs): return self.model(x, new_ts, **kwargs) +###################### UNET Model ####################### def conv_nd(dims, *args, **kwargs): if dims == 1: return paddle.nn.Conv1D(*args, **kwargs) @@ -894,15 +1026,14 @@ def avg_pool_nd(dims, *args, **kwargs): Create a 1D, 2D, or 3D average pooling module. """ if dims == 1: - return paddle.nn.AvgPool1d(*args, **kwargs, exclusive=False) + return paddle.nn.AvgPool1D(*args, **kwargs, exclusive=False) elif dims == 2: - return paddle.nn.AvgPool2d(*args, **kwargs, exclusive=False) + return paddle.nn.AvgPool2D(*args, **kwargs, exclusive=False) elif dims == 3: - return paddle.nn.AvgPool3d(*args, **kwargs, exclusive=False) + return paddle.nn.AvgPool3D(*args, **kwargs, exclusive=False) raise ValueError(f"unsupported dimensions: {dims}") - class Downsample(paddle.nn.Layer): def __init__(self, channels, use_conv, dims=2, out_channels=None): super().__init__() @@ -923,6 +1054,7 @@ def forward(self, x): assert tuple(x.shape)[1] == self.channels return self.op(x) + class Upsample(paddle.nn.Layer): def __init__(self, channels, use_conv, dims=2, out_channels=None): super().__init__() @@ -947,10 +1079,11 @@ def forward(self, x): x = self.conv(x) return x + def count_flops_attn(model, _x, y): b, c, *spatial = tuple(y[0].shape) num_spatial = int(np.prod(spatial)) - matmul_ops = 2 * b * num_spatial**2 * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += paddle.to_tensor(data=[matmul_ops], dtype="float64") @@ -963,7 +1096,7 @@ def forward(self, qkv): bs, width, length = tuple(qkv.shape) assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape((bs * self.n_heads, ch * 3, length)).split(3, axis=1) + q, k, v = qkv.reshape((bs * self.n_heads, ch * 3, length)).split(ch, axis=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale) weight = paddle.nn.functional.softmax( @@ -988,7 +1121,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.chunk(chunks=3, axis=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = paddle.einsum( + weight = paddle.einsum(# 非复数 "bct,bcs->bts", (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), @@ -1041,8 +1174,7 @@ def forward(ctx, run_function, length, *args): @staticmethod def backward(ctx, *output_grads): - """Class Method: *.requires_grad_, can not convert, please check whether it is torch.Tensor.*/Optimizer.*/nn.Module.*/torch.distributions.Distribution.*/torch.autograd.function.FunctionCtx.*/torch.profiler.profile.*/torch.autograd.profiler.profile.*, and convert manually""" - ctx.input_tensors = [paddle.stop_gradient(x, stop=False) for x in ctx.input_tensors] + ctx.input_tensors = [stop_gradient(x, stop=False) for x in ctx.input_tensors] with paddle.enable_grad(): shallow_copies = [x.view_as(other=x) for x in ctx.input_tensors] # print(shallow_copies) @@ -1057,10 +1189,12 @@ def backward(ctx, *output_grads): del ctx.input_tensors del ctx.input_params del output_tensors - return (None, None) + input_grads + return [None, None] + input_grads + -def stop_gradient(self, *args, **kwargs): - return self +def stop_gradient(input, stop): + input.stop_gradient = stop + return input class AttentionBlock(paddle.nn.Layer): @@ -1122,7 +1256,7 @@ def timestep_embedding(timesteps, dim, max_period=10000): x=-math.log(max_period) * paddle.arange(start=0, end=half, dtype="float32") / half - ).to(paddle.CUDAPlace(0)) + )#.to(paddle.CUDAPlace(0)) args = timesteps[:, None].astype(dtype="float32") * freqs[None] embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) if dim % 2: @@ -1206,7 +1340,7 @@ def __init__( ) if self.num_classes is not None: self.label_emb = paddle.nn.Embedding( - num_embeddings=num_classes, embedding_dim=time_embed_dim + num_embeddings=self.num_classes, embedding_dim=time_embed_dim ) ch = input_ch = int(channel_mult[0] * model_channels) self.input_blocks = paddle.nn.LayerList( @@ -1219,7 +1353,8 @@ def __init__( ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): - layers = [ + layers = [] + layers.append( ResBlock( ch, time_embed_dim, @@ -1229,7 +1364,7 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) - ] + ) ch = int(mult * model_channels) if ds in attention_resolutions: layers.append( @@ -1298,7 +1433,8 @@ def __init__( for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() - layers = [ + layers = [] + layers.append( ResBlock( ch + ich, time_embed_dim, @@ -1308,7 +1444,7 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) - ] + ) ch = int(model_channels * mult) if ds in attention_resolutions: layers.append( From 119f02cfcedac4b1dd25235921197d283418b3d0 Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Tue, 7 Oct 2025 22:57:01 +0800 Subject: [PATCH 03/11] update unet model --- examples/confild/conf/confild_case1.yaml | 4 +- examples/confild/conf/confild_case2.yaml | 4 +- examples/confild/conf/confild_case3.yaml | 4 +- examples/confild/conf/confild_case4.yaml | 4 +- examples/confild/conf/un_confild_case1.yaml | 42 +- examples/confild/conf/un_confild_case2.yaml | 43 +- examples/confild/conf/un_confild_case3.yaml | 43 +- examples/confild/conf/un_confild_case4.yaml | 41 +- examples/confild/resample.py | 154 +++++ examples/confild/un_confild.py | 660 ++++++++++++++++++-- 10 files changed, 904 insertions(+), 95 deletions(-) create mode 100644 examples/confild/resample.py diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml index f80c7bb447..4bcec2db32 100644 --- a/examples/confild/conf/confild_case1.yaml +++ b/examples/confild/conf/confild_case1.yaml @@ -115,8 +115,8 @@ Uncondiction_INFER: noise_schedule: "cosine" Data: - data_path: ../case1/data.npy - coor_path: ../case1/coor.npy + data_path: /home/aistudio/work/extracted/data/Case1/data.npy + coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy normalizer: method: "-11" dim: 0 diff --git a/examples/confild/conf/confild_case2.yaml b/examples/confild/conf/confild_case2.yaml index 26b0f3c217..0ab1da60a4 100644 --- a/examples/confild/conf/confild_case2.yaml +++ b/examples/confild/conf/confild_case2.yaml @@ -100,8 +100,8 @@ INFER: batch_size: 40 Data: - data_path: ../case2/data.npy - coor_path: ../case2/coor.npy + data_path: /home/aistudio/work/extracted/data/Case2/data.npy + # coor_path: ../case2/coords.npy normalizer: method: "-11" dim: 0 diff --git a/examples/confild/conf/confild_case3.yaml b/examples/confild/conf/confild_case3.yaml index 2b7ea04ff6..9369a7e944 100644 --- a/examples/confild/conf/confild_case3.yaml +++ b/examples/confild/conf/confild_case3.yaml @@ -100,8 +100,8 @@ INFER: batch_size: 100 Data: - data_path: ../case3/data.npy - coor_path: ../case3/coor.npy + data_path: /home/aistudio/work/extracted/data/Case3/data.npy + coor_path: /home/aistudio/work/extracted/data/Case3/coords.npy normalizer: method: "-11" dim: 0 diff --git a/examples/confild/conf/confild_case4.yaml b/examples/confild/conf/confild_case4.yaml index 3e1491a3f1..820c5a92dc 100644 --- a/examples/confild/conf/confild_case4.yaml +++ b/examples/confild/conf/confild_case4.yaml @@ -100,8 +100,8 @@ INFER: batch_size: 4 Data: - data_path: ../case4/data.npy - coor_path: ../case4/coor.npy + data_path: /home/aistudio/work/extracted/data/Case4/data.npy + coor_path: /home/aistudio/work/extracted/data/Case4/coords.npy normalizer: method: "-11" dim: 0 diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml index 6295c04ee3..94cf05e326 100644 --- a/examples/confild/conf/un_confild_case1.yaml +++ b/examples/confild/conf/un_confild_case1.yaml @@ -25,28 +25,33 @@ hydra: subdir: ./ # general settings -mode: infer # running mode: infer +mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 0 + lr : 5.e-5 + weight_decay: 0.0 + lr_anneal_steps: 0 + final_lr: 0. + microbatch: -1 EVAL: - epochs: 9800 mutil_GPU: 1 - microbatch: -1 lr : 5.e-5 ema_rate: "0.9999" log_interval: 1000 save_interval: 10000 lr_anneal_steps: 0 - batch_size : 16 - test_batch_size : 16 time_length : 128 latent_length : 128 UNET: - image_size : 128 num_channels: 128 num_res_blocks: 2 @@ -54,14 +59,33 @@ UNET: num_head_channels: 64 attention_resolutions: "32,16,8" channel_mult: null - ema_path: /add/ema/path/here + ema_path: /home/aistudio/work/extracted/data/Case1/diffusion/ema.pdparams Diff: steps: 1000 noise_schedule: "cosine" -CNF: "ConditionalNeuralField/training_recipes/case1.yml" +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/extracted/data/Case1/data.npy + coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy + load_data_fn: load_elbow_flow + normalizer: + method: "-11" + dim: 0 + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 3 + hidden_features: 128 + in_coord_features: 2 + in_latent_features: 128 + normalizer_params_path: /home/aistudio/work/extracted/data/Case1/cnf/normalizer_params.pdparams + model_path: ./outputs_confild_case1/confild_case1/epoch_99999 DATA: max_val: 1.0 - min_val: -1.0 \ No newline at end of file + min_val: -1.0 + train_data: "/home/aistudio/work/extracted/data/Case1/train_data.npy" + valid_data: "/home/aistudio/work/extracted/data/Case1/valid_data.npy" \ No newline at end of file diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml index bacc83b3a8..3c424e82e1 100644 --- a/examples/confild/conf/un_confild_case2.yaml +++ b/examples/confild/conf/un_confild_case2.yaml @@ -25,29 +25,33 @@ hydra: subdir: ./ # general settings -mode: infer # running mode: infer +mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 - +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 0 + lr : 5.e-5 + weight_decay: 0.0 + lr_anneal_steps: 0 + final_lr: 0. + microbatch: -1 EVAL: - epochs: 9800 mutil_GPU: 1 - microbatch: -1 lr : 5.e-5 ema_rate: "0.9999" log_interval: 1000 save_interval: 10000 lr_anneal_steps: 0 - batch_size : 16 - test_batch_size : 16 time_length : 256 latent_length : 256 UNET: - image_size : 256 num_channels: 128 num_res_blocks: 2 @@ -55,14 +59,33 @@ UNET: num_head_channels: 64 attention_resolutions: "32,16,8" channel_mult: null - ema_path: /add/ema/path/here + ema_path: /home/aistudio/work/extracted/data/Case2/diffusion/ema.pdparams Diff: steps: 1000 noise_schedule: "cosine" -CNF: "ConditionalNeuralField/training_recipes/case2.yml" +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/extracted/data/Case2/data.npy + # coor_path: ../Case2/coords.npy + load_data_fn: load_channel_flow + normalizer: + method: "-11" + dim: 0 + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 10 + out_features: 4 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + normalizer_params_path: /home/aistudio/work/extracted/data/Case2/cnf/normalizer_params.pdparams + model_path: ./outputs_confild_case2/confild_case2/epoch_99999 DATA: max_val: 1.0 - min_val: -1.0 \ No newline at end of file + min_val: -1.0 + train_data: "/home/aistudio/work/extracted/data/Case2/train_data.npy" + valid_data: "/home/aistudio/work/extracted/data/Case2/valid_data.npy" \ No newline at end of file diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml index 22d7b7cc8e..14c714e937 100644 --- a/examples/confild/conf/un_confild_case3.yaml +++ b/examples/confild/conf/un_confild_case3.yaml @@ -25,27 +25,33 @@ hydra: subdir: ./ # general settings -mode: infer # running mode: infer +mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +TRAIN: + batch_size : 16 + test_batch_size : 16 + ema_rate: "0.9999" + lr_anneal_steps: 0 + lr : 5.e-5 + weight_decay: 0.0 + lr_anneal_steps: 0 + final_lr: 0. + microbatch: -1 + EVAL: - epochs: 9800 mutil_GPU: 2 - microbatch: -1 lr : 5.e-5 ema_rate: "0.9999" log_interval: 1000 save_interval: 10000 lr_anneal_steps: 0 - batch_size : 16 - test_batch_size : 16 time_length : 256 latent_length : 256 UNET: - image_size : 256 num_channels: 128 num_res_blocks: 2 @@ -53,12 +59,33 @@ UNET: num_head_channels: 64 attention_resolutions: "32,16,8" channel_mult: null - ema_path: /add/ema/path/here + ema_path: /home/aistudio/work/extracted/data/Case3/diffusion/ema.pdparams Diff: steps: 1000 noise_schedule: "cosine" -CNF: "ConditionalNeuralField/training_recipes/case3.yml" +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/extracted/data/Case3/data.npy + coor_path: /home/aistudio/work/extracted/data/Case3/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_periodic_hill_flow + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 117 + out_features: 2 + hidden_features: 256 + in_coord_features: 2 + in_latent_features: 256 + normalizer_params_path: /home/aistudio/work/extracted/data/Case3/cnf/normalizer_params.pdparams + model_path: ./outputs_confild_case3/confild_case3/epoch_99999 +DATA: min_val: -1.0 + max_val: 1.0 + train_data: "/home/aistudio/work/extracted/data/Case3/train_data.npy" + valid_data: "/home/aistudio/work/extracted/data/Case3/valid_data.npy" diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml index 009f0296eb..29e33325df 100644 --- a/examples/confild/conf/un_confild_case4.yaml +++ b/examples/confild/conf/un_confild_case4.yaml @@ -25,23 +25,29 @@ hydra: subdir: ./ # general settings -mode: infer # running mode: infer +mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +TRAIN: + batch_size : 8 + test_batch_size : 8 + ema_rate: "0.9999" + lr_anneal_steps: 0 + lr : 5.e-5 + weight_decay: 0.0 + lr_anneal_steps: 0 + final_lr: 0. + microbatch: -1 EVAL: - epochs: 9800 mutil_GPU: 2 - microbatch: -1 lr : 5.e-5 ema_rate: "0.9999" log_interval: 1000 save_interval: 10000 lr_anneal_steps: 0 - batch_size : 8 - test_batch_size : 8 time_length : 256 latent_length : 256 @@ -54,12 +60,33 @@ UNET: num_head_channels: 64 attention_resolutions: "32,16,8" channel_mult: "1, 1, 2, 2, 4, 4" - ema_path: /add/ema/path/here + ema_path: /home/aistudio/work/extracted/data/Case4/diffusion/ema.pdparams Diff: steps: 1000 noise_schedule: "cosine" -CNF: "ConditionalNeuralField/training_recipes/case4.yml" +CNF: + mutil_GPU: 1 + data_path: /home/aistudio/work/extracted/data/Case4/data.npy + coor_path: /home/aistudio/work/extracted/data/Case4/coords.npy + normalizer: + method: "-11" + dim: 0 + load_data_fn: load_3d_flow + CONFILD: + input_keys: ["confild_x", "latent_z"] + output_keys: ["confild_output"] + num_hidden_layers: 15 + out_features: 3 + hidden_features: 384 + in_coord_features: 3 + in_latent_features: 384 + normalizer_params_path: /home/aistudio/work/extracted/data/Case4/cnf/normalizer_params.pdparams + model_path: ./outputs_confild_case4/confild_case4/epoch_99999 +DATA: min_val: -1.0 + max_val: 1.0 + train_data: "/home/aistudio/work/extracted/data/Case4/train_data.npy" + valid_data: "/home/aistudio/work/extracted/data/Case4/valid_data.npy" diff --git a/examples/confild/resample.py b/examples/confild/resample.py new file mode 100644 index 0000000000..c1cd5b8eee --- /dev/null +++ b/examples/confild/resample.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod + +import numpy as np +import paddle as th +import paddle.distributed as dist + +# TODO +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.to_tensor(indices_np, dtype='int64') + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.to_tensor(weights_np, dtype='float32') + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + + batch_sizes = [ + th.to_tensor([0], dtype=th.int32, place=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.to_tensor([len(local_ts)], dtype=th.int32, place=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py index 84012f5422..c1c3546975 100644 --- a/examples/confild/un_confild.py +++ b/examples/confild/un_confild.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +# 导入必要的库 +from abc import ABC, abstractmethod +import copy import enum +import functools import math import hydra import matplotlib.pyplot as plt import numpy as np import paddle +import os from omegaconf import DictConfig -from paddle.distributed import fleet -from paddle.io import DataLoader -from paddle.io import DistributedBatchSampler +from resample import UniformSampler, LossAwareSampler -import ppsci from ppsci.arch import UNetModel -from ppsci.arch import LatentContainer from ppsci.arch import SIRENAutodecoder_film from ppsci.arch import SpacedDiffusion from ppsci.arch import ModelVarType @@ -33,6 +34,65 @@ from ppsci.utils import logger +def mean_flat(tensor): + """ + 计算张量除批次维度外所有维度的平均值 + + 参数: + tensor: 输入张量 + + 返回: + 除批次维度外所有维度的平均值 + """ + return tensor.mean(axis=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + 计算两个高斯分布之间的KL散度 + + 参数: + mean1: 第一个高斯分布的均值 + logvar1: 第一个高斯分布的对数方差 + mean2: 第二个高斯分布的均值 + logvar2: 第二个高斯分布的对数方差 + + 返回: + 两个高斯分布之间的KL散度 + """ + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + paddle.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2) + ) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + 从一维numpy数组中为一批索引提取值 + + 参数: + arr: 一维numpy数组 + timesteps: 时间步索引 + broadcast_shape: 广播形状 + + 返回: + 提取并广播后的张量 + """ + # 修复变量名错误 + res = paddle.to_tensor(arr)[timesteps].astype(timesteps.dtype) + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +# 添加用于存储训练和验证损失的全局变量 +train_losses = [] # 存储训练损失 +valid_losses = [] # 存储验证损失 + + def create_model( image_size, num_channels, @@ -53,6 +113,32 @@ def create_model( use_fp16=False, use_new_attention_order=False, ): + """ + 创建UNet模型 + + 参数: + image_size: 图像尺寸 + num_channels: 模型通道数 + num_res_blocks: 每个下采样级别的残差块数 + dims: 数据维度(1=1D, 2=2D, 3=3D) + out_channels: 输出张量的通道数 + channel_mult: 每个级别的通道乘数 + learn_sigma: 是否学习方差 + class_cond: 是否使用类别条件 + use_checkpoint: 是否启用梯度检查点 + attention_resolutions: 应用注意力的下采样率 + num_heads: 注意力头数 + num_head_channels: 每个注意力头的通道数 + num_heads_upsample: 上采样块的注意力头数 + use_scale_shift_norm: 是否使用FiLM-like调节 + dropout: Dropout概率 + resblock_updown: 是否使用残差块进行重采样 + use_fp16: 是否使用float16精度 + use_new_attention_order: 是否使用优化的注意力模式 + + 返回: + UNet模型实例 + """ if channel_mult is None: if image_size == 512: channel_mult = (0.5, 1, 1, 2, 2, 4, 4) @@ -65,7 +151,9 @@ def create_model( else: raise ValueError(f"unsupported image size: {image_size}") else: - channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) + # 修复channel_mult处理逻辑,确保类型正确 + if isinstance(channel_mult, str): + channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) attention_ds = [] for res in attention_resolutions.split(","): @@ -94,21 +182,36 @@ def create_model( class LossType(enum.Enum): - MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + """ + 损失类型枚举 + """ + MSE = enum.auto() # 使用原始MSE损失(学习方差时使用KL) RESCALED_MSE = ( enum.auto() - ) # use raw MSE loss (with RESCALED_KL when learning variances) - KL = enum.auto() # use the variational lower-bound - RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + ) # 使用原始MSE损失(学习方差时使用RESCALED_KL) + KL = enum.auto() # 使用变分下界 + RESCALED_KL = enum.auto() # 类似KL,但重新缩放以估计完整的VLB def is_vb(self): + """ + 判断是否为变分下界损失 + """ return self == LossType.KL or self == LossType.RESCALED_KL def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + 获取命名的beta调度 + + 参数: + schedule_name: 调度名称("linear"或"cosine") + num_diffusion_timesteps: 扩散步骤数 + + 返回: + beta值数组 + """ if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. + # Ho等人的线性调度,扩展为适用于任何数量的扩散步骤 scale = 1000 / num_diffusion_timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 @@ -125,6 +228,17 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + 基于alpha_bar创建betas + + 参数: + num_diffusion_timesteps: 扩散步骤数 + alpha_bar: 累积alpha值函数 + max_beta: beta的最大值 + + 返回: + beta值数组 + """ betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps @@ -134,6 +248,16 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): def space_timesteps(num_timesteps, section_counts): + """ + 在基础扩散过程中跳过步骤的时间步空间化 + + 参数: + num_timesteps: 原始时间步数 + section_counts: 每个部分的时间步数 + + 返回: + 保留的时间步集合 + """ if isinstance(section_counts, str): if section_counts.startswith("ddim"): desired_count = int(section_counts[len("ddim") :]) @@ -180,6 +304,23 @@ def create_gaussian_diffusion( rescale_learned_sigmas=False, timestep_respacing="", ): + """ + 创建高斯扩散过程 + + 参数: + steps: 扩散步骤数 + learn_sigma: 是否学习方差 + sigma_small: 是否使用小方差 + noise_schedule: 噪声调度("linear"或"cosine") + use_kl: 是否使用KL损失 + predict_xstart: 是否预测初始x + rescale_timesteps: 是否重新缩放时间步 + rescale_learned_sigmas: 是否重新缩放学习的sigma + timestep_respacing: 时间步重新间隔 + + 返回: + 高斯扩散过程实例 + """ betas = get_named_beta_schedule(noise_schedule, steps) if use_kl: loss_type = LossType.RESCALED_KL @@ -210,6 +351,15 @@ def create_gaussian_diffusion( def load_elbow_flow(path): + """ + 加载肘管流数据 + + 参数: + path: 数据文件路径 + + 返回: + 肘管流数据(从索引1开始) + """ return np.load(f"{path}")[1:] @@ -219,26 +369,76 @@ def load_channel_flow( t_end=1200, t_every=1, ): + """ + 加载通道流数据 + + 参数: + path: 数据文件路径 + t_start: 起始时间步 + t_end: 结束时间步 + t_every: 采样间隔 + + 返回: + 通道流数据 + """ return np.load(f"{path}")[t_start:t_end:t_every] def load_periodic_hill_flow(path): + """ + 加载周期性山丘流数据 + + 参数: + path: 数据文件路径 + + 返回: + 周期性山丘流数据 + """ data = np.load(f"{path}") return data def load_3d_flow(path): + """ + 加载3D流数据 + + 参数: + path: 数据文件路径 + + 返回: + 3D流数据 + """ data = np.load(f"{path}") return data class Normalizer_ts(object): + """ + 时间序列归一化器 + """ def __init__(self, params=[], method="-11", dim=None): + """ + 初始化归一化器 + + 参数: + params: 归一化参数 + method: 归一化方法("-11", "01", "ms", "none") + dim: 归一化维度 + """ self.params = params self.method = method self.dim = dim def fit_normalize(self, data): + """ + 拟合并归一化数据 + + 参数: + data: 输入数据 + + 返回: + 归一化后的数据 + """ assert type(data) == paddle.Tensor if len(self.params) == 0: if self.method == "-11" or self.method == "01": @@ -268,16 +468,37 @@ def fit_normalize(self, data): return self.fnormalize(data, self.params, self.method) def normalize(self, new_data): + """ + 归一化新数据 + + 参数: + new_data: 新数据 + + 返回: + 归一化后的数据 + """ if not new_data.place == self.params[0].place: self.params = self.params[0], self.params[1] return self.fnormalize(new_data, self.params, self.method) def denormalize(self, new_data_norm): + """ + 反归一化数据 + + 参数: + new_data_norm: 归一化后的数据 + + 返回: + 反归一化后的数据 + """ if not new_data_norm.place == self.params[0].place: self.params = self.params[0], self.params[1] return self.fdenormalize(new_data_norm, self.params, self.method) def get_params(self): + """ + 获取归一化参数 + """ if self.method == "ms": print("returning mean and std") elif self.method == "01": @@ -290,6 +511,17 @@ def get_params(self): @staticmethod def fnormalize(data, params, method): + """ + 执行归一化 + + 参数: + data: 输入数据 + params: 归一化参数 + method: 归一化方法 + + 返回: + 归一化后的数据 + """ if method == "-11": return (data - params[1]) / ( params[0] - params[1] @@ -305,6 +537,17 @@ def fnormalize(data, params, method): @staticmethod def fdenormalize(data_norm, params, method): + """ + 执行反归一化 + + 参数: + data_norm: 归一化后的数据 + params: 归一化参数 + method: 归一化方法 + + 返回: + 反归一化后的数据 + """ if method == "-11": return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] elif method == "01": @@ -318,33 +561,37 @@ def fdenormalize(data_norm, params, method): def create_slim(cfg): - world_size = cfg.multiGPU + """ + 创建SLIM模型 + + 参数: + cfg: 配置对象 + + 返回: + CNF模型、输入归一化器、输出归一化器和坐标 + """ + world_size = cfg.CNF.multiGPU ###### read data - fois ###### - if cfg.Data.load_data_fn == "load_3d_flow": - fois = load_3d_flow(cfg.Data.data_path) - elif cfg.Data.load_data_fn == "load_elbow_flow": - fois = load_elbow_flow(cfg.Data.data_path) - elif cfg.Data.load_data_fn == "load_channel_flow": - fois = load_channel_flow(cfg.Data.data_path) - elif cfg.Data.load_data_fn == "load_periodic_hill_flow": - fois = load_periodic_hill_flow(cfg.Data.data_path) + if cfg.CNF.load_data_fn == "load_3d_flow": + fois = load_3d_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_elbow_flow": + fois = load_elbow_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_channel_flow": + fois = load_channel_flow(cfg.CNF.data_path) + elif cfg.CNF.load_data_fn == "load_periodic_hill_flow": + fois = load_periodic_hill_flow(cfg.CNF.data_path) else: - fois = np.load(cfg.Data.data_path) + fois = np.load(cfg.CNF.data_path) # 计算空间形状和轴 spatio_shape = fois.shape[1:-1] - spatio_axis = list( - range( - fois.ndim if isinstance(fois, np.ndarray) else fois.dim() - ) - )[1:-1] ###### read data - coordinate ###### - if cfg.Data.coor_path is None: + if cfg.CNF.coor_path is None: coord = [np.linspace(0, 1, i) for i in spatio_shape] coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1) else: - coord = np.load(cfg.Data.coor_path) + coord = np.load(cfg.CNF.coor_path) coord = coord.astype("float32") fois = fois.astype("float32") @@ -358,21 +605,327 @@ def create_slim(cfg): N_samples = fois.shape[0] ###### normalizer ###### - in_normalizer = Normalizer_ts(**cfg.Data.normalizer) - out_normalizer = Normalizer_ts(**cfg.Data.normalizer) + in_normalizer = Normalizer_ts(**cfg.CNF.normalizer) + out_normalizer = Normalizer_ts(**cfg.CNF.normalizer) # 使用最新的模型参数 - norm_params = paddle.load(f"{hyper_para.save_path}/normalizer_params.pt") + norm_params = paddle.load(cfg.CNF.normalizer_params_path) in_normalizer.params = norm_params["x_normalizer_params"] out_normalizer.params = norm_params["y_normalizer_params"] - cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) - - normed_coords = in_normalizer.normalize(coord)# 训练集就是测试集 + cnf_model = SIRENAutodecoder_film(**cfg.CNF.CONFILD) return cnf_model, in_normalizer, out_normalizer, coord +def dl_iter(dl): + """ + 数据加载器迭代器 + + 参数: + dl: 数据加载器 + + 返回: + 无限迭代数据加载器 + """ + while True: + yield from dl + + +def train(cfg): + """ + 训练函数 + + 参数: + cfg: 配置对象 + """ + # create parameters + batch_size = cfg.TRAIN.batch_size + test_batch_size = cfg.TRAIN.test_batch_size + ema_rate = cfg.TRAIN.ema_rate + ema_rate = ( + [ema_rate] + if isinstance(ema_rate, float) + else [float(x) for x in ema_rate.split(",")] + ) + + lr_anneal_steps = cfg.TRAIN.lr_anneal_steps + final_lr = cfg.TRAIN.final_lr + step = 0 + resume_step = 0 + microbatch = cfg.TRAIN.microbatch if cfg.TRAIN.microbatch > 0 else batch_size + + ## Data Preprocessing + train_data = np.load(cfg.DATA.train_data) + valid_data = np.load(cfg.DATA.valid_data) + max_val, min_val = np.max(train_data, keepdims=True), np.min(train_data, keepdims=True) + norm_train_data = -1 + (train_data - min_val)*2. / (max_val - min_val) + norm_valid_data = -1 + (valid_data - min_val)*2. / (max_val - min_val) + + norm_train_data = paddle.to_tensor(norm_train_data[:, None, ...]) + norm_valid_data = paddle.to_tensor(norm_valid_data[:, None, ...]) + + dl_train = dl_iter(paddle.io.DataLoader(paddle.io.TensorDataset(norm_train_data), batch_size=batch_size, shuffle=True)) + dl_valid = dl_iter(paddle.io.DataLoader(paddle.io.TensorDataset(norm_valid_data), batch_size=test_batch_size, shuffle=True)) + + unet_model = create_model(image_size=cfg.UNET.image_size, + num_channels= cfg.UNET.num_channels, + num_res_blocks= cfg.UNET.num_res_blocks, + num_heads=cfg.UNET.num_heads, + num_head_channels=cfg.UNET.num_head_channels, + attention_resolutions=cfg.UNET.attention_resolutions, + channel_mult=cfg.UNET.channel_mult + ) + diff_model = create_gaussian_diffusion(steps=cfg.Diff.steps, + noise_schedule=cfg.Diff.noise_schedule + ) + + # 初始化AdamW优化器 + opt = paddle.optimizer.AdamW( + parameters=unet_model.parameters(), learning_rate=cfg.TRAIN.lr, weight_decay=cfg.TRAIN.weight_decay + ) + + schedule_sampler = UniformSampler(diff_model) + + # 初始化EMA参数 + ema_params = [ + copy.deepcopy(unet_model.parameters()) + for _ in range(len(ema_rate)) + ] + + # 清空损失记录 + global train_losses, valid_losses + train_losses.clear() + valid_losses.clear() + + while ( + not lr_anneal_steps + or step + resume_step < lr_anneal_steps + ): + cond = {} + # 获取下一个训练批次和验证批次的数据 + train_batch, = next(dl_train) + valid_batch, = next(dl_valid) + # 前向传播 + unet_model.train() + unet_model.clear_grad() + + for i in range(0, train_batch.shape[0], microbatch): + # 获取当前微批次数据 + micro = train_batch[i : i + microbatch] + micro_cond = { + k: v[i : i + microbatch] + for k, v in cond.items() + } + + # 从调度采样器中采样时间步 + t, weights = schedule_sampler.sample(micro.shape[0]) + + # 创建部分应用的损失计算函数 + compute_losses = functools.partial( + diff_model.training_losses, + unet_model, + micro, + t, + model_kwargs=micro_cond + ) + + # 计算损失 + losses = compute_losses() + # 添加训练标记 + losses["valid"] = False + + # 如果使用损失感知采样器,则更新本地损失 + if isinstance(schedule_sampler, LossAwareSampler): + schedule_sampler.update_with_local_losses( + t, losses["loss"].detach() + ) + + # 计算加权平均损失 + loss = (losses["loss"] * weights).mean() + + # 记录损失字典 + log_loss_dict( + diff_model, t, {k: v * weights for k, v in losses.items()}, is_valid=False + ) + + # 反向传播 + # unet_model.backward(loss) + loss.backward() + + # 不计算梯度,节省内存 + with paddle.no_grad(): + # 同样分解成微批次处理 + for i in range(0, valid_batch.shape[0], microbatch): + # 获取当前微批次数据 + micro = valid_batch[i : i + microbatch] + micro_cond = { + k: v[i : i + microbatch] + for k, v in cond.items() + } + + # 判断是否为最后一个微批次 + last_batch = (i + microbatch) >= valid_batch.shape[0] + + # 采样时间步 + t, weights = schedule_sampler.sample(micro.shape[0]) + + # 创建部分应用的损失计算函数 + compute_losses = functools.partial( + diff_model.training_losses, + unet_model, + micro, + t, + model_kwargs=micro_cond + ) + + # 计算验证损失 + losses = compute_losses() + # 添加验证标记 + losses["valid"] = True + + # 记录验证损失 + log_loss_dict( + diff_model, t, {k: v * weights for k, v in losses.items()}, is_valid=True + ) + + grad_norm, param_norm = _compute_norms(unet_model) + opt.step() + # took_step = unet_model.optimize(opt) + # 更新ema参数 + _update_ema(ema_rate, ema_params, unet_model.parameters()) + # 更新学习率 + _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr) + + step += 1 + + # 每100步打印一次训练和验证损失 + if step % 100 == 0: + if len(train_losses) > 0 and len(valid_losses) > 0: + print(f"Step {step}: Train Loss: {train_losses[-1]:.6f}, Valid Loss: {valid_losses[-1]:.6f}") + + # 保存模型 + paddle.save(unet_model.state_dict(), "unet.pdparams") + + # 绘制训练和验证损失曲线 + plot_losses() + + +def plot_losses(): + """ + 绘制训练和验证损失曲线 + """ + if len(train_losses) == 0 or len(valid_losses) == 0: + print("没有足够的数据来绘制损失曲线") + return + + plt.figure(figsize=(10, 6)) + plt.plot(train_losses, label='Training Loss', alpha=0.8) + plt.plot(valid_losses, label='Validation Loss', alpha=0.8) + plt.xlabel('Training Steps') + plt.ylabel('Loss') + plt.title('Training and Validation Loss') + plt.legend() + plt.grid(True) + plt.tight_layout() + + # 保存图像 + plt.savefig('loss_curve.png', dpi=300, bbox_inches='tight') + print("损失曲线已保存为 loss_curve.png") + + # 显示图像 + plt.show() + + +def _compute_norms(model, grad_scale=1.0): + """ + 计算模型参数和梯度的范数 + + 参数: + model: 模型 + grad_scale: 梯度缩放因子 + + 返回: + 梯度范数和参数范数 + """ + grad_norm = 0.0 + param_norm = 0.0 + for p in model.parameters(): + with paddle.no_grad(): + param_norm += paddle.norm(p, p=2, dtype=paddle.float32).item() ** 2 + if p.grad is not None: + grad_norm += paddle.norm(p.grad, p=2, dtype=paddle.float32).item() ** 2 + return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) + + +def _update_ema(ema_rate, ema_params, source_params, rate=0.99): + """ + 更新EMA(指数移动平均)参数 + EMA有助于提高生成质量,减少模型权重噪声 + + 参数: + ema_rate: EMA衰减率 + ema_params: EMA参数 + source_params: 源参数 + rate: 衰减率 + """ + for rate, target_params in zip(ema_rate, ema_params): + for targ, src in zip(target_params, source_params): + updated = targ.detach() * rate + src * (1 - rate) + targ.set_value(updated) + + +def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr): + """ + 学习率退火调整 + 根据训练进度线性降低学习率 + + 参数: + lr_anneal_steps: 学习率退火步数 + step: 当前步数 + resume_step: 恢复步数 + opt: 优化器 + final_lr: 最终学习率 + lr: 初始学习率 + """ + if not lr_anneal_steps: + return + frac_done = (step + resume_step) / lr_anneal_steps + new_lr = final_lr * (frac_done) + lr * (1 - frac_done) + opt.set_lr(new_lr) + + +def log_loss_dict(diffusion, ts, losses, is_valid=False): + """ + 记录损失字典信息 + + 参数: + diffusion: 扩散模型对象 + ts: 时间步张量 + losses: 损失字典 + is_valid: 是否为验证损失 + """ + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # 记录分位数(特别是四个四分位数) + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) + + # 记录训练和验证损失 + if key == "loss": + if is_valid: + valid_losses.append(values.mean().item()) + else: + train_losses.append(values.mean().item()) + + def evaluate(cfg): + """ + 评估函数 + + 参数: + cfg: 配置对象 + """ ## Create model and diffusion unet_model = create_model(image_size=cfg.UNET.image_size, num_channels=cfg.UNET.num_channels, @@ -391,13 +944,13 @@ def evaluate(cfg): sample_fn = diff_model.p_sample_loop gen_latents = sample_fn(unet_model, (cfg.EVAL.test_batch_size, 1, cfg.EVAL.time_length, cfg.EVAL.latent_length))[:, 0] - max_val, min_val = np.load(cfg.DATA.max_val), np.load(cfg.DATA.min_val) + max_val, min_val = cfg.DATA.max_val, cfg.DATA.min_val#np.load(cfg.DATA.max_val), np.load(cfg.DATA.min_val) max_val, min_val = paddle.to_tensor(max_val), paddle.to_tensor(min_val) gen_latents = (gen_latents + 1)*(max_val - min_val)/2. + min_val # 获取模型 nf, in_normalizer, out_normalizer, coord = create_slim(cfg) - nf.set_state_dict(paddle.load(cfg.CONFILD.ema_path)) + nf.set_state_dict(paddle.load(cfg.CNF.model_path)) coord = in_normalizer.normalize(coord) batch_size = 1 # if you are limited by your GPU Memory, please change the batch_size variable accordingly @@ -412,35 +965,36 @@ def evaluate(cfg): new_latents = new_latents[:, None, None] else: new_latents = new_latents[:, None] - out = nf(coord.to(new_latents.device), new_latents) + input_data = { + "confild_x": coord, + "latent_z": new_latents + } + out = nf(input_data)["confild_output"] out = out_normalizer.denormalize(out) gen_fields.append(out.detach().cpu().numpy()) gen_fields = np.concatenate(gen_fields) - np.save(inp.save_path, gen_fields) - - # 绘制结果 + np.save(cfg.save_path, gen_fields) @hydra.main(version_base=None, config_path="./conf", config_name="un_confild_case1.yaml") def main(cfg: DictConfig): - if cfg.mode == "eval": + """ + 主函数 + + 参数: + cfg: 配置对象 + """ + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": evaluate(cfg) else: raise ValueError( - f"cfg.mode should in ['eval'], but got '{cfg.mode}'" + f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'" ) if __name__ == "__main__": - # main() - # 构建create_model - my_model = create_model(image_size=128, - num_channels=128, - num_res_blocks=2, - num_heads=4, - num_head_channels=64, - attention_resolutions="32,16,8") - #保存参数 - paddle.save(my_model.state_dict(), "my_model.pdparams") \ No newline at end of file + main() \ No newline at end of file From 411ca7c80ded968e5c33be84eacb6f1a7cd62fbc Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Tue, 7 Oct 2025 23:30:35 +0800 Subject: [PATCH 04/11] delete duplicate key --- examples/confild/conf/un_confild_case1.yaml | 1 - examples/confild/conf/un_confild_case2.yaml | 1 - examples/confild/conf/un_confild_case3.yaml | 1 - examples/confild/conf/un_confild_case4.yaml | 1 - 4 files changed, 4 deletions(-) diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml index 94cf05e326..544c6c05df 100644 --- a/examples/confild/conf/un_confild_case1.yaml +++ b/examples/confild/conf/un_confild_case1.yaml @@ -37,7 +37,6 @@ TRAIN: lr_anneal_steps: 0 lr : 5.e-5 weight_decay: 0.0 - lr_anneal_steps: 0 final_lr: 0. microbatch: -1 diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml index 3c424e82e1..007242d656 100644 --- a/examples/confild/conf/un_confild_case2.yaml +++ b/examples/confild/conf/un_confild_case2.yaml @@ -37,7 +37,6 @@ TRAIN: lr_anneal_steps: 0 lr : 5.e-5 weight_decay: 0.0 - lr_anneal_steps: 0 final_lr: 0. microbatch: -1 diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml index 14c714e937..be06be4a2b 100644 --- a/examples/confild/conf/un_confild_case3.yaml +++ b/examples/confild/conf/un_confild_case3.yaml @@ -37,7 +37,6 @@ TRAIN: lr_anneal_steps: 0 lr : 5.e-5 weight_decay: 0.0 - lr_anneal_steps: 0 final_lr: 0. microbatch: -1 diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml index 29e33325df..39a802cbfe 100644 --- a/examples/confild/conf/un_confild_case4.yaml +++ b/examples/confild/conf/un_confild_case4.yaml @@ -37,7 +37,6 @@ TRAIN: lr_anneal_steps: 0 lr : 5.e-5 weight_decay: 0.0 - lr_anneal_steps: 0 final_lr: 0. microbatch: -1 From 645cb4ccdddbfa9acac8b82f5fbbf7843dd9957f Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Wed, 8 Oct 2025 12:23:50 +0800 Subject: [PATCH 05/11] fix --- examples/confild/conf/un_confild_case1.yaml | 2 + examples/confild/conf/un_confild_case2.yaml | 2 + examples/confild/conf/un_confild_case3.yaml | 2 + examples/confild/conf/un_confild_case4.yaml | 2 + examples/confild/un_confild.py | 70 ++++---- ppsci/arch/__init__.py | 3 +- ppsci/arch/confild.py | 170 +++++++++++++++----- 7 files changed, 180 insertions(+), 71 deletions(-) diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml index 544c6c05df..76b0c8f27a 100644 --- a/examples/confild/conf/un_confild_case1.yaml +++ b/examples/confild/conf/un_confild_case1.yaml @@ -29,6 +29,7 @@ mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +save_path: ${output_dir}/result.npy TRAIN: batch_size : 16 @@ -49,6 +50,7 @@ EVAL: lr_anneal_steps: 0 time_length : 128 latent_length : 128 + test_batch_size: 16 UNET: image_size : 128 diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml index 007242d656..641a10b9d7 100644 --- a/examples/confild/conf/un_confild_case2.yaml +++ b/examples/confild/conf/un_confild_case2.yaml @@ -29,6 +29,7 @@ mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +save_path: ${output_dir}/result.npy TRAIN: batch_size : 16 @@ -49,6 +50,7 @@ EVAL: lr_anneal_steps: 0 time_length : 256 latent_length : 256 + test_batch_size: 16 UNET: image_size : 256 diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml index be06be4a2b..65be7105db 100644 --- a/examples/confild/conf/un_confild_case3.yaml +++ b/examples/confild/conf/un_confild_case3.yaml @@ -29,6 +29,7 @@ mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +save_path: ${output_dir}/result.npy TRAIN: batch_size : 16 @@ -49,6 +50,7 @@ EVAL: lr_anneal_steps: 0 time_length : 256 latent_length : 256 + test_batch_size: 16 UNET: image_size : 256 diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml index 39a802cbfe..9d5c5dd3b8 100644 --- a/examples/confild/conf/un_confild_case4.yaml +++ b/examples/confild/conf/un_confild_case4.yaml @@ -29,6 +29,7 @@ mode: eval # running mode: infer seed: 2025 output_dir: ${hydra:run.dir} log_freq: 20 +save_path: ${output_dir}/result.npy TRAIN: batch_size : 8 @@ -49,6 +50,7 @@ EVAL: lr_anneal_steps: 0 time_length : 256 latent_length : 256 + test_batch_size: 8 UNET: diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py index c1c3546975..a457cca7cc 100644 --- a/examples/confild/un_confild.py +++ b/examples/confild/un_confild.py @@ -32,6 +32,7 @@ from ppsci.arch import ModelVarType from ppsci.arch import ModelMeanType from ppsci.utils import logger +from ppsci.arch import LossType def mean_flat(tensor): @@ -181,22 +182,22 @@ def create_model( ) -class LossType(enum.Enum): - """ - 损失类型枚举 - """ - MSE = enum.auto() # 使用原始MSE损失(学习方差时使用KL) - RESCALED_MSE = ( - enum.auto() - ) # 使用原始MSE损失(学习方差时使用RESCALED_KL) - KL = enum.auto() # 使用变分下界 - RESCALED_KL = enum.auto() # 类似KL,但重新缩放以估计完整的VLB +# class LossType(enum.Enum): +# """ +# 损失类型枚举 +# """ +# MSE = enum.auto() # 使用原始MSE损失(学习方差时使用KL) +# RESCALED_MSE = ( +# enum.auto() +# ) # 使用原始MSE损失(学习方差时使用RESCALED_KL) +# KL = enum.auto() # 使用变分下界 +# RESCALED_KL = enum.auto() # 类似KL,但重新缩放以估计完整的VLB - def is_vb(self): - """ - 判断是否为变分下界损失 - """ - return self == LossType.KL or self == LossType.RESCALED_KL +# def is_vb(self): +# """ +# 判断是否为变分下界损失 +# """ +# return self == LossType.KL or self == LossType.RESCALED_KL def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): @@ -570,7 +571,6 @@ def create_slim(cfg): 返回: CNF模型、输入归一化器、输出归一化器和坐标 """ - world_size = cfg.CNF.multiGPU ###### read data - fois ###### if cfg.CNF.load_data_fn == "load_3d_flow": fois = load_3d_flow(cfg.CNF.data_path) @@ -703,13 +703,19 @@ def train(cfg): ): cond = {} # 获取下一个训练批次和验证批次的数据 - train_batch, = next(dl_train) - valid_batch, = next(dl_valid) + train_batch = next(dl_train) + valid_batch = next(dl_valid) # 前向传播 unet_model.train() - unet_model.clear_grad() + # def zero_grad(model_params): + for param in unet_model.parameters(): + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + unet_model.clear_grad() - for i in range(0, train_batch.shape[0], microbatch): + for i in range(0, len(train_batch), microbatch): # 获取当前微批次数据 micro = train_batch[i : i + microbatch] micro_cond = { @@ -718,13 +724,14 @@ def train(cfg): } # 从调度采样器中采样时间步 - t, weights = schedule_sampler.sample(micro.shape[0]) + t, weights = schedule_sampler.sample(len(micro)) # 创建部分应用的损失计算函数 + new_micro = paddle.to_tensor(micro) compute_losses = functools.partial( diff_model.training_losses, unet_model, - micro, + new_micro, t, model_kwargs=micro_cond ) @@ -755,7 +762,7 @@ def train(cfg): # 不计算梯度,节省内存 with paddle.no_grad(): # 同样分解成微批次处理 - for i in range(0, valid_batch.shape[0], microbatch): + for i in range(0, len(valid_batch), microbatch): # 获取当前微批次数据 micro = valid_batch[i : i + microbatch] micro_cond = { @@ -764,18 +771,20 @@ def train(cfg): } # 判断是否为最后一个微批次 - last_batch = (i + microbatch) >= valid_batch.shape[0] + last_batch = (i + microbatch) >= len(valid_batch) # 采样时间步 - t, weights = schedule_sampler.sample(micro.shape[0]) + t, weights = schedule_sampler.sample(len(micro)) # 创建部分应用的损失计算函数 + new_micro = paddle.to_tensor(micro) compute_losses = functools.partial( diff_model.training_losses, unet_model, - micro, + new_micro, t, - model_kwargs=micro_cond + model_kwargs=micro_cond, + valid=True ) # 计算验证损失 @@ -896,7 +905,7 @@ def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr): def log_loss_dict(diffusion, ts, losses, is_valid=False): """ - 记录损失字典信息 + 记录损失字典的日志 参数: diffusion: 扩散模型对象 @@ -905,11 +914,12 @@ def log_loss_dict(diffusion, ts, losses, is_valid=False): is_valid: 是否为验证损失 """ for key, values in losses.items(): - logger.logkv_mean(key, values.mean().item()) + # 使用logger.info替代logger.logkv_mean记录平均损失值 + logger.info(f"{key}: {values.mean().item():.6f}") # 记录分位数(特别是四个四分位数) for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): quartile = int(4 * sub_t / diffusion.num_timesteps) - logger.logkv_mean(f"{key}_q{quartile}", sub_loss) + logger.info(f"{key}_q{quartile}: {sub_loss:.6f}") # 记录训练和验证损失 if key == "loss": diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index 346ed22b3d..8d498cb310 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -22,7 +22,7 @@ from ppsci.arch.amgnet import AMGNet # isort:skip from ppsci.arch.base import Arch # isort:skip from ppsci.arch.cfdgcn import CFDGCN # isort:skip -from ppsci.arch.confild import LatentContainer, SIRENAutodecoder_film, SpacedDiffusion, UNetModel, ModelVarType, ModelMeanType # isort:skip +from ppsci.arch.confild import LatentContainer, LossType, SIRENAutodecoder_film, SpacedDiffusion, UNetModel, ModelVarType, ModelMeanType # isort:skip from ppsci.arch.smc_reac import SuzukiMiyauraModel # isort:skip from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip @@ -103,6 +103,7 @@ "LatentNO", "LatentNO_time", "LNO", + "LossType", "MLP", "ModelList", "ModelVarType", diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py index f50db439d3..255009d899 100644 --- a/ppsci/arch/confild.py +++ b/ppsci/arch/confild.py @@ -773,50 +773,133 @@ def p_sample_loop_progressive( yield out img = out["sample"] - def training_losses(self, model, x_start, t, noise=None): + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, valid=False): + if model_kwargs is None: + model_kwargs = {} if noise is None: noise = paddle.randn(x_start.shape) x_t = self.q_sample(x_start=x_start, t=t, noise=noise) + # terms = {} + # model_output = model(x_t, t) - model_output = model(x_t, t) - - # Handle different model outputs - if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: - assert model_output.shape[1] == 2 * x_start.shape[1], "Output channels must be 2x input channels" - model_output, model_var_values = paddle.split(model_output, 2, axis=1) + # # Handle different model outputs + # if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + # assert model_output.shape[1] == 2 * x_start.shape[1], "Output channels must be 2x input channels" + # model_output, model_var_values = split(model_output, 2, axis=1) # Calculate the MSE loss for epsilon prediction - target = noise - mse_loss = mean_flat((target - model_output) ** 2) + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=True, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = split(model_output, C, axis=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = paddle.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=True, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape - # Calculate the KL divergence loss if needed - vb_losses = 0 - if self.loss_type.is_vb(): - # Compute the KL divergence between q and p distributions - true_mean, true_log_variance_clipped, pred_xstart = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - ) - p_mean, p_log_variance_clipped, _ = self.p_mean_variance( - model, x_t, t, clip_denoised=False, denoised_fn=None - ) - - kl = normal_kl(true_mean, true_log_variance_clipped, p_mean, p_log_variance_clipped) - vb_losses = mean_flat(kl) - - # Choose the loss based on loss_type - if self.loss_type == LossType.MSE: - losses = mse_loss - elif self.loss_type == LossType.RESCALED_MSE: - losses = mse_loss * self.num_timesteps - elif self.loss_type == LossType.KL: - losses = vb_losses - elif self.loss_type == LossType.RESCALED_KL: - losses = vb_losses * self.num_timesteps + if valid == False: + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + terms["valid_mse"] = mean_flat((target - model_output) ** 2) else: - raise NotImplementedError(f"Unknown loss type: {self.loss_type}") - - return {"loss": losses.mean()} + raise NotImplementedError(self.loss_type) + + return terms + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = paddle.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = paddle.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = paddle.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = paddle.where( + x < -0.999, + log_cdf_plus, + paddle.where(x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + paddle.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * paddle.pow(x, 3)))) class LossType(enum.Enum): @@ -999,7 +1082,7 @@ def _forward(self, x, emb): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = paddle.chunk(x=emb_out, chunks=2, axis=1) + (scale, shift) = paddle.chunk(x=emb_out, chunks=2, axis=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: @@ -1096,7 +1179,8 @@ def forward(self, qkv): bs, width, length = tuple(qkv.shape) assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape((bs * self.n_heads, ch * 3, length)).split(ch, axis=1) + # split_size: 为 int 时 torch 表示块的大小,paddle 表示块的个数 + (q, k, v) = split(qkv.reshape((bs * self.n_heads, ch * 3, length)), ch, 1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale) weight = paddle.nn.functional.softmax( @@ -1119,7 +1203,7 @@ def forward(self, qkv): bs, width, length = tuple(qkv.shape) assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(chunks=3, axis=1) + (q, k, v) = qkv.chunk(chunks=3, axis=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = paddle.einsum(# 非复数 "bct,bcs->bts", @@ -1184,12 +1268,18 @@ def backward(ctx, *output_grads): inputs=ctx.input_tensors + ctx.input_params, grad_outputs=output_grads, allow_unused=True, - retain_graph=True, create_graph=False + # retain_graph=True, create_graph=False ) del ctx.input_tensors del ctx.input_params del output_tensors - return [None, None] + input_grads + + # 确保将input_grads转换为元组,然后与(None, None)连接 + # PyLayer要求backward方法返回元组类型 + # if input_grads: + return tuple(input_grads) + # else: + # return (None, None) def stop_gradient(input, stop): From 9bf9f565cc19a2e3ee27935be7c6e994993d67b5 Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Sun, 19 Oct 2025 22:59:33 +0800 Subject: [PATCH 06/11] Add detailed docstrings for activation functions, layers, and models in confild.py, enhancing code documentation and usability. --- ppsci/arch/confild.py | 387 +++++++++++++++++++++++++++++++++--------- 1 file changed, 308 insertions(+), 79 deletions(-) diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py index 255009d899..a880e1b8a3 100644 --- a/ppsci/arch/confild.py +++ b/ppsci/arch/confild.py @@ -9,24 +9,64 @@ ###################### ConFILD Model ####################### class Swish(paddle.nn.Layer): + """ + Swish activation function: f(x) = x * sigmoid(x). + + A smooth, non-monotonic activation function that has been shown to work + better than ReLU on deeper models across a number of challenging datasets. + """ def __init__(self): super().__init__() self.Sigmoid = paddle.nn.Sigmoid() def forward(self, x): + """ + Apply Swish activation. + + Args: + x (paddle.Tensor): Input tensor. + + Returns: + paddle.Tensor: Output tensor with same shape as input. + """ return x * self.Sigmoid(x) class Sine(paddle.nn.Layer): + """ + Sine activation function for SIREN (Sinusoidal Representation Networks). + + Args: + w0 (float, optional): Frequency parameter for sine activation. Defaults to DEFAULT_W0 (30.0). + """ def __init__(self, w0=DEFAULT_W0): self.w0 = w0 super().__init__() def forward(self, input): + """ + Apply sine activation with frequency modulation. + + Args: + input (paddle.Tensor): Input tensor. + + Returns: + paddle.Tensor: sin(w0 * input). + """ return paddle.sin(x=self.w0 * input) def sine_init(m, w0=DEFAULT_W0): + """ + Weight initialization for SIREN hidden layers. + + Initializes weights uniformly in [-√(6/n)/w0, √(6/n)/w0] where n is input dimension. + This initialization is critical for maintaining stable signal propagation in SIREN networks. + + Args: + m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute). + w0 (float, optional): Frequency parameter. Defaults to DEFAULT_W0. + """ with paddle.no_grad(): if hasattr(m, "weight"): num_input = m.weight.shape[-1] @@ -36,6 +76,15 @@ def sine_init(m, w0=DEFAULT_W0): def first_layer_sine_init(m): + """ + Weight initialization for SIREN first layer. + + Initializes weights uniformly in [-1/n, 1/n] where n is input dimension. + Different from hidden layers to handle raw coordinate inputs properly. + + Args: + m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute). + """ with paddle.no_grad(): if hasattr(m, "weight"): num_input = m.weight.shape[-1] @@ -93,16 +142,34 @@ def init_weights_xavier(m): class BatchLinear(paddle.nn.Linear): """ - This is a linear transformation implemented manually. It also allows maually input parameters. - for initialization, (in_features, out_features) needs to be provided. - weight is of shape (out_features*in_features) - bias is of shape (out_features) - + Batch-wise linear transformation layer that supports manual parameter injection. + + This layer extends paddle.nn.Linear to allow passing parameters explicitly, + which is useful for meta-learning and hypernetwork applications. + + Args: + in_features (int): Size of input features. + out_features (int): Size of output features. + + Note: + - Weight shape: (out_features, in_features) + - Bias shape: (out_features,) """ __doc__ = paddle.nn.Linear.__doc__ def forward(self, input, params=None): + """ + Forward pass with optional external parameters. + + Args: + input (paddle.Tensor): Input tensor of shape (..., in_features). + params (OrderedDict, optional): External parameters dict containing 'weight' and optionally 'bias'. + If None, uses internal parameters. Defaults to None. + + Returns: + paddle.Tensor: Output tensor of shape (..., out_features). + """ if params is None: params = OrderedDict(self.named_parameters()) bias = params.get("bias", None) @@ -116,7 +183,14 @@ def forward(self, input, params=None): class FeatureMapping: """ - This is feature mapping class for fourier feature networks + Feature mapping class for Fourier Feature Networks. + + Supports multiple mapping strategies including Gaussian random Fourier features, + positional encoding, and radial basis functions (RBF) for improving coordinate-based + neural network representations. + + Reference: + Tancik et al. "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains" """ def __init__( @@ -136,16 +210,22 @@ def __init__( rbf_std=0.5, ): """ - inputs: - in_freatures: number of input features - mapping_size: output features for Gaussian mapping - rand_key: random key for Gaussian mapping - tau: standard deviation for Gaussian mapping - num_freqs: number of frequencies for P.E. - scale = 2: base scale of frequencies for P.E. - init_scale: initial scale for P.E. - use_nyquist: use nyquist to calculate num_freqs or not. - + Initialize feature mapping. + + Args: + in_features (int): Number of input features. + mode (str, optional): Mapping mode. Options: "basic", "gaussian", "positional", "rbf". Defaults to "basic". + gaussian_mapping_size (int, optional): Output dimension for Gaussian mapping. Defaults to 256. + gaussian_rand_key (int, optional): Random seed for Gaussian mapping. Defaults to 0. + gaussian_tau (float, optional): Standard deviation for Gaussian mapping. Defaults to 1.0. + pe_num_freqs (int, optional): Number of frequency bands for positional encoding. Defaults to 4. + pe_scale (int, optional): Base scale for frequencies in positional encoding. Defaults to 2. + pe_init_scale (int, optional): Initial scale multiplier for positional encoding. Defaults to 1. + pe_use_nyquist (bool, optional): Use Nyquist frequency to determine num_freqs. Defaults to True. + pe_lowest_dim (int, optional): Lowest dimension for Nyquist calculation. Defaults to None. + rbf_out_features (int, optional): Number of RBF centers. Defaults to None. + rbf_range (float, optional): Range for RBF center initialization. Defaults to 1.0. + rbf_std (float, optional): Standard deviation for RBF kernels. Defaults to 0.5. """ self.mode = mode if mode == "basic": @@ -191,7 +271,14 @@ def get_num_frequencies_nyquist(self, samples): @staticmethod def fourier_mapping(x, B): """ - x is the input, B is the reference information + Apply Fourier feature mapping: [sin(2πxB^T), cos(2πxB^T)]. + + Args: + x (paddle.Tensor): Input coordinates of shape (..., in_features). + B (np.ndarray): Frequency matrix of shape (mapping_size, in_features). + + Returns: + paddle.Tensor: Fourier features of shape (..., 2 * mapping_size). """ if B is None: return x @@ -216,37 +303,47 @@ def gaussian(alpha): class SIRENAutodecoder_film(paddle.nn.Layer): """ - siren network with author decoding + SIREN (Sinusoidal Representation Networks) with FiLM conditioning for autodecoding. + + This architecture uses sine activations and latent code modulation (FiLM) for + implicit neural representations. It takes both coordinate inputs and latent codes, + making it suitable for learning multiple shapes/scenes with a single network. + + Reference: + Sitzmann et al. "Implicit Neural Representations with Periodic Activation Functions" (NeurIPS 2020) Args: - input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict. - output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict. - in_coord_features (int, optional): Number of input coordinates features - in_latent_features (int, optional): Number of input latent features - out_features (int, optional): Number of output features - num_hidden_layers (int, optional): Number of hidden layers - hidden_features (int, optional): Number of hidden features - outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False. - nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine". - weight_init (Callable, optional): Weight initialization function. Defaults to None. - bias_init (Callable, optional): Bias initialization function. Defaults to None. - premap_mode (str, optional): Feature mapping mode. Defaults to None. + input_keys (Tuple[str, ...], optional): Keys to get input tensors from dict. First key for coordinates, second for latents. + output_keys (Tuple[str, ...], optional): Keys to save output tensors into dict. + in_coord_features (int, optional): Number of input coordinate features (e.g., 2 for 2D, 3 for 3D). + in_latent_features (int, optional): Number of latent features for conditioning. + out_features (int, optional): Number of output features (e.g., 3 for RGB). + num_hidden_layers (int, optional): Number of hidden layers. + hidden_features (int, optional): Number of hidden layer features. + outermost_linear (bool, optional): Whether to use linear layer at output. Defaults to False. + nonlinearity (str, optional): Activation function. Options: "sine", "relu", "tanh", etc. Defaults to "sine". + weight_init (Callable, optional): Custom weight initialization function. Defaults to None. + bias_init (Callable, optional): Custom bias initialization function. Defaults to None. + premap_mode (str, optional): Feature mapping mode before network. Options: "gaussian", "positional", "rbf". Defaults to None. Examples: + >>> import ppsci >>> model = ppsci.arch.SIRENAutodecoder_film( - input_keys=["input1", "input2"], - output_keys=("output",), - in_coord_features=2, - in_latent_features=128, - out_features=3, - num_hidden_layers=10, - hidden_features=128, - ) - >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])} + ... input_keys=["coords", "latents"], + ... output_keys=("output",), + ... in_coord_features=2, + ... in_latent_features=128, + ... out_features=3, + ... num_hidden_layers=10, + ... hidden_features=128, + ... ) + >>> input_data = { + ... "coords": paddle.randn([1000, 2]), + ... "latents": paddle.randn([1000, 128]) + ... } >>> out_dict = model(input_data) - >>> for k, v in out_dict.items(): - ... print(k, v.shape) - output [22, 918, 3] + >>> print(out_dict["output"].shape) + [1000, 3] """ def __init__( @@ -325,24 +422,37 @@ def disable_gradient(self): class LatentContainer(paddle.nn.Layer): """ - a model container that stores latents for multi GPU + Learnable latent code container for autodecoding applications. + + This module stores and retrieves per-sample latent codes, which can be used + for representing multiple instances (shapes, scenes) with a single decoder network. + Supports multi-GPU training and different dimensional arrangements. + + Reference: + Park et al. "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" (CVPR 2019) Args: - input_key (Tuple[str, ...], optional): Key to get the input tensor from the dict. Defaults to ("intput",). - output_key (Tuple[str, ...], optional): Key to save the output tensor into the dict. Defaults to ("output",). - N_samples (int, optional): Number of samples. Defaults to None. - N_features (int, optional): Number of features. Defaults to None. - dims (int, optional): Number of dimensions. Defaults to None. - lumped (bool, optional): Whether to lump the latents. Defaults to False. + input_keys (Tuple[str, ...], optional): Key to get batch indices from dict. Defaults to ("input",). + output_keys (Tuple[str, ...], optional): Key to save latent codes into dict. Defaults to ("output",). + N_samples (int, optional): Total number of samples/instances in dataset. Defaults to None. + N_features (int, optional): Dimension of latent codes. Defaults to None. + dims (int, optional): Number of spatial dimensions (for proper broadcasting). Defaults to None. + lumped (bool, optional): If True, adds single dimension; if False, adds dims dimensions. Defaults to False. Examples: - >>> model = ppsci.arch.LatentContainer(N_samples=1600, N_features=128, dims=2, lumped=True) - >>> input_data = paddle.linspace(0, 1600, 1600, 'int64') - >>> input_dict = {"input": input_data} + >>> import ppsci + >>> import paddle + >>> model = ppsci.arch.LatentContainer( + ... N_samples=1600, + ... N_features=128, + ... dims=2, + ... lumped=True + ... ) + >>> batch_indices = paddle.randint(0, 1600, [32], dtype='int64') + >>> input_dict = {"input": batch_indices} >>> out_dict = model(input_dict) - >>> for k, v in out_dict.items(): - ... print(k, v.shape) - output [1600, 1, 128] + >>> print(out_dict["output"].shape) + [32, 1, 128] """ def __init__( @@ -441,6 +551,23 @@ def normal_kl(mean1, logvar1, mean2, logvar2): class GaussianDiffusion: + """ + Gaussian diffusion process for denoising diffusion probabilistic models (DDPM). + + Implements the forward diffusion process q(x_t|x_0) and reverse denoising process p(x_{t-1}|x_t). + Supports various parameterizations (epsilon, x_0, x_{t-1}) and variance schedules. + + Reference: + Ho et al. "Denoising Diffusion Probabilistic Models" (NeurIPS 2020) + Nichol & Dhariwal "Improved Denoising Diffusion Probabilistic Models" (ICML 2021) + + Args: + betas (np.ndarray): Noise schedule β_t for t=0,...,T-1. + model_mean_type (ModelMeanType): Parameterization of model output. + model_var_type (ModelVarType): Variance parameterization (fixed or learned). + loss_type (LossType): Loss function type (MSE, KL, etc.). + rescale_timesteps (bool, optional): Rescale timesteps to [0, 1000]. Defaults to False. + """ def __init__( self, *, @@ -916,12 +1043,18 @@ def is_vb(self): class SpacedDiffusion(GaussianDiffusion): """ - A diffusion process which can skip steps in a base diffusion process. + Accelerated diffusion process that skips timesteps for faster sampling. + + Implements DDIM-style sampling by using a subset of timesteps from the original + diffusion process, enabling faster inference without retraining the model. + + Reference: + Song et al. "Denoising Diffusion Implicit Models" (ICLR 2021) Args: - use_timesteps: a collection (sequence or set) of timesteps from the - original diffusion process to retain. - kwargs: the kwargs to create the base diffusion process. + use_timesteps (Sequence[int]): Collection of timesteps to retain from original process + (e.g., [0, 10, 20, ..., 1000] for 100-step sampling). + **kwargs: Additional arguments for base GaussianDiffusion (betas, model_mean_type, etc.). """ def __init__(self, use_timesteps, **kwargs): @@ -1004,6 +1137,24 @@ def forward(self, x, emb): class ResBlock(TimestepBlock): + """ + Residual block with timestep embedding for diffusion models. + + Implements a residual connection with two convolutional layers, timestep conditioning, + and optional up/downsampling. Supports FiLM-style adaptive normalization. + + Args: + channels (int): Number of input channels. + emb_channels (int): Number of timestep embedding channels. + dropout (float): Dropout probability. + out_channels (int, optional): Number of output channels. Defaults to channels. + use_conv (bool, optional): Use conv for skip connection if channels differ. Defaults to False. + use_scale_shift_norm (bool, optional): Use FiLM-style conditioning. Defaults to False. + dims (int, optional): Spatial dimensions (1D/2D/3D). Defaults to 2. + use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False. + up (bool, optional): Apply upsampling. Defaults to False. + down (bool, optional): Apply downsampling. Defaults to False. + """ def __init__( self, channels, @@ -1118,6 +1269,17 @@ def avg_pool_nd(dims, *args, **kwargs): class Downsample(paddle.nn.Layer): + """ + Spatial downsampling layer (2x reduction). + + Can use either strided convolution or average pooling for downsampling. + + Args: + channels (int): Number of input channels. + use_conv (bool): Use strided conv (True) or avg pooling (False). + dims (int, optional): Spatial dimensions. Defaults to 2. + out_channels (int, optional): Number of output channels. Defaults to channels. + """ def __init__(self, channels, use_conv, dims=2, out_channels=None): super().__init__() self.channels = channels @@ -1134,11 +1296,23 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None): self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): + """Apply downsampling.""" assert tuple(x.shape)[1] == self.channels return self.op(x) class Upsample(paddle.nn.Layer): + """ + Spatial upsampling layer (2x expansion). + + Uses nearest-neighbor interpolation followed by optional convolution. + + Args: + channels (int): Number of input channels. + use_conv (bool): Apply convolution after upsampling. + dims (int, optional): Spatial dimensions. Defaults to 2. + out_channels (int, optional): Number of output channels. Defaults to channels. + """ def __init__(self, channels, use_conv, dims=2, out_channels=None): super().__init__() self.channels = channels @@ -1149,6 +1323,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None): self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) def forward(self, x): + """Apply upsampling.""" assert tuple(x.shape)[1] == self.channels if self.dims == 3: x = paddle.nn.functional.interpolate( @@ -1288,6 +1463,19 @@ def stop_gradient(input, stop): class AttentionBlock(paddle.nn.Layer): + """ + Self-attention block for spatial feature maps. + + Applies multi-head self-attention over spatial locations in feature maps, + allowing the model to capture long-range dependencies. + + Args: + channels (int): Number of input/output channels. + num_heads (int, optional): Number of attention heads. Defaults to 1. + num_head_channels (int, optional): Channels per head (overrides num_heads). Defaults to -1. + use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False. + use_new_attention_order (bool, optional): Use optimized attention implementation. Defaults to False. + """ def __init__( self, channels, @@ -1341,6 +1529,20 @@ def convert_module_to_f32(l): def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings for diffusion models. + + Similar to positional encodings in transformers, but for continuous timesteps. + Uses sinusoids of exponentially increasing frequencies. + + Args: + timesteps (paddle.Tensor): Timestep values of shape (batch_size,). + dim (int): Embedding dimension. + max_period (int, optional): Maximum period for sinusoids. Defaults to 10000. + + Returns: + paddle.Tensor: Timestep embeddings of shape (batch_size, dim). + """ half = dim // 2 freqs = paddle.exp( x=-math.log(max_period) @@ -1358,28 +1560,55 @@ def timestep_embedding(timesteps, dim, max_period=10000): class UNetModel(paddle.nn.Layer): """ - The full UNet model with attention and timestep embedding. + Full UNet model with attention and timestep embedding for diffusion models. + + Implements a U-Net architecture with residual blocks, self-attention at multiple resolutions, + and timestep conditioning via adaptive normalization (FiLM). Designed for denoising diffusion + probabilistic models (DDPM) and can be conditioned on class labels. + + Reference: + Ronneberger et al. "U-Net: Convolutional Networks for Biomedical Image Segmentation" (MICCAI 2015) + Dhariwal & Nichol "Diffusion Models Beat GANs on Image Synthesis" (NeurIPS 2021) Args: - image_size (int): Input image size (maintained for interface compatibility) - in_channels (int): Number of channels in input tensor - model_channels (int): Base channel count for model - out_channels (int): Number of channels in output tensor - num_res_blocks (int): Residual blocks per downsampling level - attention_resolutions (list/tuple): Downsample rates to apply attention (e.g., [4, 8]) - dropout (float, optional): Dropout probability. Default: 0.0 - channel_mult (tuple, optional): Channel multipliers per level. Default: (1, 2, 4, 8) - conv_resample (bool, optional): Use convolutional resampling. Default: True - dims (int, optional): Data dimensionality (1=1D, 2=2D, 3=3D). Default: 2 - num_classes (int, optional): Number of classes for conditional generation. Default: None - use_checkpoint (bool, optional): Enable gradient checkpointing. Default: False - use_fp16 (bool, optional): Use float16 precision. Default: False - num_heads (int, optional): Number of attention heads. Default: 1 - num_head_channels (int, optional): Fixed channels per head (overrides num_heads). Default: -1 - num_heads_upsample (int, optional): Heads for upsampling blocks. Default: -1 (use num_heads) - use_scale_shift_norm (bool, optional): Use FiLM-like conditioning. Default: False - resblock_updown (bool, optional): Use residual blocks for resampling. Default: False - use_new_attention_order (bool, optional): Use optimized attention pattern. Default: False + image_size (int): Input image size (maintained for interface compatibility). + in_channels (int): Number of channels in input tensor. + model_channels (int): Base channel count for model (multiplied by channel_mult). + out_channels (int): Number of channels in output tensor. + num_res_blocks (int): Number of residual blocks per downsampling level. + attention_resolutions (list/tuple): Downsample factors where to apply attention (e.g., [4, 8, 16]). + dropout (float, optional): Dropout probability in residual blocks. Defaults to 0.0. + channel_mult (tuple, optional): Channel multipliers per level (e.g., (1, 2, 4, 8)). Defaults to (1, 2, 4, 8). + conv_resample (bool, optional): Use learned convolutional up/downsampling. Defaults to True. + dims (int, optional): Data dimensionality (1=1D, 2=2D, 3=3D). Defaults to 2. + num_classes (int, optional): Number of classes for class-conditional generation. Defaults to None. + use_checkpoint (bool, optional): Enable gradient checkpointing to save memory. Defaults to False. + use_fp16 (bool, optional): Use float16 precision for forward pass. Defaults to False. + num_heads (int, optional): Number of attention heads in each attention block. Defaults to 1. + num_head_channels (int, optional): Fixed channels per head (overrides num_heads if set). Defaults to -1. + num_heads_upsample (int, optional): Attention heads for upsampling blocks. Defaults to -1 (use num_heads). + use_scale_shift_norm (bool, optional): Use FiLM-style conditioning in ResBlocks. Defaults to False. + resblock_updown (bool, optional): Use ResBlocks for up/downsampling instead of conv layers. Defaults to False. + use_new_attention_order (bool, optional): Use optimized QKV attention implementation. Defaults to False. + + Examples: + >>> import ppsci + >>> import paddle + >>> model = ppsci.arch.UNetModel( + ... image_size=64, + ... in_channels=3, + ... model_channels=128, + ... out_channels=3, + ... num_res_blocks=2, + ... attention_resolutions=[8, 16], + ... channel_mult=(1, 2, 4, 8), + ... num_heads=4, + ... ) + >>> x = paddle.randn([4, 3, 64, 64]) + >>> t = paddle.randint(0, 1000, [4]) + >>> out = model(x, t) + >>> print(out.shape) + [4, 3, 64, 64] """ def __init__( From b415fc291efd015c491c940a754b6a28e3a15677 Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Sun, 2 Nov 2025 11:00:01 +0800 Subject: [PATCH 07/11] Refactor training process in un_confild.py to improve EMA parameter handling and loss logging. Update _update_ema function to work with parameter dictionaries. Enhance log_loss_dict to control loss aggregation during validation. Modify GaussianDiffusion class to ensure loss calculation includes valid_mse and vb terms correctly. --- examples/confild/un_confild.py | 93 +++++++++++++++++++--------------- ppsci/arch/confild.py | 4 ++ 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py index a457cca7cc..7374c85a6e 100644 --- a/examples/confild/un_confild.py +++ b/examples/confild/un_confild.py @@ -687,10 +687,12 @@ def train(cfg): schedule_sampler = UniformSampler(diff_model) # 初始化EMA参数 - ema_params = [ - copy.deepcopy(unet_model.parameters()) - for _ in range(len(ema_rate)) - ] + ema_params = [] + for _ in range(len(ema_rate)): + ema_param_dict = {} + for name, param in unet_model.named_parameters(): + ema_param_dict[name] = copy.deepcopy(param.detach()) + ema_params.append(ema_param_dict) # 清空损失记录 global train_losses, valid_losses @@ -707,13 +709,8 @@ def train(cfg): valid_batch = next(dl_valid) # 前向传播 unet_model.train() - # def zero_grad(model_params): - for param in unet_model.parameters(): - # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group - if param.grad is not None: - param.grad.detach_() - param.grad.zero_() - unet_model.clear_grad() + # 清零梯度(使用clear_grad更高效) + unet_model.clear_grad() for i in range(0, len(train_batch), microbatch): # 获取当前微批次数据 @@ -727,19 +724,17 @@ def train(cfg): t, weights = schedule_sampler.sample(len(micro)) # 创建部分应用的损失计算函数 - new_micro = paddle.to_tensor(micro) + # 注意:micro已经是tensor(来自DataLoader),无需再次转换 compute_losses = functools.partial( diff_model.training_losses, unet_model, - new_micro, + micro, t, model_kwargs=micro_cond ) # 计算损失 losses = compute_losses() - # 添加训练标记 - losses["valid"] = False # 如果使用损失感知采样器,则更新本地损失 if isinstance(schedule_sampler, LossAwareSampler): @@ -750,38 +745,39 @@ def train(cfg): # 计算加权平均损失 loss = (losses["loss"] * weights).mean() - # 记录损失字典 + # 记录损失字典(排除非张量类型的键) log_loss_dict( - diff_model, t, {k: v * weights for k, v in losses.items()}, is_valid=False + diff_model, t, {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)}, is_valid=False ) # 反向传播 # unet_model.backward(loss) loss.backward() - # 不计算梯度,节省内存 + # 不计算梯度,节省内存,设置模型为评估模式 + unet_model.eval() with paddle.no_grad(): + # 聚合所有微批次的验证损失 + all_valid_losses = [] + # 同样分解成微批次处理 for i in range(0, len(valid_batch), microbatch): # 获取当前微批次数据 micro = valid_batch[i : i + microbatch] micro_cond = { - k: v[i : i + microbatch] - for k, v in cond.items() + k: v[i : i + microbatch] + for k, v in cond.items() } - # 判断是否为最后一个微批次 - last_batch = (i + microbatch) >= len(valid_batch) - # 采样时间步 t, weights = schedule_sampler.sample(len(micro)) # 创建部分应用的损失计算函数 - new_micro = paddle.to_tensor(micro) + # 注意:micro已经是tensor(来自DataLoader),无需再次转换 compute_losses = functools.partial( diff_model.training_losses, unet_model, - new_micro, + micro, t, model_kwargs=micro_cond, valid=True @@ -789,19 +785,31 @@ def train(cfg): # 计算验证损失 losses = compute_losses() - # 添加验证标记 - losses["valid"] = True - # 记录验证损失 + # 记录验证损失(排除非张量类型的键,如布尔标记等) + valid_loss_dict = {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)} + # 验证时不添加到列表,而是在外部聚合后统一添加 log_loss_dict( - diff_model, t, {k: v * weights for k, v in losses.items()}, is_valid=True + diff_model, t, valid_loss_dict, is_valid=True, add_to_list=False ) + + # 收集损失用于聚合 + if "loss" in valid_loss_dict: + all_valid_losses.append(valid_loss_dict["loss"].mean().item()) + + # 聚合整个验证批次的平均损失并添加一次 + if len(all_valid_losses) > 0: + avg_valid_loss = sum(all_valid_losses) / len(all_valid_losses) + valid_losses.append(avg_valid_loss) + # 验证结束后切换回训练模式 + unet_model.train() + grad_norm, param_norm = _compute_norms(unet_model) opt.step() # took_step = unet_model.optimize(opt) # 更新ema参数 - _update_ema(ema_rate, ema_params, unet_model.parameters()) + _update_ema(ema_rate, ema_params, unet_model) # 更新学习率 _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr) @@ -866,21 +874,21 @@ def _compute_norms(model, grad_scale=1.0): return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) -def _update_ema(ema_rate, ema_params, source_params, rate=0.99): +def _update_ema(ema_rate, ema_params, source_model): """ 更新EMA(指数移动平均)参数 EMA有助于提高生成质量,减少模型权重噪声 参数: - ema_rate: EMA衰减率 - ema_params: EMA参数 - source_params: 源参数 - rate: 衰减率 + ema_rate: EMA衰减率列表 + ema_params: EMA参数字典列表 + source_model: 源模型 """ - for rate, target_params in zip(ema_rate, ema_params): - for targ, src in zip(target_params, source_params): - updated = targ.detach() * rate + src * (1 - rate) - targ.set_value(updated) + for rate, target_params_dict in zip(ema_rate, ema_params): + for name, target_param in target_params_dict.items(): + source_param = dict(source_model.named_parameters())[name] + updated = target_param.detach() * rate + source_param.detach() * (1 - rate) + target_param.set_value(updated) def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr): @@ -903,7 +911,7 @@ def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr): opt.set_lr(new_lr) -def log_loss_dict(diffusion, ts, losses, is_valid=False): +def log_loss_dict(diffusion, ts, losses, is_valid=False, add_to_list=True): """ 记录损失字典的日志 @@ -912,6 +920,7 @@ def log_loss_dict(diffusion, ts, losses, is_valid=False): ts: 时间步张量 losses: 损失字典 is_valid: 是否为验证损失 + add_to_list: 是否将损失添加到全局列表中(用于验证时聚合控制) """ for key, values in losses.items(): # 使用logger.info替代logger.logkv_mean记录平均损失值 @@ -921,8 +930,8 @@ def log_loss_dict(diffusion, ts, losses, is_valid=False): quartile = int(4 * sub_t / diffusion.num_timesteps) logger.info(f"{key}_q{quartile}: {sub_loss:.6f}") - # 记录训练和验证损失 - if key == "loss": + # 记录训练和验证损失到全局列表 + if key == "loss" and add_to_list: if is_valid: valid_losses.append(values.mean().item()) else: diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py index a880e1b8a3..873a73371d 100644 --- a/ppsci/arch/confild.py +++ b/ppsci/arch/confild.py @@ -970,6 +970,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, val terms["loss"] = terms["mse"] else: terms["valid_mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["valid_mse"] + terms["vb"] + else: + terms["loss"] = terms["valid_mse"] else: raise NotImplementedError(self.loss_type) From f34c982757a43f8d6e9c082184f6fb7c8826528f Mon Sep 17 00:00:00 2001 From: wangguan1995 <772359200@qq.com> Date: Mon, 10 Nov 2025 12:18:27 +0000 Subject: [PATCH 08/11] fix --- examples/confild/conf/confild_case1.yaml | 7 +- examples/confild/conf/un_confild_case1.yaml | 13 +- examples/confild/confild.py | 154 +++++++++++++------- 3 files changed, 109 insertions(+), 65 deletions(-) diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml index 4bcec2db32..9502b71a7c 100644 --- a/examples/confild/conf/confild_case1.yaml +++ b/examples/confild/conf/confild_case1.yaml @@ -11,8 +11,7 @@ defaults: hydra: run: # dynamic output directory according to running time and override name - # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} - dir: ./outputs_confild_case1 + dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} job: name: ${mode} # name of logfile chdir: false # keep current working directory unchanged @@ -115,8 +114,8 @@ Uncondiction_INFER: noise_schedule: "cosine" Data: - data_path: /home/aistudio/work/extracted/data/Case1/data.npy - coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy + data_path: data/Case1/case1_data.npy + coor_path: data/Case1/case1_coords.npy normalizer: method: "-11" dim: 0 diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml index 76b0c8f27a..1abe829b62 100644 --- a/examples/confild/conf/un_confild_case1.yaml +++ b/examples/confild/conf/un_confild_case1.yaml @@ -11,8 +11,7 @@ defaults: hydra: run: # dynamic output directory according to running time and override name - # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} - dir: ./outputs_un_confild_case1 + dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} job: name: ${mode} # name of logfile chdir: false # keep current working directory unchanged @@ -60,7 +59,7 @@ UNET: num_head_channels: 64 attention_resolutions: "32,16,8" channel_mult: null - ema_path: /home/aistudio/work/extracted/data/Case1/diffusion/ema.pdparams + ema_path: data/Case1/diffusion/ema.pdparams Diff: steps: 1000 @@ -68,8 +67,8 @@ Diff: CNF: mutil_GPU: 1 - data_path: /home/aistudio/work/extracted/data/Case1/data.npy - coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy + data_path: data/Case1/data.npy + coor_path: data/Case1/coords.npy load_data_fn: load_elbow_flow normalizer: method: "-11" @@ -88,5 +87,5 @@ CNF: DATA: max_val: 1.0 min_val: -1.0 - train_data: "/home/aistudio/work/extracted/data/Case1/train_data.npy" - valid_data: "/home/aistudio/work/extracted/data/Case1/valid_data.npy" \ No newline at end of file + train_data: "data/Case1/train_data.npy" + valid_data: "data/Case1/valid_data.npy" diff --git a/examples/confild/confild.py b/examples/confild/confild.py index 8003010630..c8e6a227e6 100644 --- a/examples/confild/confild.py +++ b/examples/confild/confild.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import enum -import math import hydra import matplotlib.pyplot as plt import numpy as np @@ -24,12 +22,8 @@ from paddle.io import DistributedBatchSampler import ppsci -from ppsci.arch import UNetModel from ppsci.arch import LatentContainer from ppsci.arch import SIRENAutodecoder_film -from ppsci.arch import SpacedDiffusion -from ppsci.arch import ModelVarType -from ppsci.arch import ModelMeanType from ppsci.utils import logger @@ -121,13 +115,9 @@ def get_params(self): @staticmethod def fnormalize(data, params, method): if method == "-11": - return (data - params[1]) / ( - params[0] - params[1] - ) * 2 - 1 + return (data - params[1]) / (params[0] - params[1]) * 2 - 1 elif method == "01": - return (data - params[1]) / ( - params[0] - params[1] - ) + return (data - params[1]) / (params[0] - params[1]) elif method == "ms": return (data - params[0]) / params[1] elif method == "none": @@ -138,9 +128,7 @@ def fdenormalize(data_norm, params, method): if method == "-11": return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1] elif method == "01": - return data_norm * ( - params[0] - params[1] - ) + params[1] + return data_norm * (params[0] - params[1]) + params[1] elif method == "ms": return data_norm * params[1] + params[0] elif method == "none": @@ -154,7 +142,11 @@ def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> Non self.total_samples = tuple(fois.shape)[0] self.coords = coord.numpy() # 存储全局索引 - self.global_indices = global_indices if global_indices is not None else np.arange(self.total_samples) + self.global_indices = ( + global_indices + if global_indices is not None + else np.arange(self.total_samples) + ) def __len__(self): return self.total_samples @@ -165,7 +157,11 @@ def __getitem__(self, idx): if hasattr(self, "extra_in"): extra_id = idx % tuple(self.fois.shape)[1] idb = idx // tuple(self.fois.shape)[1] - return (self.coords, self.extra_in[extra_id]), self.fois[idb, extra_id], global_idx + return ( + (self.coords, self.extra_in[extra_id]), + self.fois[idb, extra_id], + global_idx, + ) else: return self.coords, self.fois[idx], global_idx @@ -187,9 +183,7 @@ def getdata(cfg): # 计算空间形状和轴 spatio_shape = fois.shape[1:-1] spatio_axis = list( - range( - fois.ndim if isinstance(fois, np.ndarray) else fois.dim() - ) + range(fois.ndim if isinstance(fois, np.ndarray) else fois.dim()) )[1:-1] ###### read data - coordinate ###### @@ -202,11 +196,7 @@ def getdata(cfg): fois = fois.astype("float32") ###### convert to tensor ###### - fois = ( - paddle.to_tensor(fois) - if not isinstance(fois, paddle.Tensor) - else fois - ) + fois = paddle.to_tensor(fois) if not isinstance(fois, paddle.Tensor) else fois coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord N_samples = fois.shape[0] @@ -219,13 +209,22 @@ def getdata(cfg): out_normalizer.fit_normalize( fois if cfg.Latent.lumped else fois.flatten(0, cfg.Latent.dims) ) - normed_coords = in_normalizer.normalize(coord)# 训练集就是测试集 + normed_coords = in_normalizer.normalize(coord) # 训练集就是测试集 normed_fois = out_normalizer.normalize(fois) return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer -def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices): +def signal_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, +): cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) latents_model = LatentContainer(**cfg.Latent) @@ -274,7 +273,7 @@ def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio epoch_loss = paddle.stack(x=train_loss).mean().item() losses.append(epoch_loss) print("epoch {}, train loss {}".format(i + 1, epoch_loss)) - if i % 100 == 0: + if (i + 1) % cfg.log_freq == 0: test_error = [] cnf_model.eval() latents_model.eval() @@ -297,9 +296,14 @@ def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio test_error.append(error) test_error = paddle.concat(x=test_error).mean(axis=0) print("test MAE: ", test_error) - if i % 100 == 0: - paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") - paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") + if (i + 1) % cfg.log_freq == 0: + paddle.save( + cnf_model.state_dict(), f"{cfg.output_dir}/cnf_model_{i+1}.pdparams" + ) + paddle.save( + latents_model.state_dict(), + f"{cfg.output_dir}/latents_model_{i+1}.pdparams", + ) # 绘制损失图 plt.figure(figsize=(10, 6)) plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") @@ -323,7 +327,16 @@ def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio plt.show() -def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices): +def mutil_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, +): fleet.init(is_collective=True) cnf_model = SIRENAutodecoder_film(**cfg.CONFILD) cnf_model = fleet.distributed_model(cnf_model) @@ -447,30 +460,52 @@ def train(cfg): test_normed_fois = normed_fois train_indices = list(range(N_samples)) test_indices = list(range(N_samples)) - - + if world_size > 1: import paddle.distributed as dist + dist.init_parallel_env() - mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, - spatio_axis, out_normalizer, train_indices, test_indices) + mutil_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, + ) else: - signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, - spatio_axis, out_normalizer, train_indices, test_indices) + signal_train( + cfg, + normed_coords, + train_normed_fois, + test_normed_fois, + spatio_axis, + out_normalizer, + train_indices, + test_indices, + ) def evaluate(cfg: DictConfig): # set data # normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg) normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg) - + # [918,2] + # [16000,918,3] + print(normed_coords.shape) + print(normed_fois.shape) + t_std = 15698 + normed_fois = normed_fois[t_std:] + # exit() if len(normed_coords.shape) + 1 == len(normed_fois.shape): normed_coords = paddle.tile( normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape) ) - idx = paddle.to_tensor( - np.array([i for i in range(normed_fois.shape[0])]), dtype="int64" + time = paddle.to_tensor( + np.array([i for i in range(t_std, t_std + normed_fois.shape[0])]), dtype="int64" ) # set model confild = SIRENAutodecoder_film(**cfg.CONFILD) @@ -491,23 +526,34 @@ def evaluate(cfg: DictConfig): latent, cfg.EVAL.latent_pretrained_model_path, ) - latent_test_pred = latent({"latent_x": idx}) + latent_test_pred = latent({"latent_x": time}) y_test_pred = [] for i in range(normed_coords.shape[0]): - y_test_pred.append( - confild( - { - "confild_x": normed_coords[i], - "latent_z": latent_test_pred["latent_z"][i], - } - )["confild_output"].numpy() - ) - y_test_pred = paddle.to_tensor(np.array(y_test_pred)) + ouput = confild( + { + "confild_x": normed_coords[i], + "latent_z": latent_test_pred["latent_z"][i], + } + )["confild_output"].numpy() + y_test_pred.append(ouput) + y_test_pred = paddle.to_tensor(np.array(y_test_pred)) y_test_pred = out_normalizer.denormalize(y_test_pred) y_test = out_normalizer.denormalize(normed_fois) - logger.info("Result is {}".format(y_test.numpy())) + def calc_err(): + var_name = ["u", "v", "p"] + for i, var in enumerate(var_name): + u_true_mean = paddle.mean(y_test[:, :, i], axis=0) + u_label_mean = paddle.mean(y_test_pred[:, :, i], axis=0) + u_true_std = paddle.std(y_test[:, :, i], axis=0) + u_label_std = paddle.std(y_test_pred[:, :, i], axis=0) + avg_discrepancy = paddle.mean(paddle.abs(u_true_mean - u_label_mean)) + std_discrepancy = paddle.mean(paddle.abs(u_true_std - u_label_std)) + print(f"Average Discrepancy [{var}] Value: {avg_discrepancy:.3f}") + print(f"Standard Deviation of [{var}] Magnitude: {std_discrepancy:.4f}") + + calc_err() def inference(cfg): @@ -519,11 +565,11 @@ def inference(cfg): ) fois_len = normed_fois.shape[0] - idxs = np.array([i for i in range(fois_len)]) + times = np.array([i for i in range(fois_len)]) from deploy import python_infer latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent) - input_dict = {"latent_x": idxs} + input_dict = {"latent_x": times} output_dict = latent_predictor.predict(input_dict, cfg.INFER.batch_size) cnf_predictor = python_infer.GeneralPredictor(cfg.INFER.Confild) From 7ec07b0dd975888d4e336e63db89851e93f4e761 Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Mon, 10 Nov 2025 22:40:34 +0800 Subject: [PATCH 09/11] Refactor training loop in confild.py and un_confild.py to improve readability and consistency. Update variable names for clarity, enhance gradient handling, and streamline loss logging. Fix tensor type conversions and ensure proper handling of model parameters during training and evaluation. --- docs/zh/examples/confild.md | 1 + examples/confild/confild.py | 56 +++++----- examples/confild/un_confild.py | 182 ++++++++++++++++----------------- ppsci/arch/confild.py | 28 ++--- 4 files changed, 133 insertions(+), 134 deletions(-) diff --git a/docs/zh/examples/confild.md b/docs/zh/examples/confild.md index 96e167182c..83fd21f697 100644 --- a/docs/zh/examples/confild.md +++ b/docs/zh/examples/confild.md @@ -367,3 +367,4 @@ def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer): ``` ## 5. 实验结果 + diff --git a/examples/confild/confild.py b/examples/confild/confild.py index 8003010630..8e4d498995 100644 --- a/examples/confild/confild.py +++ b/examples/confild/confild.py @@ -249,39 +249,44 @@ def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio ) losses = [] - for i in range(cfg.TRAIN.epochs): + for epoch in range(cfg.TRAIN.epochs): cnf_model.train() latents_model.train() - if i != 0: - cnf_optimizer.step() - cnf_optimizer.clear_grad(set_to_zero=False) train_loss = [] for batch_coords, batch_fois, idx in train_loader: idx = {"latent_x": idx} batch_latent = latents_model(idx) if isinstance(batch_coords, list): - batch_coords = [i for i in batch_coords] + batch_coords = [coord for coord in batch_coords] data = { "confild_x": batch_coords, "latent_z": batch_latent["latent_z"], } batch_output = cnf_model(data) loss = criterion(batch_output["confild_output"], batch_fois) + + # 清空梯度 + cnf_optimizer.clear_grad(set_to_zero=False) latents_optimizer.clear_grad(set_to_zero=False) + + # 反向传播 loss.backward() + + # 更新参数 + cnf_optimizer.step() latents_optimizer.step() train_loss.append(loss) epoch_loss = paddle.stack(x=train_loss).mean().item() losses.append(epoch_loss) - print("epoch {}, train loss {}".format(i + 1, epoch_loss)) - if i % 100 == 0: + print("epoch {}, train loss {}".format(epoch + 1, epoch_loss)) + if epoch % 100 == 0: test_error = [] cnf_model.eval() latents_model.eval() with paddle.no_grad(): for test_coords, test_fois, idx in test_loader: if isinstance(test_coords, list): - test_coords = [i for i in test_coords] + test_coords = [coord for coord in test_coords] prediction = out_normalizer.denormalize( cnf_model( { @@ -297,9 +302,8 @@ def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio test_error.append(error) test_error = paddle.concat(x=test_error).mean(axis=0) print("test MAE: ", test_error) - if i % 100 == 0: - paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") - paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") + paddle.save(cnf_model.state_dict(), f"cnf_model_{epoch}.pdparams") + paddle.save(latents_model.state_dict(), f"latents_model_{epoch}.pdparams") # 绘制损失图 plt.figure(figsize=(10, 6)) plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") @@ -364,39 +368,44 @@ def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_ criterion = paddle.nn.MSELoss() losses = [] - for i in range(cfg.TRAIN.epochs): + for epoch in range(cfg.TRAIN.epochs): cnf_model.train() latents_model.train() - if i != 0: - cnf_optimizer.step() - cnf_optimizer.clear_grad(set_to_zero=False) train_loss = [] for batch_coords, batch_fois, idx in train_loader: idx = {"latent_x": idx} batch_latent = latents_model(idx) if isinstance(batch_coords, list): - batch_coords = [i for i in batch_coords] + batch_coords = [coord for coord in batch_coords] data = { "confild_x": batch_coords, "latent_z": batch_latent["latent_z"], } batch_output = cnf_model(data) loss = criterion(batch_output["confild_output"], batch_fois) + + # 清空梯度 + cnf_optimizer.clear_grad(set_to_zero=False) latents_optimizer.clear_grad(set_to_zero=False) + + # 反向传播 loss.backward() + + # 更新参数 + cnf_optimizer.step() latents_optimizer.step() train_loss.append(loss) epoch_loss = paddle.stack(x=train_loss).mean().item() losses.append(epoch_loss) - print("epoch {}, train loss {}".format(i + 1, epoch_loss)) - if i % 100 == 0: + print("epoch {}, train loss {}".format(epoch + 1, epoch_loss)) + if epoch % 100 == 0: test_error = [] cnf_model.eval() latents_model.eval() with paddle.no_grad(): for test_coords, test_fois, idx in test_loader: if isinstance(test_coords, list): - test_coords = [i for i in test_coords] + test_coords = [coord for coord in test_coords] prediction = out_normalizer.denormalize( cnf_model( { @@ -412,9 +421,8 @@ def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_ test_error.append(error) test_error = paddle.concat(x=test_error).mean(axis=0) print("test MAE: ", test_error) - if i % 100 == 0: - paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams") - paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams") + paddle.save(cnf_model.state_dict(), f"cnf_model_{epoch}.pdparams") + paddle.save(latents_model.state_dict(), f"latents_model_{epoch}.pdparams") # 绘制损失图 plt.figure(figsize=(10, 6)) plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") @@ -470,7 +478,7 @@ def evaluate(cfg: DictConfig): ) idx = paddle.to_tensor( - np.array([i for i in range(normed_fois.shape[0])]), dtype="int64" + np.arange(normed_fois.shape[0]), dtype="int64" ) # set model confild = SIRENAutodecoder_film(**cfg.CONFILD) @@ -519,7 +527,7 @@ def inference(cfg): ) fois_len = normed_fois.shape[0] - idxs = np.array([i for i in range(fois_len)]) + idxs = np.arange(fois_len) from deploy import python_infer latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent) diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py index 7374c85a6e..bd060c100b 100644 --- a/examples/confild/un_confild.py +++ b/examples/confild/un_confild.py @@ -82,8 +82,8 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): 返回: 提取并广播后的张量 """ - # 修复变量名错误 - res = paddle.to_tensor(arr)[timesteps].astype(timesteps.dtype) + # 修复类型转换:先指定dtype再索引 + res = paddle.to_tensor(arr, dtype=timesteps.dtype)[timesteps] while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) @@ -694,73 +694,84 @@ def train(cfg): ema_param_dict[name] = copy.deepcopy(param.detach()) ema_params.append(ema_param_dict) - # 清空损失记录 global train_losses, valid_losses train_losses.clear() valid_losses.clear() - while ( - not lr_anneal_steps - or step + resume_step < lr_anneal_steps - ): - cond = {} - # 获取下一个训练批次和验证批次的数据 - train_batch = next(dl_train) - valid_batch = next(dl_valid) - # 前向传播 - unet_model.train() - # 清零梯度(使用clear_grad更高效) - unet_model.clear_grad() - - for i in range(0, len(train_batch), microbatch): - # 获取当前微批次数据 - micro = train_batch[i : i + microbatch] - micro_cond = { - k: v[i : i + microbatch] - for k, v in cond.items() - } - - # 从调度采样器中采样时间步 - t, weights = schedule_sampler.sample(len(micro)) - - # 创建部分应用的损失计算函数 - # 注意:micro已经是tensor(来自DataLoader),无需再次转换 - compute_losses = functools.partial( - diff_model.training_losses, - unet_model, - micro, - t, - model_kwargs=micro_cond - ) + valid_interval = 100 + + while lr_anneal_steps and (step + resume_step < lr_anneal_steps): + cond = {} + # 获取训练批次数据 + train_batch = next(dl_train) + + # 前向传播 + unet_model.train() + # 清零梯度 + opt.clear_grad() + + for i in range(0, len(train_batch), microbatch): + # 获取当前微批次数据 + micro = train_batch[i : i + microbatch] + micro_cond = { + k: v[i : i + microbatch] + for k, v in cond.items() + } - # 计算损失 - losses = compute_losses() + t, weights = schedule_sampler.sample(len(micro)) - # 如果使用损失感知采样器,则更新本地损失 - if isinstance(schedule_sampler, LossAwareSampler): - schedule_sampler.update_with_local_losses( - t, losses["loss"].detach() - ) + compute_losses = functools.partial( + diff_model.training_losses, + unet_model, + micro, + t, + model_kwargs=micro_cond + ) - # 计算加权平均损失 - loss = (losses["loss"] * weights).mean() - - # 记录损失字典(排除非张量类型的键) + # 计算损失 + losses = compute_losses() + + if isinstance(schedule_sampler, LossAwareSampler): + schedule_sampler.update_with_local_losses( + t, losses["loss"].detach() + ) + + # 计算加权平均损失 + loss = (losses["loss"] * weights).mean() + + num_microbatches = (len(train_batch) + microbatch - 1) // microbatch + if num_microbatches > 1: + loss = loss / num_microbatches + + if i == 0: log_loss_dict( diff_model, t, {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)}, is_valid=False ) - - # 反向传播 - # unet_model.backward(loss) - loss.backward() + + # 反向传播(梯度累积) + loss.backward() - # 不计算梯度,节省内存,设置模型为评估模式 + # 更新参数 + grad_norm, param_norm = _compute_norms(unet_model) + opt.step() + + # 更新EMA参数 + _update_ema(ema_rate, ema_params, unet_model) + + # 更新学习率 + _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr) + + step += 1 + + # 定期执行验证(每valid_interval步) + if step % valid_interval == 0: unet_model.eval() with paddle.no_grad(): - # 聚合所有微批次的验证损失 + # 获取验证批次 + valid_batch = next(dl_valid) all_valid_losses = [] - # 同样分解成微批次处理 + # 分解成微批次处理 for i in range(0, len(valid_batch), microbatch): # 获取当前微批次数据 micro = valid_batch[i : i + microbatch] @@ -772,8 +783,7 @@ def train(cfg): # 采样时间步 t, weights = schedule_sampler.sample(len(micro)) - # 创建部分应用的损失计算函数 - # 注意:micro已经是tensor(来自DataLoader),无需再次转换 + # 计算验证损失 compute_losses = functools.partial( diff_model.training_losses, unet_model, @@ -782,43 +792,20 @@ def train(cfg): model_kwargs=micro_cond, valid=True ) - - # 计算验证损失 losses = compute_losses() - # 记录验证损失(排除非张量类型的键,如布尔标记等) - valid_loss_dict = {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)} - # 验证时不添加到列表,而是在外部聚合后统一添加 - log_loss_dict( - diff_model, t, valid_loss_dict, is_valid=True, add_to_list=False - ) - - # 收集损失用于聚合 - if "loss" in valid_loss_dict: - all_valid_losses.append(valid_loss_dict["loss"].mean().item()) + # 收集损失 + if "loss" in losses: + all_valid_losses.append((losses["loss"] * weights).mean().item()) - # 聚合整个验证批次的平均损失并添加一次 + # 聚合并记录验证损失 if len(all_valid_losses) > 0: avg_valid_loss = sum(all_valid_losses) / len(all_valid_losses) valid_losses.append(avg_valid_loss) - - # 验证结束后切换回训练模式 - unet_model.train() + print(f"Step {step}: Train Loss: {train_losses[-1]:.6f}, Valid Loss: {avg_valid_loss:.6f}") - grad_norm, param_norm = _compute_norms(unet_model) - opt.step() - # took_step = unet_model.optimize(opt) - # 更新ema参数 - _update_ema(ema_rate, ema_params, unet_model) - # 更新学习率 - _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr) - - step += 1 - - # 每100步打印一次训练和验证损失 - if step % 100 == 0: - if len(train_losses) > 0 and len(valid_losses) > 0: - print(f"Step {step}: Train Loss: {train_losses[-1]:.6f}, Valid Loss: {valid_losses[-1]:.6f}") + # 切换回训练模式 + unet_model.train() # 保存模型 paddle.save(unet_model.state_dict(), "unet.pdparams") @@ -923,19 +910,22 @@ def log_loss_dict(diffusion, ts, losses, is_valid=False, add_to_list=True): add_to_list: 是否将损失添加到全局列表中(用于验证时聚合控制) """ for key, values in losses.items(): - # 使用logger.info替代logger.logkv_mean记录平均损失值 - logger.info(f"{key}: {values.mean().item():.6f}") - # 记录分位数(特别是四个四分位数) - for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): - quartile = int(4 * sub_t / diffusion.num_timesteps) - logger.info(f"{key}_q{quartile}: {sub_loss:.6f}") + # 记录平均损失值 + mean_loss = values.mean().item() + logger.info(f"{key}: {mean_loss:.6f}") + + # ts_numpy = ts.cpu().numpy() if ts.place.is_gpu_place() else ts.numpy() + # values_numpy = values.detach().cpu().numpy() if values.place.is_gpu_place() else values.detach().numpy() + # for sub_t, sub_loss in zip(ts_numpy, values_numpy): + # quartile = int(4 * sub_t / diffusion.num_timesteps) + # logger.info(f"{key}_q{quartile}: {sub_loss:.6f}") # 记录训练和验证损失到全局列表 if key == "loss" and add_to_list: if is_valid: - valid_losses.append(values.mean().item()) + valid_losses.append(mean_loss) else: - train_losses.append(values.mean().item()) + train_losses.append(mean_loss) def evaluate(cfg): @@ -972,7 +962,7 @@ def evaluate(cfg): nf.set_state_dict(paddle.load(cfg.CNF.model_path)) coord = in_normalizer.normalize(coord) - batch_size = 1 # if you are limited by your GPU Memory, please change the batch_size variable accordingly + batch_size = 1 n_samples = gen_latents.shape[0] gen_fields = [] @@ -1016,4 +1006,4 @@ def main(cfg: DictConfig): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py index 873a73371d..5dbab69e7a 100644 --- a/ppsci/arch/confild.py +++ b/ppsci/arch/confild.py @@ -236,7 +236,7 @@ def __init__( loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features) ) elif mode == "positional": - if pe_use_nyquist == "True" and pe_lowest_dim: + if pe_use_nyquist and pe_lowest_dim: pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim) self.B = pe_init_scale * np.vstack( [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)] @@ -292,12 +292,12 @@ def fourier_mapping(x, B): def rbf_mapping(self, x): size = tuple(x.shape)[:-1] + tuple(self.centers.shape) x = x.unsqueeze(axis=-2).expand(shape=size) - distances = (x - self.centers).pow(y=2).sum(axis=-1) * self.sigmas + distances = paddle.pow(x - self.centers, 2).sum(axis=-1) * self.sigmas return self.gaussian(distances) @staticmethod def gaussian(alpha): - phi = paddle.exp(x=-1 * alpha.pow(y=2)) + phi = paddle.exp(x=-1 * paddle.pow(alpha, 2)) return phi @@ -417,7 +417,7 @@ def forward(self, input_data): def disable_gradient(self): for param in self.parameters(): - param.stop_gradient = not False + param.stop_gradient = True class LatentContainer(paddle.nn.Layer): @@ -502,7 +502,7 @@ class ModelVarType(enum.Enum): def _extract_into_tensor(arr, timesteps, broadcast_shape): res = ( paddle.to_tensor(data=arr)[timesteps] - .astype(dtype="float32") + .astype(dtype=timesteps.dtype) ) while len(tuple(res.shape)) < len(broadcast_shape): res = res[..., None] @@ -537,7 +537,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for th.exp(). logvar1, logvar2 = [ - x if isinstance(x, paddle.Tensor) else paddle.to_tensor(x).to(tensor) + x if isinstance(x, paddle.Tensor) else paddle.to_tensor(x, dtype=tensor.dtype, place=tensor.place) for x in (logvar1, logvar2) ] @@ -780,7 +780,7 @@ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + eps = eps - paddle.sqrt(1 - alpha_bar) * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) @@ -940,7 +940,7 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, val model_output, model_var_values = split(model_output, C, axis=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. - frozen_out = paddle.cat([model_output.detach(), model_var_values], dim=1) + frozen_out = paddle.concat([model_output.detach(), model_var_values], axis=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, @@ -1013,13 +1013,13 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): cdf_plus = approx_standard_normal_cdf(plus_in) min_in = inv_stdv * (centered_x - 1.0 / 255.0) cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = paddle.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clamp(min=1e-12)) + log_cdf_plus = paddle.log(cdf_plus.clip(min=1e-12)) + log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clip(min=1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = paddle.where( x < -0.999, log_cdf_plus, - paddle.where(x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clamp(min=1e-12))), + paddle.where(x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clip(min=1e-12))), ) assert log_probs.shape == x.shape return log_probs @@ -1386,8 +1386,8 @@ def forward(self, qkv): scale = 1 / math.sqrt(math.sqrt(ch)) weight = paddle.einsum(# 非复数 "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), + (q * scale).reshape([bs * self.n_heads, ch, length]), + (k * scale).reshape([bs * self.n_heads, ch, length]), ) weight = paddle.nn.functional.softmax( x=weight.astype(dtype="float32"), axis=-1 @@ -1439,7 +1439,7 @@ def forward(ctx, run_function, length, *args): def backward(ctx, *output_grads): ctx.input_tensors = [stop_gradient(x, stop=False) for x in ctx.input_tensors] with paddle.enable_grad(): - shallow_copies = [x.view_as(other=x) for x in ctx.input_tensors] + shallow_copies = [x.reshape(x.shape) for x in ctx.input_tensors] # print(shallow_copies) output_tensors = ctx.run_function(*shallow_copies) input_grads = paddle.grad( From 931b8ee68f0b0677ae0c6ba2ba64c6f1ddaf8f8f Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Mon, 10 Nov 2025 23:48:25 +0800 Subject: [PATCH 10/11] Enhance model saving functionality in confild.py by ensuring output directory exists before saving model parameters. This change improves file management during training by organizing saved models into a specified output directory. --- examples/confild/confild.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/confild/confild.py b/examples/confild/confild.py index 5223028b06..3ea24140a4 100644 --- a/examples/confild/confild.py +++ b/examples/confild/confild.py @@ -302,9 +302,13 @@ def signal_train( test_error.append(error) test_error = paddle.concat(x=test_error).mean(axis=0) print("test MAE: ", test_error) + import os + if not os.path.exists(cfg.output_dir): + os.makedirs(cfg.output_dir, exist_ok=True) + + paddle.save(cnf_model.state_dict(), f"{cfg.output_dir}/cnf_model_{epoch}.pdparams") + paddle.save(latents_model.state_dict(), f"{cfg.output_dir}/latents_model_{epoch}.pdparams") - paddle.save(cnf_model.state_dict(), f"cnf_model_{epoch}.pdparams") - paddle.save(latents_model.state_dict(), f"latents_model_{epoch}.pdparams") # 绘制损失图 plt.figure(figsize=(10, 6)) plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") @@ -431,8 +435,14 @@ def mutil_train( test_error.append(error) test_error = paddle.concat(x=test_error).mean(axis=0) print("test MAE: ", test_error) - paddle.save(cnf_model.state_dict(), f"cnf_model_{epoch}.pdparams") - paddle.save(latents_model.state_dict(), f"latents_model_{epoch}.pdparams") + + import os + if not os.path.exists(cfg.output_dir): + os.makedirs(cfg.output_dir, exist_ok=True) + + paddle.save(cnf_model.state_dict(), f"{cfg.output_dir}/cnf_model_{epoch}.pdparams") + paddle.save(latents_model.state_dict(), f"{cfg.output_dir}/latents_model_{epoch}.pdparams") + # 绘制损失图 plt.figure(figsize=(10, 6)) plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss") From 46a52608aa658e3408ac4c436dd694d32cd5474e Mon Sep 17 00:00:00 2001 From: ADream-ki <2085127827@qq.com> Date: Wed, 12 Nov 2025 23:21:02 +0800 Subject: [PATCH 11/11] Make t_std configurable in evaluate function of confild.py, allowing for dynamic adjustment while defaulting to 15698. This enhances flexibility in model evaluation. --- examples/confild/confild.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/confild/confild.py b/examples/confild/confild.py index 3ea24140a4..aaf520ed85 100644 --- a/examples/confild/confild.py +++ b/examples/confild/confild.py @@ -511,7 +511,7 @@ def evaluate(cfg: DictConfig): # [16000,918,3] print(normed_coords.shape) print(normed_fois.shape) - t_std = 15698 + t_std = cfg.get("t_std", 15698) # Use configurable parameter, default to 15698 normed_fois = normed_fois[t_std:] # exit() if len(normed_coords.shape) + 1 == len(normed_fois.shape):