#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pandas as pd
import numpy as np
from mxnet.gluon import nn, rnn
from mxnet import gluon, autograd
import gluonnlp as nlp
from mxnet import nd
import mxnet as mx
import time
import itertools
import random
# In[2]:
# mx.gpu() by itself never raises, so allocate a tiny array on the GPU to
# verify one is actually usable before committing to it; else fall back to CPU.
try:
    mx.nd.zeros((1,), ctx=mx.gpu())
    ctx = mx.gpu()
except mx.MXNetError:
    ctx = mx.cpu()
# In[3]:
# Multilingual cased BERT-base backbone; pretrained=False because the
# fine-tuned weights (backbone included) are restored from net.params below.
bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',
                                            dataset_name='wiki_multilingual_cased',
                                            pretrained=False, ctx=ctx,
                                            use_pooler=True,
                                            use_decoder=False,
                                            use_classifier=False)
# In[6]:
class BERTInferenceDataset(mx.gluon.data.Dataset):
    """Wraps raw sentences for inference. label_idx is unused here but kept
    so the signature matches the training-time dataset."""

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        sent_dataset = gluon.data.SimpleDataset([[i[sent_idx]]
                                                 for i in dataset])
        # Each transformed item is (token_ids, valid_length, segment_ids).
        self.sentences = sent_dataset.transform(transform)

    def __getitem__(self, i):
        return self.sentences[i]

    def __len__(self):
        return len(self.sentences)
# In[7]:
# Subword tokenizer tied to the multilingual vocabulary; the model is cased,
# hence lower=False.
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=False)
max_len = 64
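# A quick sanity check of the tokenizer (illustrative; the exact subword
# pieces depend on the vocabulary that get_model downloaded):
print(bert_tokenizer('Hello, world!'))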
# In[8]:
class BERTClassifier(nn.Block):
    """BERT backbone plus a small classification head on the pooled
    [CLS] representation."""

    def __init__(self,
                 bert,
                 num_classes=2,
                 dropout=None,
                 prefix=None,
                 params=None):
        super(BERTClassifier, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.classifier = nn.HybridSequential(prefix=prefix)
            if dropout:
                self.classifier.add(nn.Dropout(rate=dropout))
            self.classifier.add(nn.Dense(units=num_classes))

    def forward(self, inputs, token_types, valid_length=None):
        # Discard the per-token sequence output; only the pooled vector
        # feeds the classifier.
        _, pooler = self.bert(inputs, token_types, valid_length)
        return self.classifier(pooler)
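# Shape sketch: for a batch of B sequences, the backbone maps (B, max_len)
# token ids, (B, max_len) segment ids and (B,) valid lengths to a (B, 768)
# pooled [CLS] vector; the Dense head then yields (B, 2) logits.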
# In[11]:
# Rebuild the classifier and restore the fine-tuned weights (BERT backbone
# plus classification head) saved by the training run.
file_name = "net.params"
new_net = BERTClassifier(bert_base, num_classes=2, dropout=0.3)
new_net.load_parameters(file_name, ctx=ctx)
# In[67]:
def inference(texts, raw=False):
    """Score a string or list of strings with the fine-tuned classifier.

    With raw=True, return the full (N, 2) softmax matrix (the format LIME
    expects); otherwise return the positive-class probability per input."""

    def create_iter_data(texts):
        # Round-trip through a TSV file so inference reuses the same
        # TSVDataset -> BERTInferenceDataset pipeline as training.
        if isinstance(texts, str):
            texts = [texts]
        pd.DataFrame(list(texts), columns=['text']).to_csv('iter_data.tsv', sep='\t')
        dataset_test = nlp.data.TSVDataset("iter_data.tsv", field_indices=[1],
                                           num_discard_samples=1)
        data_test = BERTInferenceDataset(dataset_test, 0, 1, bert_tokenizer,
                                         max_len, True, False)
        return mx.gluon.data.DataLoader(data_test, batch_size=1)

    res = []
    for t, v, s in create_iter_data(texts):
        token_ids = t.as_in_context(ctx)
        valid_length = v.as_in_context(ctx)
        segment_ids = s.as_in_context(ctx)
        output = new_net(token_ids, segment_ids, valid_length.astype('float32'))
        probs = nd.softmax(output, axis=1).asnumpy().tolist()
        if raw:
            res += probs
        else:
            res.append(round(probs[0][1], 4))
    if raw:
        return np.array(res)
    if len(res) == 1:
        return res[0]
    return res
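# Example usage (a sketch; the sentences are hypothetical and the scores
# depend on the fine-tuned weights in net.params):
print(inference("the model was surprisingly good"))   # single float in [0, 1]
print(inference(["first text", "second text"]))       # list of floats
print(inference("any text", raw=True).shape)          # (1, 2) probability matrix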
from lime.lime_text import LimeTextExplainer

explainer = LimeTextExplainer()

def limer(example):
    # LIME perturbs the sentence and scores each variant, so it needs the
    # raw (N, 2) probability matrix from inference(..., raw=True).
    exp = explainer.explain_instance(example,
                                     lambda x: inference(x, raw=True),
                                     top_labels=1)
    exp.show_in_notebook()
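# Render a LIME explanation for one (hypothetical) sentence; note LIME calls
# inference(..., raw=True) on several thousand perturbed variants, so this
# is slow at batch_size=1.
limer("the plot was dull but the acting saved it")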