Skip to content
Open
6 changes: 6 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
try:
from mittens.mittens.tf_mittens import Mittens, GloVe
except ImportError:
from mittens.mittens.np_mittens import Mittens, GloVe

__version__ = "0.2.2"
11 changes: 10 additions & 1 deletion mittens/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
try:
try:
from mittens.tf_mittens import Mittens, GloVe
except:
# print("Failed mittens.tf_mittens")
from mittens.mittens.tf_mittens import Mittens, GloVe
except ImportError:
# print("Failed ANY tf_mittens")
try:
from mittens.np_mittens import Mittens, GloVe
except:
# print("Failed mittens.np_mittens")
from mittens.mittens.np_mittens import Mittens, GloVe

__version__ = "0.2"
__version__ = "0.2.2"
25 changes: 23 additions & 2 deletions mittens/mittens_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from copy import copy
import random
import sys
from time import time

import numpy as np

from mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION
try:
from mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION
except:
from mittens.mittens.doc import BASE_DOC, MITTENS_PARAM_DESCRIPTION


class MittensBase(object):
Expand All @@ -31,6 +35,19 @@ def __init__(self, n=100, mittens=0.1, xmax=100, alpha=0.75,
self.max_iter = max_iter
self.errors = list()
self.test_mode = test_mode

def message(self, obj, timer=None):
if type(obj) != str:
obj = str(obj)
elapsed = 0
if timer == 'start':
self._msg_time = time()
elif timer == 'stop':
elapsed = time() - self._msg_time
if elapsed > 0:
obj = obj + ' ({:.1f}s)'.format(elapsed)
print("\r" + obj, flush=True)
return

def fit(self,
X,
Expand Down Expand Up @@ -69,14 +86,18 @@ def fit(self,
embedding of the corresponding element in `vocab`.

"""
self.message("Fitting mco {}".format(X.shape))

if fixed_initialization is not None:
assert self.test_mode, \
"Fixed initialization parameters can only be provided" \
" in test mode. Initialize {} with `test_mode=True`.". \
format(self.__class__.split(".")[-1])
self.message(" Dimensions check")
self._check_dimensions(
X, vocab, initial_embedding_dict
)
self.message(" Initializing weights and log(mco)")
weights, log_coincidence = self._initialize(X)
return self._fit(X, weights, log_coincidence,
vocab=vocab,
Expand Down Expand Up @@ -163,7 +184,7 @@ def _progressbar(self, msg, iter_num):
if self.display_progress and \
(iter_num + 1) % self.display_progress == 0:
sys.stderr.write('\r')
sys.stderr.write("Iteration {}: {}".format(iter_num + 1, msg))
sys.stderr.write("Iteration {}: {}\t\t\t".format(iter_num + 1, msg))
sys.stderr.flush()

def __repr__(self):
Expand Down
14 changes: 12 additions & 2 deletions mittens/np_mittens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
"""
import numpy as np

from mittens.mittens_base import randmatrix, noise
from mittens.mittens_base import MittensBase, GloVeBase
try:
from mittens.mittens_base import randmatrix, noise
from mittens.mittens_base import MittensBase, GloVeBase
except:
from mittens.mittens.mittens_base import randmatrix, noise
from mittens.mittens.mittens_base import MittensBase, GloVeBase



_FRAMEWORK = "NumPy"
Expand All @@ -35,6 +40,11 @@ class Mittens(MittensBase):
framework=_FRAMEWORK,
second=_DESC.format(model=MittensBase._MODEL))

def __init__(self,
**kwargs):
super().__init__(**kwargs)
self.message("NumPy Mittens initialized.")

@property
def framework(self):
return _FRAMEWORK
Expand Down
Loading