mirror of
https://github.com/FIX94/hbc.git
synced 2025-11-04 07:16:13 +01:00
371 lines
8.6 KiB
Python
371 lines
8.6 KiB
Python
#!/usr/bin/python2
|
|
# Copyright 2007,2008 Segher Boessenkool <segher@kernel.crashing.org>
|
|
# Copyright 2008 Hector Martin <marcan@marcansoft.com>
|
|
# Licensed under the terms of the GNU GPL, version 2
|
|
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt
|
|
|
|
from array import array
|
|
from struct import pack, unpack
|
|
try:
|
|
from Cryptodome.Util.number import bytes_to_long, long_to_bytes
|
|
except ImportError:
|
|
from Crypto.Util.number import bytes_to_long, long_to_bytes
|
|
|
|
# Curve constant b of the binary curve  y**2 + x*y = x**3 + x**2 + b
# (the x**2 coefficient a is 1), as 30 big-endian bytes.
ec_b = ("\x00\x66\x64\x7e\xde\x6c\x33\x2c\x7f\x8c\x09\x23\xbb\x58\x21"
        "\x3b\x33\x3b\x20\xe9\xce\x42\x81\xfe\x11\x5f\x7d\x8f\x90\xad")
|
|
|
|
def hexdump(s, sep=""):
    """Return each character of the byte string *s* as two lowercase hex
    digits, with *sep* placed between consecutive bytes."""
    return sep.join("%02x" % ord(ch) for ch in s)
|
|
|
|
def bhex(s, sep=""):
    """Hex-dump the integer *s* as big-endian bytes, zero-padded to a
    multiple of 30 bytes (one field element), separated by *sep*."""
    padded = long_to_bytes(s, 30)
    return hexdump(padded, sep)
|
|
|
|
# Use the optional C extension (_ec) for field arithmetic when it can be
# imported.  Without it, EC certificate checking falls back to the much
# slower pure-Python implementation.
try:
    import _ec
except ImportError:
    fastelt = False
else:
    fastelt = True
|
|
|
|
|
|
class ByteArray(array):
    """array('B') whose item and slice assignments are reduced modulo 256.

    The field arithmetic in this module shifts bytes left without masking
    (e.g. ``x << 1``) and relies on assignment into a ByteArray to drop the
    overflow bits.
    """

    def __new__(cls, initializer=None):
        # array() rejects an explicit None initializer, so only forward one
        # that was actually supplied; ByteArray() is then an empty array.
        if initializer is None:
            return super(ByteArray, cls).__new__(cls, 'B')
        return super(ByteArray, cls).__new__(cls, 'B', initializer)

    def __init__(self, initializer=None):
        # The contents were already filled in by __new__.
        array.__init__(self)

    def __setitem__(self, item, value):
        if isinstance(item, slice):
            # Slice assignment into an array only accepts another array, so
            # wrap the masked values instead of passing a plain list.
            array.__setitem__(self, item,
                              array('B', [x & 0xFF for x in value]))
        else:
            array.__setitem__(self, item, value & 0xFF)

    def __setslice__(self, i, j, value):
        # Python 2 routes simple ``a[i:j] = v`` assignments here rather than
        # to __setitem__; send them through the masking path above.
        self.__setitem__(slice(i, j), value)

    def __long__(self):
        """Interpret the bytes as a big-endian unsigned integer."""
        return bytes_to_long(self.tostring())

    def __str__(self):
        """Lowercase hex digits, two per byte, no separator."""
        return ''.join(["%02x"%ord(x) for x in self.tostring()])

    def __repr__(self):
        return "ByteArray('%s')"%''.join(["\\x%02x"%ord(x) for x in self.tostring()])
|
|
|
|
class ELT_PY:
    """Element of GF(2^233), with pure-Python arithmetic.

    The value lives in ``self.d``, a ByteArray of SIZE (30) bytes in
    big-endian order; only the low SIZEBITS (233) bits are significant.
    Products are reduced modulo x^233 + x^74 + 1 (see _mul_x and _square).

    Addition is XOR.  There is no __div__: ``a / b`` between two elements
    is served by ``b.__rdiv__(a)``, which relies on Python 2 old-style
    class operator dispatch -- keep this a classic class.
    """

    SIZEBITS = 233
    # Bytes needed to hold SIZEBITS bits (30).  Floor division keeps this an
    # int regardless of the division semantics in effect.
    SIZE = (SIZEBITS + 7) // 8
    # square[n] spreads the four bits of nibble n onto the even bit
    # positions of a byte, i.e. it is n squared as a GF(2) polynomial.
    square = ByteArray("\x00\x01\x04\x05\x10\x11\x14\x15\x40\x41\x44\x45\x50\x51\x54\x55")

    def __init__(self, initializer=None):
        """Create an element.

        ``initializer`` may be an int/long (encoded big-endian into SIZE
        bytes), a byte string or array of byte values (ByteArray included),
        another ELT (copied), or None for the zero element.

        Raises:
            TypeError: for any other initializer type.
            ValueError: if the resulting byte array is not SIZE bytes long.
        """
        if isinstance(initializer, (int, long)):
            self.d = ByteArray(long_to_bytes(initializer, self.SIZE))
        elif isinstance(initializer, (str, array)):
            # ByteArray subclasses array, so it is covered here as well.
            self.d = ByteArray(initializer)
        elif isinstance(initializer, ELT):
            self.d = ByteArray(initializer.d)
        elif initializer is None:
            self.d = ByteArray([0] * self.SIZE)
        else:
            raise TypeError("Invalid initializer type")
        if len(self.d) != self.SIZE:
            raise ValueError("ELT size must be 30")

    def __cmp__(self, other):
        """Order elements by their big-endian byte value.

        Comparing with the integer 0 tests for the zero element, which is
        how the rest of the module spells ``elt == 0``.
        """
        if other == 0:  # special case: zero test
            if self:
                return 1
            else:
                return 0
        if not isinstance(other, ELT):
            return NotImplemented
        return cmp(self.d, other.d)

    def __long__(self):
        return long(self.d)

    def __repr__(self):
        return repr(self.d).replace("ByteArray", "ELT")

    def __str__(self):
        return str(self.d)

    def __nonzero__(self):
        """True unless every byte is zero."""
        return any(x != 0 for x in self.d)

    def __len__(self):
        return self.SIZE

    def __add__(self, other):
        """Field addition: bytewise XOR."""
        if not isinstance(other, ELT):
            return NotImplemented
        new = ELT(self)
        for x in range(self.SIZE):
            new[x] ^= other[x]
        return new

    def _mul_x(self):
        """Return self * x: shift left one bit and fold bit 233 back in as
        x^74 + 1."""
        carry = self[0] & 1  # bit 232, which becomes x^233 after the shift
        x = 0
        d = ELT()
        for i in range(self.SIZE - 1):
            y = self[i + 1]
            d[i] = x ^ (y >> 7)
            x = y << 1  # overflow bits are dropped by ByteArray on store
        d[29] = x ^ carry          # + 1
        d[20] ^= carry << 2        # + x^74 (byte 20 holds bits 72..79)
        return d

    def __mul__(self, other):
        """Shift-and-add multiplication.

        Walks self's 233 bits from the most significant (bit 232, which is
        bit 0 of d[0]) downward, multiplying the accumulator by x each step
        and adding ``other`` where the bit is set.
        """
        if not isinstance(other, ELT):
            return NotImplemented
        d = ELT()
        i = 0
        mask = 1
        for n in range(self.SIZEBITS):
            d = d._mul_x()
            if (self[i] & mask) != 0:
                d += other
            mask >>= 1
            if mask == 0:
                mask = 0x80
                i += 1
        return d

    def __pow__(self, other):
        """Raise to a positive integer power; ``x ** -1`` is the inverse.

        Uses binary square-and-multiply, so the cost grows with the bit
        length of the exponent rather than with its value.  Exponents of 0
        or below, other than -1, yield NotImplemented.
        """
        if other == -1:
            return 1/self
        if other < 1:
            return NotImplemented
        result = None
        base = self
        while True:
            if other & 1:
                result = base if result is None else result * base
            other >>= 1
            if not other:
                return result
            base = base._square()

    def _square(self):
        """Return self**2.

        Squaring in GF(2) spreads the bits onto even positions (via the
        ``square`` nibble table) into a 60-byte product, which is then
        reduced back into SIZE bytes.
        """
        wide = ByteArray([0] * self.SIZE * 2)
        for i in range(self.SIZE):
            wide[2*i] = self.square[self[i] >> 4]
            wide[2*i + 1] = self.square[self[i] & 0xf]
        # Fold each byte of the upper half down by 233 bits (the "+ 1"
        # term) and by 233 - 74 bits (the "+ x^74" term).
        for i in range(self.SIZE):
            x = wide[i]

            wide[i + 19] ^= x >> 7
            wide[i + 20] ^= x << 1

            wide[i + 29] ^= x >> 1
            wide[i + 30] ^= x << 7
        # wide[30] may still carry bits above x^232; fold those as well.
        x = wide[30] & 0xFE

        wide[49] ^= x >> 7
        wide[50] ^= x << 1

        wide[59] ^= x >> 1

        wide[30] &= 1
        return ELT(wide[self.SIZE:])

    def _itoh_tsujii(self, b, j):
        """Return self^(2^j) * b (one step of the Itoh-Tsujii chain)."""
        t = ELT(self)
        return t**(2**j) * b

    def __rdiv__(self, other):
        """Compute ``other / self``.

        For ``other == 1`` this is the inverse self^(2^233 - 2), built with
        an Itoh-Tsujii addition chain; for an ELT ``other`` it is
        ``other * self^-1``.
        """
        if isinstance(other, ELT):
            return 1/self * other
        elif other == 1:
            t = self._itoh_tsujii(self, 1)
            s = t._itoh_tsujii(self, 1)
            t = s._itoh_tsujii(s, 3)
            s = t._itoh_tsujii(self, 1)
            t = s._itoh_tsujii(s, 7)
            s = t._itoh_tsujii(t, 14)
            t = s._itoh_tsujii(self, 1)
            s = t._itoh_tsujii(t, 29)
            t = s._itoh_tsujii(s, 58)
            s = t._itoh_tsujii(t, 116)
            return s**2
        else:
            return NotImplemented

    def __getitem__(self, item):
        return self.d[item]

    def __setitem__(self, item, value):
        self.d[item] = value

    def tobignum(self):
        """Return the element's bytes as a big-endian integer."""
        return bytes_to_long(self.d.tostring())

    def tobytes(self):
        """Return the raw SIZE-byte big-endian string."""
        return self.d.tostring()
|
|
|
|
class ELT_C(ELT_PY):
    """ELT_PY variant that delegates multiplication, inversion and squaring
    to the optional _ec C extension, passing it the operands' raw byte
    strings.  Everything else is inherited from ELT_PY.
    """

    def __mul__(self, other):
        # Only element * element is handled; anything else is declined.
        if isinstance(other, ELT):
            product = _ec.elt_mul(self.d.tostring(), other.d.tostring())
            return ELT(product)
        return NotImplemented

    def __rdiv__(self, other):
        # Only 1 / self (the inverse) goes to C.  Any other dividend uses the
        # pure-Python path, which calls back here for the 1 / self part.
        if other != 1:
            return ELT_PY.__rdiv__(self, other)
        inverse = _ec.elt_inv(self.d.tostring())
        return ELT(inverse)

    def _square(self):
        squared = _ec.elt_square(self.d.tostring())
        return ELT(squared)
|
|
|
|
# Field-element implementation used by the rest of the module: the C-backed
# one when _ec imported successfully, the pure-Python one otherwise.
ELT = ELT_C if fastelt else ELT_PY
|
|
|
|
class Point:
    """Affine point on y**2 + x*y = x**3 + x**2 + b over GF(2^233).

    Coordinates are ELT instances.  Point(0, 0) represents the point at
    infinity (the group identity).  The arithmetic below returns it in that
    case, and ``p == 0`` tests for it.  Points support ``+`` and scalar
    multiplication ``k * p`` / ``p * k`` with an integer k.
    """

    def __init__(self, x, y=None):
        """Build a point from a 60-byte string (x || y), a Point (copied),
        or an x/y pair of anything ELT() accepts."""
        if isinstance(x, str) and (y is None) and (len(x) == 60):
            self.x = ELT(x[:30])
            self.y = ELT(x[30:])
        elif isinstance(x, Point):
            self.x = ELT(x.x)
            self.y = ELT(x.y)
        else:
            self.x = ELT(x)
            self.y = ELT(y)

    def on_curve(self):
        """Return True if the point satisfies the curve equation.

        In characteristic 2 subtraction is addition, so the check is
        x**3 + x**2 + y**2 + x*y + b == 0.
        """
        return (self.x**3 + self.x**2 + self.y**2 + self.x*self.y + ELT(ec_b)) == 0

    def __cmp__(self, other):
        """Compare with 0 (identity test) or another Point (x, then y)."""
        if other == 0:
            if self.x or self.y:
                return 1
            else:
                return 0
        elif isinstance(other, Point):
            ca = cmp(self.x, other.x)
            if ca != 0:
                return ca
            return cmp(self.y, other.y)
        return NotImplemented

    def _double(self):
        """Return 2*self.  A point with x == 0 doubles to the identity."""
        if self.x == 0:
            return Point(0, 0)

        s = self.y/self.x + self.x   # tangent slope
        rx = s**2 + s
        rx[29] ^= 1                  # + a, the curve's x**2 coefficient (1)
        ry = s * rx + rx + self.x**2
        return Point(rx, ry)

    def __add__(self, other):
        """Return self + other using the chord rule; P + (-P) is the
        identity and P + P is delegated to _double."""
        if not isinstance(other, Point):
            return NotImplemented
        if self == 0:
            return Point(other)
        if other == 0:
            return Point(self)
        u = self.x + other.x
        if u == 0:
            u = self.y + other.y
            if u == 0:
                return self._double()
            else:
                return Point(0, 0)

        s = (self.y + other.y) / u   # chord slope
        t = s**2 + s + other.x
        t[29] ^= 1                   # + a (= 1)

        rx = t + self.x
        ry = s*t + self.y + rx
        return Point(rx, ry)

    def __mul__(self, other):
        """Return other * self for an integer scalar ``other``.

        Double-and-add over the big-endian bytes of the scalar, most
        significant bit first.  Every byte is consumed, so scalars wider
        than 30 bytes are handled too.  Non-integers yield NotImplemented.
        """
        if not isinstance(other, (int, long)):
            return NotImplemented
        d = Point(0, 0)
        for ch in long_to_bytes(other, 30):
            byte = ord(ch)
            mask = 0x80
            while mask != 0:
                d = d._double()
                if byte & mask:
                    d += self
                mask >>= 1
        return d

    def __rmul__(self, other):
        # Call __mul__ directly so an unsupported scalar yields
        # NotImplemented instead of re-dispatching through the * operator.
        return self.__mul__(other)

    def __str__(self):
        return "(%s,%s)"%(str(self.x), str(self.y))

    def __repr__(self):
        return "Point"+str(self)

    def __nonzero__(self):
        # Old-style classes require __nonzero__ to return an int/bool; the
        # bare ``self.x or self.y`` would hand back an ELT instance.
        return bool(self.x or self.y)

    def tobytes(self):
        """Return the 60-byte encoding x || y."""
        return self.x.tobytes() + self.y.tobytes()
|
|
|
|
def bn_inv(a, N):
    """Return the multiplicative inverse of *a* modulo the prime *N*.

    By Fermat's little theorem, a**(N-1) == 1 (mod N) for prime N when a is
    not a multiple of N, so a**(N-2) is the inverse.  Only valid for prime
    N.  If a is a multiple of N (and N > 2) the result is 0, not an inverse.
    """
    fermat_exponent = N - 2
    return pow(a, fermat_exponent, N)
|
|
|
|
|
|
# Order n of the base point ec_G.  All ECDSA scalar arithmetic below is done
# modulo this value.
ec_N = bytes_to_long(
    "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    "\x13\xe9\x74\xe7\x2f\x8a\x69\x22\x03\x1d\x26\x03\xcf\xe0\xd7"
)
|
|
|
|
# Base point G, encoded as x || y (30 big-endian bytes per coordinate).
ec_G = Point(
    # x coordinate
    "\x00\xfa\xc9\xdf\xcb\xac\x83\x13\xbb\x21\x39\xf1\xbb\x75\x5f"
    "\xef\x65\xbc\x39\x1f\x8b\x36\xf8\xf8\xeb\x73\x71\xfd\x55\x8b"
    # y coordinate
    "\x01\x00\x6a\x08\xa4\x19\x03\x35\x06\x78\xe5\x85\x28\xbe\xbf"
    "\x8a\x0b\xef\xf8\x67\xa7\xca\x36\x71\x6f\x7e\x01\xf8\x10\x52"
)
|
|
|
|
def generate_ecdsa(k, sha):
    """Sign the digest *sha* with the private key *k*.

    Args:
        k: private key as a big-endian byte string; must encode an integer
            in [1, ec_N - 1].
        sha: message digest as a byte string, used as a big-endian integer.

    Returns:
        (r, s) as 30-byte big-endian strings.

    Raises:
        Exception: if the key is out of range or /dev/random cannot supply
            enough random bytes.
    """
    k = bytes_to_long(k)

    if not 0 < k < ec_N:
        raise Exception("Invalid private key")

    e = bytes_to_long(sha)

    # ECDSA requires a nonzero nonce and nonzero r and s; draw a fresh
    # nonce in the (astronomically unlikely) event any of them is zero.
    while True:
        with open("/dev/random", "rb") as f:
            m = f.read(30)
        if len(m) != 30:
            raise Exception("Failed to get random data")
        m = bytes_to_long(m) % ec_N
        if m == 0:
            continue

        r = (m * ec_G).x.tobignum() % ec_N
        if r == 0:
            continue

        kk = ((r * k) + e) % ec_N
        s = (bn_inv(m, ec_N) * kk) % ec_N
        if s == 0:
            continue

        return long_to_bytes(r, 30), long_to_bytes(s, 30)
|
|
|
|
def check_ecdsa(q, r, s, sha):
    """Return True iff (r, s) is a valid ECDSA signature of digest *sha*
    under the public key *q*.

    Args:
        q: public key, as a 60-byte x || y string or a Point.
        r, s: signature components as big-endian byte strings.
        sha: message digest as a byte string.
    """
    q = Point(q)
    r = bytes_to_long(r)
    s = bytes_to_long(s)
    e = bytes_to_long(sha)

    # Both components must lie in [1, ec_N - 1].  Without this, r = s = 0
    # makes bn_inv(0) return 0, so r1 becomes the identity (x == 0) and the
    # "signature" (0, 0) would verify for any message and key.
    if not (0 < r < ec_N and 0 < s < ec_N):
        return False

    s_inv = bn_inv(s, ec_N)

    w1 = (e * s_inv) % ec_N
    w2 = (r * s_inv) % ec_N

    r1 = w1 * ec_G + w2 * q
    if r1 == 0:
        # The identity has no meaningful x coordinate; reject explicitly.
        return False

    rx = r1.x.tobignum() % ec_N

    return rx == r
|
|
|
|
def priv_to_pub(k):
    """Return the 60-byte public key (x || y of k*G) for the private key
    *k*, given as a big-endian byte string."""
    return (bytes_to_long(k) * ec_G).tobytes()
|
|
|
|
def gen_priv_key():
    """Return a fresh private key as a 30-byte big-endian string.

    The key is drawn from /dev/random, reduced modulo ec_N, and redrawn in
    the (astronomically unlikely) case it comes out as zero.

    Raises:
        Exception: if /dev/random cannot supply 30 bytes.
    """
    while True:
        with open("/dev/random", "rb") as f:
            k = f.read(30)
        if len(k) != 30:
            raise Exception("Failed to get random data")

        k = bytes_to_long(k) % ec_N
        if k != 0:
            return long_to_bytes(k, 30)
|