simplify channel validation

pull/5/head
Carlos Cardoso 5 years ago
parent a1d03c42b1
commit f94118b00e

@ -22,7 +22,7 @@ const (
type Dev struct { type Dev struct {
i2c i2c.Dev i2c i2c.Dev
state map[uint8]State state [4]State
} }
func New(bus i2c.Bus, address uint16) (*Dev, error) { func New(bus i2c.Bus, address uint16) (*Dev, error) {
@ -32,7 +32,7 @@ func New(bus i2c.Bus, address uint16) (*Dev, error) {
d := &Dev{ d := &Dev{
i2c: i2c.Dev{Bus: bus, Addr: address}, i2c: i2c.Dev{Bus: bus, Addr: address},
state: buildChannelsList(), state: [4]State{StateOff, StateOff, StateOff, StateOff},
} }
if err := d.reset(); err != nil { if err := d.reset(); err != nil {
@ -47,30 +47,30 @@ func (d *Dev) Halt() error {
} }
func (d *Dev) On(channel uint8) error { func (d *Dev) On(channel uint8) error {
if err := d.isValidChannel(channel); err != nil { if !d.isValidChannel(channel) {
return err return errInvalidChannel
} }
_, err := d.i2c.Write([]byte{channel, byte(StateOn)}) _, err := d.i2c.Write([]byte{channel, byte(StateOn)})
d.state[channel] = StateOn d.state[channel-1] = StateOn
return err return err
} }
func (d *Dev) Off(channel uint8) error { func (d *Dev) Off(channel uint8) error {
if err := d.isValidChannel(channel); err != nil { if !d.isValidChannel(channel) {
return err return errInvalidChannel
} }
_, err := d.i2c.Write([]byte{channel, byte(StateOff)}) _, err := d.i2c.Write([]byte{channel, byte(StateOff)})
d.state[channel] = StateOff d.state[channel-1] = StateOff
return err return err
} }
func (d *Dev) State(channel uint8) (State, error) { func (d *Dev) State(channel uint8) (State, error) {
if err := d.isValidChannel(channel); err != nil { if !d.isValidChannel(channel) {
return 0, err return 0, errInvalidChannel
} }
return d.state[channel], nil return d.state[channel-1], nil
} }
func (d *Dev) AvailableChannels() []uint8 { func (d *Dev) AvailableChannels() []uint8 {
@ -84,9 +84,8 @@ func (s State) String() string {
return "on" return "on"
} }
// Reset resets the registers to the default values.
func (d *Dev) reset() error { func (d *Dev) reset() error {
for channel := range d.state { for _, channel := range d.AvailableChannels() {
err := d.Off(channel) err := d.Off(channel)
if err != nil { if err != nil {
return err return err
@ -99,34 +98,14 @@ func (d *Dev) reset() error {
// Up to 4 HATs can be stacked and each one need a different address to // Up to 4 HATs can be stacked and each one need a different address to
// work. // work.
func isValidAddress(address uint16) error { func isValidAddress(address uint16) error {
validAddresses := [...]uint16{0x10, 0x11, 0x12, 0x13} switch address {
case 0x10, 0x11, 0x12, 0x13:
for _, addr := range validAddresses { return nil
if address == addr { default:
return nil return errInvalidAddress
}
} }
return errInvalidAddress
} }
func (d *Dev) isValidChannel(channel uint8) error { func (d *Dev) isValidChannel(channel uint8) bool {
if _, exists := d.state[channel]; !exists { return channel >= 1 && channel <= 4
return errInvalidChannel
}
return nil
}
// EP-0099 offers 4 channels per board
func buildChannelsList() map[uint8]State {
// Using a map instead of list since indexes of channels are not zero-based
// values. That would cause loops to have to correct channel ids while
// looping through items or reading/setting values.
// With maps, keys correspond to actual channels on the board.
return map[uint8]State{
1: StateOff,
2: StateOff,
3: StateOff,
4: StateOff,
}
} }

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"reflect"
"testing" "testing"
"periph.io/x/conn/v3/i2c/i2ctest" "periph.io/x/conn/v3/i2c/i2ctest"
@ -46,16 +47,10 @@ func TestAvailableChannels(t *testing.T) {
expected := []uint8{0x01, 0x02, 0x03, 0x04} expected := []uint8{0x01, 0x02, 0x03, 0x04}
dev, _ := New(bus, testDefaultValidAddress) dev, _ := New(bus, testDefaultValidAddress)
channels := dev.AvailableChannels() list := dev.AvailableChannels()
if len(channels) != len(expected) { if !reflect.DeepEqual(expected, list) {
t.Fatal("Available channels len should be ", len(expected), ", got ", len(channels)) t.Fatal("Available channels should be ", expected, " got ", list)
}
for i := 0; i < len(expected); i++ {
if channels[i] != expected[i] {
t.Fatal("Available channels should be ", expected, " got ", channels)
}
} }
} }
@ -63,8 +58,6 @@ func TestHalt(t *testing.T) {
bus := initTestBus() bus := initTestBus()
dev, _ := New(bus, testDefaultValidAddress) dev, _ := New(bus, testDefaultValidAddress)
resetTestBusOps(bus)
dev.Halt() dev.Halt()
checkDevReset(t, dev, bus) checkDevReset(t, dev, bus)
} }
@ -73,8 +66,6 @@ func TestOn(t *testing.T) {
bus := initTestBus() bus := initTestBus()
dev, _ := New(bus, testDefaultValidAddress) dev, _ := New(bus, testDefaultValidAddress)
resetTestBusOps(bus)
err := dev.On(3) err := dev.On(3)
if err != nil { if err != nil {
@ -89,8 +80,6 @@ func TestOff(t *testing.T) {
bus := initTestBus() bus := initTestBus()
dev, _ := New(bus, testDefaultValidAddress) dev, _ := New(bus, testDefaultValidAddress)
resetTestBusOps(bus)
err := dev.Off(4) err := dev.Off(4)
if err != nil { if err != nil {
@ -112,6 +101,20 @@ func TestReturnErrorForInvalidChannel(t *testing.T) {
if err := dev.Off(98); err != errInvalidChannel { if err := dev.Off(98); err != errInvalidChannel {
t.Fatal("Off should return invalid channel error, got ", err) t.Fatal("Off should return invalid channel error, got ", err)
} }
if err := dev.Off(98); err != errInvalidChannel {
t.Fatal("Off should return invalid channel error, got ", err)
}
}
func TestStateToString(t *testing.T) {
if s := fmt.Sprintf("%s", StateOn); s != "on" {
t.Fatal("StateOn as string should be 'on', got ", s)
}
if s := fmt.Sprintf("%s", StateOff); s != "off" {
t.Fatal("StateOn as string should be 'off', got ", s)
}
} }
func initTestBus() *i2ctest.Record { func initTestBus() *i2ctest.Record {
@ -121,10 +124,6 @@ func initTestBus() *i2ctest.Record {
} }
} }
func resetTestBusOps(bus *i2ctest.Record) {
bus.Ops = []i2ctest.IO{}
}
func checkChannelState(t *testing.T, dev *Dev, channel uint8, state State) { func checkChannelState(t *testing.T, dev *Dev, channel uint8, state State) {
if actual, _ := dev.State(channel); actual != state { if actual, _ := dev.State(channel); actual != state {
msg := fmt.Sprintf("Channel %d should have state %s, got: %s", channel, state, actual) msg := fmt.Sprintf("Channel %d should have state %s, got: %s", channel, state, actual)

Loading…
Cancel
Save