|
479 | 479 | "\n", |
480 | 480 | "# For BayesFlow devs: this ensures that the latest dev version can be found\n", |
481 | 481 | "import sys\n", |
482 | | - "sys.path.append('../')\n", |
| 482 | + "sys.path.append(\"../\")\n", |
483 | 483 | "\n", |
484 | 484 | "import bayesflow as bf" |
485 | 485 | ] |
|
513 | 513 | "source": [ |
514 | 514 | "def prior_helper():\n", |
515 | 515 | "    \"\"\"The ABC prior returns a pyabc Parameter object, which we convert to a dict.\"\"\"\n",
516 | | - " return dict(rate=prior.rvs()['rate'])\n", |
| 516 | + " return dict(rate=prior.rvs()[\"rate\"])\n", |
517 | 517 | "\n", |
518 | 518 | "def sim_helper(rate):\n", |
519 | 519 | "    \"\"\"The simulator returns a dict; we extract the output at the test times.\"\"\"\n",
520 | | - " temp = sim({'rate': rate})\n", |
| 520 | + " temp = sim({\"rate\": rate})\n", |
521 | 521 | " xt_ind = np.searchsorted(temp[\"t\"], t_test_times) - 1\n", |
522 | 522 | " obs = temp[\"X\"][:, 1][xt_ind]\n", |
523 | 523 | " return dict(obs=obs)" |
|
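Note on the hunk above: the `prior` and `sim` objects these helpers wrap are defined outside this diff. A minimal sketch of what they might look like with pyabc; the uniform bounds and the placeholder simulator are assumptions, not taken from the notebook:

import numpy as np
import pyabc

# Hypothetical prior over the rate; pyabc's Distribution.rvs() returns a
# Parameter object with dict-style access, which prior_helper unwraps.
prior = pyabc.Distribution(rate=pyabc.RV("uniform", 0, 1))

def sim(pars):
    # Placeholder simulator (assumed): returns time points "t" and a 2D
    # trajectory array "X" whose second column sim_helper reads out.
    rng = np.random.default_rng()
    t = np.linspace(0, 10, 101)
    X = np.stack([t, rng.poisson(pars["rate"] * t)], axis=-1)
    return dict(t=t, X=X)
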
568 | 568 | ], |
569 | 569 | "source": [ |
570 | 570 | "adapter = bf.approximators.ContinuousApproximator.build_adapter(\n", |
571 | | - " inference_variables='rate',\n", |
572 | | - " inference_conditions='obs',\n", |
| 571 | + " inference_variables=\"rate\",\n", |
| 572 | + " inference_conditions=\"obs\",\n", |
573 | 573 | " summary_variables=None\n", |
574 | 574 | ")\n", |
575 | 575 | "adapter" |
|
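In the hunk above, `build_adapter` wires the keys of each simulation dict ("rate", "obs") to the inference network's inputs. A hedged sketch of how such an adapter is then typically plugged into a workflow; the network choice and the `BasicWorkflow` arguments are assumptions based on the BayesFlow 2 API, not visible in this diff:

# Sketch only: the notebook's actual network configuration is not shown here.
workflow = bf.BasicWorkflow(
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow(),
)
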
665 | 665 | "output_type": "stream", |
666 | 666 | "text": [ |
667 | 667 | "Epoch 1/10\n", |
668 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", |
| 668 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4428 - loss/inference_loss: 0.4428 - val_loss: 0.4605 - val_loss/inference_loss: 0.4605\n", |
669 | 669 | "Epoch 2/10\n", |
670 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", |
| 670 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 64ms/step - loss: 0.3700 - loss/inference_loss: 0.3700 - val_loss: 0.4467 - val_loss/inference_loss: 0.4467\n", |
671 | 671 | "Epoch 3/10\n", |
672 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", |
| 672 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 68ms/step - loss: 0.3458 - loss/inference_loss: 0.3458 - val_loss: 0.3627 - val_loss/inference_loss: 0.3627\n", |
673 | 673 | "Epoch 4/10\n", |
674 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", |
| 674 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 70ms/step - loss: 0.3771 - loss/inference_loss: 0.3771 - val_loss: 0.3637 - val_loss/inference_loss: 0.3637\n", |
675 | 675 | "Epoch 5/10\n", |
676 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", |
| 676 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 69ms/step - loss: 0.3729 - loss/inference_loss: 0.3729 - val_loss: 0.2138 - val_loss/inference_loss: 0.2138\n", |
677 | 677 | "Epoch 6/10\n", |
678 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", |
| 678 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m7s\u001B[0m 66ms/step - loss: 0.3567 - loss/inference_loss: 0.3567 - val_loss: 0.2888 - val_loss/inference_loss: 0.2888\n", |
679 | 679 | "Epoch 7/10\n", |
680 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", |
| 680 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 62ms/step - loss: 0.4077 - loss/inference_loss: 0.4077 - val_loss: 0.3235 - val_loss/inference_loss: 0.3235\n", |
681 | 681 | "Epoch 8/10\n", |
682 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", |
| 682 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 0.4124 - loss/inference_loss: 0.4124 - val_loss: 0.3256 - val_loss/inference_loss: 0.3256\n", |
683 | 683 | "Epoch 9/10\n", |
684 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 61ms/step - loss: 0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", |
| 684 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 61ms/step - loss: 0.3960 - loss/inference_loss: 0.3960 - val_loss: 0.2767 - val_loss/inference_loss: 0.2767\n", |
685 | 685 | "Epoch 10/10\n", |
686 | | - "\u001b[1m100/100\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" |
| 686 | + "\u001B[1m100/100\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m6s\u001B[0m 60ms/step - loss: 0.4217 - loss/inference_loss: 0.4217 - val_loss: 0.3482 - val_loss/inference_loss: 0.3482\n" |
687 | 687 | ] |
688 | 688 | } |
689 | 689 | ], |
|
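The training log above shows 10 epochs of 100 steps each, which is consistent with online training along these lines; the argument names follow the BayesFlow 2 `fit_online` API, and the batch size is an assumption:

# Hedged reconstruction of the call that likely produced the log above;
# batch_size is illustrative and not visible in this diff.
history = workflow.fit_online(epochs=10, num_batches_per_epoch=100, batch_size=64)
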
829 | 829 | "obs = observations[\"X\"][:, 1][xt_ind]\n", |
830 | 830 | "\n", |
831 | 831 | "# Obtain 1000 posterior samples\n", |
832 | | - "samples = workflow.sample(conditions={'obs': [obs]}, num_samples=num_samples)" |
| 832 | + "samples = workflow.sample(conditions={\"obs\": [obs]}, num_samples=num_samples)" |
833 | 833 | ] |
834 | 834 | }, |
835 | 835 | { |
|
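`workflow.sample` returns a dict of arrays keyed by parameter name, with a leading dimension for the single conditioning dataset. A sketch of inspecting the draws, mirroring the ABC diagnostic further down; `true_rate` and the shape handling are assumptions:

# Sketch: squeeze the batch dimension before plotting (shape convention assumed).
f = bf.diagnostics.plots.pairs_posterior(
    {"rate": np.squeeze(samples["rate"])}, targets=np.array([true_rate])
)
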
881 | 881 | "source": [ |
882 | 882 | "# ABC gives us weighted samples; we resample them to obtain comparable samples\n",
883 | 883 | "df, w = abc_history.get_distribution()\n", |
884 | | - "abc_samples = weighted_statistics.resample(df['rate'].values, w, 1000)\n", |
| 884 | + "abc_samples = weighted_statistics.resample(df[\"rate\"].values, w, 1000)\n", |
885 | 885 | "\n", |
886 | | - "f = bf.diagnostics.plots.pairs_posterior({'rate': abc_samples}, targets=np.array([true_rate]))" |
| 886 | + "f = bf.diagnostics.plots.pairs_posterior({\"rate\": abc_samples}, targets=np.array([true_rate]))" |
887 | 887 | ] |
888 | 888 | }, |
889 | 889 | { |
|
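`abc_history` comes from a pyabc ABC-SMC run that is not part of this diff. A minimal sketch of how such a history is typically obtained; the distance, epsilon, population count, and database path are all assumptions:

import pyabc

def abc_model(parameter):
    # pyabc passes a dict-like Parameter; reuse the helper from above (assumed).
    return sim_helper(parameter["rate"])

# Hypothetical ABC-SMC configuration, not taken from the notebook.
abc = pyabc.ABCSMC(abc_model, prior, pyabc.PNormDistance(p=2))
abc.new("sqlite:///abc.db", dict(obs=obs))
abc_history = abc.run(minimum_epsilon=0.1, max_nr_populations=8)
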