Skip to content

Commit 7e307fb

Browse files
committed
Use numba for speedup
1 parent e3b09d8 commit 7e307fb

File tree

1 file changed

+114
-58
lines changed

1 file changed

+114
-58
lines changed

diffsims/generators/simulation_generator.py

Lines changed: 114 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020

2121
from typing import Union, Sequence, Tuple
2222
import numpy as np
23+
from numba import njit
2324

2425
from orix.quaternion import Rotation
26+
from orix.vector import Vector3d
2527
from orix.crystal_map import Phase
2628

2729
from diffsims.crystallography._diffracting_vector import DiffractingVector
@@ -199,11 +201,11 @@ def calculate_diffraction2d(
199201
debye_waller_factors=debye_waller_factors,
200202
)
201203
phase_vectors = []
202-
for rot in rotate:
204+
for rot, optical_axis in zip(rotate, rotate * Vector3d.zvector()):
203205
# Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
204206
intersection, excitation_error = get_intersection_with_ewalds_sphere(
205207
recip,
206-
rot,
208+
optical_axis,
207209
wavelength,
208210
max_excitation_error,
209211
self.precession_angle,
@@ -407,7 +409,7 @@ def get_intersecting_reflections(
407409

408410
intersection, excitation_error = get_intersection_with_ewalds_sphere(
409411
recip,
410-
rot,
412+
rot * Vector3d.zvector(),
411413
wavelength,
412414
max_excitation_error,
413415
self.precession_angle,
@@ -422,11 +424,10 @@ def get_intersecting_reflections(
422424
)
423425
return intersected_vectors, hkl, shape_factor
424426

425-
426427
# TODO consider refactoring into a seperate file
427428
def get_intersection_with_ewalds_sphere(
428429
recip: DiffractingVector,
429-
rot: Rotation,
430+
optical_axis: Vector3d,
430431
wavelength: float,
431432
max_excitation_error: float,
432433
precession_angle: float = 0,
@@ -437,8 +438,8 @@ def get_intersection_with_ewalds_sphere(
437438
----------
438439
recip
439440
The reciprocal lattice vectors to rotate.
440-
rot
441-
The rotation to apply to the reciprocal lattice vectors.
441+
optical_axis
442+
Normalised vector representing the direction of the beam
442443
wavelength
443444
The wavelength of the electrons in Angstroms.
444445
max_excitation_error
@@ -456,15 +457,38 @@ def get_intersection_with_ewalds_sphere(
456457
excitation_error
457458
Excitation error of all vectors
458459
"""
459-
# Identify the excitation errors of all points (distance from point to Ewald sphere)
460+
if precession_angle == 0:
461+
return _get_intersection_with_ewalds_sphere_without_precession(
462+
recip.data,
463+
optical_axis.data.squeeze(),
464+
wavelength,
465+
max_excitation_error
466+
)
467+
return _get_intersection_with_ewalds_sphere_with_precession(
468+
recip.data,
469+
optical_axis.data.squeeze(),
470+
wavelength,
471+
max_excitation_error,
472+
precession_angle
473+
)
460474

475+
@njit(
476+
"float64[:](float64[:, :], float64[:], float64)",
477+
fastmath=True,
478+
)
479+
def _calculate_excitation_error(
480+
recip: np.ndarray,
481+
optical_axis_vector: np.ndarray,
482+
wavelength: float,
483+
) -> np.ndarray:
461484
# Instead of rotating vectors, rotate Ewald's sphere to find intersections.
462485
# Only rotate the intersecting vectors.
463486
# Using notation from https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
464487
r = 1 / wavelength
465-
u = rot.to_matrix().squeeze() @ np.array([0, 0, 1])
488+
# u = rot @ np.array([0, 0, 1])
489+
u = optical_axis_vector
466490
c = r * u
467-
o = recip.data
491+
o = recip
468492

469493
diff = o - c
470494
dot = np.dot(u, diff.T)
@@ -475,53 +499,85 @@ def get_intersection_with_ewalds_sphere(
475499
sqrt_nabla = np.sqrt(nabla)
476500
d = -dot - sqrt_nabla
477501

478-
excitation_error = d
502+
return d
479503

480-
# determine the pre-selection reflections
481-
if precession_angle == 0:
482-
intersection = np.abs(excitation_error) < max_excitation_error
483-
else:
484-
# Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
485-
"""
486-
import sympy
487-
import numpy as np
488-
489-
a = sympy.Symbol("a") # Precession angle
490-
r = sympy.Symbol("r") # Ewald's sphere radius
491-
rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
492-
493-
rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
494-
495-
u = np.asarray([0, 1])
496-
c = r * u
497-
cl = rot(a) @ c
498-
cr = rot(-a) @ c
499-
o = np.asarray([rho, z])
500-
501-
def get_d(_c):
502-
diff = o - _c
503-
dot = np.dot(u, diff)
504-
nabla = dot**2 - sum(i**2 for i in diff) + r**2
505-
sqrt_nabla = nabla**0.5
506-
return -dot - sqrt_nabla
507-
508-
d = get_d(c)
509-
d_upper = get_d(cl)
510-
d_lower = get_d(cr)
511-
512-
print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
513-
print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
514-
print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
515-
"""
516-
# In the above script, d is the same as before.
517-
# We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
518-
# (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
519-
rho = np.linalg.norm(np.dot(o, u)[:, np.newaxis] * u - o, axis=1)
520-
a = np.deg2rad(precession_angle)
521-
first_half = r * np.cos(a) - r + (r**2 - rho**2) ** 0.5
522-
upper = first_half - (r**2 - (r * np.sin(a) + rho) ** 2) ** 0.5
523-
lower = first_half - (r**2 - (r * np.sin(a) - rho) ** 2) ** 0.5
524-
intersection = (d < (upper + max_excitation_error)) & (
525-
d > (lower - max_excitation_error)
526-
)
504+
@njit(
505+
"Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64)",
506+
fastmath=True,
507+
)
508+
def _get_intersection_with_ewalds_sphere_without_precession(
509+
recip: np.ndarray,
510+
optical_axis_vector: np.ndarray,
511+
wavelength: float,
512+
max_excitation_error: float,
513+
) -> Tuple[np.ndarray, np.ndarray]:
514+
excitation_error = _calculate_excitation_error(recip, optical_axis_vector, wavelength)
515+
intersection = np.abs(excitation_error) < max_excitation_error
527516
return intersection, excitation_error
517+
518+
519+
@njit(
520+
"Tuple((bool[:], float64[:]))(float64[:, :], float64[:], float64, float64, float64)",
521+
fastmath=True,
522+
)
523+
def _get_intersection_with_ewalds_sphere_with_precession(
524+
recip: np.ndarray,
525+
optical_axis_vector: np.ndarray,
526+
wavelength: float,
527+
max_excitation_error: float,
528+
precession_angle: float,
529+
) -> Tuple[np.ndarray, np.ndarray]:
530+
# Using the following script to find equations for upper and lower bounds for precessing Ewald's sphere
531+
# (names are same as in _get_excitation_error_no_precession)
532+
"""
533+
import sympy
534+
import numpy as np
535+
536+
a = sympy.Symbol("a") # Precession angle
537+
r = sympy.Symbol("r") # Ewald's sphere radius
538+
rho, z = sympy.symbols("rho z") # cylindrical coordinates of reflection
539+
540+
rot = lambda ang: np.asarray([[sympy.cos(ang), -sympy.sin(ang)],[sympy.sin(ang), sympy.cos(ang)]])
541+
542+
u = np.asarray([0, 1])
543+
c = r * u
544+
cl = rot(a) @ c
545+
cr = rot(-a) @ c
546+
o = np.asarray([rho, z])
547+
548+
def get_d(_c):
549+
diff = o - _c
550+
dot = np.dot(u, diff)
551+
nabla = dot**2 - sum(i**2 for i in diff) + r**2
552+
sqrt_nabla = nabla**0.5
553+
return -dot - sqrt_nabla
554+
555+
d = get_d(c)
556+
d_upper = get_d(cl)
557+
d_lower = get_d(cr)
558+
559+
print(d.simplify()) # r - z - (r**2 - rho**2)**0.5
560+
print((d_upper - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) + rho)**2)**0.5
561+
print((d_lower - d).simplify()) # r*cos(a) - r + (r**2 - rho**2)**0.5 - (r**2 - (r*sin(a) - rho)**2)**0.5
562+
"""
563+
d = _calculate_excitation_error(recip, optical_axis_vector, wavelength)
564+
565+
r = 1 / wavelength
566+
u = optical_axis_vector
567+
o = recip
568+
569+
excitation_error = d
570+
# We need the distance of the reflections from the incident beam, i.e. the cylindrical coordinate rho
571+
# (using https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Vector_formulation):
572+
573+
# Numba does not support norm with axes, implement manually
574+
rho = np.sum((np.dot(o, u)[:, np.newaxis] * u - o)**2, axis=1)**0.5
575+
a = np.deg2rad(precession_angle)
576+
first_half = r * np.cos(a) - r + (r**2 - rho**2) ** 0.5
577+
upper = first_half - (r**2 - (r * np.sin(a) + rho) ** 2) ** 0.5
578+
lower = first_half - (r**2 - (r * np.sin(a) - rho) ** 2) ** 0.5
579+
intersection = (d < (upper + max_excitation_error)) & (
580+
d > (lower - max_excitation_error)
581+
)
582+
return intersection, excitation_error
583+

0 commit comments

Comments
 (0)