2121from typing import Union , Sequence , Tuple
2222import numpy as np
2323from tqdm import tqdm
24+ from numba import njit
2425
2526from orix .quaternion import Rotation
27+ from orix .vector import Vector3d
2628from orix .crystal_map import Phase
2729
2830from diffsims .crystallography import ReciprocalLatticeVector
@@ -204,11 +206,11 @@ def calculate_diffraction2d(
204206 debye_waller_factors = debye_waller_factors ,
205207 )
206208 phase_vectors = []
207- for rot in rotate :
209+ for rot , optical_axis in zip ( rotate , rotate * Vector3d . zvector ()) :
208210 # Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
209211 intersection , excitation_error = get_intersection_with_ewalds_sphere (
210212 recip ,
211- rot ,
213+ optical_axis ,
212214 wavelength ,
213215 max_excitation_error ,
214216 self .precession_angle ,
@@ -384,7 +386,7 @@ def get_intersecting_reflections(
384386
385387 intersection , excitation_error = get_intersection_with_ewalds_sphere (
386388 recip ,
387- rot ,
389+ rot * Vector3d . zvector () ,
388390 wavelength ,
389391 max_excitation_error ,
390392 self .precession_angle ,
@@ -399,11 +401,10 @@ def get_intersecting_reflections(
399401 )
400402 return intersected_vectors , hkl , shape_factor
401403
402-
403404# TODO consider refactoring into a seperate file
404405def get_intersection_with_ewalds_sphere (
405406 recip : DiffractingVector ,
406- rot : Rotation ,
407+ optical_axis : Vector3d ,
407408 wavelength : float ,
408409 max_excitation_error : float ,
409410 precession_angle : float = 0 ,
@@ -414,8 +415,8 @@ def get_intersection_with_ewalds_sphere(
414415 ----------
415416 recip
416417 The reciprocal lattice vectors to rotate.
417- rot
418- The rotation to apply to the reciprocal lattice vectors.
418+ optical_axis
419+ Normalised vector representing the direction of the beam
419420 wavelength
420421 The wavelength of the electrons in Angstroms.
421422 max_excitation_error
@@ -433,15 +434,38 @@ def get_intersection_with_ewalds_sphere(
433434 excitation_error
434435 Excitation error of all vectors
435436 """
436- # Identify the excitation errors of all points (distance from point to Ewald sphere)
437+ if precession_angle == 0 :
438+ return _get_intersection_with_ewalds_sphere_without_precession (
439+ recip .data ,
440+ optical_axis .data .squeeze (),
441+ wavelength ,
442+ max_excitation_error
443+ )
444+ return _get_intersection_with_ewalds_sphere_with_precession (
445+ recip .data ,
446+ optical_axis .data .squeeze (),
447+ wavelength ,
448+ max_excitation_error ,
449+ precession_angle
450+ )
437451
452+ @njit (
453+ "float64[:](float64[:, :], float64[:], float64)" ,
454+ fastmath = True ,
455+ )
456+ def _calculate_excitation_error (
457+ recip : np .ndarray ,
458+ optical_axis_vector : np .ndarray ,
459+ wavelength : float ,
460+ ) -> np .ndarray :
438461 # Instead of rotating vectors, rotate Ewald's sphere to find intersections.
439462 # Only rotate the intersecting vectors.
440463 # Using notation from https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
441464 r = 1 / wavelength
442- u = rot .to_matrix ().squeeze () @ np .array ([0 , 0 , 1 ])
465+ # u = rot @ np.array([0, 0, 1])
466+ u = optical_axis_vector
443467 c = r * u
444- o = recip . data
468+ o = recip
445469
446470 diff = o - c
447471 dot = np .dot (u , diff .T )
@@ -452,53 +476,85 @@ def get_intersection_with_ewalds_sphere(
452476 sqrt_nabla = np .sqrt (nabla )
453477 d = - dot - sqrt_nabla
454478
455- excitation_error = d
479+ return d
456480
457- # determine the pre-selection reflections
458- if precession_angle == 0 :
459- intersection = np .abs (excitation_error ) < max_excitation_error
460- else :
461- # Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
462- """
463- import sympy
464- import numpy as np
465-
466- a = sympy.Symbol("a") # Precession angle
467- r = sympy.Symbol("r") # Ewald's sphere radius
468- rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
469-
470- rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
471-
472- u = np.asarray([0, 1])
473- c = r * u
474- cl = rot(a) @ c
475- cr = rot(-a) @ c
476- o = np.asarray([rho, z])
477-
478- def get_d(_c):
479- diff = o - _c
480- dot = np.dot(u, diff)
481- nabla = dot**2 - sum(i**2 for i in diff) + r**2
482- sqrt_nabla = nabla**0.5
483- return -dot - sqrt_nabla
484-
485- d = get_d(c)
486- d_upper = get_d(cl)
487- d_lower = get_d(cr)
488-
489- print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
490- print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
491- print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
492- """
493- # In the above script, d is the same as before.
494- # We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
495- # (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
496- rho = np .linalg .norm (np .dot (o , u )[:, np .newaxis ] * u - o , axis = 1 )
497- a = np .deg2rad (precession_angle )
498- first_half = r * np .cos (a ) - r + (r ** 2 - rho ** 2 ) ** 0.5
499- upper = first_half - (r ** 2 - (r * np .sin (a ) + rho ) ** 2 ) ** 0.5
500- lower = first_half - (r ** 2 - (r * np .sin (a ) - rho ) ** 2 ) ** 0.5
501- intersection = (d < (upper + max_excitation_error )) & (
502- d > (lower - max_excitation_error )
503- )
481+ @njit (
482+ "Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64)" ,
483+ fastmath = True ,
484+ )
485+ def _get_intersection_with_ewalds_sphere_without_precession (
486+ recip : np .ndarray ,
487+ optical_axis_vector : np .ndarray ,
488+ wavelength : float ,
489+ max_excitation_error : float ,
490+ ) -> Tuple [np .ndarray , np .ndarray ]:
491+ excitation_error = _calculate_excitation_error (recip , optical_axis_vector , wavelength )
492+ intersection = np .abs (excitation_error ) < max_excitation_error
504493 return intersection , excitation_error
494+
495+
496+ @njit (
497+ "Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64, float64)" ,
498+ fastmath = True ,
499+ )
500+ def _get_intersection_with_ewalds_sphere_with_precession (
501+ recip : np .ndarray ,
502+ optical_axis_vector : np .ndarray ,
503+ wavelength : float ,
504+ max_excitation_error : float ,
505+ precession_angle : float ,
506+ ) -> Tuple [np .ndarray , np .ndarray ]:
507+ # Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
508+ # (names are same as in _get_excitation_error_no_precession)
509+ """
510+ import sympy
511+ import numpy as np
512+
513+ a = sympy.Symbol("a") # Precession angle
514+ r = sympy.Symbol("r") # Ewald's sphere radius
515+ rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
516+
517+ rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
518+
519+ u = np.asarray([0, 1])
520+ c = r * u
521+ cl = rot(a) @ c
522+ cr = rot(-a) @ c
523+ o = np.asarray([rho, z])
524+
525+ def get_d(_c):
526+ diff = o - _c
527+ dot = np.dot(u, diff)
528+ nabla = dot**2 - sum(i**2 for i in diff) + r**2
529+ sqrt_nabla = nabla**0.5
530+ return -dot - sqrt_nabla
531+
532+ d = get_d(c)
533+ d_upper = get_d(cl)
534+ d_lower = get_d(cr)
535+
536+ print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
537+ print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
538+ print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
539+ """
540+ d = _calculate_excitation_error (recip , optical_axis_vector , wavelength )
541+
542+ r = 1 / wavelength
543+ u = optical_axis_vector
544+ o = recip
545+
546+ excitation_error = d
547+ # We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
548+ # (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
549+
550+ # Numba does not support norm with axes, implement manually
551+ rho = np .sum ((np .dot (o , u )[:, np .newaxis ] * u - o )** 2 , axis = 1 )** 0.5
552+ a = np .deg2rad (precession_angle )
553+ first_half = r * np .cos (a ) - r + (r ** 2 - rho ** 2 ) ** 0.5
554+ upper = first_half - (r ** 2 - (r * np .sin (a ) + rho ) ** 2 ) ** 0.5
555+ lower = first_half - (r ** 2 - (r * np .sin (a ) - rho ) ** 2 ) ** 0.5
556+ intersection = (d < (upper + max_excitation_error )) & (
557+ d > (lower - max_excitation_error )
558+ )
559+ return intersection , excitation_error
560+
0 commit comments