Metadata-Version: 2.1
Name: keras-bert
Version: 0.32.0
Summary: BERT implemented in Keras
Home-page: https://github.com/CyberZHG/keras-bert
Author: CyberZHG
Author-email: CyberZHG@gmail.com
License: MIT
Description: # Keras BERT
        
        [![Travis](https://travis-ci.org/CyberZHG/keras-bert.svg)](https://travis-ci.org/CyberZHG/keras-bert)
        [![Coverage](https://coveralls.io/repos/github/CyberZHG/keras-bert/badge.svg?branch=master)](https://coveralls.io/github/CyberZHG/keras-bert)
        
        Implementation of [BERT](https://arxiv.org/pdf/1810.04805.pdf). Official pre-trained models can be loaded for feature extraction and prediction.
        
        ## Install
        
        ```bash
        pip install keras-bert
        ```
        
        ## Usage
        
        ### Load Official Pre-trained Models
        
        The [feature extraction demo](./demo/load_model/load_and_extract.py) should produce the same extraction results as the official model. The [prediction demo](./demo/load_model/load_and_predict.py) predicts the missing words in a sentence.
        
        ### Train & Use
        
        ```python
        import keras
        from keras_bert import get_base_dict, get_model, gen_batch_inputs
        
        
        # A toy input example
        sentence_pairs = [
            [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
            [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
            [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']],
        ]
        
        
        # Build token dictionary
        token_dict = get_base_dict()  # A dict that contains some special tokens
        for pairs in sentence_pairs:
            for token in pairs[0] + pairs[1]:
                if token not in token_dict:
                    token_dict[token] = len(token_dict)
        token_list = list(token_dict.keys())  # Used for selecting a random word
        
        
        # Build & train the model
        model = get_model(
            token_num=len(token_dict),
            head_num=5,
            transformer_num=12,
            embed_dim=25,
            feed_forward_dim=100,
            seq_len=20,
            pos_num=20,
            dropout_rate=0.05,
        )
        model.summary()
        
        def _generator():
            while True:
                yield gen_batch_inputs(
                    sentence_pairs,
                    token_dict,
                    token_list,
                    seq_len=20,
                    mask_rate=0.3,
                    swap_sentence_rate=1.0,
                )
        
        model.fit_generator(
            generator=_generator(),
            steps_per_epoch=1000,
            epochs=100,
            validation_data=_generator(),
            validation_steps=100,
            callbacks=[
                keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
            ],
        )
        
        
        # Use the trained model
        # `output_layer` is the last feature extraction layer (the last transformer)
        inputs, output_layer = get_model(
            token_num=len(token_dict),
            head_num=5,
            transformer_num=12,
            embed_dim=25,
            feed_forward_dim=100,
            seq_len=20,
            pos_num=20,
            dropout_rate=0.05,
            training=False,  # The input layers and output layer will be returned if `training` is `False`
        )
        ```
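        
        To make the input format more concrete, here is a minimal plain-Python sketch of how a sentence pair could be turned into padded token ids and segment ids in the usual BERT layout (`[CLS] first [SEP] second [SEP]`). It does not call `keras_bert` and is not the library's implementation: `BASE_DICT`, `build_token_dict`, and `encode_pair` are hypothetical names, and the special-token ids are assumed values for illustration.
        
        ```python
        # Illustrative only: a plain-Python sketch of BERT-style pair encoding.
        # The special-token ids below are assumptions, not values read from keras-bert.
        BASE_DICT = {'': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4}
        
        
        def build_token_dict(sentence_pairs):
            """Assign an id to every token, after the special tokens."""
            token_dict = dict(BASE_DICT)
            for first, second in sentence_pairs:
                for token in first + second:
                    if token not in token_dict:
                        token_dict[token] = len(token_dict)
            return token_dict
        
        
        def encode_pair(first, second, token_dict, seq_len):
            """Return (token_ids, segment_ids), both truncated or padded to `seq_len`."""
            tokens = ['[CLS]'] + first + ['[SEP]'] + second + ['[SEP]']
            # Segment 0 covers [CLS], the first sentence, and its [SEP]; segment 1 covers the rest.
            segments = [0] * (len(first) + 2) + [1] * (len(second) + 1)
            tokens, segments = tokens[:seq_len], segments[:seq_len]
            unk = token_dict['[UNK]']
            token_ids = [token_dict.get(token, unk) for token in tokens]
            padding = seq_len - len(token_ids)
            return token_ids + [token_dict['']] * padding, segments + [0] * padding
        
        
        pairs = [[['all', 'work'], ['no', 'play']]]
        token_dict = build_token_dict(pairs)
        # 'fun' is not in the dictionary, so it maps to the [UNK] id.
        token_ids, segment_ids = encode_pair(['all', 'work'], ['no', 'fun'], token_dict, seq_len=10)
        ```
        
        Both returned lists have length `seq_len`, which is the shape a fixed-length model such as the one built with `seq_len=20` above would expect per sample.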
        
        ### Custom Feature Extraction
        
        ```python
        import keras
        from keras_bert import get_model
        
        
        def _custom_layers(x, trainable=True):
            return keras.layers.LSTM(
                units=768,
                trainable=trainable,
                return_sequences=True,
                name='LSTM',
            )(x)
        
        model = get_model(
            token_num=200,
            embed_dim=768,
            custom_layers=_custom_layers,
        )
        ```
        
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3.6
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Description-Content-Type: text/markdown
