Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 45 additions & 41 deletions lyncs_quda/clover_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

import numpy
from cppyy.gbl.std import vector
from functools import cache

from lyncs_cppyy import make_shared, to_pointer
from .lib import lib, cupy
from .lattice_field import LatticeField
from .gauge_field import GaugeField
from .enums import QudaParity
from .enums import *

# TODO list
# We want dimension of (cu/num)py array to reflect parity and order
Expand All @@ -32,12 +34,12 @@ class CloverField(LatticeField):
* Only rho is mutable. To change other params, a new instance should be created
* QUDA convention for clover field := 1+i ( kappa csw )/4 sigma_mu,nu F_mu,nu (<-sigma_mu,nu: spinor tensor)
* so that sigma_mu,nu = i[g_mu, g_nu], F_mu,nu = (Q_mu,nu - Q_nu,mu)/8 (1/2 is missing from sigma_mu,nu)
* Apparently, an input to QUDA clover object, coeff = kappa*csw
* wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp)
* Apparently, an input to QUDA clover object, coeff = kappa*csw
* wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp)
"""

def __new__(cls, fmunu, **kwargs):
#TODO: get dofs and local dims from kwargs, instead of getting them
# TODO: get dofs and local dims from kwargs, instead of getting them
# from self.shape assuming that it has the form (dofs, local_dims)
if isinstance(fmunu, CloverField):
return fmunu
Expand All @@ -52,14 +54,14 @@ def __new__(cls, fmunu, **kwargs):
field = fmunu
else:
fmunu = GaugeField(fmunu)
if not is_clover: # not copying from a clover-field array

if not is_clover: # not copying from a clover-field array
idof = int((fmunu.ncol * fmunu.ndims) ** 2 / 2)
prec = fmunu.dtype
field = fmunu.backend.empty((idof,) + fmunu.dims, dtype=prec)

return super().__new__(cls, field, **kwargs)

def __init__(
self,
obj,
Expand All @@ -70,7 +72,7 @@ def __init__(
eps2=0,
rho=0,
computeTrLog=False,
**kwargs
**kwargs,
):
# WARNING: ndarray object is not supposed to be view-casted to CloverField object
# except in __new__, for which __init__ will be called subsequently,
Expand All @@ -88,12 +90,8 @@ def __init__(
empty=True,
)
self._fmunu = obj.compute_fmunu()
self._direct = (
False # Here, it is a flag to indicate whether the field has been computed
)
self._inverse = (
False # Here, it is a flag to indicate whether the field has been computed
)
self._direct = False # Here, it is a flag to indicate whether the field has been computed
self._inverse = False # Here, it is a flag to indicate whether the field has been computed
self.coeff = coeff
self._twisted = twisted
self._twist_flavor = tf
Expand All @@ -107,12 +105,14 @@ def __init__(
elif isinstance(obj, self.backend.ndarray):
pass
else:
raise ValueError("The input is expected to be ndarray or LatticeField object")
raise ValueError(
"The input is expected to be ndarray or LatticeField object"
)

def _prepare(self, field, copy=False, check=False, **kwargs):
# When CloverField object prepares its input, the input is assumed to be of CloverField
return super()._prepare(field, copy=copy, check=check, is_clover=True, **kwargs)

# naming suggestion: native_view? default_* saved for dofs+lattice?
def default_view(self):
N = 1 if self.order == "FLOAT2" else 4
Expand All @@ -126,6 +126,7 @@ def twisted(self):
return self._twisted

@property
@QudaTwistFlavorType
def twist_flavor(self):
return self._twist_flavor

Expand Down Expand Up @@ -153,16 +154,22 @@ def rho(self, val):
self._rho = val

@property
@QudaCloverFieldOrder
def order(self):
"Data order of the field"
if self.precision == "double":
return "FLOAT2"
return "FLOAT4"

@property
def quda_order(self):
"Quda enum for data order of the field"
return getattr(lib, f"QUDA_{self.order}_CLOVER_ORDER")
@staticmethod
@cache
def _clv_params(param, **kwargs):
"Call wrapper to cache param structures"
params = lib.CloverFieldParam()
lib.copy_struct(params, param)
for key, val in kwargs.items():
setattr(params, key, val)
return params

@property
def quda_params(self):
Expand All @@ -178,20 +185,21 @@ def quda_params(self):
an alias to inverse. not really sure what this is, but does
not work properly when reconstruct==True
"""
params = lib.CloverFieldParam()
lib.copy_struct(params, super().quda_params)
params.inverse = True
params.clover = to_pointer(self.ptr)
params.cloverInv = to_pointer(self._cloverInv.ptr)
params.coeff = self.coeff
params.twisted = self.twisted
params.twist_flavor = getattr(lib, f"QUDA_TWIST_{self.twist_flavor}")
params.mu2 = self.mu2
params.epsilon2 = self.eps2
params.rho = self.rho
params.order = self.quda_order
params.create = lib.QUDA_REFERENCE_FIELD_CREATE
params.location = self.quda_location
params = self._clv_params(
super().quda_params,
inverse=True,
clover=to_pointer(self.ptr),
cloverInv=to_pointer(self._cloverInv.ptr),
coeff=self.coeff,
twisted=self.twisted,
twist_flavor=int(self.twist_flavor),
mu2=self.mu2,
epsilon2=self.eps2,
rho=self.rho,
order=int(self.order),
create=int(QudaFieldCreate["reference"]),
location=int(self.location),
)
return params

@property
Expand Down Expand Up @@ -230,7 +238,7 @@ def trLog(self):

def is_native(self):
"Whether the field is native for Quda"
return lib.clover.isNative(self.quda_order, self.quda_precision)
return lib.clover.isNative(int(self.order), self.quda_precision)

@property
def ncol(self):
Expand Down Expand Up @@ -354,11 +362,7 @@ def computeCloverForce(self, gauge, force, D, vxs, vps, mult=2, coeffs=None):
u = gauge.extended_field(sites=R)
if gauge.precision == "double":
u = gauge.prepare_in(gauge, reconstruct="NO").extended_field(sites=R)
lib.cloverDerivative(
force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_ODD_PARITY")
)
lib.cloverDerivative(
force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_EVEN_PARITY")
)
lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["ODD"]))
lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["EVEN"]))

return force
37 changes: 13 additions & 24 deletions lyncs_quda/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .clover_field import CloverField
from .spinor_field import spinor
from .lib import lib
from .enums import QudaPrecision
from .enums import *


@dataclass(frozen=True)
Expand Down Expand Up @@ -48,6 +48,7 @@ def __post_init__(self):
# TODO: Support more Dirac types
# Unsupported: DomainWall(4D/PC), Mobius(PC/Eofa), (Improved)Staggered(KD/PC), GaugeLaplace(PC), GaugeCovDev
@property
@QudaDiracType
def type(self):
"Type of the operator"
PC = "PC" if not self.full else ""
Expand All @@ -62,22 +63,14 @@ def type(self):
return "TWISTED_CLOVER" + PC

@property
def quda_type(self):
"Quda enum for quda dslash type"
return getattr(lib, f"QUDA_{self.type}_DIRAC")

@property
@QudaMatPCType
def matPCtype(self):
if self.full:
return "INVALID"
parity = "EVEN" if self.even else "ODD"
symm = "_ASYMMETRIC" if not self.symm else ""
return f"{parity}_{parity}{symm}"

@property
def quda_matPCtype(self):
return getattr(lib, f"QUDA_MATPC_{self.matPCtype}")

@property
def is_coarse(self):
"Whether is a coarse operator"
Expand All @@ -88,26 +81,22 @@ def precision(self):
return self.gauge.precision

@property
@QudaDagType
def dagger(self):
"If the operator is daggered"
return "NO"

@property
def quda_dagger(self):
"Quda enum for if the operator is dagger"
return getattr(lib, f"QUDA_DAG_{self.dagger}")

@property
def quda_params(self):
params = lib.DiracParam()
params.type = self.quda_type
params.type = int(self.type)
params.kappa = self.kappa
params.m5 = self.m5
params.Ls = self.Ls
params.mu = self.mu
params.epsilon = self.epsilon
params.dagger = self.quda_dagger
params.matpcType = self.quda_matPCtype
params.dagger = int(self.dagger)
params.matpcType = int(self.matPCtype)

# Needs to prevent the gauge field to get destroyed
# now we store QUDA gauge object in _quda, but it
Expand Down Expand Up @@ -299,16 +288,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params):
D.quda_dirac.Dslash(
xs[-1].quda_field.Odd(),
xs[-1].quda_field.Even(),
getattr(lib, "QUDA_ODD_PARITY"),
int(QudaParity["ODD"]),
)
D.quda_dirac.M(ps[i].quda_field.Even(), xs[-1].quda_field.Even())
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES"))
D.quda_dirac.Dslash(
ps[i].quda_field.Odd(),
ps[i].quda_field.Even(),
getattr(lib, "QUDA_ODD_PARITY"),
int(QudaParity["ODD"]),
)
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO"))
D.quda_dirac.Dagger(int(QudaDagType["NO"]))
else:
# Even-odd preconditioned case (i.e., PC in Dirac.type):
# use only odd part of phi
Expand All @@ -317,16 +306,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params):
D.quda_dirac.Dslash(
xs[-1].quda_field.Even(),
xs[-1].quda_field.Odd(),
getattr(lib, "QUDA_EVEN_PARITY"),
int(QudaParity["EVEN"]),
)
D.quda_dirac.M(ps[i].quda_field.Odd(), xs[-1].quda_field.Odd())
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES"))
D.quda_dirac.Dslash(
ps[i].quda_field.Even(),
ps[i].quda_field.Odd(),
getattr(lib, "QUDA_EVEN_PARITY"),
int(QudaParity["EVEN"]),
)
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO"))
D.quda_dirac.Dagger((int(QudaDagType["NO"])))

for i in range(n):
xs[i].apply_gamma5()
Expand Down
52 changes: 37 additions & 15 deletions lyncs_quda/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def __eq__(self, other):
return self.cls is other.cls and int(self) == int(other)
return False

def __ne__(self, other):
return not (
self == other
) # TODO: perhaps better to insert if-cond for NotImpelented

def __contains__(self, other):
if isinstance(other, str):
return self.cls.clean(other) in str(self)
if isinstance(other, int):
return self.to_string(other) in str(self)
return False


class EnumMeta(type):
"Metaclass for enum types"
Expand All @@ -42,18 +54,19 @@ def items(cls):
"List of enum items"
return cls._values.items()

def clean(cls, key):
def clean(cls, rep):
# should turn everything into upper for consistency
"Strips away prefix and suffix from key"
"See enums.py to find what is prefix and suffix for a given enum value"
if isinstance(key, EnumValue):
key = str(key)
if isinstance(key, str):
key = key.lower()
if cls._prefix and key.startswith(cls._prefix):
key = key[len(cls._prefix) :]
if cls._suffix and key.endswith(cls._suffix):
key = key[: -len(cls._suffix)]
return key
if isinstance(rep, EnumValue):
rep = str(rep)
if isinstance(rep, str):
rep = rep.lower()
if cls._prefix and rep.startswith(cls._prefix):
rep = rep[len(cls._prefix) :]
if cls._suffix and rep.endswith(cls._suffix):
rep = rep[: -len(cls._suffix)]
return rep

def to_string(cls, rep):
"Returns the key representative of the given enum value"
Expand Down Expand Up @@ -103,17 +116,26 @@ class Enum(metaclass=EnumMeta):
_suffix = ""
_values = {}

def __init__(self, key, default=None, callback=None):
self.key = key
def __init__(self, fnc, lpath=None, default=None, callback=None):
# fnc is supposed to return either a stripped key name or value of
# the corresponding QUDA enum type
self.fnc = fnc
self.lpath = lpath
self.default = default
self.callback = callback

def __call__(self, instance):
# intended for property.fget, which then invokes
# property.__get__(self, obj, objtype=None)
return EnumValue(type(self), self.fnc(instance))

# not meant to be a stnadard descriptor, c.f., solver.py
def __get__(self, instance, owner):
if instance is None:
raise AttributeError

out = instance
for key in self.key.split("."):
for key in self.lpath.split("."):
out = getattr(out, key)
return type(self)[out]

Expand All @@ -126,9 +148,9 @@ def __set__(self, instance, new):
new = int(type(self)[new])

out = instance
for key in self.key.split(".")[:-1]:
for key in self.lpath.split(".")[:-1]:
out = getattr(out, key)
key = self.key.split(".")[-1]
key = self.lpath.split(".")[-1]
old = int(getattr(out, key))

setattr(out, key, new)
Expand Down
Loading