人工智能DeepMind发布神经网络

    作者:十三更新于: 2020-02-26 14:44:34

    DeepMind发布神经网络、强化学习库,网友:推动JAX发展。人工智能的定义可以分为两部分,即“ 人工”和“ 智能”。“人工”比较好理解,争议性也不大。有时我们会要考虑什么是人力所能及制造的,或者人自身的智能程度有没有高到可以创造人工智能的地步,等等。但总的来说,“人工系统”就是通常意义下的人工系统。

    DeepMind今日发布了Haiku和RLax两个库,都是基于JAX。

    JAX由谷歌提出,是TensorFlow的简化库。结合了针对线性代数的编译器XLA,和自动区分本地 Python 和 Numpy 代码的库Autograd,在高性能的机器学习研究中使用。

    而此次发布的两个库,分别针对神经网络和强化学习,大幅简化了JAX的使用。

    Haiku是基于JAX的神经网络库,允许用户使用熟悉的面向对象程序设计模型,可完全访问 JAX 的纯函数变换。

    RLax是JAX顶层的库,它提供了用于实现增强学习代理的有用构件。

    有意思的是,Reddit网友惊奇的发现Haiku这个库的名字,竟然不以“ax”结尾。

    人工智能DeepMind发布神经网络_Python开发视频_Python视频_Python课程_课课家

    当然,也有网友对这两个库表示了肯定:

    毫无疑问,对JAX起到了推动作用。

    那么,我们就来看下Haiku和RLex的庐山真面目吧。

    Haiku

    Haiku是JAX的神经网络库,它允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。

    它提供了两个核心工具:模块抽象hk.Module,和一个简单的函数转换hk.transform。

    hk.Module是Python对象,包含对其自身参数、其他模块和对用户输入应用函数方法的引用。

    hk.transform允许完全访问JAX的纯函数转换。

    其实,在JAX中有许多神经网络库,那么Haiku有什么特别之处呢?有5点。

    1、Haiku已经由DeepMind的研究人员进行了大规模测试

    DeepMind相对容易地在Haiku和JAX中复制了许多实验。其中包括图像和语言处理的大规模结果、生成模型和强化学习。

    2、Haiku是一个库,而不是一个框架

    它的设计是为了简化一些具体的事情,包括管理模型参数和其他模型状态。可以与其他库一起编写,并与JAX的其他部分一起工作。

    3、Haiku并不是另起炉灶

    它建立在Sonnet的编程模型和API之上,Sonnet是DeepMind几乎普遍采用的神经网络库。它保留了Sonnet用于状态管理的基于模块的编程模型,同时保留了对JAX函数转换的访问。

    4、过渡到Haiku是比较容易的

    通过精心的设计,从TensorFlow和Sonnet,过渡到JAX和Haiku是比较容易的。除了新的函数(如hk.transform),Haiku的目的是Sonnet 2的API。

    5、Haiku简化了JAX

    它提供了一个处理随机数的简单模型。在转换后的函数中,hk.next_rng_key()返回一个唯一的rng键。

    那么,该如何安装Haiku呢?

    Haiku是用纯Python编写的,但是通过JAX依赖于c++代码。

    首先,按照下方链接中的说明,安装带有相关加速器支持的JAX。

    httPS://github.com/google/jax#installation

    然后,只需要一句简单的pip命令就可以完成安装。

    1. $ pip install git+https://github.com/deepmind/haiku 

    接下来,是一个神经网络和损失函数的例子。

    1. import haiku as hk 
    2.  
    3. import jax.numpy as jnp 
    4.  
    5. def softmax_cross_entropy(logits, labels): 
    6.  
    7.   one_hot = hk.one_hot(labels, logits.shape[-1]) 
    8.  
    9.   return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1
    10.  
    11. def loss_fn(images, labels): 
    12.  
    13.   model = hk.Sequential([ 
    14.  
    15.       hk.Linear(1000), 
    16.  
    17.       jax.nn.relu, 
    18.  
    19.       hk.Linear(100), 
    20.  
    21.       jax.nn.relu, 
    22.  
    23.       hk.Linear(10), 
    24.  
    25.   ]) 
    26.  
    27.   logits = model(images) 
    28.  
    29.   return jnp.mean(softmax_cross_entropy(logits, labels)) 
    30.  
    31. loss_obj = hk.transform(loss_fn) 

    RLax

    RLax是JAX顶层的库,它提供了用于实现增强学习代理的有用构件。

    它所提供的操作和函数不是完整的算法,而是强化学习特定数学操作的实现。

    RLax的安装也非常简单,一个pip命令就可以搞定。

    1. pip install git+git://github.com/deepmind/rlax.git 

    使用JAX的jax.jit函数,所有的RLax代码可以不同的硬件上编译。

    RLax需要注意的是它的命名规则。

    许多函数在连续的时间步长中考虑策略、操作、奖励和值,以便计算它们的输出。在这种情况下,后缀_t和tm1通常是为了说明每个输入是在哪个步骤上生成的,例如:

    q_tm1:转换的源状态中的操作值。

    a_tm1:在源状态下选择的操作。

    r_t:在目标状态下收集的结果奖励。

    q_t:目标状态下的操作值。

    Haiku和RLax都已在GitHub上开源,有兴趣的读者可从“传送门”的链接访问。

    传送门

    Haiku:

    https://github.com/deepmind/haiku

    RLax:

    https://github.com/deepmind/rlax

    人工智能(Artificial Intelligence),英文缩写为AI。它是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。 人工智能亦称智械、机器智能,指由人制造出来的机器所表现出来的智能。通常人工智能是指通过普通计算机程序来呈现人类智能的技术。通过医学、神经科学、机器人学及统计学等的进步,有些预测则认为人类的无数职业也逐渐被人工智能取代。

课课家教育

未登录

1