This commit is contained in:
Sai Nishwanth 2023-07-17 21:38:58 -05:00 committed by GitHub
commit 145833ca71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,22 +45,23 @@ class _KeyRecorder(object):
# convert tensorio tensor type to numpy data type. # convert tensorio tensor type to numpy data type.
# also returns element size in bytes. # also returns element size in bytes.
def _get_data_type(data_type): def _get_data_type(data_type):
if data_type == 'Double': match data_type:
return (np.float64, 8) case 'Double':
return (np.float64, 8)
if data_type == 'Float': case 'Float':
return (np.float32, 4) return (np.float32, 4)
if data_type == 'Int': case 'Int':
return (np.int32, 4) return (np.int32, 4)
if data_type == 'Long': case 'Long':
return (np.int64, 8) return (np.int64, 8)
if data_type == 'Byte': case 'Byte':
return (np.int8, 1) return (np.int8, 1)
case _:
raise ValueError('Unexpected tensorio data type: ' + data_type) raise ValueError('Unexpected tensorio data type: ' + data_type)
class TensorIO(object): class TensorIO(object):