#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
import numpy as np
from mxnet.gluon import nn, rnn
from mxnet import gluon, autograd
import gluonnlp as nlp
from mxnet import nd 
import mxnet as mx
import time
import itertools
import random


# In[2]:


try:
    # Probe the GPU: allocating an array on it raises MXNetError when no GPU is
    # usable (mx.gpu() alone never fails, so it cannot be used as the check).
    mx.nd.array([0], ctx=mx.gpu())
    ctx = mx.gpu()
except mx.MXNetError:
    ctx = mx.cpu()


# In[3]:


bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',
                                             dataset_name='wiki_multilingual_cased',
                                             pretrained=False, ctx=ctx, use_pooler=True,
                                             use_decoder=False, use_classifier=False)
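# pretrained=False: no pretrained weights are downloaded here, because the
# fine-tuned weights (encoder plus classifier head) are loaded from net.params
# further below.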


# In[6]:


class BERTInferenceDataset(mx.gluon.data.Dataset):
    """Applies the BERT sentence transform to raw text, yielding
    (token_ids, valid_length, segment_ids) triples for inference."""

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        # label_idx is kept only for interface symmetry with a training dataset;
        # labels are not used at inference time.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        sent_dataset = gluon.data.SimpleDataset([[row[sent_idx]] for row in dataset])
        self.sentences = sent_dataset.transform(transform)

    def __getitem__(self, i):
        return self.sentences[i]

    def __len__(self):
        return len(self.sentences)


# In[7]:


bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=False)
max_len = 64
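# BERTSentenceTransform pads and truncates every input to max_len (64) tokens,
# including the [CLS] and [SEP] markers.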


# In[8]:


class BERTClassifier(nn.Block):
    """BERT encoder followed by a dense classification head; the head consumes
    the pooled [CLS] representation returned by the encoder."""

    def __init__(self,
                 bert,
                 num_classes=2,
                 dropout=None,
                 prefix=None,
                 params=None):
        super(BERTClassifier, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.classifier = nn.HybridSequential(prefix=prefix)
            if dropout:
                self.classifier.add(nn.Dropout(rate=dropout))
            self.classifier.add(nn.Dense(units=num_classes))

    def forward(self, inputs, token_types, valid_length=None):
        # The encoder returns (sequence_output, pooled_output); only the pooled
        # output is needed for sentence-level classification.
        _, pooler = self.bert(inputs, token_types, valid_length)
        return self.classifier(pooler)


# In[11]:


file_name = "net.params"

new_net = BERTClassifier(bert_base, num_classes=2, dropout=0.3)
new_net.load_parameters(file_name, ctx=ctx)
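# net.params is assumed to hold weights saved from a separate fine-tuning run
# (not included in this paste); load_parameters restores both the encoder and
# the classification head.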


# In[67]:

def inference(texts, raw=False):
    def create_iter_data(texts):
        # Accept either a single string or a list of strings (LIME passes a list).
        if isinstance(texts, str):
            iterable = [texts]
        elif isinstance(texts, list):
            iterable = texts
        else:
            raise TypeError("texts must be a str or a list of str")
        # Round-trip through a TSV file so nlp.data.TSVDataset can be reused as-is.
        # to_csv writes the index as column 0, so the text ends up in field 1.
        pd.DataFrame(iterable, columns=['text']).to_csv('iter_data.tsv', sep='\t')
        dataset_test = nlp.data.TSVDataset("iter_data.tsv", field_indices=[1], num_discard_samples=1)
        data_test = BERTInferenceDataset(dataset_test, 0, 1, bert_tokenizer, max_len, True, False)
        test_dataloader = mx.gluon.data.DataLoader(data_test, batch_size=1)
        return test_dataloader

    res = []
    for t, v, s in create_iter_data(texts):
        token_ids = t.as_in_context(ctx)
        valid_length = v.as_in_context(ctx)
        segment_ids = s.as_in_context(ctx)
        output = new_net(token_ids, segment_ids, valid_length.astype('float32'))
        if raw:
            # Keep the full per-class probability row (needed by LIME).
            res += nd.softmax(output, axis=1).asnumpy().tolist()
        else:
            # Keep only the positive-class probability, rounded to 4 decimals.
            res.append(
                round(nd.softmax(output, axis=1).asnumpy().tolist()[0][1], 4)
            )
    if raw:
        return np.array(res)
    if len(res) == 1:
        return res[0]
    return res

from lime.lime_text import LimeTextExplainer
explainer = LimeTextExplainer()
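
# LIME perturbs the input text and calls the classifier with a list of strings,
# so the raw=True path of inference() supplies the (n_samples, n_classes)
# probability matrix that LimeTextExplainer expects.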

def limer(example):
    exp = explainer.explain_instance(example, lambda x: inference(x, raw=True), top_labels=1)
    exp.show_in_notebook()
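

# In[ ]:


# Hypothetical usage sketch: the example strings below are illustrative and
# are not part of the original notebook.
#
#   inference("single example sentence")                 # positive-class probability (float)
#   inference(["first example", "second example"])       # list of probabilities
#   inference("single example sentence", raw=True)       # (1, 2) array of class probabilities
#   limer("a sentence to explain")                       # renders the LIME explanation in-notebook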