11# Copyright © 2021 United States Government as represented by the Administrator of the
22# National Aeronautics and Space Administration. All Rights Reserved.
33
4+ from collections import abc
45from itertools import chain
56import matplotlib .pyplot as plt
67from numbers import Number
@@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
476477 raise ValueError (f"layers must be greater than 0, got { params ['layers' ]} " )
477478 if np .isscalar (params ['units' ]):
478479 params ['units' ] = [params ['units' ] for _ in range (params ['layers' ])]
479- if not isinstance (params ['units' ], (list , np .ndarray )):
480- raise TypeError (f"units must be a list of integers, not { type (params ['units' ])} " )
480+ if not isinstance (params ['units' ], (abc . Sequence , np .ndarray )):
481+ raise TypeError (f"units must be a Sequence (e.g., list or tuple) of integers, not { type (params ['units' ])} " )
481482 if len (params ['units' ]) != params ['layers' ]:
482483 raise ValueError (f"units must be a list of integers of length { params ['layers' ]} , got { params ['units' ]} " )
483484 for i in range (params ['layers' ]):
@@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
487488 raise TypeError (f"dropout must be an float greater than or equal to 0, not { type (params ['dropout' ])} " )
488489 if params ['dropout' ] < 0 :
489490 raise ValueError (f"dropout must be greater than or equal to 0, got { params ['dropout' ]} " )
490- if not isinstance (params ['activation' ], (list , np .ndarray )):
491+ if not isinstance (params ['activation' ], (list , tuple , np .ndarray )):
491492 params ['activation' ] = [params ['activation' ] for _ in range (params ['layers' ])]
492493 if not np .isscalar (params ['validation_split' ]):
493494 raise TypeError (f"validation_split must be an float between 0 and 1, not { type (params ['validation_split' ])} " )
0 commit comments