File: train_supervised.html

package info (click to toggle)
fasttext 0.9.2%2Bds-9
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,952 kB
  • sloc: cpp: 5,459; python: 2,427; javascript: 635; sh: 621; makefile: 106; xml: 81; perl: 43
file content (66 lines) | stat: -rw-r--r-- 2,210 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no">
</head>
<body>
    <script type="module">
        /**
         * Logs the leading entries of a vector-like container to the console,
         * one `console.log` call per entry.
         *
         * @param {{size: function(): number, get: function(number): *}} predictions
         *     Container exposing `size()` and `get(index)`.
         * @param {number} [limit] Maximum number of entries to log. Any falsy
         *     value (omitted, 0, null, NaN) means "no limit".
         */
        const printVector = function(predictions, limit) {
            // Deliberately `||` rather than a default parameter: callers that
            // pass 0 or null keep getting every entry printed.
            const maxEntries = limit || Infinity;

            for (let i = 0; i < predictions.size() && i < maxEntries; i++) {
                // Read each entry exactly once and log that same value.
                const entry = predictions.get(i);
                console.log(entry);
            }
        };

        /**
         * Callback handed to the trainer: logs the five values it receives as a
         * single array, in the order they were passed.
         */
        const trainCallback = function (progress, loss, wst, lr, eta) {
            const snapshot = [progress, loss, wst, lr, eta];
            console.log(snapshot);
        };

        import {FastText, addOnPostRun} from "./fasttext.js";

        // Demo entry point: train a supervised model on "cooking.train", then
        // print a prediction and inspect the model's matrices, words and labels.
        // NOTE(review): addOnPostRun presumably defers the callback until the
        // fastText WebAssembly runtime is ready — confirm against fasttext.js.
        addOnPostRun(() => {
            const ft = new FastText();

            // Training options handed to trainSupervised; see the fastText
            // option reference for the meaning of each key.
            const trainingOptions = {
                'lr':1.0,
                'epoch':10,
                'loss':'hs',
                'wordNgrams':2,
                'dim':50,
                'bucket':200000
            };

            ft.trainSupervised("cooking.train", trainingOptions, trainCallback).then(model => {
                console.log('Trained.');

                // NOTE(review): the trailing arguments look like (k, threshold) — confirm.
                printVector(model.predict("Which baking dish is best to bake a banana bread ?", 5, 0.0));

                /* getInputMatrix */
                const inputMatrix = model.getInputMatrix();
                console.log(inputMatrix.cols());
                console.log(inputMatrix.rows());
                console.log(inputMatrix.at(1, 2));

                /* getOutputMatrix */
                const outputMatrix = model.getOutputMatrix();
                console.log(outputMatrix.cols());
                console.log(outputMatrix.rows());
                console.log(outputMatrix.at(1, 2));

                /* getWords */
                const wordsInformation = model.getWords();
                printVector(wordsInformation[0], 30);   // words
                printVector(wordsInformation[1], 30);   // frequencies

                /* getLabels */
                const labelsInformation = model.getLabels();
                printVector(labelsInformation[0], 30);  // labels
                printVector(labelsInformation[1], 30);  // frequencies
            }).catch(error => {
                // Without this handler a rejected training promise, or an
                // exception thrown while inspecting the model, would surface
                // only as an unhandled rejection.
                console.error('Training failed or the model could not be inspected:', error);
            });
        });

    </script>
</body>

</html>