mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-12-22 18:21:51 +01:00
Compare commits
2 Commits
145833ca71
...
0deed30d2c
Author | SHA1 | Date | |
---|---|---|---|
|
0deed30d2c | ||
|
b7e005427c |
@ -45,22 +45,23 @@ class _KeyRecorder(object):
|
|||||||
# convert tensorio tensor type to numpy data type.
|
# convert tensorio tensor type to numpy data type.
|
||||||
# also returns element size in bytes.
|
# also returns element size in bytes.
|
||||||
def _get_data_type(data_type):
|
def _get_data_type(data_type):
|
||||||
if data_type == 'Double':
|
match data_type:
|
||||||
return (np.float64, 8)
|
case 'Double':
|
||||||
|
return (np.float64, 8)
|
||||||
|
|
||||||
if data_type == 'Float':
|
case 'Float':
|
||||||
return (np.float32, 4)
|
return (np.float32, 4)
|
||||||
|
|
||||||
if data_type == 'Int':
|
case 'Int':
|
||||||
return (np.int32, 4)
|
return (np.int32, 4)
|
||||||
|
|
||||||
if data_type == 'Long':
|
case 'Long':
|
||||||
return (np.int64, 8)
|
return (np.int64, 8)
|
||||||
|
|
||||||
if data_type == 'Byte':
|
case 'Byte':
|
||||||
return (np.int8, 1)
|
return (np.int8, 1)
|
||||||
|
case _:
|
||||||
raise ValueError('Unexpected tensorio data type: ' + data_type)
|
raise ValueError('Unexpected tensorio data type: ' + data_type)
|
||||||
|
|
||||||
|
|
||||||
class TensorIO(object):
|
class TensorIO(object):
|
||||||
|
Loading…
Reference in New Issue
Block a user