Initialization of orthogonal tensors with respect to a pivot #931
base: master
Changes from 5 commits: 07f4035, a54b55f, cbb6e1d, 29c730f, 55e5d15, c08cd60
```diff
@@ -795,3 +795,13 @@ def eps(self, dtype: Type[np.number]) -> float:
       float: Machine epsilon.
     """
     return np.finfo(dtype).eps
+
+  def initialize_orthogonal_tensor_wrt_pivot(
+      self, shape: Sequence[int], dtype: Optional[Type[np.number]] = None,
+      pivot_axis: int = -1, seed: Optional[int] = None,
+      backend: Optional[Union[Text, AbstractBackend]] = None,
+      non_negative_diagonal: bool = False) -> Tensor:
```
**Contributor:** I don't think we need this function.
```diff
+    if seed:
+      np.random.seed(seed)
+    dtype = dtype if dtype is not None else np.float64
+    if ((np.dtype(dtype) is np.dtype(np.complex128)) or
+        (np.dtype(dtype) is np.dtype(np.complex64))):
+      q, _ = decompositions.qr(
+          np,
+          np.random.randn(*shape).astype(dtype) +
+          1j * np.random.randn(*shape).astype(dtype),
+          pivot_axis, non_negative_diagonal)
```
**Contributor:** there is an
```diff
+    else:
+      q, _ = decompositions.qr(
+          np, np.random.randn(*shape).astype(dtype),
+          pivot_axis, non_negative_diagonal)
+    return q
```
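For context, the construction this method relies on is standard: the Q factor of the QR decomposition of a Gaussian random matrix has orthonormal columns, so matricizing a random tensor at the pivot and keeping Q yields an isometry. A minimal NumPy sketch of that property (using plain `numpy.linalg.qr` rather than the backend's `decompositions.qr`):

```python
import numpy as np

# QR of a Gaussian random matrix: the Q factor has orthonormal
# columns, i.e. Q^T Q is the identity on the smaller dimension.
rng = np.random.default_rng(0)      # deterministic seed
a = rng.standard_normal((6, 4))     # tall random matrix
q, _ = np.linalg.qr(a)              # reduced QR: q has shape (6, 4)
assert np.allclose(q.T @ q, np.eye(4))
```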
```diff
@@ -21,6 +21,7 @@
 from tensornetwork import backend_contextmanager
 from tensornetwork import backends
 from tensornetwork.tensor import Tensor
+from tensornetwork.linalg import linalg

 AbstractBackend = abstract_backend.AbstractBackend
```
```diff
@@ -200,3 +201,7 @@ def random_uniform(shape: Sequence[int],
   the_tensor = initialize_tensor("random_uniform", shape, backend=backend,
                                  seed=seed, boundaries=boundaries, dtype=dtype)
   return the_tensor
+
+
+def initialize_orthogonal_tensor_wrt_pivot(
+    shape: Sequence[int], dtype: Optional[Type[np.number]] = None,
+    pivot_axis: int = -1, seed: Optional[int] = None,
+    backend: Optional[Union[Text, AbstractBackend]] = None,
+    non_negative_diagonal: bool = False) -> Tensor:
```
**Contributor:** I'm wondering if we could find a less clunky name. Some possibilities that come to my mind are `random_orthogonal` or `random_isometry`. @alewis?

**Contributor:** Please add a docstring that explains what the function is doing, what the arguments are, and what the returned values are.
```diff
+  the_tensor = initialize_tensor("randn", shape, backend=backend, seed=seed,
+                                 dtype=dtype)
+  q, _ = linalg.qr(the_tensor, pivot_axis, non_negative_diagonal)
```
**Contributor:** us
```diff
+  return q
```
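As a starting point for the requested docstring, here is a sketch of what it could say; the wording is inferred from the signature and the QR-based construction rather than taken from the PR, and names like `Tensor` and `AbstractBackend` are assumed to be in scope in the module:

```python
def initialize_orthogonal_tensor_wrt_pivot(
    shape: Sequence[int],
    dtype: Optional[Type[np.number]] = None,
    pivot_axis: int = -1,
    seed: Optional[int] = None,
    backend: Optional[Union[Text, AbstractBackend]] = None,
    non_negative_diagonal: bool = False) -> Tensor:
  """Return a random tensor that is orthogonal (an isometry) with respect
  to `pivot_axis`.

  A Gaussian random tensor of the given `shape` is matricized by grouping
  the axes before `pivot_axis` into rows and the remaining axes into
  columns; the Q factor of the QR decomposition of that matrix is
  returned, reshaped so that the axes before `pivot_axis` are restored.

  Args:
    shape: Shape of the tensor to initialize.
    dtype: Dtype of the tensor; defaults to the backend default.
    pivot_axis: Axis at which the tensor is matricized for the QR step.
    seed: Optional seed for the random number generator.
    backend: Backend to use; defaults to the default backend.
    non_negative_diagonal: If True, gauge-fix Q and R so that R has a
      non-negative diagonal, making the decomposition unique.

  Returns:
    Tensor: The orthogonal (isometric) tensor Q.
  """
```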
```diff
@@ -177,3 +177,18 @@ def inner_zero_test(dtype):
   numpyCheck = backend_obj.zeros(n.shape, dtype=dtype)
   np.testing.assert_allclose(tensor.array, tensorCheck)
   np.testing.assert_allclose(numpyT.array, numpyCheck)
+
+
+def test_initialize_orthogonal_tensor_wrt_pivot(backend):
+  shape = (5, 10, 3, 2)
+  pivot_axis = 1
+  non_negative_diagonal = False
```
**Contributor:** Please extend the test to several values of the pivot axis.
```diff
+  seed = int(time.time())
```
**Contributor:** Please use deterministic seed initialization.
```diff
+  np.random.seed(seed=seed)
```
**Contributor:** That line seems superfluous.
```diff
+  backend_obj = backends.backend_factory.get_backend(backend)
+  for dtype in dtypes[backend]["rand"]:
+    tnI = tensornetwork.initialize_orthogonal_tensor_wrt_pivot(
+        shape,
+        dtype=dtype, pivot_axis=pivot_axis,
```
**Contributor:** That line should throw a syntax error because you're passing a positional argument between named arguments.
```diff
+        seed=seed,
+        backend=backend, non_negative_diagonal=non_negative_diagonal)
```
**Contributor:** Same here.
```diff
+    npI = backend_obj.initialize_orthogonal_tensor_wrt_pivot(
+        shape, dtype=dtype, pivot_axis=pivot_axis, seed=seed,
+        non_negative_diagonal=non_negative_diagonal)
```
**Contributor:** Remove the function from the backend.
```diff
+    np.testing.assert_allclose(tnI.array, npI)
```
**Contributor:** Please replace this with a test that checks whether the initialized tensor has the desired properties.
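A sketch of what such a property test could look like, also covering several pivot axes and a fixed seed as requested above. It assumes the matricization convention of `decompositions.qr` (axes before the pivot become rows) and that the returned `Tensor` exposes its data as `.array`, as in the test above:

```python
import numpy as np
import pytest
import tensornetwork


@pytest.mark.parametrize("pivot_axis", [1, 2, 3])
def test_orthogonality_wrt_pivot(pivot_axis):
  shape = (5, 10, 3, 2)
  q = tensornetwork.initialize_orthogonal_tensor_wrt_pivot(
      shape, dtype=np.float64, pivot_axis=pivot_axis, seed=10,
      backend="numpy")
  # Matricize at the pivot: the axes before `pivot_axis` become rows.
  left_dim = int(np.prod(shape[:pivot_axis]))
  m = np.reshape(q.array, (left_dim, -1))
  # An isometry has an identity Gram matrix on its smaller dimension.
  gram = m.conj().T @ m if m.shape[0] >= m.shape[1] else m @ m.conj().T
  np.testing.assert_allclose(gram, np.eye(min(m.shape)), atol=1e-10)
```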
**Contributor:** Why did you add this function to the backend? I don't think we need it here.