Tensorflow.js 安装

2020-07-21 14:20 更新

在JavaScript项目中,TensorFlow.js的安装方法有两种:一种是通过script标签引入,另外一种就是通过npm进行安装。


如果不熟悉WEB开发的同学,我们建议使用脚本标签来获取。

使用Script Tag


将以下脚本标签添加到您的主HTML文件中:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js" rel="external nofollow" ></script>


有关脚本标签的设置,请参阅代码示例:

<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]" rel="external nofollow" > </script>

    <!-- Place your code in the script tag below. You can also use an external .js file -->
    <script>
      // Notice there is no 'import' statement. 'tf' is available on the index-page
      // because of the script tag above.

      // Define a model for linear regression.
      const model = tf.sequential();
      model.add(tf.layers.dense({units: 1, inputShape: [1]}));

      // Prepare the model for training: Specify the loss and the optimizer.
      model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

      // Generate some synthetic data for training.
      const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
      const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

      // Train the model using the data.
      model.fit(xs, ys).then(() => {
        // Use the model to do inference on a data point the model hasn't seen before:
        // Open the browser devtools to see the output
        model.predict(tf.tensor2d([5], [1, 1])).print();
      });
    </script>
  </head>

  <body>
  </body>
</html>


通过NPM(或yarn)

使用yarn或npm将TensorFlow.js添加到您的项目中。 注意:因为使用ES2017语法(如import),所以此工作流程假定您使用打包程序/转换程序将代码转换为浏览器可以理解的内容。 

yarn add @tensorflow/tfjs  
或者
npm install @tensorflow/tfjs

在NPM中输入以下代码:


import * as tf from '@tensorflow/tfjs';

//定义一个线性回归模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

// 为训练生成一些合成数据
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

// 使用数据训练模型
model.fit(xs, ys, {epochs: 10}).then(() => {
  // 在该模型从未看到过的数据点上使用模型进行推理
  model.predict(tf.tensor2d([5], [1, 1])).print();
  //  打开浏览器开发工具查看输出
});
 


以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号