2020
2121from typing import Union , Sequence , Tuple
2222import numpy as np
23+ from numba import njit
2324
2425from orix .quaternion import Rotation
26+ from orix .vector import Vector3d
2527from orix .crystal_map import Phase
2628
2729from diffsims .crystallography ._diffracting_vector import DiffractingVector
@@ -199,11 +201,11 @@ def calculate_diffraction2d(
199201 debye_waller_factors = debye_waller_factors ,
200202 )
201203 phase_vectors = []
202- for rot in rotate :
204+ for rot , optical_axis in zip ( rotate , rotate * Vector3d . zvector ()) :
203205 # Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
204206 intersection , excitation_error = get_intersection_with_ewalds_sphere (
205207 recip ,
206- rot ,
208+ optical_axis ,
207209 wavelength ,
208210 max_excitation_error ,
209211 self .precession_angle ,
@@ -407,7 +409,7 @@ def get_intersecting_reflections(
407409
408410 intersection , excitation_error = get_intersection_with_ewalds_sphere (
409411 recip ,
410- rot ,
412+ rot * Vector3d . zvector () ,
411413 wavelength ,
412414 max_excitation_error ,
413415 self .precession_angle ,
@@ -422,11 +424,10 @@ def get_intersecting_reflections(
422424 )
423425 return intersected_vectors , hkl , shape_factor
424426
425-
426427# TODO consider refactoring into a seperate file
427428def get_intersection_with_ewalds_sphere (
428429 recip : DiffractingVector ,
429- rot : Rotation ,
430+ optical_axis : Vector3d ,
430431 wavelength : float ,
431432 max_excitation_error : float ,
432433 precession_angle : float = 0 ,
@@ -437,8 +438,8 @@ def get_intersection_with_ewalds_sphere(
437438 ----------
438439 recip
439440 The reciprocal lattice vectors to rotate.
440- rot
441- The rotation to apply to the reciprocal lattice vectors.
441+ optical_axis
442+ Normalised vector representing the direction of the beam
442443 wavelength
443444 The wavelength of the electrons in Angstroms.
444445 max_excitation_error
@@ -456,15 +457,38 @@ def get_intersection_with_ewalds_sphere(
456457 excitation_error
457458 Excitation error of all vectors
458459 """
459- # Identify the excitation errors of all points (distance from point to Ewald sphere)
460+ if precession_angle == 0 :
461+ return _get_intersection_with_ewalds_sphere_without_precession (
462+ recip .data ,
463+ optical_axis .data .squeeze (),
464+ wavelength ,
465+ max_excitation_error
466+ )
467+ return _get_intersection_with_ewalds_sphere_with_precession (
468+ recip .data ,
469+ optical_axis .data .squeeze (),
470+ wavelength ,
471+ max_excitation_error ,
472+ precession_angle
473+ )
460474
475+ @njit (
476+ "float64[:](float64[:, :], float64[:], float64)" ,
477+ fastmath = True ,
478+ )
479+ def _calculate_excitation_error (
480+ recip : np .ndarray ,
481+ optical_axis_vector : np .ndarray ,
482+ wavelength : float ,
483+ ) -> np .ndarray :
461484 # Instead of rotating vectors, rotate Ewald's sphere to find intersections.
462485 # Only rotate the intersecting vectors.
463486 # Using notation from https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
464487 r = 1 / wavelength
465- u = rot .to_matrix ().squeeze () @ np .array ([0 , 0 , 1 ])
488+ # u = rot @ np.array([0, 0, 1])
489+ u = optical_axis_vector
466490 c = r * u
467- o = recip . data
491+ o = recip
468492
469493 diff = o - c
470494 dot = np .dot (u , diff .T )
@@ -475,53 +499,85 @@ def get_intersection_with_ewalds_sphere(
475499 sqrt_nabla = np .sqrt (nabla )
476500 d = - dot - sqrt_nabla
477501
478- excitation_error = d
502+ return d
479503
480- # determine the pre-selection reflections
481- if precession_angle == 0 :
482- intersection = np .abs (excitation_error ) < max_excitation_error
483- else :
484- # Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
485- """
486- import sympy
487- import numpy as np
488-
489- a = sympy.Symbol("a") # Precession angle
490- r = sympy.Symbol("r") # Ewald's sphere radius
491- rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
492-
493- rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
494-
495- u = np.asarray([0, 1])
496- c = r * u
497- cl = rot(a) @ c
498- cr = rot(-a) @ c
499- o = np.asarray([rho, z])
500-
501- def get_d(_c):
502- diff = o - _c
503- dot = np.dot(u, diff)
504- nabla = dot**2 - sum(i**2 for i in diff) + r**2
505- sqrt_nabla = nabla**0.5
506- return -dot - sqrt_nabla
507-
508- d = get_d(c)
509- d_upper = get_d(cl)
510- d_lower = get_d(cr)
511-
512- print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
513- print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
514- print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
515- """
516- # In the above script, d is the same as before.
517- # We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
518- # (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
519- rho = np .linalg .norm (np .dot (o , u )[:, np .newaxis ] * u - o , axis = 1 )
520- a = np .deg2rad (precession_angle )
521- first_half = r * np .cos (a ) - r + (r ** 2 - rho ** 2 ) ** 0.5
522- upper = first_half - (r ** 2 - (r * np .sin (a ) + rho ) ** 2 ) ** 0.5
523- lower = first_half - (r ** 2 - (r * np .sin (a ) - rho ) ** 2 ) ** 0.5
524- intersection = (d < (upper + max_excitation_error )) & (
525- d > (lower - max_excitation_error )
526- )
504+ @njit (
505+ "Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64)" ,
506+ fastmath = True ,
507+ )
508+ def _get_intersection_with_ewalds_sphere_without_precession (
509+ recip : np .ndarray ,
510+ optical_axis_vector : np .ndarray ,
511+ wavelength : float ,
512+ max_excitation_error : float ,
513+ ) -> Tuple [np .ndarray , np .ndarray ]:
514+ excitation_error = _calculate_excitation_error (recip , optical_axis_vector , wavelength )
515+ intersection = np .abs (excitation_error ) < max_excitation_error
527516 return intersection , excitation_error
517+
518+
519+ @njit (
520+ "Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64, float64)" ,
521+ fastmath = True ,
522+ )
523+ def _get_intersection_with_ewalds_sphere_with_precession (
524+ recip : np .ndarray ,
525+ optical_axis_vector : np .ndarray ,
526+ wavelength : float ,
527+ max_excitation_error : float ,
528+ precession_angle : float ,
529+ ) -> Tuple [np .ndarray , np .ndarray ]:
530+ # Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
531+ # (names are same as in _get_excitation_error_no_precession)
532+ """
533+ import sympy
534+ import numpy as np
535+
536+ a = sympy.Symbol("a") # Precession angle
537+ r = sympy.Symbol("r") # Ewald's sphere radius
538+ rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
539+
540+ rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
541+
542+ u = np.asarray([0, 1])
543+ c = r * u
544+ cl = rot(a) @ c
545+ cr = rot(-a) @ c
546+ o = np.asarray([rho, z])
547+
548+ def get_d(_c):
549+ diff = o - _c
550+ dot = np.dot(u, diff)
551+ nabla = dot**2 - sum(i**2 for i in diff) + r**2
552+ sqrt_nabla = nabla**0.5
553+ return -dot - sqrt_nabla
554+
555+ d = get_d(c)
556+ d_upper = get_d(cl)
557+ d_lower = get_d(cr)
558+
559+ print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
560+ print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
561+ print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
562+ """
563+ d = _calculate_excitation_error (recip , optical_axis_vector , wavelength )
564+
565+ r = 1 / wavelength
566+ u = optical_axis_vector
567+ o = recip
568+
569+ excitation_error = d
570+ # We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
571+ # (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
572+
573+ # Numba does not support norm with axes, implement manually
574+ rho = np .sum ((np .dot (o , u )[:, np .newaxis ] * u - o )** 2 , axis = 1 )** 0.5
575+ a = np .deg2rad (precession_angle )
576+ first_half = r * np .cos (a ) - r + (r ** 2 - rho ** 2 ) ** 0.5
577+ upper = first_half - (r ** 2 - (r * np .sin (a ) + rho ) ** 2 ) ** 0.5
578+ lower = first_half - (r ** 2 - (r * np .sin (a ) - rho ) ** 2 ) ** 0.5
579+ intersection = (d < (upper + max_excitation_error )) & (
580+ d > (lower - max_excitation_error )
581+ )
582+ return intersection , excitation_error
583+
0 commit comments