the-algorithm/twml/twml/array.py
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

102 lines
2.0 KiB
Python

"""Module containing wrapper class to allow numpy arrays to work with twml functions"""
import ctypes as ct
from absl import logging
from libtwml import CLIB
import numpy as np
_NP_TO_TWML_TYPE = {
'float32': ct.c_int(1),
'float64': ct.c_int(2),
'int32': ct.c_int(3),
'int64': ct.c_int(4),
'int8': ct.c_int(5),
'uint8': ct.c_int(6),
}
class Array(object):
    """
    Wrapper class to allow numpy arrays to work with twml functions.

    The libtwml tensor is created over the numpy array's data pointer, so the
    wrapper keeps a reference to the array for as long as the handle exists.
    """

    def __init__(self, array):
        """
        Wraps numpy array and creates a handle that can be passed to C functions from libtwml.

        Args:
          array: numpy.ndarray to wrap. Its dtype name must be a key of
            ``_NP_TO_TWML_TYPE``, and every stride must be a non-negative
            multiple of the item size (libtwml receives strides as unsigned
            element counts).

        Raises:
          TypeError: if ``array`` is not a numpy array.
          KeyError: if the array's dtype is not supported.
          ValueError: if the array's strides cannot be expressed as
            non-negative element counts.
          RuntimeError: if ``twml_tensor_create`` does not return 1000.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError("Input must be a numpy array")
        try:
            ttype = _NP_TO_TWML_TYPE[array.dtype.name]
        except KeyError:
            logging.error("Unsupported numpy type")
            raise

        isize = array.dtype.itemsize
        # Strides are converted to element counts and passed as c_size_t.
        # ctypes performs no overflow checking, so a negative byte stride
        # (e.g. from a reversed slice) would wrap to a huge unsigned value, and
        # a stride that is not a multiple of the item size would be silently
        # truncated by the division. Reject both instead of handing C a bad layout.
        if any(stride < 0 or stride % isize for stride in array.strides):
            raise ValueError(
                "Array strides %s are not non-negative multiples of the item size %d"
                % (array.strides, isize))

        handle = ct.c_void_p(0)
        ndim = ct.c_int(array.ndim)
        # ndarray.ctypes.shape / ._as_parameter_ are the non-deprecated forms of
        # get_shape() / get_as_parameter() (deprecated in NumPy 1.21); they
        # return the same ctypes objects.
        dims = array.ctypes.shape
        strides_t = ct.c_size_t * array.ndim
        strides = strides_t(*[stride // isize for stride in array.strides])
        status = CLIB.twml_tensor_create(ct.pointer(handle),
                                         array.ctypes._as_parameter_,
                                         ndim, dims, strides, ttype)
        # Any status other than 1000 is treated as a libtwml failure.
        if status != 1000:
            raise RuntimeError("Error from libtwml")
        # Store the numpy array to ensure it isn't deleted before self
        self._array = array
        self._handle = handle
        self._type = ttype

    @property
    def handle(self):
        """
        Return the twml handle (a ctypes.c_void_p filled in by libtwml).
        """
        return self._handle

    @property
    def shape(self):
        """
        Return the shape of the wrapped numpy array.
        """
        return self._array.shape

    @property
    def ndim(self):
        """
        Return the number of dimensions of the wrapped numpy array.
        """
        return self._array.ndim

    @property
    def array(self):
        """
        Return the wrapped numpy array.
        """
        return self._array

    @property
    def dtype(self):
        """
        Return the numpy dtype of the wrapped array.
        """
        return self._array.dtype

    def __del__(self):
        """
        Delete the libtwml handle, if one was created.

        Python calls __del__ even when __init__ raised before the handle was
        stored (unsupported input, libtwml failure), so look the handle up
        defensively instead of assuming the attribute exists.
        """
        handle = getattr(self, '_handle', None)
        if handle is not None:
            CLIB.twml_tensor_delete(handle)