mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-02-15 11:49:14 +01:00
Merge 3051db26297cac671994462aea9fcbd35e364103 into fb54d8b54984f89f7dba90a18e7c3048421464c3
This commit is contained in:
commit
f218429159
@ -387,15 +387,16 @@ class Trainer(object):
|
||||
fold=i,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Sure you want to do multiple fold training")
|
||||
for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df):
|
||||
self._train_single_fold(
|
||||
mb_generator=mb_generator,
|
||||
val_data=val_data,
|
||||
test_data=test_data,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
fold=i,
|
||||
)
|
||||
i += 1
|
||||
if i == 3:
|
||||
break
|
||||
a = input("Are you sure you want to do multiple fold training? (y/N)")
|
||||
if a.lower() == "y":
|
||||
for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df):
|
||||
self._train_single_fold(
|
||||
mb_generator=mb_generator,
|
||||
val_data=val_data,
|
||||
test_data=test_data,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
fold=i,
|
||||
)
|
||||
i += 1
|
||||
if i == 3:
|
||||
break
|
||||
|
Loading…
x
Reference in New Issue
Block a user