Skip to content

Commit cde8647

Browse files
committed
Use numba for speedup
1 parent cba08e1 commit cde8647

File tree

1 file changed

+114
-58
lines changed

1 file changed

+114
-58
lines changed

diffsims/generators/simulation_generator.py

Lines changed: 114 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
from typing import Union, Sequence, Tuple
2222
import numpy as np
2323
from tqdm import tqdm
24+
from numba import njit
2425

2526
from orix.quaternion import Rotation
27+
from orix.vector import Vector3d
2628
from orix.crystal_map import Phase
2729

2830
from diffsims.crystallography import ReciprocalLatticeVector
@@ -204,11 +206,11 @@ def calculate_diffraction2d(
204206
debye_waller_factors=debye_waller_factors,
205207
)
206208
phase_vectors = []
207-
for rot in rotate:
209+
for rot, optical_axis in zip(rotate, rotate * Vector3d.zvector()):
208210
# Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
209211
intersection, excitation_error = get_intersection_with_ewalds_sphere(
210212
recip,
211-
rot,
213+
optical_axis,
212214
wavelength,
213215
max_excitation_error,
214216
self.precession_angle,
@@ -384,7 +386,7 @@ def get_intersecting_reflections(
384386

385387
intersection, excitation_error = get_intersection_with_ewalds_sphere(
386388
recip,
387-
rot,
389+
rot * Vector3d.zvector(),
388390
wavelength,
389391
max_excitation_error,
390392
self.precession_angle,
@@ -399,11 +401,10 @@ def get_intersecting_reflections(
399401
)
400402
return intersected_vectors, hkl, shape_factor
401403

402-
403404
# TODO consider refactoring into a seperate file
404405
def get_intersection_with_ewalds_sphere(
405406
recip: DiffractingVector,
406-
rot: Rotation,
407+
optical_axis: Vector3d,
407408
wavelength: float,
408409
max_excitation_error: float,
409410
precession_angle: float = 0,
@@ -414,8 +415,8 @@ def get_intersection_with_ewalds_sphere(
414415
----------
415416
recip
416417
The reciprocal lattice vectors to rotate.
417-
rot
418-
The rotation to apply to the reciprocal lattice vectors.
418+
optical_axis
419+
Normalised vector representing the direction of the beam
419420
wavelength
420421
The wavelength of the electrons in Angstroms.
421422
max_excitation_error
@@ -433,15 +434,38 @@ def get_intersection_with_ewalds_sphere(
433434
excitation_error
434435
Excitation error of all vectors
435436
"""
436-
# Identify the excitation errors of all points (distance from point to Ewald sphere)
437+
if precession_angle == 0:
438+
return _get_intersection_with_ewalds_sphere_without_precession(
439+
recip.data,
440+
optical_axis.data.squeeze(),
441+
wavelength,
442+
max_excitation_error
443+
)
444+
return _get_intersection_with_ewalds_sphere_with_precession(
445+
recip.data,
446+
optical_axis.data.squeeze(),
447+
wavelength,
448+
max_excitation_error,
449+
precession_angle
450+
)
437451

452+
@njit(
453+
"float64[:](float64[:, :], float64[:], float64)",
454+
fastmath=True,
455+
)
456+
def _calculate_excitation_error(
457+
recip: np.ndarray,
458+
optical_axis_vector: np.ndarray,
459+
wavelength: float,
460+
) -> np.ndarray:
438461
# Instead of rotating vectors, rotate Ewald's sphere to find intersections.
439462
# Only rotate the intersecting vectors.
440463
# Using notation from https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
441464
r = 1 / wavelength
442-
u = rot.to_matrix().squeeze() @ np.array([0, 0, 1])
465+
# u = rot @ np.array([0, 0, 1])
466+
u = optical_axis_vector
443467
c = r * u
444-
o = recip.data
468+
o = recip
445469

446470
diff = o - c
447471
dot = np.dot(u, diff.T)
@@ -452,53 +476,85 @@ def get_intersection_with_ewalds_sphere(
452476
sqrt_nabla = np.sqrt(nabla)
453477
d = -dot - sqrt_nabla
454478

455-
excitation_error = d
479+
return d
456480

457-
# determine the pre-selection reflections
458-
if precession_angle == 0:
459-
intersection = np.abs(excitation_error) < max_excitation_error
460-
else:
461-
# Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
462-
"""
463-
import sympy
464-
import numpy as np
465-
466-
a = sympy.Symbol("a") # Precession angle
467-
r = sympy.Symbol("r") # Ewald's sphere radius
468-
rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
469-
470-
rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
471-
472-
u = np.asarray([0, 1])
473-
c = r * u
474-
cl = rot(a) @ c
475-
cr = rot(-a) @ c
476-
o = np.asarray([rho, z])
477-
478-
def get_d(_c):
479-
diff = o - _c
480-
dot = np.dot(u, diff)
481-
nabla = dot**2 - sum(i**2 for i in diff) + r**2
482-
sqrt_nabla = nabla**0.5
483-
return -dot - sqrt_nabla
484-
485-
d = get_d(c)
486-
d_upper = get_d(cl)
487-
d_lower = get_d(cr)
488-
489-
print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
490-
print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
491-
print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
492-
"""
493-
# In the above script, d is the same as before.
494-
# We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
495-
# (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
496-
rho = np.linalg.norm(np.dot(o, u)[:, np.newaxis] * u - o, axis=1)
497-
a = np.deg2rad(precession_angle)
498-
first_half = r * np.cos(a) - r + (r**2 - rho**2) ** 0.5
499-
upper = first_half - (r**2 - (r * np.sin(a) + rho) ** 2) ** 0.5
500-
lower = first_half - (r**2 - (r * np.sin(a) - rho) ** 2) ** 0.5
501-
intersection = (d < (upper + max_excitation_error)) & (
502-
d > (lower - max_excitation_error)
503-
)
481+
@njit(
482+
"Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64)",
483+
fastmath=True,
484+
)
485+
def _get_intersection_with_ewalds_sphere_without_precession(
486+
recip: np.ndarray,
487+
optical_axis_vector: np.ndarray,
488+
wavelength: float,
489+
max_excitation_error: float,
490+
) -> Tuple[np.ndarray, np.ndarray]:
491+
excitation_error = _calculate_excitation_error(recip, optical_axis_vector, wavelength)
492+
intersection = np.abs(excitation_error) < max_excitation_error
504493
return intersection, excitation_error
494+
495+
496+
@njit(
497+
"Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64, float64)",
498+
fastmath=True,
499+
)
500+
def _get_intersection_with_ewalds_sphere_with_precession(
501+
recip: np.ndarray,
502+
optical_axis_vector: np.ndarray,
503+
wavelength: float,
504+
max_excitation_error: float,
505+
precession_angle: float,
506+
) -> Tuple[np.ndarray, np.ndarray]:
507+
# Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
508+
# (names are same as in _get_excitation_error_no_precession)
509+
"""
510+
import sympy
511+
import numpy as np
512+
513+
a = sympy.Symbol("a") # Precession angle
514+
r = sympy.Symbol("r") # Ewald's sphere radius
515+
rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
516+
517+
rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
518+
519+
u = np.asarray([0, 1])
520+
c = r * u
521+
cl = rot(a) @ c
522+
cr = rot(-a) @ c
523+
o = np.asarray([rho, z])
524+
525+
def get_d(_c):
526+
diff = o - _c
527+
dot = np.dot(u, diff)
528+
nabla = dot**2 - sum(i**2 for i in diff) + r**2
529+
sqrt_nabla = nabla**0.5
530+
return -dot - sqrt_nabla
531+
532+
d = get_d(c)
533+
d_upper = get_d(cl)
534+
d_lower = get_d(cr)
535+
536+
print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
537+
print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
538+
print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
539+
"""
540+
d = _calculate_excitation_error(recip, optical_axis_vector, wavelength)
541+
542+
r = 1 / wavelength
543+
u = optical_axis_vector
544+
o = recip
545+
546+
excitation_error = d
547+
# We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
548+
# (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
549+
550+
# Numba does not support norm with axes, implement manually
551+
rho = np.sum((np.dot(o, u)[:, np.newaxis] * u - o)**2, axis=1)**0.5
552+
a = np.deg2rad(precession_angle)
553+
first_half = r * np.cos(a) - r + (r**2 - rho**2) ** 0.5
554+
upper = first_half - (r**2 - (r * np.sin(a) + rho) ** 2) ** 0.5
555+
lower = first_half - (r**2 - (r * np.sin(a) - rho) ** 2) ** 0.5
556+
intersection = (d < (upper + max_excitation_error)) & (
557+
d > (lower - max_excitation_error)
558+
)
559+
return intersection, excitation_error
560+

0 commit comments

Comments
 (0)